| field | value |
|---|---|
| author | S. Solomon Darnell, 2025-03-28 21:52:21 -0500 |
| committer | S. Solomon Darnell, 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) |
| download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities')
320 files changed, 66328 insertions, 0 deletions
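The patch below vendors the `azure.ai.ml.entities` package into `.venv`. As a quick orientation before the raw diff, here is a minimal, hypothetical usage sketch of a few of the entities the new `__init__.py` exports (`Environment`, `ModelConfiguration`, and the backwards-compatibility `__getattr__` shim for sweep distributions). It assumes an installed `azure-ai-ml` distribution; names and keyword arguments are taken from the docstrings visible in the patch, and nothing in this sketch is part of the commit itself.

```python
# Hypothetical usage sketch (not part of the commit): a few entities exported
# by azure/ai/ml/entities/__init__.py as added in this diff.
from azure.ai.ml.entities import Environment, ModelConfiguration

# Environment is re-exported from ._assets.environment (see the import block in the patch).
env = Environment(
    name="sketch-env",
    image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04",  # illustrative base image
)

# ModelConfiguration._validate() (shown later in this diff) only accepts
# "copy" or "download" (case-insensitive) for mode.
model_cfg = ModelConfiguration(mode="download", mount_path="/var/azureml-model")

# The __getattr__ shim at the bottom of __init__.py keeps sweep distributions
# importable from this namespace, but logs a one-time warning recommending
# azure.ai.ml.sweep instead.
from azure.ai.ml.entities import Choice  # prefer: from azure.ai.ml.sweep import Choice
```

The shim forwards the lookup to `azure.ai.ml.sweep` and emits the warning only once per process, which is the migration path the patch itself documents.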
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/__init__.py new file mode 100644 index 00000000..508dea7c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/__init__.py @@ -0,0 +1,631 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Contains entities and SDK objects for Azure Machine Learning SDKv2. + +Main areas include managing compute targets, creating/managing workspaces and jobs, and submitting/accessing model, runs +and run output/logging etc. +""" +# pylint: disable=naming-mismatch +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +import logging +from typing import Any, Optional + +from azure.ai.ml._restclient.v2022_10_01.models import CreatedByType +from azure.ai.ml._restclient.v2022_10_01_preview.models import UsageUnit + +from ._assets._artifacts._package.base_environment_source import BaseEnvironment +from ._assets._artifacts._package.inferencing_server import ( + AzureMLBatchInferencingServer, + AzureMLOnlineInferencingServer, + CustomInferencingServer, + Route, + TritonInferencingServer, +) +from ._assets._artifacts._package.model_configuration import ModelConfiguration +from ._assets._artifacts._package.model_package import ( + ModelPackage, + ModelPackageInput, + PackageInputPathId, + PackageInputPathUrl, + PackageInputPathVersion, +) +from ._assets._artifacts.data import Data +from ._assets._artifacts.feature_set import FeatureSet +from ._assets._artifacts.index import Index +from ._assets._artifacts.model import Model +from ._assets.asset import Asset +from ._assets.environment import BuildContext, Environment +from ._assets.intellectual_property import IntellectualProperty +from ._assets.workspace_asset_reference import ( + WorkspaceAssetReference as WorkspaceModelReference, +) +from ._autogen_entities.models import ( + AzureOpenAIDeployment, + MarketplacePlan, + MarketplaceSubscription, + ServerlessEndpoint, +) +from ._builders import Command, Parallel, Pipeline, Spark, Sweep +from ._component.command_component import CommandComponent +from ._component.component import Component +from ._component.parallel_component import ParallelComponent +from ._component.pipeline_component import PipelineComponent +from ._component.spark_component import SparkComponent +from ._compute._aml_compute_node_info import AmlComputeNodeInfo +from ._compute._custom_applications import ( + CustomApplications, + EndpointsSettings, + ImageSettings, + VolumeSettings, +) +from ._compute._image_metadata import ImageMetadata +from ._compute._schedule import ( + ComputePowerAction, + ComputeSchedules, + ComputeStartStopSchedule, + ScheduleState, +) +from ._compute._setup_scripts import ScriptReference, SetupScripts +from ._compute._usage import Usage, UsageName +from ._compute._vm_size import VmSize +from ._compute.aml_compute import AmlCompute, AmlComputeSshSettings +from ._compute.compute import Compute, NetworkSettings +from ._compute.compute_instance import ( + AssignedUserConfiguration, + ComputeInstance, + ComputeInstanceSshSettings, +) +from ._compute.kubernetes_compute import KubernetesCompute +from ._compute.synapsespark_compute import ( + AutoPauseSettings, + AutoScaleSettings, + SynapseSparkCompute, +) +from ._compute.unsupported_compute import UnsupportedCompute +from ._compute.virtual_machine_compute import ( + 
VirtualMachineCompute, + VirtualMachineSshSettings, +) +from ._credentials import ( + AadCredentialConfiguration, + AccessKeyConfiguration, + AccountKeyConfiguration, + AmlTokenConfiguration, + ApiKeyConfiguration, + CertificateConfiguration, + IdentityConfiguration, + ManagedIdentityConfiguration, + NoneCredentialConfiguration, + PatTokenConfiguration, + SasTokenConfiguration, + ServicePrincipalConfiguration, + UserIdentityConfiguration, + UsernamePasswordConfiguration, +) +from ._data_import.data_import import DataImport +from ._data_import.schedule import ImportDataSchedule +from ._datastore.adls_gen1 import AzureDataLakeGen1Datastore +from ._datastore.azure_storage import ( + AzureBlobDatastore, + AzureDataLakeGen2Datastore, + AzureFileDatastore, +) +from ._datastore.datastore import Datastore +from ._datastore.one_lake import OneLakeArtifact, OneLakeDatastore +from ._deployment.batch_deployment import BatchDeployment +from ._deployment.batch_job import BatchJob +from ._deployment.code_configuration import CodeConfiguration +from ._deployment.container_resource_settings import ResourceSettings +from ._deployment.data_asset import DataAsset +from ._deployment.data_collector import DataCollector +from ._deployment.deployment_collection import DeploymentCollection +from ._deployment.deployment_settings import ( + BatchRetrySettings, + OnlineRequestSettings, + ProbeSettings, +) +from ._deployment.model_batch_deployment import ModelBatchDeployment +from ._deployment.model_batch_deployment_settings import ModelBatchDeploymentSettings +from ._deployment.online_deployment import ( + Deployment, + KubernetesOnlineDeployment, + ManagedOnlineDeployment, + OnlineDeployment, +) +from ._deployment.pipeline_component_batch_deployment import ( + PipelineComponentBatchDeployment, +) +from ._deployment.request_logging import RequestLogging +from ._deployment.resource_requirements_settings import ResourceRequirementsSettings +from ._deployment.scale_settings import ( + DefaultScaleSettings, + OnlineScaleSettings, + TargetUtilizationScaleSettings, +) +from ._endpoint.batch_endpoint import BatchEndpoint +from ._endpoint.endpoint import Endpoint +from ._endpoint.online_endpoint import ( + EndpointAadToken, + EndpointAuthKeys, + EndpointAuthToken, + KubernetesOnlineEndpoint, + ManagedOnlineEndpoint, + OnlineEndpoint, +) +from ._feature_set.data_availability_status import DataAvailabilityStatus +from ._feature_set.feature import Feature +from ._feature_set.feature_set_backfill_metadata import FeatureSetBackfillMetadata +from ._feature_set.feature_set_backfill_request import FeatureSetBackfillRequest +from ._feature_set.feature_set_materialization_metadata import ( + FeatureSetMaterializationMetadata, +) +from ._feature_set.feature_set_specification import FeatureSetSpecification +from ._feature_set.feature_window import FeatureWindow +from ._feature_set.materialization_compute_resource import ( + MaterializationComputeResource, +) +from ._feature_set.materialization_settings import MaterializationSettings +from ._feature_set.materialization_type import MaterializationType +from ._feature_store.feature_store import FeatureStore +from ._feature_store.materialization_store import MaterializationStore +from ._feature_store_entity.data_column import DataColumn +from ._feature_store_entity.data_column_type import DataColumnType +from ._feature_store_entity.feature_store_entity import FeatureStoreEntity +from ._indexes import AzureAISearchConfig, GitSource, IndexDataSource, LocalSource +from ._indexes import 
ModelConfiguration as IndexModelConfiguration +from ._job.command_job import CommandJob +from ._job.compute_configuration import ComputeConfiguration +from ._job.finetuning.custom_model_finetuning_job import CustomModelFineTuningJob +from ._job.input_port import InputPort +from ._job.job import Job +from ._job.job_limits import CommandJobLimits +from ._job.job_resources import JobResources +from ._job.job_resource_configuration import JobResourceConfiguration +from ._job.job_service import ( + JobService, + JupyterLabJobService, + SshJobService, + TensorBoardJobService, + VsCodeJobService, +) +from ._job.parallel.parallel_task import ParallelTask +from ._job.parallel.retry_settings import RetrySettings +from ._job.parameterized_command import ParameterizedCommand + +# Pipeline related entities goes behind component since it depends on component +from ._job.pipeline.pipeline_job import PipelineJob, PipelineJobSettings +from ._job.queue_settings import QueueSettings +from ._job.resource_configuration import ResourceConfiguration +from ._job.service_instance import ServiceInstance +from ._job.spark_job import SparkJob +from ._job.spark_job_entry import SparkJobEntry, SparkJobEntryType +from ._job.spark_resource_configuration import SparkResourceConfiguration +from ._monitoring.alert_notification import AlertNotification +from ._monitoring.compute import ServerlessSparkCompute +from ._monitoring.definition import MonitorDefinition +from ._monitoring.input_data import ( + FixedInputData, + MonitorInputData, + StaticInputData, + TrailingInputData, +) +from ._monitoring.schedule import MonitorSchedule +from ._monitoring.signals import ( + BaselineDataRange, + CustomMonitoringSignal, + DataDriftSignal, + DataQualitySignal, + DataSegment, + FADProductionData, + FeatureAttributionDriftSignal, + GenerationSafetyQualitySignal, + GenerationTokenStatisticsSignal, + LlmData, + ModelPerformanceSignal, + MonitorFeatureFilter, + PredictionDriftSignal, + ProductionData, + ReferenceData, +) +from ._monitoring.target import MonitoringTarget +from ._monitoring.thresholds import ( + CategoricalDriftMetrics, + CustomMonitoringMetricThreshold, + DataDriftMetricThreshold, + DataQualityMetricsCategorical, + DataQualityMetricsNumerical, + DataQualityMetricThreshold, + FeatureAttributionDriftMetricThreshold, + GenerationSafetyQualityMonitoringMetricThreshold, + GenerationTokenStatisticsMonitorMetricThreshold, + ModelPerformanceClassificationThresholds, + ModelPerformanceMetricThreshold, + ModelPerformanceRegressionThresholds, + NumericalDriftMetrics, + PredictionDriftMetricThreshold, +) +from ._notification.notification import Notification +from ._registry.registry import Registry +from ._registry.registry_support_classes import ( + RegistryRegionDetails, + SystemCreatedAcrAccount, + SystemCreatedStorageAccount, +) +from ._resource import Resource +from ._schedule.schedule import JobSchedule, Schedule, ScheduleTriggerResult +from ._schedule.trigger import CronTrigger, RecurrencePattern, RecurrenceTrigger +from ._system_data import SystemData +from ._validation import ValidationResult +from ._workspace._ai_workspaces.hub import Hub +from ._workspace._ai_workspaces.project import Project +from ._workspace.compute_runtime import ComputeRuntime +from ._workspace.connections.connection_subtypes import ( + APIKeyConnection, + AzureAISearchConnection, + AzureAIServicesConnection, + AzureBlobStoreConnection, + AzureContentSafetyConnection, + AzureOpenAIConnection, + AzureSpeechServicesConnection, + MicrosoftOneLakeConnection, 
+ OpenAIConnection, + SerpConnection, + ServerlessConnection, +) +from ._workspace.connections.one_lake_artifacts import OneLakeConnectionArtifact +from ._workspace.connections.workspace_connection import WorkspaceConnection +from ._workspace.customer_managed_key import CustomerManagedKey +from ._workspace.diagnose import ( + DiagnoseRequestProperties, + DiagnoseResponseResult, + DiagnoseResponseResultValue, + DiagnoseResult, + DiagnoseWorkspaceParameters, +) +from ._workspace.feature_store_settings import FeatureStoreSettings +from ._workspace.network_acls import DefaultActionType, IPRule, NetworkAcls +from ._workspace.networking import ( + FqdnDestination, + IsolationMode, + ManagedNetwork, + ManagedNetworkProvisionStatus, + OutboundRule, + PrivateEndpointDestination, + ServiceTagDestination, +) +from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint +from ._workspace.serverless_compute import ServerlessComputeSettings +from ._workspace.workspace import Workspace +from ._workspace._ai_workspaces.capability_host import ( + CapabilityHost, + CapabilityHostKind, +) +from ._workspace.workspace_keys import ( + ContainerRegistryCredential, + NotebookAccessKeys, + WorkspaceKeys, +) + +__all__ = [ + "Resource", + "Job", + "CommandJob", + "PipelineJob", + "ServiceInstance", + "SystemData", + "SparkJob", + "SparkJobEntry", + "SparkJobEntryType", + "CommandJobLimits", + "ComputeConfiguration", + "CustomModelFineTuningJob", + "CreatedByType", + "ResourceConfiguration", + "JobResources", + "JobResourceConfiguration", + "QueueSettings", + "JobService", + "SshJobService", + "TensorBoardJobService", + "VsCodeJobService", + "JupyterLabJobService", + "SparkResourceConfiguration", + "ParameterizedCommand", + "InputPort", + "BatchEndpoint", + "OnlineEndpoint", + "Deployment", + "BatchDeployment", + "BatchJob", + "CodeConfiguration", + "Endpoint", + "OnlineDeployment", + "Data", + "KubernetesOnlineEndpoint", + "ManagedOnlineEndpoint", + "KubernetesOnlineDeployment", + "ManagedOnlineDeployment", + "OnlineRequestSettings", + "OnlineScaleSettings", + "ProbeSettings", + "BatchRetrySettings", + "RetrySettings", + "ParallelTask", + "DefaultScaleSettings", + "TargetUtilizationScaleSettings", + "Asset", + "Environment", + "BuildContext", + "Model", + "ModelBatchDeployment", + "ModelBatchDeploymentSettings", + "IPRule", + "DefaultActionType", + "NetworkAcls", + "Workspace", + "WorkspaceKeys", + "WorkspaceConnection", + "AzureBlobStoreConnection", + "MicrosoftOneLakeConnection", + "AzureOpenAIConnection", + "AzureAIServicesConnection", + "AzureAISearchConnection", + "AzureContentSafetyConnection", + "AzureSpeechServicesConnection", + "APIKeyConnection", + "OpenAIConnection", + "SerpConnection", + "ServerlessConnection", + "DiagnoseRequestProperties", + "DiagnoseResult", + "DiagnoseResponseResult", + "DiagnoseResponseResultValue", + "DiagnoseWorkspaceParameters", + "PrivateEndpoint", + "OutboundRule", + "ManagedNetwork", + "FqdnDestination", + "ServiceTagDestination", + "PrivateEndpointDestination", + "IsolationMode", + "ManagedNetworkProvisionStatus", + "EndpointConnection", + "CustomerManagedKey", + "DataImport", + "Datastore", + "AzureDataLakeGen1Datastore", + "AzureBlobDatastore", + "AzureDataLakeGen2Datastore", + "AzureFileDatastore", + "OneLakeDatastore", + "OneLakeArtifact", + "OneLakeConnectionArtifact", + "Compute", + "VirtualMachineCompute", + "AmlCompute", + "ComputeInstance", + "UnsupportedCompute", + "KubernetesCompute", + "NetworkSettings", + "Component", + "PipelineJobSettings", + 
"PipelineComponentBatchDeployment", + "ParallelComponent", + "CommandComponent", + "SparkComponent", + "ResourceRequirementsSettings", + "ResourceSettings", + "AssignedUserConfiguration", + "ComputeInstanceSshSettings", + "VmSize", + "Usage", + "UsageName", + "UsageUnit", + "CronTrigger", + "RecurrenceTrigger", + "RecurrencePattern", + "JobSchedule", + "ImportDataSchedule", + "Schedule", + "ScheduleTriggerResult", + "ComputePowerAction", + "ComputeSchedules", + "ComputeStartStopSchedule", + "ScheduleState", + "PipelineComponent", + "VirtualMachineSshSettings", + "AmlComputeSshSettings", + "AmlComputeNodeInfo", + "ImageMetadata", + "CustomApplications", + "ImageSettings", + "EndpointsSettings", + "VolumeSettings", + "SetupScripts", + "ScriptReference", + "SystemCreatedAcrAccount", + "SystemCreatedStorageAccount", + "ValidationResult", + "RegistryRegionDetails", + "Registry", + "SynapseSparkCompute", + "AutoScaleSettings", + "AutoPauseSettings", + "WorkspaceModelReference", + "Hub", + "Project", + "CapabilityHost", + "CapabilityHostKind", + "Feature", + "FeatureSet", + "FeatureSetBackfillRequest", + "ComputeRuntime", + "FeatureStoreSettings", + "FeatureStoreEntity", + "DataColumn", + "DataColumnType", + "FeatureSetSpecification", + "MaterializationComputeResource", + "FeatureWindow", + "MaterializationSettings", + "MaterializationType", + "FeatureStore", + "MaterializationStore", + "Notification", + "FeatureSetBackfillMetadata", + "DataAvailabilityStatus", + "FeatureSetMaterializationMetadata", + "ServerlessComputeSettings", + # builders + "Command", + "Parallel", + "Sweep", + "Spark", + "Pipeline", + "PatTokenConfiguration", + "SasTokenConfiguration", + "ManagedIdentityConfiguration", + "AccountKeyConfiguration", + "ServicePrincipalConfiguration", + "CertificateConfiguration", + "UsernamePasswordConfiguration", + "UserIdentityConfiguration", + "AmlTokenConfiguration", + "IdentityConfiguration", + "NotebookAccessKeys", + "ContainerRegistryCredential", + "EndpointAuthKeys", + "EndpointAuthToken", + "EndpointAadToken", + "ModelPackage", + "ModelPackageInput", + "AzureMLOnlineInferencingServer", + "AzureMLBatchInferencingServer", + "TritonInferencingServer", + "CustomInferencingServer", + "ModelConfiguration", + "BaseEnvironment", + "PackageInputPathId", + "PackageInputPathUrl", + "PackageInputPathVersion", + "Route", + "AccessKeyConfiguration", + "AlertNotification", + "ServerlessSparkCompute", + "ApiKeyConfiguration", + "MonitorDefinition", + "MonitorInputData", + "MonitorSchedule", + "DataDriftSignal", + "DataQualitySignal", + "PredictionDriftSignal", + "FeatureAttributionDriftSignal", + "CustomMonitoringSignal", + "GenerationSafetyQualitySignal", + "GenerationTokenStatisticsSignal", + "ModelPerformanceSignal", + "MonitorFeatureFilter", + "DataSegment", + "FADProductionData", + "LlmData", + "ProductionData", + "ReferenceData", + "BaselineDataRange", + "MonitoringTarget", + "FixedInputData", + "StaticInputData", + "TrailingInputData", + "DataDriftMetricThreshold", + "DataQualityMetricThreshold", + "PredictionDriftMetricThreshold", + "FeatureAttributionDriftMetricThreshold", + "CustomMonitoringMetricThreshold", + "GenerationSafetyQualityMonitoringMetricThreshold", + "GenerationTokenStatisticsMonitorMetricThreshold", + "CategoricalDriftMetrics", + "NumericalDriftMetrics", + "DataQualityMetricsNumerical", + "DataQualityMetricsCategorical", + "ModelPerformanceMetricThreshold", + "ModelPerformanceClassificationThresholds", + "ModelPerformanceRegressionThresholds", + "DataCollector", + 
"IntellectualProperty", + "DataAsset", + "DeploymentCollection", + "RequestLogging", + "NoneCredentialConfiguration", + "MarketplacePlan", + "MarketplaceSubscription", + "ServerlessEndpoint", + "AccountKeyConfiguration", + "AadCredentialConfiguration", + "Index", + "AzureOpenAIDeployment", + "AzureAISearchConfig", + "IndexDataSource", + "GitSource", + "LocalSource", + "IndexModelConfiguration", +] + +# Allow importing these types for backwards compatibility + + +def __getattr__(name: str): + requested: Optional[Any] = None + + if name == "Choice": + from ..sweep import Choice + + requested = Choice + if name == "LogNormal": + from ..sweep import LogNormal + + requested = LogNormal + if name == "LogUniform": + from ..sweep import LogUniform + + requested = LogUniform + if name == "Normal": + from ..sweep import Normal + + requested = Normal + if name == "QLogNormal": + from ..sweep import QLogNormal + + requested = QLogNormal + if name == "QLogUniform": + from ..sweep import QLogUniform + + requested = QLogUniform + if name == "QNormal": + from ..sweep import QNormal + + requested = QNormal + if name == "QUniform": + from ..sweep import QUniform + + requested = QUniform + if name == "Randint": + from ..sweep import Randint + + requested = Randint + if name == "Uniform": + from ..sweep import Uniform + + requested = Uniform + + if requested: + if not getattr(__getattr__, "warning_issued", False): + logging.warning( + " %s will be removed from the azure.ai.ml.entities namespace in a future release." + " Please import from the azure.ai.ml.sweep namespace instead.", + name, + ) + __getattr__.warning_issued = True # type: ignore[attr-defined] + return requested + + raise AttributeError(f"module 'azure.ai.ml.entities' has no attribute {name}") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/__init__.py new file mode 100644 index 00000000..5ee0f971 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/__init__.py @@ -0,0 +1,17 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + + +from ._artifacts.artifact import Artifact +from ._artifacts.code import Code +from ._artifacts.data import Data +from ._artifacts.index import Index +from ._artifacts.model import Model +from .environment import Environment +from ._artifacts._package.model_package import ModelPackage +from .workspace_asset_reference import WorkspaceAssetReference + +__all__ = ["Artifact", "Model", "Code", "Data", "Index", "Environment", "WorkspaceAssetReference", "ModelPackage"] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/base_environment_source.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/base_environment_source.py new file mode 100644 index 00000000..1be67144 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/base_environment_source.py @@ -0,0 +1,48 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=redefined-builtin + +from typing import Dict, Optional + +from azure.ai.ml._restclient.v2023_08_01_preview.models import BaseEnvironmentId as RestBaseEnvironmentId +from azure.ai.ml._schema.assets.package.base_environment_source import BaseEnvironmentSourceSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + +@experimental +class BaseEnvironment: + """Base environment type. + + All required parameters must be populated in order to send to Azure. + + :param type: The type of the base environment. + :type type: str + :param resource_id: The resource id of the base environment. e.g. azureml:name:version + :type resource_id: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START base_env_entity_create] + :end-before: [END base_env_entity_create] + :language: python + :dedent: 8 + :caption: Create a Base Environment object. + """ + + def __init__(self, type: str, resource_id: Optional[str] = None): + self.type = type + self.resource_id = resource_id + + @classmethod + def _from_rest_object(cls, rest_obj: RestBaseEnvironmentId) -> "RestBaseEnvironmentId": + return BaseEnvironment(type=rest_obj.base_environment_source_type, resource_id=rest_obj.resource_id) + + def _to_dict(self) -> Dict: + return dict(BaseEnvironmentSourceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)) + + def _to_rest_object(self) -> RestBaseEnvironmentId: + return RestBaseEnvironmentId(base_environment_source_type=self.type, resource_id=self.resource_id) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/inferencing_server.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/inferencing_server.py new file mode 100644 index 00000000..6e685244 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/inferencing_server.py @@ -0,0 +1,216 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access,unused-argument + +from typing import Any, Optional + +from azure.ai.ml._restclient.v2023_02_01_preview.models import ( + AzureMLOnlineInferencingServer as RestAzureMLOnlineInferencingServer, +) +from azure.ai.ml._restclient.v2023_02_01_preview.models import CustomInferencingServer as RestCustomInferencingServer +from azure.ai.ml._restclient.v2023_02_01_preview.models import ( + OnlineInferenceConfiguration as RestOnlineInferenceConfiguration, +) +from azure.ai.ml._restclient.v2023_02_01_preview.models import Route as RestRoute +from azure.ai.ml._restclient.v2023_02_01_preview.models import TritonInferencingServer as RestTritonInferencingServer +from azure.ai.ml._restclient.v2023_08_01_preview.models import ( + AzureMLBatchInferencingServer as RestAzureMLBatchInferencingServer, +) +from azure.ai.ml._restclient.v2023_08_01_preview.models import ( + AzureMLOnlineInferencingServer as RestAzureMLOnlineInferencingServer, +) +from azure.ai.ml._utils._experimental import experimental + +from ...._deployment.code_configuration import CodeConfiguration + + +@experimental +class AzureMLOnlineInferencingServer: + """Azure ML online inferencing configurations. + + :param code_configuration: The code configuration of the inferencing server. + :type code_configuration: str + :ivar type: The type of the inferencing server. + """ + + def __init__(self, *, code_configuration: Optional[CodeConfiguration] = None, **kwargs: Any): + self.type = "azureml_online" + self.code_configuration = code_configuration + + @classmethod + def _from_rest_object(cls, rest_obj: RestAzureMLOnlineInferencingServer) -> "RestAzureMLOnlineInferencingServer": + return AzureMLOnlineInferencingServer(type=rest_obj.server_type, code_configuration=rest_obj.code_configuration) + + def _to_rest_object(self) -> RestAzureMLOnlineInferencingServer: + return RestAzureMLOnlineInferencingServer(server_type=self.type, code_configuration=self.code_configuration) + + +@experimental +class AzureMLBatchInferencingServer: + """Azure ML batch inferencing configurations. + + :param code_configuration: The code configuration of the inferencing server. + :type code_configuration: azure.ai.ml.entities.CodeConfiguration + :ivar type: The type of the inferencing server. + """ + + def __init__(self, *, code_configuration: Optional[CodeConfiguration] = None, **kwargs: Any): + self.type = "azureml_batch" + self.code_configuration = code_configuration + + @classmethod + def _from_rest_object(cls, rest_obj: RestAzureMLBatchInferencingServer) -> "RestAzureMLBatchInferencingServer": + return AzureMLBatchInferencingServer(code_configuration=rest_obj.code_configuration) + + def _to_rest_object(self) -> RestAzureMLBatchInferencingServer: + return RestAzureMLBatchInferencingServer(server_type=self.type, code_configuration=self.code_configuration) + + +@experimental +class TritonInferencingServer: + """Azure ML triton inferencing configurations. + + :param inference_configuration: The inference configuration of the inferencing server. + :type inference_configuration: azure.ai.ml.entities.CodeConfiguration + :ivar type: The type of the inferencing server. 
+ """ + + def __init__(self, *, inference_configuration: Optional[CodeConfiguration] = None, **kwargs: Any): + self.type = "triton" + self.inference_configuration = inference_configuration + + @classmethod + def _from_rest_object(cls, rest_obj: RestTritonInferencingServer) -> "RestTritonInferencingServer": + return CustomInferencingServer( + type=rest_obj.server_type, inference_configuration=rest_obj.inference_configuration + ) + + def _to_rest_object(self) -> RestTritonInferencingServer: + return RestCustomInferencingServer(server_type=self.type, inference_configuration=self.inference_configuration) + + +@experimental +class Route: + """Route. + + :param port: The port of the route. + :type port: str + :param path: The path of the route. + :type path: str + """ + + def __init__(self, *, port: Optional[str] = None, path: Optional[str] = None): + self.port = port + self.path = path + + @classmethod + def _from_rest_object(cls, rest_obj: RestRoute) -> "RestRoute": + return Route(port=rest_obj.port, path=rest_obj.path) + + def _to_rest_object(self) -> Optional[RestRoute]: + return RestRoute(port=self.port, path=self.path) + + +@experimental +class OnlineInferenceConfiguration: + """Online inference configurations. + + :param liveness_route: The liveness route of the online inference configuration. + :type liveness_route: Route + :param readiness_route: The readiness route of the online inference configuration. + :type readiness_route: Route + :param scoring_route: The scoring route of the online inference configuration. + :type scoring_route: Route + :param entry_script: The entry script of the online inference configuration. + :type entry_script: str + :param configuration: The configuration of the online inference configuration. + :type configuration: dict + """ + + def __init__( + self, + liveness_route: Optional[Route] = None, + readiness_route: Optional[Route] = None, + scoring_route: Optional[Route] = None, + entry_script: Optional[str] = None, + configuration: Optional[dict] = None, + ): + self.liveness_route = liveness_route + self.readiness_route = readiness_route + self.scoring_route = scoring_route + self.entry_script = entry_script + self.configuration = configuration + + @classmethod + def _from_rest_object(cls, rest_obj: RestOnlineInferenceConfiguration) -> "RestOnlineInferenceConfiguration": + return OnlineInferenceConfiguration( + liveness_route=Route._from_rest_object(rest_obj.liveness_route), + readiness_route=Route._from_rest_object(rest_obj.readiness_route), + scoring_route=Route._from_rest_object(rest_obj.scoring_route), + entry_script=rest_obj.entry_script, + configuration=rest_obj.configuration, + ) + + def _to_rest_object(self) -> RestOnlineInferenceConfiguration: + if self.liveness_route is not None and self.readiness_route is not None and self.scoring_route is not None: + return RestOnlineInferenceConfiguration( + liveness_route=self.liveness_route._to_rest_object(), + readiness_route=self.readiness_route._to_rest_object(), + scoring_route=self.scoring_route._to_rest_object(), + entry_script=self.entry_script, + configuration=self.configuration, + ) + + if self.liveness_route is None: + return RestOnlineInferenceConfiguration( + readiness_route=self.readiness_route._to_rest_object() if self.readiness_route is not None else None, + scoring_route=self.scoring_route._to_rest_object() if self.scoring_route is not None else None, + entry_script=self.entry_script, + configuration=self.configuration, + ) + + if self.readiness_route is None: + return 
RestOnlineInferenceConfiguration( + liveness_route=self.liveness_route._to_rest_object(), + scoring_route=self.scoring_route._to_rest_object() if self.scoring_route is not None else None, + entry_script=self.entry_script, + configuration=self.configuration, + ) + + if self.scoring_route is None: + return RestOnlineInferenceConfiguration( + liveness_route=self.liveness_route._to_rest_object(), + readiness_route=self.readiness_route._to_rest_object(), + entry_script=self.entry_script, + configuration=self.configuration, + ) + + return RestOnlineInferenceConfiguration( + entry_script=self.entry_script, + configuration=self.configuration, + ) + + +@experimental +class CustomInferencingServer: + """Custom inferencing configurations. + + :param inference_configuration: The inference configuration of the inferencing server. + :type inference_configuration: OnlineInferenceConfiguration + :ivar type: The type of the inferencing server. + """ + + def __init__(self, *, inference_configuration: Optional[OnlineInferenceConfiguration] = None, **kwargs: Any): + self.type = "custom" + self.inference_configuration = inference_configuration + + @classmethod + def _from_rest_object(cls, rest_obj: RestCustomInferencingServer) -> "RestCustomInferencingServer": + return CustomInferencingServer( + type=rest_obj.server_type, inference_configuration=rest_obj.inference_configuration + ) + + def _to_rest_object(self) -> RestCustomInferencingServer: + return RestCustomInferencingServer(server_type=self.type, inference_configuration=self.inference_configuration) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_configuration.py new file mode 100644 index 00000000..73c777cf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_configuration.py @@ -0,0 +1,55 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# ---------------------------------------------------------- + + +from typing import Optional + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2023_04_01_preview.models import ModelConfiguration as RestModelConfiguration +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + + +@experimental +class ModelConfiguration: + """ModelConfiguration. + + :param mode: The mode of the model. Possible values include: "Copy", "Download". + :type mode: str + :param mount_path: The mount path of the model. + :type mount_path: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START model_configuration_entity_create] + :end-before: [END model_configuration_entity_create] + :language: python + :dedent: 8 + :caption: Creating a Model Configuration object. 
+ """ + + def __init__(self, *, mode: Optional[str] = None, mount_path: Optional[str] = None): + self.mode = mode + self.mount_path = mount_path + + @classmethod + def _from_rest_object(cls, rest_obj: RestModelConfiguration) -> "ModelConfiguration": + return ModelConfiguration(mode=rest_obj.mode, mount_path=rest_obj.mount_path) + + def _to_rest_object(self) -> RestModelConfiguration: + self._validate() + return RestModelConfiguration(mode=self.mode, mount_path=self.mount_path) + + def _validate(self) -> None: + if self.mode is not None and self.mode.lower() not in ["copy", "download"]: + msg = "Mode must be either 'Copy' or 'Download'" + err = ValidationException( + message=msg, + target=ErrorTarget.MODEL, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + log_and_raise_error(err) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_package.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_package.py new file mode 100644 index 00000000..c4797c20 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_package.py @@ -0,0 +1,338 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access, redefined-builtin + +import re +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2023_08_01_preview.models import CodeConfiguration +from azure.ai.ml._restclient.v2023_08_01_preview.models import ModelPackageInput as RestModelPackageInput +from azure.ai.ml._restclient.v2023_08_01_preview.models import PackageInputPathId as RestPackageInputPathId +from azure.ai.ml._restclient.v2023_08_01_preview.models import PackageInputPathUrl as RestPackageInputPathUrl +from azure.ai.ml._restclient.v2023_08_01_preview.models import PackageInputPathVersion as RestPackageInputPathVersion +from azure.ai.ml._restclient.v2023_08_01_preview.models import PackageRequest, PackageResponse +from azure.ai.ml._schema.assets.package.model_package import ModelPackageSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import dump_yaml_to_file, snake_to_pascal +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._util import load_from_dict + +from .base_environment_source import BaseEnvironment +from .inferencing_server import AzureMLBatchInferencingServer, AzureMLOnlineInferencingServer +from .model_configuration import ModelConfiguration + + +@experimental +class PackageInputPathId: + """Package input path specified with a resource ID. + + :param input_path_type: The type of the input path. Accepted values are "Url", "PathId", and "PathVersion". + :type input_path_type: Optional[str] + :param resource_id: The resource ID of the input path. e.g. "azureml://subscriptions/<>/resourceGroups/ + <>/providers/Microsoft.MachineLearningServices/workspaces/<>/data/<>/versions/<>". 
+ :type resource_id: Optional[str] + """ + + def __init__( + self, + *, + input_path_type: Optional[str] = None, + resource_id: Optional[str] = None, + ) -> None: + self.input_path_type = input_path_type + self.resource_id = resource_id + + def _to_rest_object(self) -> RestPackageInputPathId: + return RestPackageInputPathId( + input_path_type=self.input_path_type, + resource_id=self.resource_id, + ) + + @classmethod + def _from_rest_object(cls, package_input_path_id_rest_object: RestPackageInputPathId) -> "PackageInputPathId": + return PackageInputPathId( + input_path_type=package_input_path_id_rest_object.input_path_type, + resource_id=package_input_path_id_rest_object.resource_id, + ) + + +@experimental +class PackageInputPathVersion: + """Package input path specified with a resource name and version. + + :param input_path_type: The type of the input path. Accepted values are "Url", "PathId", and "PathVersion". + :type input_path_type: Optional[str] + :param resource_name: The resource name of the input path. + :type resource_name: Optional[str] + :param resource_version: The resource version of the input path. + :type resource_version: Optional[str] + """ + + def __init__( + self, + *, + input_path_type: Optional[str] = None, + resource_name: Optional[str] = None, + resource_version: Optional[str] = None, + ) -> None: + self.input_path_type = input_path_type + self.resource_name = resource_name + self.resource_version = resource_version + + def _to_rest_object(self) -> RestPackageInputPathVersion: + return RestPackageInputPathVersion( + input_path_type=self.input_path_type, + resource_name=self.resource_name, + resource_version=self.resource_version, + ) + + @classmethod + def _from_rest_object( + cls, package_input_path_version_rest_object: RestPackageInputPathVersion + ) -> "PackageInputPathVersion": + return PackageInputPathVersion( + input_path_type=package_input_path_version_rest_object.input_path_type, + resource_name=package_input_path_version_rest_object.resource_name, + resource_version=package_input_path_version_rest_object.resource_version, + ) + + +@experimental +class PackageInputPathUrl: + """Package input path specified with a url. + + :param input_path_type: The type of the input path. Accepted values are "Url", "PathId", and "PathVersion". + :type input_path_type: Optional[str] + :param url: The url of the input path. e.g. "azureml://subscriptions/<>/resourceGroups/ + <>/providers/Microsoft.MachineLearningServices/workspaces/data/<>/versions/<>". + :type url: Optional[str] + """ + + def __init__(self, *, input_path_type: Optional[str] = None, url: Optional[str] = None) -> None: + self.input_path_type = input_path_type + self.url = url + + def _to_rest_object(self) -> RestPackageInputPathUrl: + return RestPackageInputPathUrl( + input_path_type=self.input_path_type, + url=self.url, + ) + + @classmethod + def _from_rest_object(cls, package_input_path_url_rest_object: RestPackageInputPathUrl) -> "PackageInputPathUrl": + return PackageInputPathUrl( + input_path_type=package_input_path_url_rest_object.input_path_type, + url=package_input_path_url_rest_object.url, + ) + + +@experimental +class ModelPackageInput: + """Model package input. + + :param type: The type of the input. + :type type: Optional[str] + :param path: The path of the input. + :type path: Optional[Union[~azure.ai.ml.entities.PackageInputPathId, ~azure.ai.ml.entities.PackageInputPathUrl, + ~azure.ai.ml.entities.PackageInputPathVersion]] + :param mode: The input mode. 
+ :type mode: Optional[str] + :param mount_path: The mount path for the input. + :type mount_path: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START model_package_input_entity_create] + :end-before: [END model_package_input_entity_create] + :language: python + :dedent: 8 + :caption: Create a Model Package Input object. + """ + + def __init__( + self, + *, + type: Optional[str] = None, + path: Optional[Union[PackageInputPathId, PackageInputPathUrl, PackageInputPathVersion]] = None, + mode: Optional[str] = None, + mount_path: Optional[str] = None, + ) -> None: + self.type = type + self.path = path + self.mode = mode + self.mount_path = mount_path + + def _to_rest_object(self) -> RestModelPackageInput: + if self.path is None: + return RestModelPackageInput( + input_type=snake_to_pascal(self.type), + path=None, + mode=snake_to_pascal(self.mode), + mount_path=self.mount_path, + ) + return RestModelPackageInput( + input_type=snake_to_pascal(self.type), + path=self.path._to_rest_object(), + mode=snake_to_pascal(self.mode), + mount_path=self.mount_path, + ) + + @classmethod + def _from_rest_object(cls, model_package_input_rest_object: RestModelPackageInput) -> "ModelPackageInput": + return ModelPackageInput( + type=model_package_input_rest_object.input_type, + path=model_package_input_rest_object.path._from_rest_object(), + mode=model_package_input_rest_object.mode, + mount_path=model_package_input_rest_object.mount_path, + ) + + +@experimental +class ModelPackage(Resource, PackageRequest): + """Model package. + + :param target_environment_name: The target environment name for the model package. + :type target_environment_name: str + :param inferencing_server: The inferencing server of the model package. + :type inferencing_server: Union[~azure.ai.ml.entities.AzureMLOnlineInferencingServer, + ~azure.ai.ml.entities.AzureMLBatchInferencingServer] + :param base_environment_source: The base environment source of the model package. + :type base_environment_source: Optional[~azure.ai.ml.entities.BaseEnvironment] + :param target_environment_version: The version of the model package. + :type target_environment_version: Optional[str] + :param environment_variables: The environment variables of the model package. + :type environment_variables: Optional[dict[str, str]] + :param inputs: The inputs of the model package. + :type inputs: Optional[list[~azure.ai.ml.entities.ModelPackageInput]] + :param model_configuration: The model configuration. + :type model_configuration: Optional[~azure.ai.ml.entities.ModelConfiguration] + :param tags: The tags of the model package. + :type tags: Optional[dict[str, str]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START model_package_entity_create] + :end-before: [END model_package_entity_create] + :language: python + :dedent: 8 + :caption: Create a Model Package object. 
+ """ + + def __init__( + self, + *, + target_environment: Union[str, Dict[str, str]], + inferencing_server: Union[AzureMLOnlineInferencingServer, AzureMLBatchInferencingServer], + base_environment_source: Optional[BaseEnvironment] = None, + environment_variables: Optional[Dict[str, str]] = None, + inputs: Optional[List[ModelPackageInput]] = None, + model_configuration: Optional[ModelConfiguration] = None, + tags: Optional[Dict[str, str]] = None, + **kwargs: Any, + ): + if isinstance(target_environment, dict): + target_environment = target_environment["name"] + env_version = None + else: + parse_id = re.match(r"azureml:(\w+):(\d+)$", target_environment) + + if parse_id: + target_environment = parse_id.group(1) + env_version = parse_id.group(2) + else: + env_version = None + + super().__init__( + name=target_environment, + target_environment_id=target_environment, + base_environment_source=base_environment_source, + inferencing_server=inferencing_server, + model_configuration=model_configuration, + inputs=inputs, + tags=tags, + environment_variables=environment_variables, + ) + self.environment_version = env_version + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "ModelPackage": + params_override = params_override or [] + data = data or {} + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + res: ModelPackage = load_from_dict(ModelPackageSchema, data, context, **kwargs) + return res + + def dump( + self, + dest: Union[str, PathLike, IO[AnyStr]], + **kwargs: Any, + ) -> None: + """Dumps the job content into a file in YAML format. + + :param dest: The local path or file stream to write the YAML content to. + If dest is a file path, a new file will be created. + If dest is an open file, the file will be written to directly. + :type dest: Union[PathLike, str, IO[AnyStr]] + :raises FileExistsError: Raised if dest is a file path and the file already exists. + :raises IOError: Raised if dest is an open file and the file is not writable. 
+ """ + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False) + + def _to_dict(self) -> Dict: + return dict(ModelPackageSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)) + + @classmethod + def _from_rest_object(cls, model_package_rest_object: PackageResponse) -> Any: + target_environment_id = model_package_rest_object.target_environment_id + return target_environment_id + + def _to_rest_object(self) -> PackageRequest: + code = None + + if ( + self.inferencing_server + and hasattr(self.inferencing_server, "code_configuration") + and self.inferencing_server.code_configuration + ): + self.inferencing_server.code_configuration._validate() + code_id = ( + self.inferencing_server.code_configuration.code + if isinstance(self.inferencing_server.code_configuration.code, str) + else self.inferencing_server.code_configuration.code.id + ) + code = CodeConfiguration( + code_id=code_id, + scoring_script=self.inferencing_server.code_configuration.scoring_script, + ) + self.inferencing_server.code_configuration = code + + package_request = PackageRequest( + target_environment_id=self.target_environment_id, + base_environment_source=( + self.base_environment_source._to_rest_object() if self.base_environment_source else None + ), + inferencing_server=self.inferencing_server._to_rest_object() if self.inferencing_server else None, + model_configuration=self.model_configuration._to_rest_object() if self.model_configuration else None, + inputs=[input._to_rest_object() for input in self.inputs] if self.inputs else None, + tags=self.tags, + environment_variables=self.environment_variables, + ) + + return package_request diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/artifact.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/artifact.py new file mode 100644 index 00000000..f82e2aa0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/artifact.py @@ -0,0 +1,131 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from abc import abstractmethod +from os import PathLike +from pathlib import Path, PurePosixPath +from typing import Any, Dict, Optional, Union +from urllib.parse import urljoin + +from azure.ai.ml._utils.utils import is_mlflow_uri, is_url +from azure.ai.ml.entities._assets.asset import Asset + + +class ArtifactStorageInfo: + def __init__( + self, + name: str, + version: str, + relative_path: str, + datastore_arm_id: Optional[str], + container_name: str, + storage_account_url: Optional[str] = None, + is_file: Optional[bool] = None, + indicator_file: Optional[str] = None, + ): + self.name = name + self.version = version + self.relative_path = relative_path + self.datastore_arm_id = datastore_arm_id + self.container_name = container_name + self.storage_account_url = storage_account_url + self.is_file = is_file + self.indicator_file = indicator_file + + @property + def full_storage_path(self) -> Optional[str]: + if self.storage_account_url is None: + return f"{self.container_name}/{self.relative_path}" + return urljoin(self.storage_account_url, f"{self.container_name}/{self.relative_path}") + + @property + def subdir_path(self) -> Optional[str]: + if self.is_file: + path = PurePosixPath(self.relative_path).parent + if self.storage_account_url is None: + return f"{self.container_name}/{path}" + return urljoin(self.storage_account_url, f"{self.container_name}/{path}") + return self.full_storage_path + + +class Artifact(Asset): + """Base class for artifact, can't be instantiated directly. + + :param name: Name of the resource. + :type name: str + :param version: Version of the resource. + :type version: str + :param path: The local or remote path to the asset. + :type path: Union[str, os.PathLike] + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param datastore: The datastore to upload the local artifact to. + :type datastore: str + :param kwargs: A dictionary of additional configuration parameters. 
+ :type kwargs: dict + """ + + def __init__( + self, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + path: Optional[Union[str, PathLike]] = None, + datastore: Optional[str] = None, + **kwargs: Any, + ): + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + self.path = path + self.datastore = datastore + + @property + def path(self) -> Optional[Union[str, PathLike]]: + return self._path + + @path.setter + def path(self, value: Optional[Union[str, PathLike]]) -> None: + if not value or is_url(value) or Path(value).is_absolute() or is_mlflow_uri(value): + self._path = value + else: + self._path = Path(self.base_path, value).resolve() + + @abstractmethod + def _to_dict(self) -> Dict: + pass + + def __eq__(self, other: Any) -> bool: + return ( + type(self) == type(other) # pylint: disable = unidiomatic-typecheck + and self.name == other.name + and self.id == other.id + and self.version == other.version + and self.description == other.description + and self.tags == other.tags + and self.properties == other.properties + and self.base_path == other.base_path + and self._is_anonymous == other._is_anonymous + ) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + @abstractmethod + def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None: + """Updates an an artifact with the remote path of a local upload. + + :param asset_artifact: The asset storage info of the artifact + :type asset_artifact: ArtifactStorageInfo + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/code.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/code.py new file mode 100644 index 00000000..b08149ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/code.py @@ -0,0 +1,142 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import os +from os import PathLike +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2022_05_01.models import CodeVersionData, CodeVersionDetails +from azure.ai.ml._schema import CodeAssetSchema +from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId +from azure.ai.ml._utils._asset_utils import IgnoreFile, get_content_hash, get_content_hash_version, get_ignore_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ArmConstants +from azure.ai.ml.entities._assets import Artifact +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict + +from .artifact import ArtifactStorageInfo + + +class Code(Artifact): + """Code for training and scoring. + + :param name: Name of the resource. + :type name: str + :param version: Version of the resource. + :type version: str + :param path: A local path or a remote uri. A datastore remote uri example is like, + "azureml://subscriptions/{}/resourcegroups/{}/workspaces/{}/datastores/{}/paths/path_on_datastore/" + :type path: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. 
+ :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param ignore_file: Ignore file for the resource. + :type ignore_file: IgnoreFile + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + path: Optional[Union[str, PathLike]] = None, + ignore_file: Optional[IgnoreFile] = None, + **kwargs: Any, + ): + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + properties=properties, + path=path, + **kwargs, + ) + self._arm_type = ArmConstants.CODE_VERSION_TYPE + if self.path and os.path.isabs(self.path): + # Only calculate hash for local files + self._ignore_file = get_ignore_file(self.path) if ignore_file is None else ignore_file + self._hash_sha256 = get_content_hash(self.path, self._ignore_file) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Code": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + res: Code = load_from_dict(CodeAssetSchema, data, context, **kwargs) + return res + + def _to_dict(self) -> Dict: + res: dict = CodeAssetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _from_rest_object(cls, code_rest_object: CodeVersionData) -> "Code": + rest_code_version: CodeVersionDetails = code_rest_object.properties + arm_id = AMLVersionedArmId(arm_id=code_rest_object.id) + code = Code( + id=code_rest_object.id, + name=arm_id.asset_name, + version=arm_id.asset_version, + path=rest_code_version.code_uri, + description=rest_code_version.description, + tags=rest_code_version.tags, + properties=rest_code_version.properties, + # pylint: disable=protected-access + creation_context=SystemData._from_rest_object(code_rest_object.system_data), + is_anonymous=rest_code_version.is_anonymous, + ) + return code + + def _to_rest_object(self) -> CodeVersionData: + properties = {} + if hasattr(self, "_hash_sha256"): + properties["hash_sha256"] = self._hash_sha256 + properties["hash_version"] = get_content_hash_version() + code_version = CodeVersionDetails(code_uri=self.path, is_anonymous=self._is_anonymous, properties=properties) + code_version_resource = CodeVersionData(properties=code_version) + + return code_version_resource + + def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None: + """Update an artifact with the remote path of a local upload. + + :param asset_artifact: The asset storage info of the artifact + :type asset_artifact: ArtifactStorageInfo + """ + if asset_artifact.is_file: + # Code paths cannot be pointers to single files. 
It must be a pointer to a container + # Skipping the setter to avoid being resolved as a local path + self._path = asset_artifact.subdir_path # pylint: disable=attribute-defined-outside-init + else: + self._path = asset_artifact.full_storage_path # pylint: disable=attribute-defined-outside-init + + # pylint: disable=unused-argument + def _to_arm_resource_param(self, **kwargs: Any) -> Dict: + properties = self._to_rest_object().properties + + return { + self._arm_type: { + ArmConstants.NAME: self.name, + ArmConstants.VERSION: self.version, + ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "CodeVersionDetails"), + } + } diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/data.py new file mode 100644 index 00000000..710e959a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/data.py @@ -0,0 +1,237 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import os +import re +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + DataContainer, + DataContainerProperties, + DataType, + DataVersionBase, + DataVersionBaseProperties, + MLTableData, + UriFileDataVersion, + UriFolderDataVersion, +) +from azure.ai.ml._schema import DataSchema +from azure.ai.ml._utils._arm_id_utils import get_arm_id_object_from_id +from azure.ai.ml._utils.utils import is_url +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, SHORT_URI_FORMAT, AssetTypes +from azure.ai.ml.entities._assets import Artifact +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .artifact import ArtifactStorageInfo + +DataAssetTypeModelMap: Dict[str, Type[DataVersionBaseProperties]] = { + AssetTypes.URI_FILE: UriFileDataVersion, + AssetTypes.URI_FOLDER: UriFolderDataVersion, + AssetTypes.MLTABLE: MLTableData, +} + + +def getModelForDataAssetType(data_asset_type: str) -> Optional[Type[DataVersionBaseProperties]]: + model = DataAssetTypeModelMap.get(data_asset_type) + if model is None: + msg = "Unknown DataType {}".format(data_asset_type) + err = ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + ) + log_and_raise_error(err) + return model + + +DataTypeMap: Dict[DataType, str] = { + DataType.URI_FILE: AssetTypes.URI_FILE, + DataType.URI_FOLDER: AssetTypes.URI_FOLDER, + DataType.MLTABLE: AssetTypes.MLTABLE, +} + + +def getDataAssetType(data_type: DataType) -> str: + return DataTypeMap.get(data_type, data_type) # pass through value if no match found + + +class Data(Artifact): + """Data for training and scoring. + + :param name: Name of the resource. + :type name: str + :param version: Version of the resource. + :type version: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. 
Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param path: The path to the asset on the datastore. This can be local or remote + :type path: str + :param type: The type of the asset. Valid values are uri_file, uri_folder, mltable. Defaults to uri_folder. + :type type: Literal[AssetTypes.URI_FILE, AssetTypes.URI_FOLDER, AssetTypes.MLTABLE] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + path: Optional[str] = None, # if type is mltable, the path has to be a folder. + type: str = AssetTypes.URI_FOLDER, # pylint: disable=redefined-builtin + **kwargs: Any, + ): + self._path: Optional[Union[Path, str, PathLike]] = None + + self._skip_validation = kwargs.pop("skip_validation", False) + self._mltable_schema_url = kwargs.pop("mltable_schema_url", None) + self._referenced_uris = kwargs.pop("referenced_uris", None) + self.type = type + super().__init__( + name=name, + version=version, + path=path, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + self.path = path + + @property + def path(self) -> Optional[Union[Path, str, PathLike]]: + return self._path + + @path.setter + def path(self, value: str) -> None: + # Call the parent setter to resolve the path with base_path if it was a local path + # TODO: Bug Item number: 2883424 + super(Data, type(self)).path.fset(self, value) # type: ignore + if self.type == AssetTypes.URI_FOLDER and self._path is not None and not is_url(self._path): + self._path = Path(os.path.join(self._path, "")) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Data": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + data_asset = Data._load_from_dict(yaml_data=data, context=context, **kwargs) + + return data_asset + + @classmethod + def _load_from_dict(cls, yaml_data: Dict, context: Dict, **kwargs: Any) -> "Data": + return Data(**load_from_dict(DataSchema, yaml_data, context, **kwargs)) + + def _to_dict(self) -> Dict: + res: dict = DataSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def _to_container_rest_object(self) -> DataContainer: + VersionDetailsClass = getModelForDataAssetType(self.type) + return DataContainer( + properties=DataContainerProperties( + properties=self.properties, + tags=self.tags, + is_archived=False, + data_type=VersionDetailsClass.data_type if VersionDetailsClass is not None else None, + ) + ) + + def _to_rest_object(self) -> Optional[DataVersionBase]: + VersionDetailsClass = getModelForDataAssetType(self.type) + if VersionDetailsClass is not None: + data_version_details = VersionDetailsClass( + description=self.description, + is_anonymous=self._is_anonymous, + tags=self.tags, + is_archived=False, + properties=self.properties, + data_uri=self.path, + auto_delete_setting=self.auto_delete_setting, + ) + if VersionDetailsClass._attribute_map.get("referenced_uris") is not None: + data_version_details.referenced_uris = self._referenced_uris + return 
DataVersionBase(properties=data_version_details) + + return None + + @classmethod + def _from_container_rest_object(cls, data_container_rest_object: DataContainer) -> "Data": + data_rest_object_details: DataContainerProperties = data_container_rest_object.properties + data = Data( + name=data_container_rest_object.name, + creation_context=SystemData._from_rest_object(data_container_rest_object.system_data), + tags=data_rest_object_details.tags, + properties=data_rest_object_details.properties, + type=getDataAssetType(data_rest_object_details.data_type), + ) + data.latest_version = data_rest_object_details.latest_version + return data + + @classmethod + def _from_rest_object(cls, data_rest_object: DataVersionBase) -> "Data": + data_rest_object_details: DataVersionBaseProperties = data_rest_object.properties + arm_id_object = get_arm_id_object_from_id(data_rest_object.id) + path = data_rest_object_details.data_uri + data = Data( + id=data_rest_object.id, + name=arm_id_object.asset_name, + version=arm_id_object.asset_version, + path=path, + type=getDataAssetType(data_rest_object_details.data_type), + description=data_rest_object_details.description, + tags=data_rest_object_details.tags, + properties=data_rest_object_details.properties, + creation_context=SystemData._from_rest_object(data_rest_object.system_data), + is_anonymous=data_rest_object_details.is_anonymous, + referenced_uris=getattr(data_rest_object_details, "referenced_uris", None), + auto_delete_setting=getattr(data_rest_object_details, "auto_delete_setting", None), + ) + return data + + @classmethod + def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] = None) -> Tuple: + from azure.ai.ml.entities._data_import.data_import import DataImport + + if "source" in data: + return DataImport, None + + return cls, None + + def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None: + regex = r"datastores\/(.+)" + # datastore_arm_id is null for registry scenario, so capture the full_storage_path + if not asset_artifact.datastore_arm_id and asset_artifact.full_storage_path: + self.path = asset_artifact.full_storage_path + else: + groups = re.search(regex, asset_artifact.datastore_arm_id) # type: ignore + if groups: + datastore_name = groups.group(1) + self.path = SHORT_URI_FORMAT.format(datastore_name, asset_artifact.relative_path) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/feature_set.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/feature_set.py new file mode 100644 index 00000000..a5bb73fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/feature_set.py @@ -0,0 +1,220 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
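Before the next file, a short usage sketch for the Data asset defined above (editor's illustration); the file path and tags are placeholders, and the imports assume the public azure.ai.ml re-exports.

from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import Data

# uri_file points at a single file; the default type is uri_folder, and an
# mltable path must be a folder (see the constructor note above).
iris_csv = Data(
    name="iris-csv",
    version="1",
    type=AssetTypes.URI_FILE,
    path="./data/iris.csv",
    description="Raw iris measurements",
    tags={"format": "csv"},
)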
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2023_10_01.models import (
+    FeaturesetContainer,
+    FeaturesetContainerProperties,
+    FeaturesetVersion,
+    FeaturesetVersionProperties,
+)
+from azure.ai.ml._schema._feature_set.feature_set_schema import FeatureSetSchema
+from azure.ai.ml._utils._arm_id_utils import AMLNamedArmId, get_arm_id_object_from_id
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LONG_URI_FORMAT, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets import Artifact
+from azure.ai.ml.entities._feature_set.feature_set_specification import FeatureSetSpecification
+from azure.ai.ml.entities._feature_set.materialization_settings import MaterializationSettings
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .artifact import ArtifactStorageInfo
+
+
+class FeatureSet(Artifact):
+    """Feature Set.
+
+    :param name: The name of the Feature Set resource.
+    :type name: str
+    :param version: The version of the Feature Set resource.
+    :type version: str
+    :param entities: Specifies the list of entities.
+    :type entities: list[str]
+    :param specification: Specifies the feature set spec details.
+    :type specification: ~azure.ai.ml.entities.FeatureSetSpecification
+    :param stage: Feature set stage. Allowed values: Development, Production, Archived. Defaults to Development.
+    :type stage: Optional[str]
+    :param description: The description of the Feature Set resource. Defaults to None.
+    :type description: Optional[str]
+    :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+    :type tags: Optional[dict[str, str]]
+    :param materialization_settings: Specifies the materialization settings. Defaults to None.
+    :type materialization_settings: Optional[~azure.ai.ml.entities.MaterializationSettings]
+    :param kwargs: A dictionary of additional configuration parameters.
+    :type kwargs: dict
+    :raises ValidationException: Raised if stage is specified and is not valid.
+
+    .. admonition:: Example:
+
+        ..
literalinclude:: ../samples/ml_samples_featurestore.py + :start-after: [START configure_feature_set] + :end-before: [END configure_feature_set] + :language: Python + :dedent: 8 + :caption: Instantiating a Feature Set object + """ + + def __init__( + self, + *, + name: str, + version: str, + entities: List[str], + specification: Optional[FeatureSetSpecification], + stage: Optional[str] = "Development", + description: Optional[str] = None, + materialization_settings: Optional[MaterializationSettings] = None, + tags: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + path=specification.path if specification is not None else None, + **kwargs, + ) + if stage and stage not in ["Development", "Production", "Archived"]: + msg = f"Stage must be Development, Production, or Archived, found {stage}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.FEATURE_SET, + error_category=ErrorCategory.USER_ERROR, + ) + self.entities = entities + self.specification = specification + self.stage = stage + self.materialization_settings = materialization_settings + self.latest_version = None + + def _to_rest_object(self) -> FeaturesetVersion: + featureset_version_properties = FeaturesetVersionProperties( + description=self.description, + properties=self.properties, + tags=self.tags, + entities=self.entities, + materialization_settings=( + self.materialization_settings._to_rest_object() if self.materialization_settings else None + ), + specification=self.specification._to_rest_object() if self.specification is not None else None, + stage=self.stage, + ) + return FeaturesetVersion(name=self.name, properties=featureset_version_properties) + + @classmethod + def _from_rest_object(cls, featureset_rest_object: FeaturesetVersion) -> Optional["FeatureSet"]: + if not featureset_rest_object: + return None + featureset_rest_object_details: FeaturesetVersionProperties = featureset_rest_object.properties + arm_id_object = get_arm_id_object_from_id(featureset_rest_object.id) + featureset = FeatureSet( + id=featureset_rest_object.id, + name=arm_id_object.asset_name, + version=arm_id_object.asset_version, + description=featureset_rest_object_details.description, + tags=featureset_rest_object_details.tags, + entities=featureset_rest_object_details.entities, + materialization_settings=MaterializationSettings._from_rest_object( + featureset_rest_object_details.materialization_settings + ), + specification=FeatureSetSpecification._from_rest_object(featureset_rest_object_details.specification), + stage=featureset_rest_object_details.stage, + properties=featureset_rest_object_details.properties, + ) + return featureset + + @classmethod + def _from_container_rest_object(cls, rest_obj: FeaturesetContainer) -> "FeatureSet": + rest_object_details: FeaturesetContainerProperties = rest_obj.properties + arm_id_object = get_arm_id_object_from_id(rest_obj.id) + featureset = FeatureSet( + name=arm_id_object.asset_name, + description=rest_object_details.description, + tags=rest_object_details.tags, + entities=[], + specification=FeatureSetSpecification(), + version="", + ) + featureset.latest_version = rest_object_details.latest_version + return featureset + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "FeatureSet": + data = 
data or {} + params_override = params_override or [] + base_path = Path(yaml_path).parent if yaml_path else Path("./") + context = { + BASE_PATH_CONTEXT_KEY: base_path, + PARAMS_OVERRIDE_KEY: params_override, + } + loaded_schema = load_from_dict(FeatureSetSchema, data, context, **kwargs) + feature_set = FeatureSet(base_path=base_path, **loaded_schema) + return feature_set + + def _to_dict(self) -> Dict: + return dict(FeatureSetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)) + + def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None: + # if datastore_arm_id is null, capture the full_storage_path + if not asset_artifact.datastore_arm_id and asset_artifact.full_storage_path: + self.path = asset_artifact.full_storage_path + else: + aml_datastore_id = AMLNamedArmId(asset_artifact.datastore_arm_id) + self.path = LONG_URI_FORMAT.format( + aml_datastore_id.subscription_id, + aml_datastore_id.resource_group_name, + aml_datastore_id.workspace_name, + aml_datastore_id.asset_name, + asset_artifact.relative_path, + ) + + if self.specification is not None: + self.specification.path = self.path + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the asset content into a file in YAML format. + + :param dest: The local path or file stream to write the YAML content to. + If dest is a file path, a new file will be created. + If dest is an open file, the file will be written to directly. + :type dest: Union[PathLike, str, IO[AnyStr]] + :raises FileExistsError: Raised if dest is a file path and the file already exists. + :raises IOError: Raised if dest is an open file and the file is not writable. + """ + + import os + import shutil + + from azure.ai.ml._utils.utils import is_url + + origin_spec_path = self.specification.path if self.specification is not None else None + if isinstance(dest, (PathLike, str)) and self.specification is not None and not is_url(self.specification.path): + if os.path.exists(dest): + raise FileExistsError(f"File {dest} already exists.") + relative_path = os.path.basename(cast(PathLike, self.specification.path)) + src_spec_path = ( + str(Path(self._base_path, self.specification.path)) if self.specification.path is not None else "" + ) + dest_spec_path = str(Path(os.path.dirname(dest), relative_path)) + if os.path.exists(dest_spec_path): + shutil.rmtree(dest_spec_path) + shutil.copytree(src=src_spec_path, dst=dest_spec_path) + self.specification.path = str(Path("./", relative_path)) + super().dump(dest=dest, **kwargs) + + if self.specification is not None: + self.specification.path = origin_spec_path diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/index.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/index.py new file mode 100644 index 00000000..35f671d3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/index.py @@ -0,0 +1,137 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
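A usage sketch for the FeatureSet entity defined above (editor's illustration). The entity reference format and the local spec folder are assumptions, and FeatureSetSpecification is assumed to take the spec folder path as shown in the feature store samples.

from azure.ai.ml.entities import FeatureSet, FeatureSetSpecification

transactions_features = FeatureSet(
    name="transactions",
    version="1",
    description="Rolling aggregations over transaction data",
    entities=["azureml:account:1"],  # assumed entity reference format
    stage="Development",             # one of Development, Production, Archived
    specification=FeatureSetSpecification(path="./featuresets/transactions/spec"),
    tags={"data_type": "nonPII"},
)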
+# --------------------------------------------------------- +from os import PathLike +from pathlib import Path +from typing import Any, Dict, Optional, Union, cast + +# cspell:disable-next-line +from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401.models import Index as RestIndex +from azure.ai.ml._schema import IndexAssetSchema +from azure.ai.ml._utils._arm_id_utils import AMLAssetId, AMLNamedArmId +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LONG_URI_FORMAT, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities._assets import Artifact +from azure.ai.ml.entities._assets._artifacts.artifact import ArtifactStorageInfo +from azure.ai.ml.entities._system_data import RestSystemData, SystemData +from azure.ai.ml.entities._util import load_from_dict + + +@experimental +class Index(Artifact): + """Index asset. + + :ivar name: Name of the resource. + :vartype name: str + :ivar version: Version of the resource. + :vartype version: str + :ivar id: Fully qualified resource Id: + azureml://workspace/{workspaceName}/indexes/{name}/versions/{version} of the index. Required. + :vartype id: str + :ivar stage: Update stage to 'Archive' for soft delete. Default is Development, which means the + asset is under development. Required. + :vartype stage: str + :ivar description: Description information of the asset. + :vartype description: Optional[str] + :ivar tags: Asset's tags. + :vartype tags: Optional[dict[str, str]] + :ivar properties: Asset's properties. + :vartype properties: Optional[dict[str, str]] + :ivar path: The local or remote path to the asset. + :vartype path: Optional[Union[str, os.PathLike]] + """ + + def __init__( + self, + *, + name: str, + version: Optional[str] = None, + stage: str = "Development", + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + properties: Optional[Dict[str, str]] = None, + path: Optional[Union[str, PathLike]] = None, + datastore: Optional[str] = None, + **kwargs: Any, + ): + self.stage = stage + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + properties=properties, + path=path, + datastore=datastore, + **kwargs, + ) + + @classmethod + def _from_rest_object(cls, index_rest_object: RestIndex) -> "Index": + """Convert the response from the Index API into a Index + + :param RestIndex index_rest_object: + :return: An Index Asset + :rtype: Index + """ + asset_id = AMLAssetId(asset_id=index_rest_object.id) + + return Index( + id=index_rest_object.id, + name=asset_id.asset_name, + version=asset_id.asset_version, + description=index_rest_object.description, + tags=index_rest_object.tags, + properties=index_rest_object.properties, + stage=index_rest_object.stage, + path=index_rest_object.storage_uri, + # pylint: disable-next=protected-access + creation_context=SystemData._from_rest_object( + RestSystemData.from_dict(index_rest_object.system_data.as_dict()) + ), + ) + + def _to_rest_object(self) -> RestIndex: + # Note: Index.name and Index.version get dropped going to RestIndex, since both are encoded in the id + # (when present) + return RestIndex( + stage=self.stage, + storage_uri=self.path, + description=self.description, + tags=self.tags, + properties=self.properties, + id=self.id, + ) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Index": + data = data or {} + 
params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + return cast(Index, load_from_dict(IndexAssetSchema, data, context, **kwargs)) + + def _to_dict(self) -> Dict: + return cast(dict, IndexAssetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)) + + def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None: + """Updates an an artifact with the remote path of a local upload. + + :param ArtifactStorageInfo asset_artifact: The asset storage info of the artifact + """ + aml_datastore_id = AMLNamedArmId(asset_artifact.datastore_arm_id) + self.path = LONG_URI_FORMAT.format( + aml_datastore_id.subscription_id, + aml_datastore_id.resource_group_name, + aml_datastore_id.workspace_name, + aml_datastore_id.asset_name, + asset_artifact.relative_path, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/model.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/model.py new file mode 100644 index 00000000..8e65bd3e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/model.py @@ -0,0 +1,219 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from os import PathLike +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + FlavorData, + ModelContainer, + ModelVersion, + ModelVersionProperties, +) +from azure.ai.ml._schema import ModelSchema +from azure.ai.ml._utils._arm_id_utils import AMLNamedArmId, AMLVersionedArmId +from azure.ai.ml._utils._asset_utils import get_ignore_file, get_object_hash +from azure.ai.ml.constants._common import ( + BASE_PATH_CONTEXT_KEY, + LONG_URI_FORMAT, + PARAMS_OVERRIDE_KEY, + ArmConstants, + AssetTypes, +) +from azure.ai.ml.entities._assets import Artifact +from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import get_md5_string, load_from_dict + +from .artifact import ArtifactStorageInfo + + +class Model(Artifact): # pylint: disable=too-many-instance-attributes + """Model for training and scoring. + + :param name: The name of the model. Defaults to a random GUID. + :type name: Optional[str] + :param version: The version of the model. Defaults to "1" if either no name or an unregistered name is provided. + Otherwise, defaults to autoincrement from the last registered version of the model with that name. + :type version: Optional[str] + :param type: The storage format for this entity, used for NCD (Novel Class Discovery). Accepted values are + "custom_model", "mlflow_model", or "triton_model". Defaults to "custom_model". + :type type: Optional[str] + :param utc_time_created: The date and time when the model was created, in + UTC ISO 8601 format. (e.g. '2020-10-19 17:44:02.096572'). + :type utc_time_created: Optional[str] + :param flavors: The flavors in which the model can be interpreted. Defaults to None. + :type flavors: Optional[dict[str, Any]] + :param path: A remote uri or a local path pointing to a model. Defaults to None. + :type path: Optional[str] + :param description: The description of the resource. Defaults to None + :type description: Optional[str] + :param tags: Tag dictionary. 
Tags can be added, removed, and updated. Defaults to None. + :type tags: Optional[dict[str, str]] + :param properties: The asset property dictionary. Defaults to None. + :type properties: Optional[dict[str, str]] + :param stage: The stage of the resource. Defaults to None. + :type stage: Optional[str] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: Optional[dict] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START model_entity_create] + :end-before: [END model_entity_create] + :language: python + :dedent: 8 + :caption: Creating a Model object. + """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + type: Optional[str] = None, # pylint: disable=redefined-builtin + path: Optional[Union[str, PathLike]] = None, + utc_time_created: Optional[str] = None, + flavors: Optional[Dict[str, Dict[str, Any]]] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + stage: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.job_name = kwargs.pop("job_name", None) + self._intellectual_property = kwargs.pop("intellectual_property", None) + self._system_metadata = kwargs.pop("system_metadata", None) + super().__init__( + name=name, + version=version, + path=path, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + self.utc_time_created = utc_time_created + self.flavors = dict(flavors) if flavors else None + self._arm_type = ArmConstants.MODEL_VERSION_TYPE + self.type = type or AssetTypes.CUSTOM_MODEL + self.stage = stage + if self._is_anonymous and self.path: + _ignore_file = get_ignore_file(self.path) + _upload_hash = get_object_hash(self.path, _ignore_file) + self.name = get_md5_string(_upload_hash) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Model": + params_override = params_override or [] + data = data or {} + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + res: Model = load_from_dict(ModelSchema, data, context, **kwargs) + return res + + def _to_dict(self) -> Dict: + return dict(ModelSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)) + + @classmethod + def _from_rest_object(cls, model_rest_object: ModelVersion) -> "Model": + rest_model_version: ModelVersionProperties = model_rest_object.properties + arm_id = AMLVersionedArmId(arm_id=model_rest_object.id) + model_stage = rest_model_version.stage if hasattr(rest_model_version, "stage") else None + model_system_metadata = ( + rest_model_version.system_metadata if hasattr(rest_model_version, "system_metadata") else None + ) + if hasattr(rest_model_version, "flavors"): + flavors = {key: flavor.data for key, flavor in rest_model_version.flavors.items()} + model = Model( + id=model_rest_object.id, + name=arm_id.asset_name, + version=arm_id.asset_version, + path=rest_model_version.model_uri, + description=rest_model_version.description, + tags=rest_model_version.tags, + flavors=flavors, # pylint: disable=possibly-used-before-assignment + properties=rest_model_version.properties, + stage=model_stage, + # pylint: disable=protected-access + creation_context=SystemData._from_rest_object(model_rest_object.system_data), + type=rest_model_version.model_type, + job_name=rest_model_version.job_name, + 
intellectual_property=( + IntellectualProperty._from_rest_object(rest_model_version.intellectual_property) + if rest_model_version.intellectual_property + else None + ), + system_metadata=model_system_metadata, + ) + return model + + @classmethod + def _from_container_rest_object(cls, model_container_rest_object: ModelContainer) -> "Model": + model = Model( + name=model_container_rest_object.name, + version="1", + id=model_container_rest_object.id, + # pylint: disable=protected-access + creation_context=SystemData._from_rest_object(model_container_rest_object.system_data), + ) + model.latest_version = model_container_rest_object.properties.latest_version + + # Setting version to None since if version is not provided it is defaulted to "1". + # This should go away once container concept is finalized. + model.version = None + return model + + def _to_rest_object(self) -> ModelVersion: + model_version = ModelVersionProperties( + description=self.description, + tags=self.tags, + properties=self.properties, + flavors=( + {key: FlavorData(data=dict(value)) for key, value in self.flavors.items()} if self.flavors else None + ), # flatten OrderedDict to dict + model_type=self.type, + model_uri=self.path, + stage=self.stage, + is_anonymous=self._is_anonymous, + ) + model_version.system_metadata = self._system_metadata if hasattr(self, "_system_metadata") else None + + model_version_resource = ModelVersion(properties=model_version) + + return model_version_resource + + def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None: + # datastore_arm_id is null for registry scenario, so capture the full_storage_path + if not asset_artifact.datastore_arm_id and asset_artifact.full_storage_path: + self.path = asset_artifact.full_storage_path + else: + aml_datastore_id = AMLNamedArmId(asset_artifact.datastore_arm_id) + self.path = LONG_URI_FORMAT.format( + aml_datastore_id.subscription_id, + aml_datastore_id.resource_group_name, + aml_datastore_id.workspace_name, + aml_datastore_id.asset_name, + asset_artifact.relative_path, + ) + + def _to_arm_resource_param(self, **kwargs: Any) -> Dict: # pylint: disable=unused-argument + properties = self._to_rest_object().properties + + return { + self._arm_type: { + ArmConstants.NAME: self.name, + ArmConstants.VERSION: self.version, + ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "ModelVersionProperties"), + } + } diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/asset.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/asset.py new file mode 100644 index 00000000..b6ee2b55 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/asset.py @@ -0,0 +1,145 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import uuid +from abc import abstractmethod +from os import PathLike +from typing import IO, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + + +class Asset(Resource): + """Base class for asset. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :param name: The name of the asset. Defaults to a random GUID. 
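A usage sketch for the Model entity defined above (editor's illustration); the path is a placeholder for a local folder or a remote job-output URI, and the imports assume the public azure.ai.ml re-exports.

from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import Model

churn_model = Model(
    name="churn-classifier",
    version="1",
    type=AssetTypes.MLFLOW_MODEL,  # or AssetTypes.CUSTOM_MODEL / AssetTypes.TRITON_MODEL
    path="./artifacts/model",      # local folder or a remote URI such as a job output
    description="MLflow model trained on the churn dataset",
    tags={"framework": "sklearn"},
)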
+ :type name: Optional[str]] + :param version: The version of the asset. Defaults to "1" if no name is provided, otherwise defaults to + autoincrement from the last registered version of the asset with that name. For a model name that has + never been registered, a default version will be assigned. + :type version: Optional[str] + :param description: The description of the resource. Defaults to None. + :type description: Optional[str] + :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None. + :type tags: Optional[dict[str, str]] + :param properties: The asset property dictionary. Defaults to None. + :type properties: Optional[dict[str, str]] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: Optional[dict] + """ + + def __init__( + self, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + self._is_anonymous = kwargs.pop("is_anonymous", False) + self._auto_increment_version = kwargs.pop("auto_increment_version", False) + self.auto_delete_setting = kwargs.pop("auto_delete_setting", None) + + if not name and version is None: + name = _get_random_name() + version = "1" + self._is_anonymous = True + elif version is not None and not name: + msg = "If version is specified, name must be specified also." + err = ValidationException( + message=msg, + target=ErrorTarget.ASSET, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + log_and_raise_error(err) + + super().__init__( + name=name, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + + self.version = version + self.latest_version = None + + @abstractmethod + def _to_dict(self) -> Dict: + """Dump the artifact content into a pure dict object.""" + + @property + def version(self) -> Optional[str]: + """The asset version. + + :return: The asset version. + :rtype: str + """ + return self._version + + @version.setter + def version(self, value: str) -> None: + """Sets the asset version. + + :param value: The asset version. + :type value: str + :raises ValidationException: Raised if value is not a string. + """ + if value: + if not isinstance(value, str): + msg = f"Asset version must be a string, not type {type(value)}." + err = ValidationException( + message=msg, + target=ErrorTarget.ASSET, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + log_and_raise_error(err) + + self._version = value + self._auto_increment_version = self.name and not self._version + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the asset content into a file in YAML format. + + :param dest: The local path or file stream to write the YAML content to. + If dest is a file path, a new file will be created. + If dest is an open file, the file will be written to directly. + :type dest: Union[PathLike, str, IO[AnyStr]] + :raises FileExistsError: Raised if dest is a file path and the file already exists. + :raises IOError: Raised if dest is an open file and the file is not writable. 
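To make the dump contract above concrete, a small sketch (editor's illustration; churn_model is the hypothetical Model instance from the earlier sketch).

# Writing to a path creates a new YAML file; per the docstring it fails if the file exists.
churn_model.dump("./churn-model.yml")

# Writing to an already-open, writable stream is also supported.
with open("./churn-model-copy.yml", "x", encoding="utf-8") as stream:
    churn_model.dump(stream)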
+ """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def __eq__(self, other: Any) -> bool: + return bool( + self.name == other.name + and self.id == other.id + and self.version == other.version + and self.description == other.description + and self.tags == other.tags + and self.properties == other.properties + and self.base_path == other.base_path + and self._is_anonymous == other._is_anonymous + and self._auto_increment_version == other._auto_increment_version + and self.auto_delete_setting == other.auto_delete_setting + ) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +def _get_random_name() -> str: + return str(uuid.uuid4()) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/auto_delete_setting.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/auto_delete_setting.py new file mode 100644 index 00000000..ea6bf9e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/auto_delete_setting.py @@ -0,0 +1,42 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoDeleteSetting as RestAutoDeleteSetting +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import AutoDeleteCondition +from azure.ai.ml.entities._mixins import DictMixin + + +@experimental +class AutoDeleteSetting(DictMixin): + """Class which defines the auto delete setting. + :param condition: When to check if an asset is expired. + Possible values include: "CreatedGreaterThan", "LastAccessedGreaterThan". + :type condition: AutoDeleteCondition + :param value: Expiration condition value. + :type value: str + """ + + def __init__( + self, + *, + condition: AutoDeleteCondition = AutoDeleteCondition.CREATED_GREATER_THAN, + value: Union[str, None] = None + ): + self.condition = condition + self.value = value + + def _to_rest_object(self) -> RestAutoDeleteSetting: + return RestAutoDeleteSetting(condition=self.condition, value=self.value) + + @classmethod + def _from_rest_object(cls, obj: RestAutoDeleteSetting) -> "AutoDeleteSetting": + return cls(condition=obj.condition, value=obj.value) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, AutoDeleteSetting): + return NotImplemented + return self.condition == other.condition and self.value == other.value diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/environment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/environment.py new file mode 100644 index 00000000..865273fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/environment.py @@ -0,0 +1,478 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access, too-many-instance-attributes + +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import yaml # type: ignore[import] + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2023_04_01_preview.models import BuildContext as RestBuildContext +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + EnvironmentContainer, + EnvironmentVersion, + EnvironmentVersionProperties, +) +from azure.ai.ml._schema import EnvironmentSchema +from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId +from azure.ai.ml._utils._asset_utils import get_ignore_file, get_object_hash +from azure.ai.ml._utils.utils import dump_yaml, is_url, load_file, load_yaml +from azure.ai.ml.constants._common import ANONYMOUS_ENV_NAME, BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ArmConstants +from azure.ai.ml.entities._assets.asset import Asset +from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty +from azure.ai.ml.entities._mixins import LocalizableMixin +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import get_md5_string, load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + + +class BuildContext: + """Docker build context for Environment. + + :param path: The local or remote path to the the docker build context directory. + :type path: Union[str, os.PathLike] + :param dockerfile_path: The path to the dockerfile relative to root of docker build context directory. + :type dockerfile_path: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START build_context_entity_create] + :end-before: [END build_context_entity_create] + :language: python + :dedent: 8 + :caption: Create a Build Context object. + """ + + def __init__( + self, + *, + dockerfile_path: Optional[str] = None, + path: Optional[Union[str, os.PathLike]] = None, + ): + self.dockerfile_path = dockerfile_path + self.path = path + + def _to_rest_object(self) -> RestBuildContext: + return RestBuildContext(context_uri=self.path, dockerfile_path=self.dockerfile_path) + + @classmethod + def _from_rest_object(cls, rest_obj: RestBuildContext) -> "BuildContext": + return BuildContext( + path=rest_obj.context_uri, + dockerfile_path=rest_obj.dockerfile_path, + ) + + def __eq__(self, other: Any) -> bool: + res: bool = self.dockerfile_path == other.dockerfile_path and self.path == other.path + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +class Environment(Asset, LocalizableMixin): + """Environment for training. + + :param name: Name of the resource. + :type name: str + :param version: Version of the asset. + :type version: str + :param description: Description of the resource. + :type description: str + :param image: URI of a custom base image. + :type image: str + :param build: Docker build context to create the environment. Mutually exclusive with "image" + :type build: ~azure.ai.ml.entities._assets.environment.BuildContext + :param conda_file: Path to configuration file listing conda packages to install. + :type conda_file: typing.Union[str, os.PathLike] + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. 
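A usage sketch for the BuildContext/Environment pair defined in this file (editor's illustration; the context directory and Dockerfile name are placeholders). As the docstring notes, "image" and "build" are mutually exclusive.

from azure.ai.ml.entities import BuildContext, Environment

docker_env = Environment(
    name="train-env-docker",
    version="1",
    build=BuildContext(path="./docker-context", dockerfile_path="Dockerfile"),
    description="Environment built from a local Docker build context",
    tags={"purpose": "training"},
)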
+ :type properties: dict[str, str] + :param datastore: The datastore to upload the local artifact to. + :type datastore: str + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START env_entity_create] + :end-before: [END env_entity_create] + :language: python + :dedent: 8 + :caption: Create a Environment object. + """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + image: Optional[str] = None, + build: Optional[BuildContext] = None, + conda_file: Optional[Union[str, os.PathLike, Dict]] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + datastore: Optional[str] = None, + **kwargs: Any, + ): + self._arm_type: str = "" + self.latest_version: str = "" # type: ignore[assignment] + self.image: Optional[str] = None + inference_config = kwargs.pop("inference_config", None) + os_type = kwargs.pop("os_type", None) + self._intellectual_property = kwargs.pop("intellectual_property", None) + + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + + self.conda_file = conda_file + self.image = image + self.build = build + self.inference_config = inference_config + self.os_type = os_type + self._arm_type = ArmConstants.ENVIRONMENT_VERSION_TYPE + self._conda_file_path = ( + _resolve_path(base_path=self.base_path, input=conda_file) + if isinstance(conda_file, (os.PathLike, str)) + else None + ) + self.path = None + self.datastore = datastore + self._upload_hash = None + + self._translated_conda_file = None + if self.conda_file: + self._translated_conda_file = dump_yaml(self.conda_file, sort_keys=True) # service needs str representation + + if self.build and self.build.path and not is_url(self.build.path): + path = Path(self.build.path) + if not path.is_absolute(): + path = Path(self.base_path, path).resolve() + self.path = path + + if self._is_anonymous: + if self.path: + self._ignore_file = get_ignore_file(path) + self._upload_hash = get_object_hash(path, self._ignore_file) + self._generate_anonymous_name_version(source="build") + elif self.image: + self._generate_anonymous_name_version( + source="image", conda_file=self._translated_conda_file, inference_config=self.inference_config + ) + + @property + def conda_file(self) -> Optional[Union[str, os.PathLike, Dict]]: + """Conda environment specification. + + :return: Conda dependencies loaded from `conda_file` param. + :rtype: Optional[Union[str, os.PathLike]] + """ + return self._conda_file + + @conda_file.setter + def conda_file(self, value: Optional[Union[str, os.PathLike, Dict]]) -> None: + """Set conda environment specification. + + :param value: A path to a local conda dependencies yaml file or a loaded yaml dictionary of dependencies. 
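Since conda_file accepts either a path or an already-loaded dictionary, a sketch of the dictionary form (editor's illustration; the base image name is an example value only).

from azure.ai.ml.entities import Environment

sklearn_env = Environment(
    name="sklearn-env",
    version="1",
    image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04",  # example base image
    conda_file={
        "name": "sklearn-env",
        "channels": ["conda-forge"],
        "dependencies": ["python=3.10", "pip", {"pip": ["scikit-learn", "pandas"]}],
    },
)

# Equivalently, conda_file may be a path to a conda YAML file:
# Environment(name="sklearn-env", version="1", image="...", conda_file="./conda.yml")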
+ :type value: Union[str, os.PathLike, Dict] + :return: None + """ + if not isinstance(value, Dict): + value = _deserialize(self.base_path, value, is_conda=True) + self._conda_file = value + + @classmethod + def _load( + cls, + data: Optional[dict] = None, + yaml_path: Optional[Union[os.PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Environment": + params_override = params_override or [] + data = data or {} + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + res: Environment = load_from_dict(EnvironmentSchema, data, context, **kwargs) + return res + + def _to_rest_object(self) -> EnvironmentVersion: + self.validate() + environment_version = EnvironmentVersionProperties() + if self.conda_file: + environment_version.conda_file = self._translated_conda_file + if self.image: + environment_version.image = self.image + if self.build: + environment_version.build = self.build._to_rest_object() + if self.os_type: + environment_version.os_type = self.os_type + if self.tags: + environment_version.tags = self.tags + if self._is_anonymous: + environment_version.is_anonymous = self._is_anonymous + if self.inference_config: + environment_version.inference_config = self.inference_config + if self.description: + environment_version.description = self.description + if self.properties: + environment_version.properties = self.properties + + environment_version_resource = EnvironmentVersion(properties=environment_version) + + return environment_version_resource + + @classmethod + def _from_rest_object(cls, env_rest_object: EnvironmentVersion) -> "Environment": + rest_env_version = env_rest_object.properties + arm_id = AMLVersionedArmId(arm_id=env_rest_object.id) + + environment = Environment( + id=env_rest_object.id, + name=arm_id.asset_name, + version=arm_id.asset_version, + description=rest_env_version.description, + tags=rest_env_version.tags, + creation_context=( + SystemData._from_rest_object(env_rest_object.system_data) if env_rest_object.system_data else None + ), + is_anonymous=rest_env_version.is_anonymous, + image=rest_env_version.image, + os_type=rest_env_version.os_type, + inference_config=rest_env_version.inference_config, + build=BuildContext._from_rest_object(rest_env_version.build) if rest_env_version.build else None, + properties=rest_env_version.properties, + intellectual_property=( + IntellectualProperty._from_rest_object(rest_env_version.intellectual_property) + if rest_env_version.intellectual_property + else None + ), + ) + + if rest_env_version.conda_file: + translated_conda_file = yaml.safe_load(rest_env_version.conda_file) + environment.conda_file = translated_conda_file + environment._translated_conda_file = rest_env_version.conda_file + + return environment + + @classmethod + def _from_container_rest_object(cls, env_container_rest_object: EnvironmentContainer) -> "Environment": + env = Environment( + name=env_container_rest_object.name, + version="1", + id=env_container_rest_object.id, + creation_context=SystemData._from_rest_object(env_container_rest_object.system_data), + ) + env.latest_version = env_container_rest_object.properties.latest_version + + # Setting version to None since if version is not provided it is defaulted to "1". + # This should go away once container concept is finalized. 
+ env.version = None + return env + + def _to_arm_resource_param(self, **kwargs: Any) -> Dict: # pylint: disable=unused-argument + properties = self._to_rest_object().properties + + return { + self._arm_type: { + ArmConstants.NAME: self.name, + ArmConstants.VERSION: self.version, + ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "EnvironmentVersion"), + } + } + + def _to_dict(self) -> Dict: + res: dict = EnvironmentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def validate(self) -> None: + """Validate the environment by checking its name, image and build + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START env_entities_validate] + :end-before: [END env_entities_validate] + :language: python + :dedent: 8 + :caption: Validate environment example. + """ + + if self.name is None: + msg = "Environment name is required" + err = ValidationException( + message=msg, + target=ErrorTarget.ENVIRONMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + log_and_raise_error(err) + if self.image is None and self.build is None: + msg = "Docker image or Dockerfile is required for environments" + err = ValidationException( + message=msg, + target=ErrorTarget.ENVIRONMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + log_and_raise_error(err) + if self.image and self.build: + msg = "Docker image or Dockerfile should be provided not both" + err = ValidationException( + message=msg, + target=ErrorTarget.ENVIRONMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + log_and_raise_error(err) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Environment): + return NotImplemented + return ( + self.name == other.name + and self.id == other.id + and self.version == other.version + and self.description == other.description + and self.tags == other.tags + and self.properties == other.properties + and self.base_path == other.base_path + and self.image == other.image + and self.build == other.build + and self.conda_file == other.conda_file + and self.inference_config == other.inference_config + and self._is_anonymous == other._is_anonymous + and self.os_type == other.os_type + and self._intellectual_property == other._intellectual_property + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def _generate_anonymous_name_version( + self, source: str, conda_file: Optional[str] = None, inference_config: Optional[Dict] = None + ) -> None: + hash_str = "" + if source == "image": + hash_str = hash_str.join(get_md5_string(self.image)) + if inference_config: + hash_str = hash_str.join(get_md5_string(yaml.dump(inference_config, sort_keys=True))) + if conda_file: + hash_str = hash_str.join(get_md5_string(conda_file)) + if source == "build": + if self.build is not None and not self.build.dockerfile_path: + hash_str = hash_str.join(get_md5_string(self._upload_hash)) + else: + if self.build is not None: + hash_str = hash_str.join(get_md5_string(self._upload_hash)).join( + get_md5_string(self.build.dockerfile_path) + ) + version_hash = get_md5_string(hash_str) + self.version = version_hash + self.name = ANONYMOUS_ENV_NAME + + def _localize(self, base_path: str) -> None: + """Called on an asset got from service to clean up remote attributes 
like id, creation_context, etc. and update + base_path. + + :param base_path: The base path + :type base_path: str + """ + if not getattr(self, "id", None): + raise ValueError("Only remote asset can be localize but got a {} without id.".format(type(self))) + self._id = None + self._creation_context = None + self._base_path = base_path + if self._is_anonymous: + self.name, self.version = None, None + + +# TODO: Remove _DockerBuild and _DockerConfiguration classes once local endpoint moves to using updated env +class _DockerBuild: + """Helper class to encapsulate Docker build info for Environment.""" + + def __init__( + self, + base_path: Optional[Union[str, os.PathLike]] = None, + dockerfile: Optional[str] = None, + ): + self.dockerfile = _deserialize(base_path, dockerfile) + + @classmethod + def _to_rest_object(cls) -> None: + return None + + def _from_rest_object(self, rest_obj: Any) -> None: + self.dockerfile = rest_obj.dockerfile + + def __eq__(self, other: Any) -> bool: + res: bool = self.dockerfile == other.dockerfile + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +def _deserialize( + base_path: Optional[Union[str, os.PathLike]], + input: Optional[Union[str, os.PathLike, Dict]], # pylint: disable=redefined-builtin + is_conda: bool = False, +) -> Optional[Union[str, os.PathLike, Dict]]: + """Deserialize user input files for conda and docker. + + :param base_path: The base path for all files supplied by user. + :type base_path: Union[str, os.PathLike] + :param input: Input to be deserialized. Will be either dictionary of file contents or path to file. + :type input: Union[str, os.PathLike, Dict[str, str]] + :param is_conda: If file is conda file, it will be returned as dictionary + :type is_conda: bool + :return: The deserialized data + :rtype: Union[str, Dict] + """ + + if input: + path = _resolve_path(base_path=base_path, input=input) + data: Union[str, Dict] = "" + if is_conda: + data = load_yaml(path) + else: + data = load_file(path) + return data + return input + + +def _resolve_path(base_path: Any, input: Any) -> Path: # pylint: disable=redefined-builtin + """Deserialize user input files for conda and docker. + + :param base_path: The base path for all files supplied by user. + :type base_path: Union[str, os.PathLike] + :param input: Input to be deserialized. Will be either dictionary of file contents or path to file. + :type input: Union[str, os.PathLike, Dict[str, str]] + :return: The resolved path + :rtype: Path + """ + + path = Path(input) + if not path.is_absolute(): + path = Path(base_path, path).resolve() + return path diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/federated_learning_silo.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/federated_learning_silo.py new file mode 100644 index 00000000..8255f887 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/federated_learning_silo.py @@ -0,0 +1,123 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# TODO determine where this file should live. + +from os import PathLike +from typing import IO, Any, AnyStr, Dict, List, Optional, Union + +from azure.ai.ml import Input +from azure.ai.ml._utils.utils import dump_yaml_to_file, load_yaml +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + +# Entity representation of a federated learning silo. 
+# Used by Federated Learning DSL nodes as inputs for creating
+# FL subgraphs in pipelines.
+# The functionality of this entity is limited, and it exists mostly
+# to simplify the process of loading and validating these objects from YAML.
+class FederatedLearningSilo:
+    def __init__(
+        self,
+        *,
+        compute: str,
+        datastore: str,
+        inputs: Dict[str, Input],
+    ):
+        """
+        A pseudo-entity that represents a federated learning silo, which is an isolated compute with its own
+        datastore and input targets. This is meant to be used in conjunction with the
+        Federated Learning DSL node to create federated learning pipelines. This does NOT represent any specific
+        AML resource, and is instead merely meant to simplify client-side experiences with managing FL data
+        distribution.
+        Standard usage involves the "load_list" classmethod to load a list of these objects from YAML, which serves
+        as a necessary input for FL processes.
+
+        :param compute: The resource id of a compute.
+        :type compute: str
+        :param datastore: The resource id of a datastore.
+        :type datastore: str
+        :param inputs: A dictionary of input entities that exist in the previously specified datastore.
+            The keys of this dictionary are the keyword names that these inputs should be entered into.
+        :type inputs: dict[str, Input]
+        """
+        self.compute = compute
+        self.datastore = datastore
+        self.inputs = inputs
+
+    def dump(
+        self,
+        dest: Union[str, PathLike, IO[AnyStr]],
+        # pylint: disable=unused-argument
+        **kwargs: Any,
+    ) -> None:
+        """Dump the Federated Learning Silo spec into a file in YAML format.
+
+        :param dest: Either
+          * A path to a local file
+          * A writeable file-like object
+        :type dest: Union[str, PathLike, IO[AnyStr]]
+        """
+        yaml_serialized = self._to_dict()
+        dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False)
+
+    def _to_dict(self) -> Dict:
+        # JIT import to avoid experimental warnings on unrelated calls
+        from azure.ai.ml._schema.assets.federated_learning_silo import FederatedLearningSiloSchema
+
+        schema = FederatedLearningSiloSchema(context={BASE_PATH_CONTEXT_KEY: "./"})
+
+        return dict(schema.dump(self))
+
+    @classmethod
+    def _load_from_dict(cls, silo_dict: dict) -> "FederatedLearningSilo":
+        data_input = silo_dict.get("inputs", {})
+        return FederatedLearningSilo(compute=silo_dict["compute"], datastore=silo_dict["datastore"], inputs=data_input)
+
+    # simple load based off mltable metadata loading style
+    @classmethod
+    def _load(
+        cls,
+        yaml_path: Optional[Union[PathLike, str]] = None,
+    ) -> "FederatedLearningSilo":
+        yaml_dict = load_yaml(yaml_path)
+        return FederatedLearningSilo._load_from_dict(silo_dict=yaml_dict)
+
+    @classmethod
+    def load_list(
+        cls,
+        *,
+        yaml_path: Optional[Union[PathLike, str]],
+        list_arg: str,
+    ) -> List["FederatedLearningSilo"]:
+        """
+        Loads a list of federated learning silos from YAML. This is the expected entry point
+        for this class; load a list of these, then supply them to the federated learning DSL
+        package node in order to produce an FL pipeline.
+
+        The structure of the supplied YAML file is assumed to be a list of FL silos under the
+        name specified by the list_arg input, as shown below.
+
+        list_arg:
+        - silo 1 ...
+        - silo 2 ...
+
+        :keyword yaml_path: A path leading to a local YAML file which contains a list of
+            FederatedLearningSilo objects.
+ :paramtype yaml_path: Optional[Union[PathLike, str]] + :keyword list_arg: A string that names the top-level value which contains the list + of FL silos. + :paramtype list_arg: str + :return: The list of federated learning silos + :rtype: List[FederatedLearningSilo] + """ + yaml_dict = load_yaml(yaml_path) + return [ + FederatedLearningSilo._load_from_dict(silo_dict=silo_yaml_dict) for silo_yaml_dict in yaml_dict[list_arg] + ] + + # There are no to/from rest object functions because this object has no + # rest object equivalent. Any conversions should be done as part of the + # to/from rest object functions of OTHER entity objects. diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/intellectual_property.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/intellectual_property.py new file mode 100644 index 00000000..58b96a1b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/intellectual_property.py @@ -0,0 +1,49 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import IntellectualProperty as RestIntellectualProperty +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._assets import IPProtectionLevel +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +@experimental +class IntellectualProperty(RestTranslatableMixin): + """Intellectual property settings definition. + + :keyword publisher: The publisher's name. + :paramtype publisher: Optional[str] + :keyword protection_level: Asset Protection Level. Accepted values are IPProtectionLevel.ALL ("all") and + IPProtectionLevel.NONE ("none"). Defaults to IPProtectionLevel.ALL ("all"). + :paramtype protection_level: Optional[Union[str, ~azure.ai.ml.constants.IPProtectionLevel]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START intellectual_property_configuration] + :end-before: [END intellectual_property_configuration] + :language: python + :dedent: 8 + :caption: Configuring intellectual property settings on a CommandComponent. 
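+
+    A minimal inline sketch for quick reference (the publisher value and variable name below
+    are illustrative, not taken from the sample above)::
+
+        from azure.ai.ml.constants import IPProtectionLevel
+        from azure.ai.ml.entities import IntellectualProperty
+
+        ip = IntellectualProperty(publisher="contoso", protection_level=IPProtectionLevel.ALL)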
+    """
+
+    def __init__(
+        self, *, publisher: Optional[str] = None, protection_level: IPProtectionLevel = IPProtectionLevel.ALL
+    ) -> None:
+        self.publisher = publisher
+        self.protection_level = protection_level
+
+    def _to_rest_object(self) -> RestIntellectualProperty:
+        return RestIntellectualProperty(publisher=self.publisher, protection_level=self.protection_level)
+
+    @classmethod
+    def _from_rest_object(cls, obj: RestIntellectualProperty) -> "IntellectualProperty":
+        return cls(publisher=obj.publisher, protection_level=obj.protection_level)
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, IntellectualProperty):
+            return NotImplemented
+        return self.publisher == other.publisher and self.protection_level == other.protection_level
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/workspace_asset_reference.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/workspace_asset_reference.py
new file mode 100644
index 00000000..1e7d1ba2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/workspace_asset_reference.py
@@ -0,0 +1,87 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview.models import (
+    ResourceManagementAssetReferenceData,
+    ResourceManagementAssetReferenceDetails,
+)
+from azure.ai.ml._schema import WorkspaceAssetReferenceSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets.asset import Asset
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class WorkspaceAssetReference(Asset):
+    """Workspace Model Reference.
+
+    This is for SDK internal use only and might be deprecated in the future.
+    :param name: Model name
+    :type name: str
+    :param version: Model version
+    :type version: str
+    :param asset_id: Model asset id
+    :type asset_id: str
+    :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + asset_id: Optional[str] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ): + super().__init__( + name=name, + version=version, + properties=properties, + **kwargs, + ) + self.asset_id = asset_id + + @classmethod + def _load( + cls: Any, + data: Optional[dict] = None, + yaml_path: Optional[Union[os.PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "WorkspaceAssetReference": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + res: WorkspaceAssetReference = load_from_dict(WorkspaceAssetReferenceSchema, data, context, **kwargs) + return res + + def _to_rest_object(self) -> ResourceManagementAssetReferenceData: + resource_management_details = ResourceManagementAssetReferenceDetails( + destination_name=self.name, + destination_version=self.version, + source_asset_id=self.asset_id, + ) + resource_management = ResourceManagementAssetReferenceData(properties=resource_management_details) + return resource_management + + @classmethod + def _from_rest_object(cls, resource_object: ResourceManagementAssetReferenceData) -> "WorkspaceAssetReference": + resource_management = WorkspaceAssetReference( + name=resource_object.properties.destination_name, + version=resource_object.properties.destination_version, + asset_id=resource_object.properties.source_asset_id, + ) + + return resource_management + + def _to_dict(self) -> Dict: + return dict(WorkspaceAssetReferenceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/__init__.py new file mode 100644 index 00000000..8dfc61b2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/__init__.py @@ -0,0 +1,20 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + + +try: + from ._patch import __all__ as _patch_all + from ._patch import * # pylint: disable=unused-wildcard-import +except ImportError: + _patch_all = [] +from ._patch import patch_sdk as _patch_sdk + +__all__ = [] +__all__.extend([p for p in _patch_all if p not in __all__]) + +_patch_sdk() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_model_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_model_base.py new file mode 100644 index 00000000..5bf680b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_model_base.py @@ -0,0 +1,881 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +# pylint: disable=protected-access, broad-except + +import calendar +import decimal +import functools +import sys +import logging +import base64 +import re +import copy +import typing +import enum +import email.utils +from datetime import datetime, date, time, timedelta, timezone +from json import JSONEncoder +from typing_extensions import Self +import isodate +from azure.core.exceptions import DeserializationError +from azure.core import CaseInsensitiveEnumMeta +from azure.core.pipeline import PipelineResponse +from azure.core.serialization import NULL + +if sys.version_info >= (3, 9): + from collections.abc import MutableMapping +else: + from typing import MutableMapping + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"] + +TZ_UTC = timezone.utc +_T = typing.TypeVar("_T") + + +def _timedelta_as_isostr(td: timedelta) -> str: + """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S' + + Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython + + :param timedelta td: The timedelta to convert + :rtype: str + :return: ISO8601 version of this timedelta + """ + + # Split seconds to larger units + seconds = td.total_seconds() + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + + days, hours, minutes = list(map(int, (days, hours, minutes))) + seconds = round(seconds, 6) + + # Build date + date_str = "" + if days: + date_str = "%sD" % days + + if hours or minutes or seconds: + # Build time + time_str = "T" + + # Hours + bigger_exists = date_str or hours + if bigger_exists: + time_str += "{:02}H".format(hours) + + # Minutes + bigger_exists = bigger_exists or minutes + if bigger_exists: + time_str += "{:02}M".format(minutes) + + # Seconds + try: + if seconds.is_integer(): + seconds_string = "{:02}".format(int(seconds)) + else: + # 9 chars long w/ leading 0, 6 digits after decimal + seconds_string = "%09.6f" % seconds + # Remove trailing zeros + seconds_string = seconds_string.rstrip("0") + except AttributeError: # int.is_integer() raises + seconds_string = "{:02}".format(seconds) + + time_str += "{}S".format(seconds_string) + else: + time_str = "" + + return "P" + date_str + time_str + + +def _serialize_bytes(o, format: typing.Optional[str] = None) -> str: + encoded = base64.b64encode(o).decode() + if format == "base64url": + return encoded.strip("=").replace("+", "-").replace("/", "_") + return encoded + + +def _serialize_datetime(o, format: typing.Optional[str] = None): + if hasattr(o, "year") and hasattr(o, "hour"): + if format == "rfc7231": + return email.utils.format_datetime(o, usegmt=True) + if format == "unix-timestamp": + return int(calendar.timegm(o.utctimetuple())) + + # astimezone() fails for naive times in Python 2.7, so make make sure o is aware (tzinfo is set) + if not o.tzinfo: + iso_formatted = o.replace(tzinfo=TZ_UTC).isoformat() + else: + iso_formatted = o.astimezone(TZ_UTC).isoformat() + # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt) + return iso_formatted.replace("+00:00", "Z") + # Next try datetime.date or datetime.time + return o.isoformat() + + +def _is_readonly(p): + try: + return p._visibility == ["read"] # pylint: disable=protected-access + except AttributeError: + return False + + +class SdkJSONEncoder(JSONEncoder): + """A JSON 
encoder that's capable of serializing datetime objects and bytes.""" + + def __init__(self, *args, exclude_readonly: bool = False, format: typing.Optional[str] = None, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + self.format = format + + def default(self, o): # pylint: disable=too-many-return-statements + if _is_model(o): + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) + try: + return super(SdkJSONEncoder, self).default(o) + except TypeError: + if isinstance(o, type(NULL)): + return None + if isinstance(o, decimal.Decimal): + return float(o) + if isinstance(o, (bytes, bytearray)): + return _serialize_bytes(o, self.format) + try: + # First try datetime.datetime + return _serialize_datetime(o, self.format) + except AttributeError: + pass + # Last, try datetime.timedelta + try: + return _timedelta_as_isostr(o) + except AttributeError: + # This will be raised when it hits value.total_seconds in the method above + pass + return super(SdkJSONEncoder, self).default(o) + + +_VALID_DATE = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" + r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") +_VALID_RFC7231 = re.compile( + r"(Mon|Tue|Wed|Thu|Fri|Sat|Sun),\s\d{2}\s" + r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT" +) + + +def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: + """Deserialize ISO-8601 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + attr = attr.upper() + match = _VALID_DATE.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + check_decimal = attr.split(".") + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + attr = attr.replace(decimal_str, decimal_str[0:6]) + + date_obj = isodate.parse_datetime(attr) + test_utc = date_obj.utctimetuple() + if test_utc.tm_year > 9999 or test_utc.tm_year < 1: + raise OverflowError("Hit max or min date") + return date_obj + + +def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime: + """Deserialize RFC7231 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + match = _VALID_RFC7231.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + return email.utils.parsedate_to_datetime(attr) + + +def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime: + """Deserialize unix timestamp into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + return datetime.fromtimestamp(attr, TZ_UTC) + + +def _deserialize_date(attr: typing.Union[str, date]) -> date: + """Deserialize ISO-8601 formatted string into Date object. + :param str attr: response string to be deserialized. 
+ :rtype: date + :returns: The date object from that input + """ + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + if isinstance(attr, date): + return attr + return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore + + +def _deserialize_time(attr: typing.Union[str, time]) -> time: + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :rtype: datetime.time + :returns: The time object from that input + """ + if isinstance(attr, time): + return attr + return isodate.parse_time(attr) + + +def _deserialize_bytes(attr): + if isinstance(attr, (bytes, bytearray)): + return attr + return bytes(base64.b64decode(attr)) + + +def _deserialize_bytes_base64(attr): + if isinstance(attr, (bytes, bytearray)): + return attr + padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore + attr = attr + padding # type: ignore + encoded = attr.replace("-", "+").replace("_", "/") + return bytes(base64.b64decode(encoded)) + + +def _deserialize_duration(attr): + if isinstance(attr, timedelta): + return attr + return isodate.parse_duration(attr) + + +def _deserialize_decimal(attr): + if isinstance(attr, decimal.Decimal): + return attr + return decimal.Decimal(str(attr)) + + +_DESERIALIZE_MAPPING = { + datetime: _deserialize_datetime, + date: _deserialize_date, + time: _deserialize_time, + bytes: _deserialize_bytes, + bytearray: _deserialize_bytes, + timedelta: _deserialize_duration, + typing.Any: lambda x: x, + decimal.Decimal: _deserialize_decimal, +} + +_DESERIALIZE_MAPPING_WITHFORMAT = { + "rfc3339": _deserialize_datetime, + "rfc7231": _deserialize_datetime_rfc7231, + "unix-timestamp": _deserialize_datetime_unix_timestamp, + "base64": _deserialize_bytes, + "base64url": _deserialize_bytes_base64, +} + + +def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None): + if rf and rf._format: + return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) + return _DESERIALIZE_MAPPING.get(annotation) + + +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + +def _get_model(module_name: str, model_name: str): + models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} + module_end = module_name.rsplit(".", 1)[0] + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) + if isinstance(model_name, str): + model_name = model_name.split(".")[-1] + if model_name not in models: + return model_name + return models[model_name] + + +_UNSET = object() + + +class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=unsubscriptable-object + def __init__(self, data: typing.Dict[str, typing.Any]) -> None: + self._data = copy.deepcopy(data) + + def __contains__(self, key: typing.Any) -> bool: + return key in self._data + + def __getitem__(self, key: str) -> typing.Any: + return self._data.__getitem__(key) + + def __setitem__(self, key: str, value: typing.Any) -> None: + self._data.__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + self._data.__delitem__(key) + + def __iter__(self) -> typing.Iterator[typing.Any]: + return self._data.__iter__() + + def __len__(self) -> int: + return self._data.__len__() + + def __ne__(self, other: typing.Any) -> bool: + 
return not self.__eq__(other) + + def keys(self) -> typing.KeysView[str]: + return self._data.keys() + + def values(self) -> typing.ValuesView[typing.Any]: + return self._data.values() + + def items(self) -> typing.ItemsView[str, typing.Any]: + return self._data.items() + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + try: + return self[key] + except KeyError: + return default + + @typing.overload + def pop(self, key: str) -> typing.Any: ... + + @typing.overload + def pop(self, key: str, default: _T) -> _T: ... + + @typing.overload + def pop(self, key: str, default: typing.Any) -> typing.Any: ... + + def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any: + if default is _UNSET: + return self._data.pop(key) + return self._data.pop(key, default) + + def popitem(self) -> typing.Tuple[str, typing.Any]: + return self._data.popitem() + + def clear(self) -> None: + self._data.clear() + + def update(self, *args: typing.Any, **kwargs: typing.Any) -> None: + self._data.update(*args, **kwargs) + + @typing.overload + def setdefault(self, key: str, default: None = None) -> None: ... + + @typing.overload + def setdefault(self, key: str, default: typing.Any) -> typing.Any: ... + + def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any: + if default is _UNSET: + return self._data.setdefault(key) + return self._data.setdefault(key, default) + + def __eq__(self, other: typing.Any) -> bool: + try: + other_model = self.__class__(other) + except Exception: + return False + return self._data == other_model._data + + def __repr__(self) -> str: + return str(self._data) + + +def _is_model(obj: typing.Any) -> bool: + return getattr(obj, "_is_model", False) + + +def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements + if isinstance(o, list): + return [_serialize(x, format) for x in o] + if isinstance(o, dict): + return {k: _serialize(v, format) for k, v in o.items()} + if isinstance(o, set): + return {_serialize(x, format) for x in o} + if isinstance(o, tuple): + return tuple(_serialize(x, format) for x in o) + if isinstance(o, (bytes, bytearray)): + return _serialize_bytes(o, format) + if isinstance(o, decimal.Decimal): + return float(o) + if isinstance(o, enum.Enum): + return o.value + try: + # First try datetime.datetime + return _serialize_datetime(o, format) + except AttributeError: + pass + # Last, try datetime.timedelta + try: + return _timedelta_as_isostr(o) + except AttributeError: + # This will be raised when it hits value.total_seconds in the method above + pass + return o + + +def _get_rest_field( + attr_to_rest_field: typing.Dict[str, "_RestField"], rest_name: str +) -> typing.Optional["_RestField"]: + try: + return next(rf for rf in attr_to_rest_field.values() if rf._rest_name == rest_name) + except StopIteration: + return None + + +def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typing.Any: + if not rf: + return _serialize(value, None) + if rf._is_multipart_file_input: + return value + if rf._is_model: + return _deserialize(rf._type, value) + return _serialize(value, rf._format) + + +class Model(_MyMutableMapping): + _is_model = True + + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: + class_name = self.__class__.__name__ + if len(args) > 1: + raise TypeError(f"{class_name}.__init__() takes 2 positional arguments but {len(args) + 1} were given") + dict_to_pass = { + rest_field._rest_name: rest_field._default + for rest_field in 
self._attr_to_rest_field.values() + if rest_field._default is not _UNSET + } + if args: + dict_to_pass.update( + {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()} + ) + else: + non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field] + if non_attr_kwargs: + # actual type errors only throw the first wrong keyword arg they see, so following that. + raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") + dict_to_pass.update( + { + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) + for k, v in kwargs.items() + if v is not None + } + ) + super().__init__(dict_to_pass) + + def copy(self) -> "Model": + return Model(self.__dict__) + + def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: + # we know the last three classes in mro are going to be 'Model', 'dict', and 'object' + mros = cls.__mro__[:-3][::-1] # ignore model, dict, and object parents, and reverse the mro order + attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property + k: v for mro_class in mros for k, v in mro_class.__dict__.items() if k[0] != "_" and hasattr(v, "_type") + } + annotations = { + k: v + for mro_class in mros + if hasattr(mro_class, "__annotations__") + for k, v in mro_class.__annotations__.items() + } + for attr, rf in attr_to_rest_field.items(): + rf._module = cls.__module__ + if not rf._type: + rf._type = rf._get_deserialize_callable_from_annotation(annotations.get(attr, None)) + if not rf._rest_name_input: + rf._rest_name_input = attr + cls._attr_to_rest_field: typing.Dict[str, _RestField] = dict(attr_to_rest_field.items()) + + return super().__new__(cls) # pylint: disable=no-value-for-parameter + + def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: + for base in cls.__bases__: + if hasattr(base, "__mapping__"): + base.__mapping__[discriminator or cls.__name__] = cls # type: ignore + + @classmethod + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: + for v in cls.__dict__.values(): + if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: + return v._rest_name # pylint: disable=protected-access + return None + + @classmethod + def _deserialize(cls, data, exist_discriminators): + if not hasattr(cls, "__mapping__"): + return cls(data) + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) + mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pyright: ignore + if mapped_cls == cls: + return cls(data) + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. 
+ :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if ( + exclude_readonly + and k in readonly_props # pyright: ignore # pylint: disable=possibly-used-before-assignment + ): + continue + is_multipart_file_input = False + try: + is_multipart_file_input = next( + rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k + )._is_multipart_file_input + except StopIteration: + pass + result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, type(NULL)): + return None + if isinstance(v, (list, tuple, set)): + return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v) + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v + + +def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj): + if _is_model(obj): + return obj + return _deserialize(model_deserializer, obj) + + +def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj): + if obj is None: + return obj + return _deserialize_with_callable(if_obj_deserializer, obj) + + +def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + +def _deserialize_dict( + value_deserializer: typing.Optional[typing.Callable], + module: typing.Optional[str], + obj: typing.Dict[typing.Any, typing.Any], +): + if obj is None: + return obj + return {k: _deserialize(value_deserializer, v, module) for k, v in obj.items()} + + +def _deserialize_multiple_sequence( + entry_deserializers: typing.List[typing.Optional[typing.Callable]], + module: typing.Optional[str], + obj, +): + if obj is None: + return obj + return type(obj)(_deserialize(deserializer, entry, module) for entry, deserializer in zip(obj, entry_deserializers)) + + +def _deserialize_sequence( + deserializer: typing.Optional[typing.Callable], + module: typing.Optional[str], + obj, +): + if obj is None: + return obj + return type(obj)(_deserialize(deserializer, entry, module) for entry in obj) + + +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911 + annotation: typing.Any, + module: typing.Optional[str], + rf: typing.Optional["_RestField"] = None, +) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]: + if not annotation or annotation in [int, float]: + return None + + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + + try: + if module and _is_model(annotation): + if rf: + rf._is_model = True + + return functools.partial(_deserialize_model, annotation) # pyright: ignore + except Exception: + pass + + # is it a literal? 
+ try: + if annotation.__origin__ is typing.Literal: # pyright: ignore + return None + except AttributeError: + pass + + # is it optional? + try: + if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore + if_obj_deserializer = _get_deserialize_callable_from_annotation( + next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore + ) + + return functools.partial(_deserialize_with_optional, if_obj_deserializer) + except AttributeError: + pass + + if getattr(annotation, "__origin__", None) is typing.Union: + # initial ordering is we make `string` the last deserialization option, because it is often them most generic + deserializers = [ + _get_deserialize_callable_from_annotation(arg, module, rf) + for arg in sorted( + annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore + ) + ] + + return functools.partial(_deserialize_with_union, deserializers) + + try: + if annotation._name == "Dict": # pyright: ignore + value_deserializer = _get_deserialize_callable_from_annotation( + annotation.__args__[1], module, rf # pyright: ignore + ) + + return functools.partial( + _deserialize_dict, + value_deserializer, + module, + ) + except (AttributeError, IndexError): + pass + try: + if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore + if len(annotation.__args__) > 1: # pyright: ignore + + entry_deserializers = [ + _get_deserialize_callable_from_annotation(dt, module, rf) + for dt in annotation.__args__ # pyright: ignore + ] + return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module) + deserializer = _get_deserialize_callable_from_annotation( + annotation.__args__[0], module, rf # pyright: ignore + ) + + return functools.partial(_deserialize_sequence, deserializer, module) + except (TypeError, IndexError, AttributeError, SyntaxError): + pass + + def _deserialize_default( + deserializer, + obj, + ): + if obj is None: + return obj + try: + return _deserialize_with_callable(deserializer, obj) + except Exception: + pass + return obj + + if get_deserializer(annotation, rf): + return functools.partial(_deserialize_default, get_deserializer(annotation, rf)) + + return functools.partial(_deserialize_default, annotation) + + +def _deserialize_with_callable( + deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], + value: typing.Any, +): + try: + if value is None or isinstance(value, type(NULL)): + return None + if deserializer is None: + return value + if isinstance(deserializer, CaseInsensitiveEnumMeta): + try: + return deserializer(value) + except ValueError: + # for unknown value, return raw value + return value + if isinstance(deserializer, type) and issubclass(deserializer, Model): + return deserializer._deserialize(value, []) + return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) + except Exception as e: + raise DeserializationError() from e + + +def _deserialize( + deserializer: typing.Any, + value: typing.Any, + module: typing.Optional[str] = None, + rf: typing.Optional["_RestField"] = None, + format: typing.Optional[str] = None, +) -> typing.Any: + if isinstance(value, PipelineResponse): + value = value.http_response.json() + if rf is None and format: + rf = _RestField(format=format) + if not isinstance(deserializer, functools.partial): + deserializer = _get_deserialize_callable_from_annotation(deserializer, module, rf) + return _deserialize_with_callable(deserializer, value) + + +class _RestField: + def 
__init__( + self, + *, + name: typing.Optional[str] = None, + type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin + is_discriminator: bool = False, + visibility: typing.Optional[typing.List[str]] = None, + default: typing.Any = _UNSET, + format: typing.Optional[str] = None, + is_multipart_file_input: bool = False, + is_required: bool = False, + ): + self._type = type + self._rest_name_input = name + self._module: typing.Optional[str] = None + self._is_discriminator = is_discriminator + self._visibility = visibility + self._is_model = False + self._default = default + self._format = format + self._is_multipart_file_input = is_multipart_file_input + self._is_required = is_required + + @property + def _class_type(self) -> typing.Any: + return getattr(self._type, "args", [None])[0] + + @property + def _rest_name(self) -> str: + if self._rest_name_input is None: + raise ValueError("Rest name was never set") + return self._rest_name_input + + def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin + # by this point, type and rest_name will have a value bc we default + # them in __new__ of the Model class + item = obj.get(self._rest_name) + if item is None: + return item + if self._is_model: + return item + return _deserialize(self._type, _serialize(item, self._format), rf=self) + + def __set__(self, obj: Model, value) -> None: + if value is None: + # we want to wipe out entries if users set attr to None + try: + obj.__delitem__(self._rest_name) + except KeyError: + pass + return + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return + obj.__setitem__(self._rest_name, _serialize(value, self._format)) + + def _get_deserialize_callable_from_annotation( + self, annotation: typing.Any + ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]: + return _get_deserialize_callable_from_annotation(annotation, self._module, self) + + +def rest_field( + *, + name: typing.Optional[str] = None, + type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin + visibility: typing.Optional[typing.List[str]] = None, + default: typing.Any = _UNSET, + format: typing.Optional[str] = None, + is_multipart_file_input: bool = False, + is_required: bool = False, +) -> typing.Any: + return _RestField( + name=name, + type=type, + visibility=visibility, + default=default, + format=format, + is_multipart_file_input=is_multipart_file_input, + is_required=is_required, + ) + + +def rest_discriminator( + *, + name: typing.Optional[str] = None, + type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin +) -> typing.Any: + return _RestField(name=name, type=type, is_discriminator=True) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_patch.py new file mode 100644 index 00000000..f7dd3251 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_patch.py @@ -0,0 +1,20 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Customize generated code here. 
+ +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" +from typing import List + +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_serialization.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_serialization.py new file mode 100644 index 00000000..2f781d74 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_serialization.py @@ -0,0 +1,1998 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +# pylint: skip-file +# pyright: reportUnnecessaryTypeIgnoreComment=false + +from base64 import b64decode, b64encode +import calendar +import datetime +import decimal +import email +from enum import Enum +import json +import logging +import re +import sys +import codecs +from typing import ( + Dict, + Any, + cast, + Optional, + Union, + AnyStr, + IO, + Mapping, + Callable, + TypeVar, + MutableMapping, + Type, + List, + Mapping, +) + +try: + from urllib import quote # type: ignore +except ImportError: + from urllib.parse import quote +import xml.etree.ElementTree as ET + +import isodate # type: ignore + +from azure.core.exceptions import DeserializationError, SerializationError +from azure.core.serialization import NULL as CoreNull + +_BOM = codecs.BOM_UTF8.decode(encoding="utf-8") + +ModelType = TypeVar("ModelType", bound="Model") +JSON = MutableMapping[str, Any] + + +class RawDeserializer: + + # Accept "text" because we're open minded people... + JSON_REGEXP = re.compile(r"^(application|text)/([a-z+.]+\+)?json$") + + # Name used in context + CONTEXT_NAME = "deserialized_data" + + @classmethod + def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: Optional[str] = None) -> Any: + """Decode data according to content-type. + + Accept a stream of data as well, but will be load at once in memory for now. 
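+
+        For example, assuming JSON content (values illustrative)::
+
+            RawDeserializer.deserialize_from_text(b'{"name": "value"}', "application/json")
+            # -> {'name': 'value'}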
+ + If no content-type, will return the string version (not bytes, not stream) + + :param data: Input, could be bytes or stream (will be decoded with UTF8) or text + :type data: str or bytes or IO + :param str content_type: The content type. + """ + if hasattr(data, "read"): + # Assume a stream + data = cast(IO, data).read() + + if isinstance(data, bytes): + data_as_str = data.decode(encoding="utf-8-sig") + else: + # Explain to mypy the correct type. + data_as_str = cast(str, data) + + # Remove Byte Order Mark if present in string + data_as_str = data_as_str.lstrip(_BOM) + + if content_type is None: + return data + + if cls.JSON_REGEXP.match(content_type): + try: + return json.loads(data_as_str) + except ValueError as err: + raise DeserializationError("JSON is invalid: {}".format(err), err) + elif "xml" in (content_type or []): + try: + + try: + if isinstance(data, unicode): # type: ignore + # If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string + data_as_str = data_as_str.encode(encoding="utf-8") # type: ignore + except NameError: + pass + + return ET.fromstring(data_as_str) # nosec + except ET.ParseError as err: + # It might be because the server has an issue, and returned JSON with + # content-type XML.... + # So let's try a JSON load, and if it's still broken + # let's flow the initial exception + def _json_attemp(data): + try: + return True, json.loads(data) + except ValueError: + return False, None # Don't care about this one + + success, json_result = _json_attemp(data) + if success: + return json_result + # If i'm here, it's not JSON, it's not XML, let's scream + # and raise the last context in this block (the XML exception) + # The function hack is because Py2.7 messes up with exception + # context otherwise. + _LOGGER.critical("Wasn't XML not JSON, failing") + raise DeserializationError("XML is invalid") from err + raise DeserializationError("Cannot deserialize content-type: {}".format(content_type)) + + @classmethod + def deserialize_from_http_generics(cls, body_bytes: Optional[Union[AnyStr, IO]], headers: Mapping) -> Any: + """Deserialize from HTTP response. + + Use bytes and headers to NOT use any requests/aiohttp or whatever + specific implementation. + Headers will tested for "content-type" + """ + # Try to use content-type from headers if available + content_type = None + if "content-type" in headers: + content_type = headers["content-type"].split(";")[0].strip().lower() + # Ouch, this server did not declare what it sent... + # Let's guess it's JSON... + # Also, since Autorest was considering that an empty body was a valid JSON, + # need that test as well.... + else: + content_type = "application/json" + + if body_bytes: + return cls.deserialize_from_text(body_bytes, content_type) + return None + + +_LOGGER = logging.getLogger(__name__) + +try: + _long_type = long # type: ignore +except NameError: + _long_type = int + + +class UTC(datetime.tzinfo): + """Time Zone info for handling UTC""" + + def utcoffset(self, dt): + """UTF offset for UTC is 0.""" + return datetime.timedelta(0) + + def tzname(self, dt): + """Timestamp representation.""" + return "Z" + + def dst(self, dt): + """No daylight saving for UTC.""" + return datetime.timedelta(hours=1) + + +try: + from datetime import timezone as _FixedOffset # type: ignore +except ImportError: # Python 2.7 + + class _FixedOffset(datetime.tzinfo): # type: ignore + """Fixed offset in minutes east from UTC. 
+ Copy/pasted from Python doc + :param datetime.timedelta offset: offset in timedelta format + """ + + def __init__(self, offset): + self.__offset = offset + + def utcoffset(self, dt): + return self.__offset + + def tzname(self, dt): + return str(self.__offset.total_seconds() / 3600) + + def __repr__(self): + return "<FixedOffset {}>".format(self.tzname(None)) + + def dst(self, dt): + return datetime.timedelta(0) + + def __getinitargs__(self): + return (self.__offset,) + + +try: + from datetime import timezone + + TZ_UTC = timezone.utc +except ImportError: + TZ_UTC = UTC() # type: ignore + +_FLATTEN = re.compile(r"(?<!\\)\.") + + +def attribute_transformer(key, attr_desc, value): + """A key transformer that returns the Python attribute. + + :param str key: The attribute name + :param dict attr_desc: The attribute metadata + :param object value: The value + :returns: A key using attribute name + """ + return (key, value) + + +def full_restapi_key_transformer(key, attr_desc, value): + """A key transformer that returns the full RestAPI key path. + + :param str _: The attribute name + :param dict attr_desc: The attribute metadata + :param object value: The value + :returns: A list of keys using RestAPI syntax. + """ + keys = _FLATTEN.split(attr_desc["key"]) + return ([_decode_attribute_map_key(k) for k in keys], value) + + +def last_restapi_key_transformer(key, attr_desc, value): + """A key transformer that returns the last RestAPI key. + + :param str key: The attribute name + :param dict attr_desc: The attribute metadata + :param object value: The value + :returns: The last RestAPI key. + """ + key, value = full_restapi_key_transformer(key, attr_desc, value) + return (key[-1], value) + + +def _create_xml_node(tag, prefix=None, ns=None): + """Create a XML node.""" + if prefix and ns: + ET.register_namespace(prefix, ns) + if ns: + return ET.Element("{" + ns + "}" + tag) + else: + return ET.Element(tag) + + +class Model(object): + """Mixin for all client request body/response body models to support + serialization and deserialization. 
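+
+    A rough, illustrative sketch of how generated models use this mixin (the ``Cat`` model
+    and its attributes are invented for this example, not part of any generated client)::
+
+        class Cat(Model):
+            _attribute_map = {
+                "name": {"key": "name", "type": "str"},
+                "age": {"key": "age", "type": "int"},
+            }
+
+            def __init__(self, *, name=None, age=None, **kwargs):
+                super().__init__(**kwargs)
+                self.name = name
+                self.age = age
+
+        cat = Cat(name="felix", age=3)
+        cat.serialize()   # roughly {"name": "felix", "age": 3}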
+ """ + + _subtype_map: Dict[str, Dict[str, Any]] = {} + _attribute_map: Dict[str, Dict[str, Any]] = {} + _validation: Dict[str, Dict[str, Any]] = {} + + def __init__(self, **kwargs: Any) -> None: + self.additional_properties: Optional[Dict[str, Any]] = {} + for k in kwargs: + if k not in self._attribute_map: + _LOGGER.warning("%s is not a known attribute of class %s and will be ignored", k, self.__class__) + elif k in self._validation and self._validation[k].get("readonly", False): + _LOGGER.warning("Readonly attribute %s will be ignored in class %s", k, self.__class__) + else: + setattr(self, k, kwargs[k]) + + def __eq__(self, other: Any) -> bool: + """Compare objects by comparing all attributes.""" + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other: Any) -> bool: + """Compare objects by comparing all attributes.""" + return not self.__eq__(other) + + def __str__(self) -> str: + return str(self.__dict__) + + @classmethod + def enable_additional_properties_sending(cls) -> None: + cls._attribute_map["additional_properties"] = {"key": "", "type": "{object}"} + + @classmethod + def is_xml_model(cls) -> bool: + try: + cls._xml_map # type: ignore + except AttributeError: + return False + return True + + @classmethod + def _create_xml_node(cls): + """Create XML node.""" + try: + xml_map = cls._xml_map # type: ignore + except AttributeError: + xml_map = {} + + return _create_xml_node(xml_map.get("name", cls.__name__), xml_map.get("prefix", None), xml_map.get("ns", None)) + + def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON: + """Return the JSON that would be sent to server from this model. + + This is an alias to `as_dict(full_restapi_key_transformer, keep_readonly=False)`. + + If you want XML serialization, you can pass the kwargs is_xml=True. + + :param bool keep_readonly: If you want to serialize the readonly attributes + :returns: A dict JSON compatible object + :rtype: dict + """ + serializer = Serializer(self._infer_class_models()) + return serializer._serialize(self, keep_readonly=keep_readonly, **kwargs) # type: ignore + + def as_dict( + self, + keep_readonly: bool = True, + key_transformer: Callable[[str, Dict[str, Any], Any], Any] = attribute_transformer, + **kwargs: Any + ) -> JSON: + """Return a dict that can be serialized using json.dump. + + Advanced usage might optionally use a callback as parameter: + + .. code::python + + def my_key_transformer(key, attr_desc, value): + return key + + Key is the attribute name used in Python. Attr_desc + is a dict of metadata. Currently contains 'type' with the + msrest type and 'key' with the RestAPI encoded key. + Value is the current value in this object. + + The string returned will be used to serialize the key. + If the return type is a list, this is considered hierarchical + result dict. + + See the three examples in this file: + + - attribute_transformer + - full_restapi_key_transformer + - last_restapi_key_transformer + + If you want XML serialization, you can pass the kwargs is_xml=True. + + :param function key_transformer: A key transformer function. 
+ :returns: A dict JSON compatible object + :rtype: dict + """ + serializer = Serializer(self._infer_class_models()) + return serializer._serialize(self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs) # type: ignore + + @classmethod + def _infer_class_models(cls): + try: + str_models = cls.__module__.rsplit(".", 1)[0] + models = sys.modules[str_models] + client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} + if cls.__name__ not in client_models: + raise ValueError("Not Autorest generated code") + except Exception: + # Assume it's not Autorest generated (tests?). Add ourselves as dependencies. + client_models = {cls.__name__: cls} + return client_models + + @classmethod + def deserialize(cls: Type[ModelType], data: Any, content_type: Optional[str] = None) -> ModelType: + """Parse a str using the RestAPI syntax and return a model. + + :param str data: A str using RestAPI structure. JSON by default. + :param str content_type: JSON by default, set application/xml if XML. + :returns: An instance of this model + :raises: DeserializationError if something went wrong + """ + deserializer = Deserializer(cls._infer_class_models()) + return deserializer(cls.__name__, data, content_type=content_type) # type: ignore + + @classmethod + def from_dict( + cls: Type[ModelType], + data: Any, + key_extractors: Optional[Callable[[str, Dict[str, Any], Any], Any]] = None, + content_type: Optional[str] = None, + ) -> ModelType: + """Parse a dict using given key extractor return a model. + + By default consider key + extractors (rest_key_case_insensitive_extractor, attribute_key_case_insensitive_extractor + and last_rest_key_case_insensitive_extractor) + + :param dict data: A dict using RestAPI structure + :param str content_type: JSON by default, set application/xml if XML. + :returns: An instance of this model + :raises: DeserializationError if something went wrong + """ + deserializer = Deserializer(cls._infer_class_models()) + deserializer.key_extractors = ( # type: ignore + [ # type: ignore + attribute_key_case_insensitive_extractor, + rest_key_case_insensitive_extractor, + last_rest_key_case_insensitive_extractor, + ] + if key_extractors is None + else key_extractors + ) + return deserializer(cls.__name__, data, content_type=content_type) # type: ignore + + @classmethod + def _flatten_subtype(cls, key, objects): + if "_subtype_map" not in cls.__dict__: + return {} + result = dict(cls._subtype_map[key]) + for valuetype in cls._subtype_map[key].values(): + result.update(objects[valuetype]._flatten_subtype(key, objects)) + return result + + @classmethod + def _classify(cls, response, objects): + """Check the class _subtype_map for any child classes. + We want to ignore any inherited _subtype_maps. + Remove the polymorphic key from the initial data. + """ + for subtype_key in cls.__dict__.get("_subtype_map", {}).keys(): + subtype_value = None + + if not isinstance(response, ET.Element): + rest_api_response_key = cls._get_rest_key_parts(subtype_key)[-1] + subtype_value = response.pop(rest_api_response_key, None) or response.pop(subtype_key, None) + else: + subtype_value = xml_key_extractor(subtype_key, cls._attribute_map[subtype_key], response) + if subtype_value: + # Try to match base class. 
Can be class name only + # (bug to fix in Autorest to support x-ms-discriminator-name) + if cls.__name__ == subtype_value: + return cls + flatten_mapping_type = cls._flatten_subtype(subtype_key, objects) + try: + return objects[flatten_mapping_type[subtype_value]] # type: ignore + except KeyError: + _LOGGER.warning( + "Subtype value %s has no mapping, use base class %s.", + subtype_value, + cls.__name__, + ) + break + else: + _LOGGER.warning("Discriminator %s is absent or null, use base class %s.", subtype_key, cls.__name__) + break + return cls + + @classmethod + def _get_rest_key_parts(cls, attr_key): + """Get the RestAPI key of this attr, split it and decode part + :param str attr_key: Attribute key must be in attribute_map. + :returns: A list of RestAPI part + :rtype: list + """ + rest_split_key = _FLATTEN.split(cls._attribute_map[attr_key]["key"]) + return [_decode_attribute_map_key(key_part) for key_part in rest_split_key] + + +def _decode_attribute_map_key(key): + """This decode a key in an _attribute_map to the actual key we want to look at + inside the received data. + + :param str key: A key string from the generated code + """ + return key.replace("\\.", ".") + + +class Serializer(object): + """Request object model serializer.""" + + basic_types = {str: "str", int: "int", bool: "bool", float: "float"} + + _xml_basic_types_serializers = {"bool": lambda x: str(x).lower()} + days = {0: "Mon", 1: "Tue", 2: "Wed", 3: "Thu", 4: "Fri", 5: "Sat", 6: "Sun"} + months = { + 1: "Jan", + 2: "Feb", + 3: "Mar", + 4: "Apr", + 5: "May", + 6: "Jun", + 7: "Jul", + 8: "Aug", + 9: "Sep", + 10: "Oct", + 11: "Nov", + 12: "Dec", + } + validation = { + "min_length": lambda x, y: len(x) < y, + "max_length": lambda x, y: len(x) > y, + "minimum": lambda x, y: x < y, + "maximum": lambda x, y: x > y, + "minimum_ex": lambda x, y: x <= y, + "maximum_ex": lambda x, y: x >= y, + "min_items": lambda x, y: len(x) < y, + "max_items": lambda x, y: len(x) > y, + "pattern": lambda x, y: not re.match(y, x, re.UNICODE), + "unique": lambda x, y: len(x) != len(set(x)), + "multiple": lambda x, y: x % y != 0, + } + + def __init__(self, classes: Optional[Mapping[str, type]] = None): + self.serialize_type = { + "iso-8601": Serializer.serialize_iso, + "rfc-1123": Serializer.serialize_rfc, + "unix-time": Serializer.serialize_unix, + "duration": Serializer.serialize_duration, + "date": Serializer.serialize_date, + "time": Serializer.serialize_time, + "decimal": Serializer.serialize_decimal, + "long": Serializer.serialize_long, + "bytearray": Serializer.serialize_bytearray, + "base64": Serializer.serialize_base64, + "object": self.serialize_object, + "[]": self.serialize_iter, + "{}": self.serialize_dict, + } + self.dependencies: Dict[str, type] = dict(classes) if classes else {} + self.key_transformer = full_restapi_key_transformer + self.client_side_validation = True + + def _serialize(self, target_obj, data_type=None, **kwargs): + """Serialize data into a string according to type. + + :param target_obj: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str, dict + :raises: SerializationError if serialization fails. 
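+
+        A small illustrative sketch (assuming ``serializer = Serializer()``; the exact
+        output string comes from isodate)::
+
+            serializer._serialize(datetime.timedelta(minutes=5), "duration")   # roughly 'PT5M'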
+ """ + key_transformer = kwargs.get("key_transformer", self.key_transformer) + keep_readonly = kwargs.get("keep_readonly", False) + if target_obj is None: + return None + + attr_name = None + class_name = target_obj.__class__.__name__ + + if data_type: + return self.serialize_data(target_obj, data_type, **kwargs) + + if not hasattr(target_obj, "_attribute_map"): + data_type = type(target_obj).__name__ + if data_type in self.basic_types.values(): + return self.serialize_data(target_obj, data_type, **kwargs) + + # Force "is_xml" kwargs if we detect a XML model + try: + is_xml_model_serialization = kwargs["is_xml"] + except KeyError: + is_xml_model_serialization = kwargs.setdefault("is_xml", target_obj.is_xml_model()) + + serialized = {} + if is_xml_model_serialization: + serialized = target_obj._create_xml_node() + try: + attributes = target_obj._attribute_map + for attr, attr_desc in attributes.items(): + attr_name = attr + if not keep_readonly and target_obj._validation.get(attr_name, {}).get("readonly", False): + continue + + if attr_name == "additional_properties" and attr_desc["key"] == "": + if target_obj.additional_properties is not None: + serialized.update(target_obj.additional_properties) + continue + try: + + orig_attr = getattr(target_obj, attr) + if is_xml_model_serialization: + pass # Don't provide "transformer" for XML for now. Keep "orig_attr" + else: # JSON + keys, orig_attr = key_transformer(attr, attr_desc.copy(), orig_attr) + keys = keys if isinstance(keys, list) else [keys] + + kwargs["serialization_ctxt"] = attr_desc + new_attr = self.serialize_data(orig_attr, attr_desc["type"], **kwargs) + + if is_xml_model_serialization: + xml_desc = attr_desc.get("xml", {}) + xml_name = xml_desc.get("name", attr_desc["key"]) + xml_prefix = xml_desc.get("prefix", None) + xml_ns = xml_desc.get("ns", None) + if xml_desc.get("attr", False): + if xml_ns: + ET.register_namespace(xml_prefix, xml_ns) + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + serialized.set(xml_name, new_attr) # type: ignore + continue + if xml_desc.get("text", False): + serialized.text = new_attr # type: ignore + continue + if isinstance(new_attr, list): + serialized.extend(new_attr) # type: ignore + elif isinstance(new_attr, ET.Element): + # If the down XML has no XML/Name, we MUST replace the tag with the local tag. But keeping the namespaces. 
+ if "name" not in getattr(orig_attr, "_xml_map", {}): + splitted_tag = new_attr.tag.split("}") + if len(splitted_tag) == 2: # Namespace + new_attr.tag = "}".join([splitted_tag[0], xml_name]) + else: + new_attr.tag = xml_name + serialized.append(new_attr) # type: ignore + else: # That's a basic type + # Integrate namespace if necessary + local_node = _create_xml_node(xml_name, xml_prefix, xml_ns) + local_node.text = str(new_attr) + serialized.append(local_node) # type: ignore + else: # JSON + for k in reversed(keys): # type: ignore + new_attr = {k: new_attr} + + _new_attr = new_attr + _serialized = serialized + for k in keys: # type: ignore + if k not in _serialized: + _serialized.update(_new_attr) # type: ignore + _new_attr = _new_attr[k] # type: ignore + _serialized = _serialized[k] + except ValueError as err: + if isinstance(err, SerializationError): + raise + + except (AttributeError, KeyError, TypeError) as err: + msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj)) + raise SerializationError(msg) from err + else: + return serialized + + def body(self, data, data_type, **kwargs): + """Serialize data intended for a request body. + + :param data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: dict + :raises: SerializationError if serialization fails. + :raises: ValueError if data is None + """ + + # Just in case this is a dict + internal_data_type_str = data_type.strip("[]{}") + internal_data_type = self.dependencies.get(internal_data_type_str, None) + try: + is_xml_model_serialization = kwargs["is_xml"] + except KeyError: + if internal_data_type and issubclass(internal_data_type, Model): + is_xml_model_serialization = kwargs.setdefault("is_xml", internal_data_type.is_xml_model()) + else: + is_xml_model_serialization = False + if internal_data_type and not isinstance(internal_data_type, Enum): + try: + deserializer = Deserializer(self.dependencies) + # Since it's on serialization, it's almost sure that format is not JSON REST + # We're not able to deal with additional properties for now. + deserializer.additional_properties_detection = False + if is_xml_model_serialization: + deserializer.key_extractors = [ # type: ignore + attribute_key_case_insensitive_extractor, + ] + else: + deserializer.key_extractors = [ + rest_key_case_insensitive_extractor, + attribute_key_case_insensitive_extractor, + last_rest_key_case_insensitive_extractor, + ] + data = deserializer._deserialize(data_type, data) + except DeserializationError as err: + raise SerializationError("Unable to build a model: " + str(err)) from err + + return self._serialize(data, data_type, **kwargs) + + def url(self, name, data, data_type, **kwargs): + """Serialize data intended for a URL path. + + :param data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str + :raises: TypeError if serialization fails. + :raises: ValueError if data is None + """ + try: + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + + if kwargs.get("skip_quote") is True: + output = str(output) + output = output.replace("{", quote("{")).replace("}", quote("}")) + else: + output = quote(str(output), safe="") + except SerializationError: + raise TypeError("{} must be type {}.".format(name, data_type)) + else: + return output + + def query(self, name, data, data_type, **kwargs): + """Serialize data intended for a URL query. 
+ + :param data: The data to be serialized. + :param str data_type: The type to be serialized from. + :keyword bool skip_quote: Whether to skip quote the serialized result. + Defaults to False. + :rtype: str, list + :raises: TypeError if serialization fails. + :raises: ValueError if data is None + """ + try: + # Treat the list aside, since we don't want to encode the div separator + if data_type.startswith("["): + internal_data_type = data_type[1:-1] + do_quote = not kwargs.get("skip_quote", False) + return self.serialize_iter(data, internal_data_type, do_quote=do_quote, **kwargs) + + # Not a list, regular serialization + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + if kwargs.get("skip_quote") is True: + output = str(output) + else: + output = quote(str(output), safe="") + except SerializationError: + raise TypeError("{} must be type {}.".format(name, data_type)) + else: + return str(output) + + def header(self, name, data, data_type, **kwargs): + """Serialize data intended for a request header. + + :param data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str + :raises: TypeError if serialization fails. + :raises: ValueError if data is None + """ + try: + if data_type in ["[str]"]: + data = ["" if d is None else d for d in data] + + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + except SerializationError: + raise TypeError("{} must be type {}.".format(name, data_type)) + else: + return str(output) + + def serialize_data(self, data, data_type, **kwargs): + """Serialize generic data according to supplied data type. + + :param data: The data to be serialized. + :param str data_type: The type to be serialized from. + :param bool required: Whether it's essential that the data not be + empty or None + :raises: AttributeError if required data is None. + :raises: ValueError if data is None + :raises: SerializationError if serialization fails. + """ + if data is None: + raise ValueError("No value for given attribute") + + try: + if data is CoreNull: + return None + if data_type in self.basic_types.values(): + return self.serialize_basic(data, data_type, **kwargs) + + elif data_type in self.serialize_type: + return self.serialize_type[data_type](data, **kwargs) + + # If dependencies is empty, try with current data class + # It has to be a subclass of Enum anyway + enum_type = self.dependencies.get(data_type, data.__class__) + if issubclass(enum_type, Enum): + return Serializer.serialize_enum(data, enum_obj=enum_type) + + iter_type = data_type[0] + data_type[-1] + if iter_type in self.serialize_type: + return self.serialize_type[iter_type](data, data_type[1:-1], **kwargs) + + except (ValueError, TypeError) as err: + msg = "Unable to serialize value: {!r} as type: {!r}." + raise SerializationError(msg.format(data, data_type)) from err + else: + return self._serialize(data, **kwargs) + + @classmethod + def _get_custom_serializers(cls, data_type, **kwargs): + custom_serializer = kwargs.get("basic_types_serializers", {}).get(data_type) + if custom_serializer: + return custom_serializer + if kwargs.get("is_xml", False): + return cls._xml_basic_types_serializers.get(data_type) + + @classmethod + def serialize_basic(cls, data, data_type, **kwargs): + """Serialize basic builting data type. + Serializes objects to str, int, float or bool. 
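+ 
+ For example (illustrative), ``Serializer.serialize_basic(3, "float")`` returns ``3.0``
+ and ``Serializer.serialize_basic("1", "int")`` returns ``1``.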
+ + Possible kwargs: + - basic_types_serializers dict[str, callable] : If set, use the callable as serializer + - is_xml bool : If set, use xml_basic_types_serializers + + :param data: Object to be serialized. + :param str data_type: Type of object in the iterable. + """ + custom_serializer = cls._get_custom_serializers(data_type, **kwargs) + if custom_serializer: + return custom_serializer(data) + if data_type == "str": + return cls.serialize_unicode(data) + return eval(data_type)(data) # nosec + + @classmethod + def serialize_unicode(cls, data): + """Special handling for serializing unicode strings in Py2. + Encode to UTF-8 if unicode, otherwise handle as a str. + + :param data: Object to be serialized. + :rtype: str + """ + try: # If I received an enum, return its value + return data.value + except AttributeError: + pass + + try: + if isinstance(data, unicode): # type: ignore + # Don't change it, JSON and XML ElementTree are totally able + # to serialize correctly u'' strings + return data + except NameError: + return str(data) + else: + return str(data) + + def serialize_iter(self, data, iter_type, div=None, **kwargs): + """Serialize iterable. + + Supported kwargs: + - serialization_ctxt dict : The current entry of _attribute_map, or same format. + serialization_ctxt['type'] should be same as data_type. + - is_xml bool : If set, serialize as XML + + :param list attr: Object to be serialized. + :param str iter_type: Type of object in the iterable. + :param bool required: Whether the objects in the iterable must + not be None or empty. + :param str div: If set, this str will be used to combine the elements + in the iterable into a combined string. Default is 'None'. + :keyword bool do_quote: Whether to quote the serialized result of each iterable element. + Defaults to False. + :rtype: list, str + """ + if isinstance(data, str): + raise SerializationError("Refuse str type as a valid iter type.") + + serialization_ctxt = kwargs.get("serialization_ctxt", {}) + is_xml = kwargs.get("is_xml", False) + + serialized = [] + for d in data: + try: + serialized.append(self.serialize_data(d, iter_type, **kwargs)) + except ValueError as err: + if isinstance(err, SerializationError): + raise + serialized.append(None) + + if kwargs.get("do_quote", False): + serialized = ["" if s is None else quote(str(s), safe="") for s in serialized] + + if div: + serialized = ["" if s is None else str(s) for s in serialized] + serialized = div.join(serialized) + + if "xml" in serialization_ctxt or is_xml: + # XML serialization is more complicated + xml_desc = serialization_ctxt.get("xml", {}) + xml_name = xml_desc.get("name") + if not xml_name: + xml_name = serialization_ctxt["key"] + + # Create a wrap node if necessary (use the fact that Element and list have "append") + is_wrapped = xml_desc.get("wrapped", False) + node_name = xml_desc.get("itemsName", xml_name) + if is_wrapped: + final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)) + else: + final_result = [] + # All list elements to "local_node" + for el in serialized: + if isinstance(el, ET.Element): + el_node = el + else: + el_node = _create_xml_node(node_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)) + if el is not None: # Otherwise it writes "None" :-p + el_node.text = str(el) + final_result.append(el_node) + return final_result + return serialized + + def serialize_dict(self, attr, dict_type, **kwargs): + """Serialize a dictionary of objects. + + :param dict attr: Object to be serialized. 
+ :param str dict_type: Type of object in the dictionary. + :param bool required: Whether the objects in the dictionary must + not be None or empty. + :rtype: dict + """ + serialization_ctxt = kwargs.get("serialization_ctxt", {}) + serialized = {} + for key, value in attr.items(): + try: + serialized[self.serialize_unicode(key)] = self.serialize_data(value, dict_type, **kwargs) + except ValueError as err: + if isinstance(err, SerializationError): + raise + serialized[self.serialize_unicode(key)] = None + + if "xml" in serialization_ctxt: + # XML serialization is more complicated + xml_desc = serialization_ctxt["xml"] + xml_name = xml_desc["name"] + + final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None)) + for key, value in serialized.items(): + ET.SubElement(final_result, key).text = value + return final_result + + return serialized + + def serialize_object(self, attr, **kwargs): + """Serialize a generic object. + This will be handled as a dictionary. If object passed in is not + a basic type (str, int, float, dict, list) it will simply be + cast to str. + + :param dict attr: Object to be serialized. + :rtype: dict or str + """ + if attr is None: + return None + if isinstance(attr, ET.Element): + return attr + obj_type = type(attr) + if obj_type in self.basic_types: + return self.serialize_basic(attr, self.basic_types[obj_type], **kwargs) + if obj_type is _long_type: + return self.serialize_long(attr) + if obj_type is str: + return self.serialize_unicode(attr) + if obj_type is datetime.datetime: + return self.serialize_iso(attr) + if obj_type is datetime.date: + return self.serialize_date(attr) + if obj_type is datetime.time: + return self.serialize_time(attr) + if obj_type is datetime.timedelta: + return self.serialize_duration(attr) + if obj_type is decimal.Decimal: + return self.serialize_decimal(attr) + + # If it's a model or I know this dependency, serialize as a Model + elif obj_type in self.dependencies.values() or isinstance(attr, Model): + return self._serialize(attr) + + if obj_type == dict: + serialized = {} + for key, value in attr.items(): + try: + serialized[self.serialize_unicode(key)] = self.serialize_object(value, **kwargs) + except ValueError: + serialized[self.serialize_unicode(key)] = None + return serialized + + if obj_type == list: + serialized = [] + for obj in attr: + try: + serialized.append(self.serialize_object(obj, **kwargs)) + except ValueError: + pass + return serialized + return str(attr) + + @staticmethod + def serialize_enum(attr, enum_obj=None): + try: + result = attr.value + except AttributeError: + result = attr + try: + enum_obj(result) # type: ignore + return result + except ValueError: + for enum_value in enum_obj: # type: ignore + if enum_value.value.lower() == str(attr).lower(): + return enum_value.value + error = "{!r} is not valid value for enum {!r}" + raise SerializationError(error.format(attr, enum_obj)) + + @staticmethod + def serialize_bytearray(attr, **kwargs): + """Serialize bytearray into base-64 string. + + :param attr: Object to be serialized. + :rtype: str + """ + return b64encode(attr).decode() + + @staticmethod + def serialize_base64(attr, **kwargs): + """Serialize str into base-64 string. + + :param attr: Object to be serialized. + :rtype: str + """ + encoded = b64encode(attr).decode("ascii") + return encoded.strip("=").replace("+", "-").replace("/", "_") + + @staticmethod + def serialize_decimal(attr, **kwargs): + """Serialize Decimal object to float. 
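+ 
+ For example (illustrative), ``serialize_decimal(decimal.Decimal("1.5"))`` returns ``1.5``.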
+ + :param attr: Object to be serialized. + :rtype: float + """ + return float(attr) + + @staticmethod + def serialize_long(attr, **kwargs): + """Serialize long (Py2) or int (Py3). + + :param attr: Object to be serialized. + :rtype: int/long + """ + return _long_type(attr) + + @staticmethod + def serialize_date(attr, **kwargs): + """Serialize Date object into ISO-8601 formatted string. + + :param Date attr: Object to be serialized. + :rtype: str + """ + if isinstance(attr, str): + attr = isodate.parse_date(attr) + t = "{:04}-{:02}-{:02}".format(attr.year, attr.month, attr.day) + return t + + @staticmethod + def serialize_time(attr, **kwargs): + """Serialize Time object into ISO-8601 formatted string. + + :param datetime.time attr: Object to be serialized. + :rtype: str + """ + if isinstance(attr, str): + attr = isodate.parse_time(attr) + t = "{:02}:{:02}:{:02}".format(attr.hour, attr.minute, attr.second) + if attr.microsecond: + t += ".{:02}".format(attr.microsecond) + return t + + @staticmethod + def serialize_duration(attr, **kwargs): + """Serialize TimeDelta object into ISO-8601 formatted string. + + :param TimeDelta attr: Object to be serialized. + :rtype: str + """ + if isinstance(attr, str): + attr = isodate.parse_duration(attr) + return isodate.duration_isoformat(attr) + + @staticmethod + def serialize_rfc(attr, **kwargs): + """Serialize Datetime object into RFC-1123 formatted string. + + :param Datetime attr: Object to be serialized. + :rtype: str + :raises: TypeError if format invalid. + """ + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + utc = attr.utctimetuple() + except AttributeError: + raise TypeError("RFC1123 object must be valid Datetime object.") + + return "{}, {:02} {} {:04} {:02}:{:02}:{:02} GMT".format( + Serializer.days[utc.tm_wday], + utc.tm_mday, + Serializer.months[utc.tm_mon], + utc.tm_year, + utc.tm_hour, + utc.tm_min, + utc.tm_sec, + ) + + @staticmethod + def serialize_iso(attr, **kwargs): + """Serialize Datetime object into ISO-8601 formatted string. + + :param Datetime attr: Object to be serialized. + :rtype: str + :raises: SerializationError if format invalid. + """ + if isinstance(attr, str): + attr = isodate.parse_datetime(attr) + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + utc = attr.utctimetuple() + if utc.tm_year > 9999 or utc.tm_year < 1: + raise OverflowError("Hit max or min date") + + microseconds = str(attr.microsecond).rjust(6, "0").rstrip("0").ljust(3, "0") + if microseconds: + microseconds = "." + microseconds + date = "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}".format( + utc.tm_year, utc.tm_mon, utc.tm_mday, utc.tm_hour, utc.tm_min, utc.tm_sec + ) + return date + microseconds + "Z" + except (ValueError, OverflowError) as err: + msg = "Unable to serialize datetime object." + raise SerializationError(msg) from err + except AttributeError as err: + msg = "ISO-8601 object must be valid Datetime object." + raise TypeError(msg) from err + + @staticmethod + def serialize_unix(attr, **kwargs): + """Serialize Datetime object into IntTime format. + This is represented as seconds. + + :param Datetime attr: Object to be serialized. 
+ :rtype: int + :raises: SerializationError if format invalid + """ + if isinstance(attr, int): + return attr + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + return int(calendar.timegm(attr.utctimetuple())) + except AttributeError: + raise TypeError("Unix time object must be valid Datetime object.") + + +def rest_key_extractor(attr, attr_desc, data): + key = attr_desc["key"] + working_data = data + + while "." in key: + # Need the cast, as for some reasons "split" is typed as list[str | Any] + dict_keys = cast(List[str], _FLATTEN.split(key)) + if len(dict_keys) == 1: + key = _decode_attribute_map_key(dict_keys[0]) + break + working_key = _decode_attribute_map_key(dict_keys[0]) + working_data = working_data.get(working_key, data) + if working_data is None: + # If at any point while following flatten JSON path see None, it means + # that all properties under are None as well + return None + key = ".".join(dict_keys[1:]) + + return working_data.get(key) + + +def rest_key_case_insensitive_extractor(attr, attr_desc, data): + key = attr_desc["key"] + working_data = data + + while "." in key: + dict_keys = _FLATTEN.split(key) + if len(dict_keys) == 1: + key = _decode_attribute_map_key(dict_keys[0]) + break + working_key = _decode_attribute_map_key(dict_keys[0]) + working_data = attribute_key_case_insensitive_extractor(working_key, None, working_data) + if working_data is None: + # If at any point while following flatten JSON path see None, it means + # that all properties under are None as well + return None + key = ".".join(dict_keys[1:]) + + if working_data: + return attribute_key_case_insensitive_extractor(key, None, working_data) + + +def last_rest_key_extractor(attr, attr_desc, data): + """Extract the attribute in "data" based on the last part of the JSON path key.""" + key = attr_desc["key"] + dict_keys = _FLATTEN.split(key) + return attribute_key_extractor(dict_keys[-1], None, data) + + +def last_rest_key_case_insensitive_extractor(attr, attr_desc, data): + """Extract the attribute in "data" based on the last part of the JSON path key. + + This is the case insensitive version of "last_rest_key_extractor" + """ + key = attr_desc["key"] + dict_keys = _FLATTEN.split(key) + return attribute_key_case_insensitive_extractor(dict_keys[-1], None, data) + + +def attribute_key_extractor(attr, _, data): + return data.get(attr) + + +def attribute_key_case_insensitive_extractor(attr, _, data): + found_key = None + lower_attr = attr.lower() + for key in data: + if lower_attr == key.lower(): + found_key = key + break + + return data.get(found_key) + + +def _extract_name_from_internal_type(internal_type): + """Given an internal type XML description, extract correct XML name with namespace. 
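+ 
+ For example (illustrative), a model class whose ``_xml_map`` is
+ ``{"name": "Pet", "ns": "http://example.org/pets"}`` yields ``"{http://example.org/pets}Pet"``.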
+ + :param dict internal_type: An model type + :rtype: tuple + :returns: A tuple XML name + namespace dict + """ + internal_type_xml_map = getattr(internal_type, "_xml_map", {}) + xml_name = internal_type_xml_map.get("name", internal_type.__name__) + xml_ns = internal_type_xml_map.get("ns", None) + if xml_ns: + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + return xml_name + + +def xml_key_extractor(attr, attr_desc, data): + if isinstance(data, dict): + return None + + # Test if this model is XML ready first + if not isinstance(data, ET.Element): + return None + + xml_desc = attr_desc.get("xml", {}) + xml_name = xml_desc.get("name", attr_desc["key"]) + + # Look for a children + is_iter_type = attr_desc["type"].startswith("[") + is_wrapped = xml_desc.get("wrapped", False) + internal_type = attr_desc.get("internalType", None) + internal_type_xml_map = getattr(internal_type, "_xml_map", {}) + + # Integrate namespace if necessary + xml_ns = xml_desc.get("ns", internal_type_xml_map.get("ns", None)) + if xml_ns: + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + + # If it's an attribute, that's simple + if xml_desc.get("attr", False): + return data.get(xml_name) + + # If it's x-ms-text, that's simple too + if xml_desc.get("text", False): + return data.text + + # Scenario where I take the local name: + # - Wrapped node + # - Internal type is an enum (considered basic types) + # - Internal type has no XML/Name node + if is_wrapped or (internal_type and (issubclass(internal_type, Enum) or "name" not in internal_type_xml_map)): + children = data.findall(xml_name) + # If internal type has a local name and it's not a list, I use that name + elif not is_iter_type and internal_type and "name" in internal_type_xml_map: + xml_name = _extract_name_from_internal_type(internal_type) + children = data.findall(xml_name) + # That's an array + else: + if internal_type: # Complex type, ignore itemsName and use the complex type name + items_name = _extract_name_from_internal_type(internal_type) + else: + items_name = xml_desc.get("itemsName", xml_name) + children = data.findall(items_name) + + if len(children) == 0: + if is_iter_type: + if is_wrapped: + return None # is_wrapped no node, we want None + else: + return [] # not wrapped, assume empty list + return None # Assume it's not there, maybe an optional node. + + # If is_iter_type and not wrapped, return all found children + if is_iter_type: + if not is_wrapped: + return children + else: # Iter and wrapped, should have found one node only (the wrap one) + if len(children) != 1: + raise DeserializationError( + "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format( + xml_name + ) + ) + return list(children[0]) # Might be empty list and that's ok. + + # Here it's not a itertype, we should have found one element only or empty + if len(children) > 1: + raise DeserializationError("Find several XML '{}' where it was not expected".format(xml_name)) + return children[0] + + +class Deserializer(object): + """Response object model deserializer. + + :param dict classes: Class type dictionary for deserializing complex types. + :ivar list key_extractors: Ordered list of extractors to be used by this deserializer. 
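+ 
+ Illustrative usage, assuming ``MyModel`` is a generated msrest model class::
+ 
+ deserializer = Deserializer({"MyModel": MyModel})
+ entity = deserializer("MyModel", rest_response)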
+ """ + + basic_types = {str: "str", int: "int", bool: "bool", float: "float"} + + valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") + + def __init__(self, classes: Optional[Mapping[str, type]] = None): + self.deserialize_type = { + "iso-8601": Deserializer.deserialize_iso, + "rfc-1123": Deserializer.deserialize_rfc, + "unix-time": Deserializer.deserialize_unix, + "duration": Deserializer.deserialize_duration, + "date": Deserializer.deserialize_date, + "time": Deserializer.deserialize_time, + "decimal": Deserializer.deserialize_decimal, + "long": Deserializer.deserialize_long, + "bytearray": Deserializer.deserialize_bytearray, + "base64": Deserializer.deserialize_base64, + "object": self.deserialize_object, + "[]": self.deserialize_iter, + "{}": self.deserialize_dict, + } + self.deserialize_expected_types = { + "duration": (isodate.Duration, datetime.timedelta), + "iso-8601": (datetime.datetime), + } + self.dependencies: Dict[str, type] = dict(classes) if classes else {} + self.key_extractors = [rest_key_extractor, xml_key_extractor] + # Additional properties only works if the "rest_key_extractor" is used to + # extract the keys. Making it to work whatever the key extractor is too much + # complicated, with no real scenario for now. + # So adding a flag to disable additional properties detection. This flag should be + # used if your expect the deserialization to NOT come from a JSON REST syntax. + # Otherwise, result are unexpected + self.additional_properties_detection = True + + def __call__(self, target_obj, response_data, content_type=None): + """Call the deserializer to process a REST response. + + :param str target_obj: Target data type to deserialize to. + :param requests.Response response_data: REST response object. + :param str content_type: Swagger "produces" if available. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. + """ + data = self._unpack_content(response_data, content_type) + return self._deserialize(target_obj, data) + + def _deserialize(self, target_obj, data): + """Call the deserializer on a model. + + Data needs to be already deserialized as JSON or XML ElementTree + + :param str target_obj: Target data type to deserialize to. + :param object data: Object to deserialize. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. + """ + # This is already a model, go recursive just in case + if hasattr(data, "_attribute_map"): + constants = [name for name, config in getattr(data, "_validation", {}).items() if config.get("constant")] + try: + for attr, mapconfig in data._attribute_map.items(): + if attr in constants: + continue + value = getattr(data, attr) + if value is None: + continue + local_type = mapconfig["type"] + internal_data_type = local_type.strip("[]{}") + if internal_data_type not in self.dependencies or isinstance(internal_data_type, Enum): + continue + setattr(data, attr, self._deserialize(local_type, value)) + return data + except AttributeError: + return + + response, class_name = self._classify_target(target_obj, data) + + if isinstance(response, str): + return self.deserialize_data(data, response) + elif isinstance(response, type) and issubclass(response, Enum): + return self.deserialize_enum(data, response) + + if data is None: + return data + try: + attributes = response._attribute_map # type: ignore + d_attrs = {} + for attr, attr_desc in attributes.items(): + # Check empty string. 
If it's not empty, someone has a real "additionalProperties"... + if attr == "additional_properties" and attr_desc["key"] == "": + continue + raw_value = None + # Enhance attr_desc with some dynamic data + attr_desc = attr_desc.copy() # Do a copy, do not change the real one + internal_data_type = attr_desc["type"].strip("[]{}") + if internal_data_type in self.dependencies: + attr_desc["internalType"] = self.dependencies[internal_data_type] + + for key_extractor in self.key_extractors: + found_value = key_extractor(attr, attr_desc, data) + if found_value is not None: + if raw_value is not None and raw_value != found_value: + msg = ( + "Ignoring extracted value '%s' from %s for key '%s'" + " (duplicate extraction, follow extractors order)" + ) + _LOGGER.warning(msg, found_value, key_extractor, attr) + continue + raw_value = found_value + + value = self.deserialize_data(raw_value, attr_desc["type"]) + d_attrs[attr] = value + except (AttributeError, TypeError, KeyError) as err: + msg = "Unable to deserialize to object: " + class_name # type: ignore + raise DeserializationError(msg) from err + else: + additional_properties = self._build_additional_properties(attributes, data) + return self._instantiate_model(response, d_attrs, additional_properties) + + def _build_additional_properties(self, attribute_map, data): + if not self.additional_properties_detection: + return None + if "additional_properties" in attribute_map and attribute_map.get("additional_properties", {}).get("key") != "": + # Check empty string. If it's not empty, someone has a real "additionalProperties" + return None + if isinstance(data, ET.Element): + data = {el.tag: el.text for el in data} + + known_keys = { + _decode_attribute_map_key(_FLATTEN.split(desc["key"])[0]) + for desc in attribute_map.values() + if desc["key"] != "" + } + present_keys = set(data.keys()) + missing_keys = present_keys - known_keys + return {key: data[key] for key in missing_keys} + + def _classify_target(self, target, data): + """Check to see whether the deserialization target object can + be classified into a subclass. + Once classification has been determined, initialize object. + + :param str target: The target object type to deserialize to. + :param str/dict data: The response data to deserialize. + """ + if target is None: + return None, None + + if isinstance(target, str): + try: + target = self.dependencies[target] + except KeyError: + return target, target + + try: + target = target._classify(data, self.dependencies) # type: ignore + except AttributeError: + pass # Target is not a Model, no classify + return target, target.__class__.__name__ # type: ignore + + def failsafe_deserialize(self, target_obj, data, content_type=None): + """Ignores any errors encountered in deserialization, + and falls back to not deserializing the object. Recommended + for use in error deserialization, as we want to return the + HttpResponseError to users, and not have them deal with + a deserialization error. + + :param str target_obj: The target object type to deserialize to. + :param str/dict data: The response data to deserialize. + :param str content_type: Swagger "produces" if available. + """ + try: + return self(target_obj, data, content_type=content_type) + except: + _LOGGER.debug( + "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True + ) + return None + + @staticmethod + def _unpack_content(raw_data, content_type=None): + """Extract the correct structure for deserialization. 
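+ 
+ For example (illustrative), a ``b'{"name": "foo"}'`` payload with content type
+ ``application/json`` is returned as the parsed dict ``{"name": "foo"}``.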
+ + If raw_data is a PipelineResponse, try to extract the result of RawDeserializer. + if we can't, raise. Your Pipeline should have a RawDeserializer. + + If not a pipeline response and raw_data is bytes or string, use content-type + to decode it. If no content-type, try JSON. + + If raw_data is something else, bypass all logic and return it directly. + + :param raw_data: Data to be processed. + :param content_type: How to parse if raw_data is a string/bytes. + :raises JSONDecodeError: If JSON is requested and parsing is impossible. + :raises UnicodeDecodeError: If bytes is not UTF8 + """ + # Assume this is enough to detect a Pipeline Response without importing it + context = getattr(raw_data, "context", {}) + if context: + if RawDeserializer.CONTEXT_NAME in context: + return context[RawDeserializer.CONTEXT_NAME] + raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize") + + # Assume this is enough to recognize universal_http.ClientResponse without importing it + if hasattr(raw_data, "body"): + return RawDeserializer.deserialize_from_http_generics(raw_data.text(), raw_data.headers) + + # Assume this enough to recognize requests.Response without importing it. + if hasattr(raw_data, "_content_consumed"): + return RawDeserializer.deserialize_from_http_generics(raw_data.text, raw_data.headers) + + if isinstance(raw_data, (str, bytes)) or hasattr(raw_data, "read"): + return RawDeserializer.deserialize_from_text(raw_data, content_type) # type: ignore + return raw_data + + def _instantiate_model(self, response, attrs, additional_properties=None): + """Instantiate a response model passing in deserialized args. + + :param response: The response model class. + :param d_attrs: The deserialized response attributes. + """ + if callable(response): + subtype = getattr(response, "_subtype_map", {}) + try: + readonly = [k for k, v in response._validation.items() if v.get("readonly")] + const = [k for k, v in response._validation.items() if v.get("constant")] + kwargs = {k: v for k, v in attrs.items() if k not in subtype and k not in readonly + const} + response_obj = response(**kwargs) + for attr in readonly: + setattr(response_obj, attr, attrs.get(attr)) + if additional_properties: + response_obj.additional_properties = additional_properties + return response_obj + except TypeError as err: + msg = "Unable to deserialize {} into model {}. ".format(kwargs, response) # type: ignore + raise DeserializationError(msg + str(err)) + else: + try: + for attr, value in attrs.items(): + setattr(response, attr, value) + return response + except Exception as exp: + msg = "Unable to populate response model. " + msg += "Type: {}, Error: {}".format(type(response), exp) + raise DeserializationError(msg) + + def deserialize_data(self, data, data_type): + """Process data for deserialization according to data type. + + :param str data: The response string to be deserialized. + :param str data_type: The type to deserialize to. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. 
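+ 
+ For example (illustrative), ``deserialize_data("2019-01-01T00:00:00Z", "iso-8601")``
+ returns a ``datetime``, and ``deserialize_data(["1", "2"], "[int]")`` returns ``[1, 2]``.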
+ """ + if data is None: + return data + + try: + if not data_type: + return data + if data_type in self.basic_types.values(): + return self.deserialize_basic(data, data_type) + if data_type in self.deserialize_type: + if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())): + return data + + is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"] + if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text: + return None + data_val = self.deserialize_type[data_type](data) + return data_val + + iter_type = data_type[0] + data_type[-1] + if iter_type in self.deserialize_type: + return self.deserialize_type[iter_type](data, data_type[1:-1]) + + obj_type = self.dependencies[data_type] + if issubclass(obj_type, Enum): + if isinstance(data, ET.Element): + data = data.text + return self.deserialize_enum(data, obj_type) + + except (ValueError, TypeError, AttributeError) as err: + msg = "Unable to deserialize response data." + msg += " Data: {}, {}".format(data, data_type) + raise DeserializationError(msg) from err + else: + return self._deserialize(obj_type, data) + + def deserialize_iter(self, attr, iter_type): + """Deserialize an iterable. + + :param list attr: Iterable to be deserialized. + :param str iter_type: The type of object in the iterable. + :rtype: list + """ + if attr is None: + return None + if isinstance(attr, ET.Element): # If I receive an element here, get the children + attr = list(attr) + if not isinstance(attr, (list, set)): + raise DeserializationError("Cannot deserialize as [{}] an object of type {}".format(iter_type, type(attr))) + return [self.deserialize_data(a, iter_type) for a in attr] + + def deserialize_dict(self, attr, dict_type): + """Deserialize a dictionary. + + :param dict/list attr: Dictionary to be deserialized. Also accepts + a list of key, value pairs. + :param str dict_type: The object type of the items in the dictionary. + :rtype: dict + """ + if isinstance(attr, list): + return {x["key"]: self.deserialize_data(x["value"], dict_type) for x in attr} + + if isinstance(attr, ET.Element): + # Transform <Key>value</Key> into {"Key": "value"} + attr = {el.tag: el.text for el in attr} + return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()} + + def deserialize_object(self, attr, **kwargs): + """Deserialize a generic object. + This will be handled as a dictionary. + + :param dict attr: Dictionary to be deserialized. + :rtype: dict + :raises: TypeError if non-builtin datatype encountered. 
+ """ + if attr is None: + return None + if isinstance(attr, ET.Element): + # Do no recurse on XML, just return the tree as-is + return attr + if isinstance(attr, str): + return self.deserialize_basic(attr, "str") + obj_type = type(attr) + if obj_type in self.basic_types: + return self.deserialize_basic(attr, self.basic_types[obj_type]) + if obj_type is _long_type: + return self.deserialize_long(attr) + + if obj_type == dict: + deserialized = {} + for key, value in attr.items(): + try: + deserialized[key] = self.deserialize_object(value, **kwargs) + except ValueError: + deserialized[key] = None + return deserialized + + if obj_type == list: + deserialized = [] + for obj in attr: + try: + deserialized.append(self.deserialize_object(obj, **kwargs)) + except ValueError: + pass + return deserialized + + else: + error = "Cannot deserialize generic object with type: " + raise TypeError(error + str(obj_type)) + + def deserialize_basic(self, attr, data_type): + """Deserialize basic builtin data type from string. + Will attempt to convert to str, int, float and bool. + This function will also accept '1', '0', 'true' and 'false' as + valid bool values. + + :param str attr: response string to be deserialized. + :param str data_type: deserialization data type. + :rtype: str, int, float or bool + :raises: TypeError if string format is not valid. + """ + # If we're here, data is supposed to be a basic type. + # If it's still an XML node, take the text + if isinstance(attr, ET.Element): + attr = attr.text + if not attr: + if data_type == "str": + # None or '', node <a/> is empty string. + return "" + else: + # None or '', node <a/> with a strong type is None. + # Don't try to model "empty bool" or "empty int" + return None + + if data_type == "bool": + if attr in [True, False, 1, 0]: + return bool(attr) + elif isinstance(attr, str): + if attr.lower() in ["true", "1"]: + return True + elif attr.lower() in ["false", "0"]: + return False + raise TypeError("Invalid boolean value: {}".format(attr)) + + if data_type == "str": + return self.deserialize_unicode(attr) + return eval(data_type)(attr) # nosec + + @staticmethod + def deserialize_unicode(data): + """Preserve unicode objects in Python 2, otherwise return data + as a string. + + :param str data: response string to be deserialized. + :rtype: str or unicode + """ + # We might be here because we have an enum modeled as string, + # and we try to deserialize a partial dict with enum inside + if isinstance(data, Enum): + return data + + # Consider this is real string + try: + if isinstance(data, unicode): # type: ignore + return data + except NameError: + return str(data) + else: + return str(data) + + @staticmethod + def deserialize_enum(data, enum_obj): + """Deserialize string into enum object. + + If the string is not a valid enum value it will be returned as-is + and a warning will be logged. + + :param str data: Response string to be deserialized. If this value is + None or invalid it will be returned as-is. + :param Enum enum_obj: Enum object to deserialize to. + :rtype: Enum + """ + if isinstance(data, enum_obj) or data is None: + return data + if isinstance(data, Enum): + data = data.value + if isinstance(data, int): + # Workaround. We might consider remove it in the future. 
+ try: + return list(enum_obj.__members__.values())[data] + except IndexError: + error = "{!r} is not a valid index for enum {!r}" + raise DeserializationError(error.format(data, enum_obj)) + try: + return enum_obj(str(data)) + except ValueError: + for enum_value in enum_obj: + if enum_value.value.lower() == str(data).lower(): + return enum_value + # We don't fail anymore for unknown value, we deserialize as a string + _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj) + return Deserializer.deserialize_unicode(data) + + @staticmethod + def deserialize_bytearray(attr): + """Deserialize string into bytearray. + + :param str attr: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return bytearray(b64decode(attr)) # type: ignore + + @staticmethod + def deserialize_base64(attr): + """Deserialize base64 encoded string into string. + + :param str attr: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore + attr = attr + padding # type: ignore + encoded = attr.replace("-", "+").replace("_", "/") + return b64decode(encoded) + + @staticmethod + def deserialize_decimal(attr): + """Deserialize string into Decimal object. + + :param str attr: response string to be deserialized. + :rtype: Decimal + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + return decimal.Decimal(str(attr)) # type: ignore + except decimal.DecimalException as err: + msg = "Invalid decimal {}".format(attr) + raise DeserializationError(msg) from err + + @staticmethod + def deserialize_long(attr): + """Deserialize string into long (Py2) or int (Py3). + + :param str attr: response string to be deserialized. + :rtype: long or int + :raises: ValueError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return _long_type(attr) # type: ignore + + @staticmethod + def deserialize_duration(attr): + """Deserialize ISO-8601 formatted string into TimeDelta object. + + :param str attr: response string to be deserialized. + :rtype: TimeDelta + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + duration = isodate.parse_duration(attr) + except (ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize duration object." + raise DeserializationError(msg) from err + else: + return duration + + @staticmethod + def deserialize_date(attr): + """Deserialize ISO-8601 formatted string into Date object. + + :param str attr: response string to be deserialized. + :rtype: Date + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + return isodate.parse_date(attr, defaultmonth=0, defaultday=0) + + @staticmethod + def deserialize_time(attr): + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :rtype: datetime.time + :raises: DeserializationError if string format invalid. 
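+ 
+ For example (illustrative), ``deserialize_time("11:12:13")`` returns ``datetime.time(11, 12, 13)``.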
+ """ + if isinstance(attr, ET.Element): + attr = attr.text + if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + return isodate.parse_time(attr) + + @staticmethod + def deserialize_rfc(attr): + """Deserialize RFC-1123 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: Datetime + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + parsed_date = email.utils.parsedate_tz(attr) # type: ignore + date_obj = datetime.datetime( + *parsed_date[:6], tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0) / 60)) + ) + if not date_obj.tzinfo: + date_obj = date_obj.astimezone(tz=TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to rfc datetime object." + raise DeserializationError(msg) from err + else: + return date_obj + + @staticmethod + def deserialize_iso(attr): + """Deserialize ISO-8601 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: Datetime + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + attr = attr.upper() # type: ignore + match = Deserializer.valid_date.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + check_decimal = attr.split(".") + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + attr = attr.replace(decimal_str, decimal_str[0:6]) + + date_obj = isodate.parse_datetime(attr) + test_utc = date_obj.utctimetuple() + if test_utc.tm_year > 9999 or test_utc.tm_year < 1: + raise OverflowError("Hit max or min date") + except (ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize datetime object." + raise DeserializationError(msg) from err + else: + return date_obj + + @staticmethod + def deserialize_unix(attr): + """Serialize Datetime object into IntTime format. + This is represented as seconds. + + :param int attr: Object to be serialized. + :rtype: Datetime + :raises: DeserializationError if format invalid + """ + if isinstance(attr, ET.Element): + attr = int(attr.text) # type: ignore + try: + attr = int(attr) + date_obj = datetime.datetime.fromtimestamp(attr, TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to unix datetime object." + raise DeserializationError(msg) from err + else: + return date_obj diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/__init__.py new file mode 100644 index 00000000..cda7689a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/__init__.py @@ -0,0 +1,18 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- + +from ._models import AzureOpenAIDeployment +from ._models import ServerlessEndpoint +from ._models import MarketplaceSubscription +from ._patch import __all__ as _patch_all +from ._patch import * # pylint: disable=unused-wildcard-import +from ._patch import patch_sdk as _patch_sdk + +__all__ = ["AzureOpenAIDeployment", "ServerlessEndpoint", "MarketplaceSubscription"] +__all__.extend([p for p in _patch_all if p not in __all__]) +_patch_sdk() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_models.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_models.py new file mode 100644 index 00000000..3b12203d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_models.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, overload + +from .. import _model_base +from .._model_base import rest_field + +if TYPE_CHECKING: + from .. import models as _models + + +class AzureOpenAIDeployment(_model_base.Model): + """Azure OpenAI Deployment Information. + + Readonly variables are only populated by the server, and will be ignored when sending a request. + + :ivar name: The deployment name. + :vartype name: str + :ivar model_name: The name of the model to deploy. + :vartype model_name: str + :ivar model_version: The model version to deploy. + :vartype model_version: str + :ivar connection_name: The name of the connection to deploy to. + :vartype connection_name: str + :ivar target_url: The target URL of the AOAI resource for the deployment. + :vartype target_url: str + :ivar id: The ARM resource id of the deployment. + :vartype id: str + :ivar properties: Properties of the deployment. + :vartype properties: dict[str, str] + :ivar tags: Tags of the deployment. + :vartype tags: dict[str, str] + """ + + name: Optional[str] = rest_field(visibility=["read"]) + """The deployment name.""" + model_name: Optional[str] = rest_field(visibility=["read"]) + """The name of the model to deploy.""" + model_version: Optional[str] = rest_field(visibility=["read"]) + """The model version to deploy.""" + connection_name: Optional[str] = rest_field(visibility=["read"]) + """The name of the connection to deploy to.""" + target_url: Optional[str] = rest_field(visibility=["read"]) + """The target URL of the AOAI resource for the deployment.""" + id: Optional[str] = rest_field(visibility=["read"]) + """The ARM resource id of the deployment.""" + + +class MarketplacePlan(_model_base.Model): + """Marketplace Subscription Definition. + + Readonly variables are only populated by the server, and will be ignored when sending a request. + + :ivar publisher_id: The id of the publisher. + :vartype publisher_id: str + :ivar offer_id: The id of the offering associated with the plan. + :vartype offer_id: str + :ivar plan_id: The id of the plan. + :vartype plan_id: str + :ivar term_id: The term id. 
+ :vartype term_id: str + """ + + publisher_id: Optional[str] = rest_field(visibility=["read"]) + """The id of the publisher.""" + offer_id: Optional[str] = rest_field(visibility=["read"]) + """The id of the offering associated with the plan.""" + plan_id: Optional[str] = rest_field(visibility=["read"]) + """The id of the plan.""" + term_id: Optional[str] = rest_field(visibility=["read"]) + """The term id.""" + + +class MarketplaceSubscription(_model_base.Model): + """Marketplace Subscription Definition. + + Readonly variables are only populated by the server, and will be ignored when sending a request. + + All required parameters must be populated in order to send to server. + + :ivar name: The marketplace subscription name. Required. + :vartype name: str + :ivar model_id: Model id for which to create marketplace subscription. Required. + :vartype model_id: str + :ivar marketplace_plan: The plan associated with the marketplace subscription. + :vartype marketplace_plan: ~azure.ai.ml.entities.models.MarketplacePlan + :ivar status: Status of the marketplace subscription. Possible values are: + "pending_fulfillment_start", "subscribed", "unsubscribed", "suspended". + :vartype status: str + :ivar provisioning_state: Provisioning state of the marketplace subscription. Possible values + are: "creating", "deleting", "succeeded", "failed", "updating", and "canceled". + :vartype provisioning_state: str + :ivar id: ARM resource id of the marketplace subscription. + :vartype id: str + """ + + name: str = rest_field() + """The marketplace subscription name. Required.""" + model_id: str = rest_field() + """Model id for which to create marketplace subscription. Required.""" + marketplace_plan: Optional["_models.MarketplacePlan"] = rest_field(visibility=["read"]) + """The plan associated with the marketplace subscription.""" + status: Optional[str] = rest_field(visibility=["read"]) + """Status of the marketplace subscription. Possible values are: \"pending_fulfillment_start\", + \"subscribed\", \"unsubscribed\", \"suspended\".""" + provisioning_state: Optional[str] = rest_field(visibility=["read"]) + """Provisioning state of the marketplace subscription. Possible values are: \"creating\", + \"deleting\", \"succeeded\", \"failed\", \"updating\", and \"canceled\".""" + id: Optional[str] = rest_field(visibility=["read"]) + """ARM resource id of the marketplace subscription.""" + + @overload + def __init__( + self, + *, + name: str, + model_id: str, + ): ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + super().__init__(*args, **kwargs) + + +class ServerlessEndpoint(_model_base.Model): + """Serverless Endpoint Definition. + + Readonly variables are only populated by the server, and will be ignored when sending a request. + + All required parameters must be populated in order to send to server. + + :ivar name: The deployment name. Required. + :vartype name: str + :ivar auth_mode: Authentication mode of the endpoint. + :vartype auth_mode: str + :ivar model_id: The id of the model to deploy. Required. + :vartype model_id: str + :ivar location: Location in which to create endpoint. + :vartype location: str + :ivar provisioning_state: Provisioning state of the endpoint. Possible values are: "creating", + "deleting", "succeeded", "failed", "updating", and "canceled". 
+ :vartype provisioning_state: str
+ :ivar tags: Tags for the endpoint.
+ :vartype tags: dict[str, str]
+ :ivar properties: Properties of the endpoint.
+ :vartype properties: dict[str, str]
+ :ivar description: Description of the endpoint.
+ :vartype description: str
+ :ivar scoring_uri: Scoring uri of the endpoint.
+ :vartype scoring_uri: str
+ :ivar id: ARM resource id of the endpoint.
+ :vartype id: str
+ :ivar headers: Headers required to hit the endpoint.
+ :vartype headers: dict[str, str]
+ """
+
+ name: str = rest_field()
+ """The deployment name. Required."""
+ auth_mode: Optional[str] = rest_field()
+ """Authentication mode of the endpoint. Possible values are: \"key\", \"aad\".
+ Defaults to \"key\" if not given."""
+ model_id: str = rest_field()
+ """The id of the model to deploy. Required."""
+ location: Optional[str] = rest_field(visibility=["read"])
+ """Location in which to create endpoint."""
+ provisioning_state: Optional[str] = rest_field(visibility=["read"])
+ """Provisioning state of the endpoint. Possible values are: \"creating\", \"deleting\",
+ \"succeeded\", \"failed\", \"updating\", and \"canceled\"."""
+ tags: Optional[Dict[str, str]] = rest_field()
+ """Tags for the endpoint."""
+ properties: Optional[Dict[str, str]] = rest_field()
+ """Properties of the endpoint."""
+ description: Optional[str] = rest_field()
+ """Description of the endpoint."""
+ scoring_uri: Optional[str] = rest_field(visibility=["read"])
+ """Scoring uri of the endpoint."""
+ id: Optional[str] = rest_field(visibility=["read"])
+ """ARM resource id of the endpoint."""
+ headers: Optional[Dict[str, str]] = rest_field(visibility=["read"])
+ """Headers required to hit the endpoint."""
+
+ @overload
+ def __init__(
+ self,
+ *,
+ name: str,
+ model_id: str,
+ auth_mode: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ properties: Optional[Dict[str, str]] = None,
+ description: Optional[str] = None,
+ ): ...
+
+ @overload
+ def __init__(self, mapping: Mapping[str, Any]):
+ """
+ :param mapping: raw JSON to initialize the model.
+ :type mapping: Mapping[str, Any]
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation
+ super().__init__(*args, **kwargs)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_patch.py
new file mode 100644
index 00000000..da29aeb3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_patch.py
@@ -0,0 +1,223 @@
+# ------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# ------------------------------------
+
+# pylint: disable=protected-access
+
+"""Customize generated code here.
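+
+For example (illustrative), the patched ``ServerlessEndpoint`` below is typically constructed as
+``ServerlessEndpoint(name="my-endpoint", model_id="azureml://registries/.../models/.../versions/1")``
+and then submitted through the ``MLClient`` serverless-endpoint operations.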
+ +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" +import json +from typing import Any, Dict, List, Optional + +from azure.ai.ml._restclient.v2024_01_01_preview.models import MarketplaceSubscription as RestMarketplaceSubscription +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + MarketplaceSubscriptionProperties as RestMarketplaceSubscriptionProperties, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import ModelSettings as RestModelSettings +from azure.ai.ml._restclient.v2024_01_01_preview.models import ServerlessEndpoint as RestServerlessEndpoint +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + ServerlessEndpointProperties as RestServerlessEndpointProperties, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import Sku as RestSku +from azure.ai.ml._restclient.v2024_04_01_preview.models import ( + EndpointDeploymentResourcePropertiesBasicResource, + OpenAIEndpointDeploymentResourceProperties, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._system_data import SystemData + +from .._model_base import rest_field +from ._models import AzureOpenAIDeployment as _AzureOpenAIDeployment +from ._models import MarketplacePlan as _MarketplacePlan +from ._models import MarketplaceSubscription as _MarketplaceSubscription +from ._models import ServerlessEndpoint as _ServerlessEndpoint + +__all__: List[str] = [ + "AzureOpenAIDeployment", + "ServerlessEndpoint", + "MarketplaceSubscription", + "MarketplacePlan", +] # Add all objects you want publicly available to users at this package level + +_NULL = object() + + +func_to_attr_type = { + "_deserialize_dict": dict, + "_deserialize_sequence": list, +} + + +def _get_rest_field_type(field): + if hasattr(field, "_type"): + if field._type.func.__name__ == "_deserialize_default": + return field._type.args[0] + if func_to_attr_type.get(field._type.func.__name__): + return func_to_attr_type[field._type.func.__name__] + return _get_rest_field_type(field._type.args[0]) + if hasattr(field, "func") and func_to_attr_type.get(field.func.__name__): + return func_to_attr_type[field.func.__name__] + if hasattr(field, "args"): + return _get_rest_field_type(field.args[0]) + return field + + +class ValidationMixin: + def _validate(self) -> None: + # verify types + for attr, field in self._attr_to_rest_field.items(): # type: ignore + try: + attr_value = self.__getitem__(attr) # type: ignore + attr_type = type(attr_value) + except KeyError as exc: + if field._visibility and "read" in field._visibility: + # read-only field, no need to validate + continue + if field._type.func.__name__ != "_deserialize_with_optional": + # i'm required + raise ValueError(f"attr {attr} is a required property for {self.__class__.__name__}") from exc + else: + if getattr(attr_value, "_is_model", False): + attr_value._validate() + rest_field_type = _get_rest_field_type(field) + if attr_type != rest_field_type: + raise ValueError(f"Type of attr {attr} is of type {attr_type}, not {rest_field_type}") + + +@experimental +class AzureOpenAIDeployment(_AzureOpenAIDeployment): + + system_data: Optional[SystemData] = rest_field(visibility=["read"]) + """System data of the deployment.""" + + @classmethod + def _from_rest_object(cls, obj: EndpointDeploymentResourcePropertiesBasicResource) -> "AzureOpenAIDeployment": + properties: OpenAIEndpointDeploymentResourceProperties = obj.properties + return cls( + 
name=obj.name, + model_name=properties.model.name, + model_version=properties.model.version, + id=obj.id, + system_data=SystemData._from_rest_object(obj.system_data), + ) + + def as_dict(self, *, exclude_readonly: bool = False) -> Dict[str, Any]: + d = super().as_dict(exclude_readonly=exclude_readonly) + d["system_data"] = json.loads(json.dumps(self.system_data._to_dict())) # type: ignore + return d + + +AzureOpenAIDeployment.__doc__ += ( + _AzureOpenAIDeployment.__doc__.strip() # type: ignore + + """ + :ivar system_data: System data of the deployment. + :vartype system_data: ~azure.ai.ml.entities.SystemData +""" +) + + +@experimental +class MarketplacePlan(_MarketplacePlan): + pass + + +@experimental +class ServerlessEndpoint(_ServerlessEndpoint, ValidationMixin): + + system_data: Optional[SystemData] = rest_field(visibility=["read"]) + """System data of the endpoint.""" + + def _to_rest_object(self) -> RestServerlessEndpoint: + return RestServerlessEndpoint( + properties=RestServerlessEndpointProperties( + model_settings=RestModelSettings(model_id=self.model_id), + ), + auth_mode="key", # only key is supported for now + tags=self.tags, + sku=RestSku(name="Consumption"), + location=self.location, + ) + + @classmethod + def _from_rest_object(cls, obj: RestServerlessEndpoint) -> "ServerlessEndpoint": + return cls( # type: ignore + name=obj.name, + id=obj.id, + tags=obj.tags, + location=obj.location, + auth_mode=obj.properties.auth_mode, + provisioning_state=camel_to_snake(obj.properties.provisioning_state), + model_id=obj.properties.model_settings.model_id if obj.properties.model_settings else None, + scoring_uri=obj.properties.inference_endpoint.uri if obj.properties.inference_endpoint else None, + system_data=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + headers=obj.properties.inference_endpoint.headers if obj.properties.inference_endpoint else None, + ) + + def as_dict(self, *, exclude_readonly: bool = False) -> Dict[str, Any]: + d = super().as_dict(exclude_readonly=exclude_readonly) + d["system_data"] = json.loads(json.dumps(self.system_data._to_dict())) # type: ignore + return d + + +ServerlessEndpoint.__doc__ += ( + _ServerlessEndpoint.__doc__.strip() # type: ignore + + """ + :ivar system_data: System data of the endpoint. 
+ :vartype system_data: ~azure.ai.ml.entities.SystemData +""" +) + + +@experimental +class MarketplaceSubscription(_MarketplaceSubscription, ValidationMixin): + + system_data: Optional[SystemData] = rest_field(visibility=["read"]) + """System data of the endpoint.""" + + def _to_rest_object(self) -> RestMarketplaceSubscription: + return RestMarketplaceSubscription(properties=RestMarketplaceSubscriptionProperties(model_id=self.model_id)) + + @classmethod + def _from_rest_object(cls, obj: RestMarketplaceSubscription) -> "MarketplaceSubscription": + properties = obj.properties + return cls( # type: ignore + name=obj.name, + id=obj.id, + model_id=properties.model_id, + marketplace_plan=MarketplacePlan( + publisher_id=properties.marketplace_plan.publisher_id, + offer_id=properties.marketplace_plan.offer_id, + plan_id=properties.marketplace_plan.plan_id, + ), + status=camel_to_snake(properties.marketplace_subscription_status), + provisioning_state=camel_to_snake(properties.provisioning_state), + system_data=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + ) + + def as_dict(self, *, exclude_readonly: bool = False) -> Dict[str, Any]: + d = super().as_dict(exclude_readonly=exclude_readonly) + if self.system_data: + d["system_data"] = json.loads(json.dumps(self.system_data._to_dict())) + return d + + +MarketplaceSubscription.__doc__ = ( + _MarketplaceSubscription.__doc__.strip() # type: ignore + + """ + :ivar system_data: System data of the marketplace subscription. + :vartype system_data: ~azure.ai.ml.entities.SystemData +""" +) + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/__init__.py new file mode 100644 index 00000000..95dfca0a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/__init__.py @@ -0,0 +1,28 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from .base_node import BaseNode, parse_inputs_outputs +from .command import Command +from .do_while import DoWhile +from .import_node import Import +from .parallel import Parallel +from .pipeline import Pipeline +from .spark import Spark +from .sweep import Sweep +from .data_transfer import DataTransfer, DataTransferCopy, DataTransferImport, DataTransferExport + +__all__ = [ + "BaseNode", + "Sweep", + "Parallel", + "Command", + "Import", + "Spark", + "Pipeline", + "parse_inputs_outputs", + "DoWhile", + "DataTransfer", + "DataTransferCopy", + "DataTransferImport", + "DataTransferExport", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/base_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/base_node.py new file mode 100644 index 00000000..98eba6a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/base_node.py @@ -0,0 +1,568 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access + +import logging +import os +import uuid +from abc import abstractmethod +from enum import Enum +from functools import wraps +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml._utils._arm_id_utils import get_resource_name_from_arm_id_safe +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import CommonYamlFields +from azure.ai.ml.constants._component import NodeType +from azure.ai.ml.entities import Data, Model +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job._input_output_helpers import build_input_output +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._job.pipeline._attr_dict import _AttrDict +from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput +from azure.ai.ml.entities._job.pipeline._io.mixin import NodeWithGroupInputMixin +from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression +from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution +from azure.ai.ml.entities._mixins import YamlTranslatableMixin +from azure.ai.ml.entities._util import convert_ordered_dict_to_dict, resolve_pipeline_parameters +from azure.ai.ml.entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin +from azure.ai.ml.exceptions import ErrorTarget, ValidationException + +module_logger = logging.getLogger(__name__) + + +def parse_inputs_outputs(data: dict) -> dict: + """Parse inputs and outputs from data. If data is a list, parse each item in the list. + + :param data: A dict that may contain "inputs" or "outputs" keys + :type data: dict + :return: Dict with parsed "inputs" and "outputs" keys + :rtype: Dict + """ + + if "inputs" in data: + data["inputs"] = {key: build_input_output(val) for key, val in data["inputs"].items()} + if "outputs" in data: + data["outputs"] = {key: build_input_output(val, inputs=False) for key, val in data["outputs"].items()} + return data + + +def pipeline_node_decorator(func: Any) -> Any: + """Wrap a function and add its return value to the current DSL pipeline. + + :param func: The function to be wrapped. + :type func: callable + :return: The wrapped function. + :rtype: callable + """ + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + automl_job = func(*args, **kwargs) + from azure.ai.ml.dsl._pipeline_component_builder import ( + _add_component_to_current_definition_builder, + _is_inside_dsl_pipeline_func, + ) + + if _is_inside_dsl_pipeline_func(): + # Build automl job to automl node if it's defined inside DSL pipeline func. + automl_job._instance_id = str(uuid.uuid4()) + _add_component_to_current_definition_builder(automl_job) + return automl_job + + return wrapper + + +# pylint: disable=too-many-instance-attributes +class BaseNode(Job, YamlTranslatableMixin, _AttrDict, PathAwareSchemaValidatableMixin, NodeWithGroupInputMixin): + """Base class for node in pipeline, used for component version consumption. Can't be instantiated directly. + + You should not instantiate this class directly. Instead, you should + create from a builder function. + + :param type: Type of pipeline node. Defaults to JobType.COMPONENT. + :type type: str + :param component: Id or instance of the component version to be run for the step + :type component: Component + :param inputs: The inputs for the node. 
+ :type inputs: Optional[Dict[str, Union[ + ~azure.ai.ml.entities._job.pipeline._io.PipelineInput, + ~azure.ai.ml.entities._job.pipeline._io.NodeOutput, + ~azure.ai.ml.entities.Input, + str, + bool, + int, + float, + Enum, + 'Input']]] + :param outputs: Mapping of output data bindings used in the job. + :type outputs: Optional[Dict[str, Union[str, ~azure.ai.ml.entities.Output, 'Output']]] + :param name: The name of the node. + :type name: Optional[str] + :param display_name: The display name of the node. + :type display_name: Optional[str] + :param description: The description of the node. + :type description: Optional[str] + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: Optional[Dict] + :param properties: The properties of the job. + :type properties: Optional[Dict] + :param comment: Comment of the pipeline node, which will be shown in designer canvas. + :type comment: Optional[str] + :param compute: Compute definition containing the compute information for the step. + :type compute: Optional[str] + :param experiment_name: Name of the experiment the job will be created under, + if None is provided, default will be set to current directory name. + Will be ignored as a pipeline step. + :type experiment_name: Optional[str] + :param kwargs: Additional keyword arguments for future compatibility. + """ + + def __init__( + self, + *, + type: str = JobType.COMPONENT, # pylint: disable=redefined-builtin + component: Any, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + comment: Optional[str] = None, + compute: Optional[str] = None, + experiment_name: Optional[str] = None, + **kwargs: Any, + ) -> None: + self._init = True + # property _source can't be set + source = kwargs.pop("_source", None) + _from_component_func = kwargs.pop("_from_component_func", False) + self._name: Optional[str] = None + super(BaseNode, self).__init__( + type=type, + name=name, + display_name=display_name, + description=description, + tags=tags, + properties=properties, + compute=compute, + experiment_name=experiment_name, + **kwargs, + ) + self.comment = comment + + # initialize io + inputs = resolve_pipeline_parameters(inputs) + inputs, outputs = inputs or {}, outputs or {} + # parse empty dict to None so we won't pass default mode, type to backend + # add `isinstance` to avoid converting to expression + for k, v in inputs.items(): + if isinstance(v, dict) and v == {}: + inputs[k] = None + + # TODO: get rid of self._job_inputs, self._job_outputs once we have unified Input + self._job_inputs, self._job_outputs = inputs, outputs + if isinstance(component, Component): + # Build the inputs from component input definition and given inputs, unfilled inputs will be None + self._inputs = self._build_inputs_dict(inputs or {}, input_definition_dict=component.inputs) + # Build the outputs from component output definition and given outputs, unfilled outputs will be None + self._outputs = self._build_outputs_dict(outputs or {}, output_definition_dict=component.outputs) + else: + # Build inputs/outputs dict without meta when definition not available + self._inputs = self._build_inputs_dict(inputs or {}) + self._outputs = self._build_outputs_dict(outputs or {}) + + self._component = component + self._referenced_control_flow_node_instance_id: Optional[str] = None + self.kwargs = kwargs + + # Generate an id for 
every instance + self._instance_id = str(uuid.uuid4()) + if _from_component_func: + # add current component in pipeline stack for dsl scenario + self._register_in_current_pipeline_component_builder() + + if source is None: + if isinstance(component, Component): + source = self._component._source + else: + source = Component._resolve_component_source_from_id(id=self._component) + self._source = source + self._validate_required_input_not_provided = True + self._init = False + + @property + def name(self) -> Optional[str]: + """Get the name of the node. + + :return: The name of the node. + :rtype: str + """ + return self._name + + @name.setter + def name(self, value: str) -> None: + """Set the name of the node. + + :param value: The name to set for the node. + :type value: str + :return: None + """ + # when name is not lower case, lower it to make sure it's a valid node name + if value and value != value.lower(): + module_logger.warning( + "Changing node name %s to lower case: %s since upper case is not allowed node name.", + value, + value.lower(), + ) + value = value.lower() + self._name = value + + @classmethod + def _get_supported_inputs_types(cls) -> Any: + """Get the supported input types for node input. + + :param cls: The class (or instance) to retrieve supported input types for. + :type cls: object + + :return: A tuple of supported input types. + :rtype: tuple + """ + # supported input types for node input + return ( + PipelineInput, + NodeOutput, + Input, + Data, + Model, + str, + bool, + int, + float, + Enum, + PipelineExpression, + ) + + @property + def _skip_required_compute_missing_validation(self) -> bool: + return False + + def _initializing(self) -> bool: + # use this to indicate ongoing init process so all attributes set during init process won't be set as + # arbitrary attribute in _AttrDict + # TODO: replace this hack + return self._init + + def _set_base_path(self, base_path: Optional[Union[str, os.PathLike]]) -> None: + """Set the base path for the node. + + Will be used for schema validation. If not set, will use Path.cwd() as the base path + (default logic defined in SchemaValidatableMixin._base_path_for_validation). + + :param base_path: The new base path + :type base_path: Union[str, os.PathLike] + """ + self._base_path = base_path + + def _set_referenced_control_flow_node_instance_id(self, instance_id: str) -> None: + """Set the referenced control flow node instance id. + + If this node is referenced to a control flow node, the instance_id will not be modified. + + :param instance_id: The new instance id + :type instance_id: str + """ + if not self._referenced_control_flow_node_instance_id: + self._referenced_control_flow_node_instance_id = instance_id + + def _get_component_id(self) -> Union[str, Component]: + """Return component id if possible. + + :return: The component id + :rtype: Union[str, Component] + """ + if isinstance(self._component, Component) and self._component.id: + # If component is remote, return it's asset id + return self._component.id + # Otherwise, return the component version or arm id. + res: Union[str, Component] = self._component + return res + + def _get_component_name(self) -> Optional[str]: + # first use component version/job's display name or name as component name + # make it unique when pipeline build finished. 
+ if self._component is None: + return None + if isinstance(self._component, str): + return self._component + return str(self._component.name) + + def _to_dict(self) -> Dict: + return dict(convert_ordered_dict_to_dict(self._dump_for_validation())) + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException: + return ValidationException( + message=message, + no_personal_data_message=no_personal_data_message, + target=ErrorTarget.PIPELINE, + ) + + def _validate_inputs(self) -> MutableValidationResult: + validation_result = self._create_empty_validation_result() + if self._validate_required_input_not_provided: + # validate required inputs not provided + if isinstance(self._component, Component): + for key, meta in self._component.inputs.items(): + # raise error when required input with no default value not set + if ( + not self._is_input_set(input_name=key) # input not provided + and meta.optional is not True # and it's required + and meta.default is None # and it does not have default + ): + validation_result.append_error( + yaml_path=f"inputs.{key}", + message=f"Required input {key!r} for component {self.name!r} not provided.", + ) + + inputs = self._build_inputs() + for input_name, input_obj in inputs.items(): + if isinstance(input_obj, SweepDistribution): + validation_result.append_error( + yaml_path=f"inputs.{input_name}", + message=f"Input of command {self.name} is a SweepDistribution, " + f"please use command.sweep to transform the command into a sweep node.", + ) + return validation_result + + def _customized_validate(self) -> MutableValidationResult: + """Validate the resource with customized logic. + + Override this method to add customized validation logic. + + :return: The validation result + :rtype: MutableValidationResult + """ + validate_result = self._validate_inputs() + return validate_result + + @classmethod + def _get_skip_fields_in_schema_validation(cls) -> List[str]: + return [ + "inputs", # processed separately + "outputs", # processed separately + "name", + "display_name", + "experiment_name", # name is not part of schema but may be set in dsl/yml file + "kwargs", + ] + + @classmethod + def _get_component_attr_name(cls) -> str: + return "component" + + @abstractmethod + def _to_job(self) -> Job: + """This private function is used by the CLI to get a plain job object + so that the CLI can properly serialize the object. + + It is needed as BaseNode._to_dict() dumps objects using pipeline child job schema instead of standalone job + schema, for example Command objects dump have a nested component property, which doesn't apply to stand alone + command jobs. BaseNode._to_dict() needs to be able to dump to both pipeline child job dict as well as stand + alone job dict base on context. 
+ """ + + @classmethod + def _from_rest_object(cls, obj: dict) -> "BaseNode": + if CommonYamlFields.TYPE not in obj: + obj[CommonYamlFields.TYPE] = NodeType.COMMAND + + from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory + + # todo: refine Hard code for now to support different task type for DataTransfer node + _type = obj[CommonYamlFields.TYPE] + if _type == NodeType.DATA_TRANSFER: + _type = "_".join([NodeType.DATA_TRANSFER, obj.get("task", "")]) + instance: BaseNode = pipeline_node_factory.get_create_instance_func(_type)() + init_kwargs = instance._from_rest_object_to_init_params(obj) + # TODO: Bug Item number: 2883415 + instance.__init__(**init_kwargs) # type: ignore + return instance + + @classmethod + def _from_rest_object_to_init_params(cls, obj: dict) -> Dict: + """Convert the rest object to a dict containing items to init the node. + + Will be used in _from_rest_object. Please override this method instead of _from_rest_object to make the logic + reusable. + + :param obj: The REST object + :type obj: dict + :return: The init params + :rtype: Dict + """ + inputs = obj.get("inputs", {}) + outputs = obj.get("outputs", {}) + + obj["inputs"] = BaseNode._from_rest_inputs(inputs) + obj["outputs"] = BaseNode._from_rest_outputs(outputs) + + # Change computeId -> compute + compute_id = obj.pop("computeId", None) + obj["compute"] = get_resource_name_from_arm_id_safe(compute_id) + + # Change componentId -> component. Note that sweep node has no componentId. + if "componentId" in obj: + obj["component"] = obj.pop("componentId") + + # distribution, sweep won't have distribution + if "distribution" in obj and obj["distribution"]: + from azure.ai.ml.entities._job.distribution import DistributionConfiguration + + obj["distribution"] = DistributionConfiguration._from_rest_object(obj["distribution"]) + + return obj + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + """List of fields to be picked from self._to_dict() in self._to_rest_object(). + + By default, returns an empty list. + + Override this method to add custom fields. + + :return: List of fields to pick + :rtype: List[str] + """ + + return [] + + def _to_rest_object(self, **kwargs: Any) -> dict: # pylint: disable=unused-argument + """Convert self to a rest object for remote call. + + :return: The rest object + :rtype: dict + """ + base_dict, rest_obj = self._to_dict(), {} + for key in self._picked_fields_from_dict_to_rest_object(): + if key in base_dict: + rest_obj[key] = base_dict.get(key) + + rest_obj.update( + dict( # pylint: disable=use-dict-literal + name=self.name, + type=self.type, + display_name=self.display_name, + tags=self.tags, + computeId=self.compute, + inputs=self._to_rest_inputs(), + outputs=self._to_rest_outputs(), + properties=self.properties, + _source=self._source, + # add all arbitrary attributes to support setting unknown attributes + **self._get_attrs(), + ) + ) + # only add comment in REST object when it is set + if self.comment is not None: + rest_obj.update({"comment": self.comment}) + + return dict(convert_ordered_dict_to_dict(rest_obj)) + + @property + def inputs(self) -> Dict: + """Get the inputs for the object. + + :return: A dictionary containing the inputs for the object. + :rtype: Dict[str, Union[Input, str, bool, int, float]] + """ + return self._inputs # type: ignore + + @property + def outputs(self) -> Dict: + """Get the outputs of the object. + + :return: A dictionary containing the outputs for the object. 
+ :rtype: Dict[str, Union[str, Output]] + """ + return self._outputs # type: ignore + + def __str__(self) -> str: + try: + return str(self._to_yaml()) + except BaseException: # pylint: disable=W0718 + # add try catch in case component job failed in schema parse + _obj: _AttrDict = _AttrDict() + return _obj.__str__() + + def __hash__(self) -> int: # type: ignore + return hash(self.__str__()) + + def __help__(self) -> Any: + # only show help when component has definition + if isinstance(self._component, Component): + # TODO: Bug Item number: 2883422 + return self._component.__help__() # type: ignore + return None + + def __bool__(self) -> bool: + # _attr_dict will return False if no extra attributes are set + return True + + def _get_origin_job_outputs(self) -> Dict[str, Union[str, Output]]: + """Restore outputs to JobOutput/BindingString and return them. + + :return: The origin job outputs + :rtype: Dict[str, Union[str, Output]] + """ + outputs: Dict = {} + if self.outputs is not None: + for output_name, output_obj in self.outputs.items(): + if isinstance(output_obj, NodeOutput): + outputs[output_name] = output_obj._data + else: + raise TypeError("unsupported built output type: {}: {}".format(output_name, type(output_obj))) + return outputs + + def _get_telemetry_values(self) -> Dict: + telemetry_values = {"type": self.type, "source": self._source} + return telemetry_values + + def _register_in_current_pipeline_component_builder(self) -> None: + """Register this node in current pipeline component builder by adding self to a global stack.""" + from azure.ai.ml.dsl._pipeline_component_builder import _add_component_to_current_definition_builder + + # TODO: would it be better if we make _add_component_to_current_definition_builder a public function of + # _PipelineComponentBuilderStack and make _PipelineComponentBuilderStack a singleton? + _add_component_to_current_definition_builder(self) + + def _is_input_set(self, input_name: str) -> bool: + built_inputs = self._build_inputs() + return input_name in built_inputs and built_inputs[input_name] is not None + + @classmethod + def _refine_optional_inputs_with_no_value(cls, node: "BaseNode", kwargs: Any) -> None: + """Refine optional inputs that have no default value and no value is provided when calling command/parallel + function. + + This is to align with behavior of calling component to generate a pipeline node. + + :param node: The node + :type node: BaseNode + :param kwargs: The kwargs + :type kwargs: dict + """ + for key, value in node.inputs.items(): + meta = value._data + if ( + isinstance(meta, Input) + and meta._is_primitive_type is False + and meta.optional is True + and not meta.path + and key not in kwargs + ): + value._data = None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command.py new file mode 100644 index 00000000..0073307c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command.py @@ -0,0 +1,1017 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access,too-many-lines +import copy +import logging +import os +from enum import Enum +from os import PathLike +from typing import Any, Dict, List, Optional, Tuple, Union, cast, overload + +from marshmallow import INCLUDE, Schema + +from azure.ai.ml._restclient.v2025_01_01_preview.models import CommandJob as RestCommandJob +from azure.ai.ml._restclient.v2025_01_01_preview.models import JobBase +from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, UnionField +from azure.ai.ml._schema.job.command_job import CommandJobSchema +from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema +from azure.ai.ml._schema.job.services import JobServiceSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LOCAL_COMPUTE_PROPERTY, LOCAL_COMPUTE_TARGET +from azure.ai.ml.constants._component import ComponentSource, NodeType +from azure.ai.ml.entities._assets import Environment +from azure.ai.ml.entities._component.command_component import CommandComponent +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, + _BaseJobIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, from_rest_inputs_to_dataset_literal +from azure.ai.ml.entities._job.command_job import CommandJob +from azure.ai.ml.entities._job.distribution import ( + DistributionConfiguration, + MpiDistribution, + PyTorchDistribution, + RayDistribution, + TensorFlowDistribution, +) +from azure.ai.ml.entities._job.job_limits import CommandJobLimits +from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration +from azure.ai.ml.entities._job.job_service import ( + JobService, + JobServiceBase, + JupyterLabJobService, + SshJobService, + TensorBoardJobService, + VsCodeJobService, +) +from azure.ai.ml.entities._job.queue_settings import QueueSettings +from azure.ai.ml.entities._job.sweep.early_termination_policy import EarlyTerminationPolicy +from azure.ai.ml.entities._job.sweep.objective import Objective +from azure.ai.ml.entities._job.sweep.search_space import ( + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + SweepDistribution, + Uniform, +) +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from ..._schema import PathAwareSchema +from ..._schema.job.distribution import ( + MPIDistributionSchema, + PyTorchDistributionSchema, + RayDistributionSchema, + TensorFlowDistributionSchema, +) +from .._job.pipeline._io import NodeWithGroupInputMixin +from .._util import ( + convert_ordered_dict_to_dict, + from_rest_dict_to_dummy_rest_object, + get_rest_dict_for_node_attrs, + load_from_dict, + validate_attribute_type, +) +from .base_node import BaseNode +from .sweep import Sweep + +module_logger = logging.getLogger(__name__) + + +class Command(BaseNode, NodeWithGroupInputMixin): + """Base class for command node, used for command component version consumption. + + You should not instantiate this class directly. Instead, you should create it using the builder function: command(). + + :keyword component: The ID or instance of the command component or job to be run for the step. 
+ :paramtype component: Union[str, ~azure.ai.ml.entities.CommandComponent] + :keyword compute: The compute target the job will run on. + :paramtype compute: Optional[str] + :keyword inputs: A mapping of input names to input data sources used in the job. + :paramtype inputs: Optional[dict[str, Union[ + ~azure.ai.ml.Input, str, bool, int, float, Enum]]] + :keyword outputs: A mapping of output names to output data sources used in the job. + :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]] + :keyword limits: The limits for the command component or job. + :paramtype limits: ~azure.ai.ml.entities.CommandJobLimits + :keyword identity: The identity that the command job will use while running on compute. + :paramtype identity: Optional[Union[ + dict[str, str], + ~azure.ai.ml.entities.ManagedIdentityConfiguration, + ~azure.ai.ml.entities.AmlTokenConfiguration, + ~azure.ai.ml.entities.UserIdentityConfiguration]] + :keyword distribution: The configuration for distributed jobs. + :paramtype distribution: Optional[Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution, + ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]] + :keyword environment: The environment that the job will run in. + :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + :keyword environment_variables: A dictionary of environment variable names and values. + These environment variables are set on the process where the user script is being executed. + :paramtype environment_variables: Optional[dict[str, str]] + :keyword resources: The compute resource configuration for the command. + :paramtype resources: Optional[~azure.ai.ml.entities.JobResourceConfiguration] + :keyword services: The interactive services for the node. This is an experimental parameter, and may change at any + time. Please see https://aka.ms/azuremlexperimental for more information. + :paramtype services: Optional[dict[str, Union[~azure.ai.ml.entities.JobService, + ~azure.ai.ml.entities.JupyterLabJobService, + ~azure.ai.ml.entities.SshJobService, ~azure.ai.ml.entities.TensorBoardJobService, + ~azure.ai.ml.entities.VsCodeJobService]]] + :keyword queue_settings: Queue settings for the job. + :paramtype queue_settings: Optional[~azure.ai.ml.entities.QueueSettings] + :keyword parent_job_name: parent job id for command job + :paramtype parent_job_name: Optional[str] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Command cannot be successfully validated. + Details will be provided in the error message. 
+ """ + + # pylint: disable=too-many-instance-attributes + def __init__( + self, + *, + component: Union[str, CommandComponent], + compute: Optional[str] = None, + inputs: Optional[ + Dict[ + str, + Union[ + Input, + str, + bool, + int, + float, + Enum, + ], + ] + ] = None, + outputs: Optional[Dict[str, Union[str, Output]]] = None, + limits: Optional[CommandJobLimits] = None, + identity: Optional[ + Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + distribution: Optional[ + Union[ + Dict, + MpiDistribution, + TensorFlowDistribution, + PyTorchDistribution, + RayDistribution, + DistributionConfiguration, + ] + ] = None, + environment: Optional[Union[Environment, str]] = None, + environment_variables: Optional[Dict] = None, + resources: Optional[JobResourceConfiguration] = None, + services: Optional[ + Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]] + ] = None, + queue_settings: Optional[QueueSettings] = None, + parent_job_name: Optional[str] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + + # resolve normal dict to dict[str, JobService] + services = _resolve_job_services(services) + kwargs.pop("type", None) + self._parameters: dict = kwargs.pop("parameters", {}) + BaseNode.__init__( + self, + type=NodeType.COMMAND, + inputs=inputs, + outputs=outputs, + component=component, + compute=compute, + services=services, + **kwargs, + ) + + # init mark for _AttrDict + self._init = True + # initialize command job properties + self.limits = limits + self.identity = identity + self._distribution = distribution + self.environment_variables = {} if environment_variables is None else environment_variables + self.environment: Any = environment + self._resources = resources + self._services = services + self.queue_settings = queue_settings + self.parent_job_name = parent_job_name + + if isinstance(self.component, CommandComponent): + self.resources = self.resources or self.component.resources # type: ignore[assignment] + self.distribution = self.distribution or self.component.distribution + + self._swept: bool = False + self._init = False + + @classmethod + def _get_supported_inputs_types(cls) -> Tuple: + supported_types = super()._get_supported_inputs_types() or () + return ( + SweepDistribution, + *supported_types, + ) + + @classmethod + def _get_supported_outputs_types(cls) -> Tuple: + return str, Output + + @property + def parameters(self) -> Dict[str, str]: + """MLFlow parameters to be logged during the job. + + :return: The MLFlow parameters to be logged during the job. + :rtype: dict[str, str] + """ + return self._parameters + + @property + def distribution( + self, + ) -> Optional[ + Union[ + Dict, + MpiDistribution, + TensorFlowDistribution, + PyTorchDistribution, + RayDistribution, + DistributionConfiguration, + ] + ]: + """The configuration for the distributed command component or job. + + :return: The configuration for distributed jobs. + :rtype: Union[~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution, + ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution] + """ + return self._distribution + + @distribution.setter + def distribution( + self, + value: Union[Dict, PyTorchDistribution, TensorFlowDistribution, MpiDistribution, RayDistribution], + ) -> None: + """Sets the configuration for the distributed command component or job. 
+ + :param value: The configuration for distributed jobs. + :type value: Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution, + ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution] + """ + if isinstance(value, dict): + dist_schema = UnionField( + [ + NestedField(PyTorchDistributionSchema, unknown=INCLUDE), + NestedField(TensorFlowDistributionSchema, unknown=INCLUDE), + NestedField(MPIDistributionSchema, unknown=INCLUDE), + ExperimentalField(NestedField(RayDistributionSchema, unknown=INCLUDE)), + ] + ) + value = dist_schema._deserialize(value=value, attr=None, data=None) + self._distribution = value + + @property + def resources(self) -> JobResourceConfiguration: + """The compute resource configuration for the command component or job. + + :rtype: ~azure.ai.ml.entities.JobResourceConfiguration + """ + return cast(JobResourceConfiguration, self._resources) + + @resources.setter + def resources(self, value: Union[Dict, JobResourceConfiguration]) -> None: + """Sets the compute resource configuration for the command component or job. + + :param value: The compute resource configuration for the command component or job. + :type value: Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration] + """ + if isinstance(value, dict): + value = JobResourceConfiguration(**value) + self._resources = value + + @property + def queue_settings(self) -> Optional[QueueSettings]: + """The queue settings for the command component or job. + + :return: The queue settings for the command component or job. + :rtype: ~azure.ai.ml.entities.QueueSettings + """ + return self._queue_settings + + @queue_settings.setter + def queue_settings(self, value: Union[Dict, QueueSettings]) -> None: + """Sets the queue settings for the command component or job. + + :param value: The queue settings for the command component or job. + :type value: Union[dict, ~azure.ai.ml.entities.QueueSettings] + """ + if isinstance(value, dict): + value = QueueSettings(**value) + self._queue_settings = value + + @property + def identity( + self, + ) -> Optional[Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]]: + """The identity that the job will use while running on compute. + + :return: The identity that the job will use while running on compute. + :rtype: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration]] + """ + return self._identity + + @identity.setter + def identity( + self, + value: Optional[Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]], + ) -> None: + """Sets the identity that the job will use while running on compute. + + :param value: The identity that the job will use while running on compute. + :type value: Union[dict[str, str], ~azure.ai.ml.ManagedIdentityConfiguration, + ~azure.ai.ml.AmlTokenConfiguration, ~azure.ai.ml.UserIdentityConfiguration] + """ + if isinstance(value, dict): + identity_schema = UnionField( + [ + NestedField(ManagedIdentitySchema, unknown=INCLUDE), + NestedField(AMLTokenIdentitySchema, unknown=INCLUDE), + NestedField(UserIdentitySchema, unknown=INCLUDE), + ] + ) + value = identity_schema._deserialize(value=value, attr=None, data=None) + self._identity = value + + @property + def services( + self, + ) -> Optional[ + Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]] + ]: + """The interactive services for the node. 
+ + This is an experimental parameter, and may change at any time. + Please see https://aka.ms/azuremlexperimental for more information. + + :rtype: dict[str, Union[~azure.ai.ml.entities.JobService, ~azure.ai.ml.entities.JupyterLabJobService, + ~azure.ai.ml.entities.SshJobService, ~azure.ai.ml.entities.TensorBoardJobService, + ~azure.ai.ml.entities.VsCodeJobService]] + """ + return self._services + + @services.setter + def services( + self, + value: Dict, + ) -> None: + """Sets the interactive services for the node. + + This is an experimental parameter, and may change at any time. + Please see https://aka.ms/azuremlexperimental for more information. + + :param value: The interactive services for the node. + :type value: dict[str, Union[~azure.ai.ml.entities.JobService, ~azure.ai.ml.entities.JupyterLabJobService, + ~azure.ai.ml.entities.SshJobService, ~azure.ai.ml.entities.TensorBoardJobService, + ~azure.ai.ml.entities.VsCodeJobService]] + """ + self._services = _resolve_job_services(value) # type: ignore[assignment] + + @property + def component(self) -> Union[str, CommandComponent]: + """The ID or instance of the command component or job to be run for the step. + + :return: The ID or instance of the command component or job to be run for the step. + :rtype: Union[str, ~azure.ai.ml.entities.CommandComponent] + """ + return self._component + + @property + def command(self) -> Optional[str]: + """The command to be executed. + + :rtype: Optional[str] + """ + # the same as code + if not isinstance(self.component, CommandComponent): + return None + + if self.component.command is None: + return None + return str(self.component.command) + + @command.setter + def command(self, value: str) -> None: + """Sets the command to be executed. + + :param value: The command to be executed. + :type value: str + """ + if isinstance(self.component, CommandComponent): + self.component.command = value + else: + msg = "Can't set command property for a registered component {}. Tried to set it to {}." + raise ValidationException( + message=msg.format(self.component, value), + no_personal_data_message=msg, + target=ErrorTarget.COMMAND_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + @property + def code(self) -> Optional[Union[str, PathLike]]: + """The source code to run the job. + + :rtype: Optional[Union[str, os.PathLike]] + """ + # BaseNode is an _AttrDict to allow dynamic attributes, so that lower version of SDK can work with attributes + # added in higher version of SDK. + # self.code will be treated as an Arbitrary attribute if it raises AttributeError in getting + # (when self.component doesn't have attribute code, self.component = 'azureml:xxx:1' e.g. + # you may check _AttrDict._is_arbitrary_attr for detailed logic for Arbitrary judgement), + # then its value will be set to _AttrDict and be deserialized as {"shape": {}} instead of None, + # which is invalid in schema validation. + if not isinstance(self.component, CommandComponent): + return None + + if self.component.code is None: + return None + + return str(self.component.code) + + @code.setter + def code(self, value: str) -> None: + """Sets the source code to run the job. + + :param value: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url + pointing to a remote location. 
+ :type value: str + """ + if isinstance(self.component, CommandComponent): + self.component.code = value + else: + msg = "Can't set code property for a registered component {}" + raise ValidationException( + message=msg.format(self.component), + no_personal_data_message=msg, + target=ErrorTarget.COMMAND_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def set_resources( + self, + *, + instance_type: Optional[Union[str, List[str]]] = None, + instance_count: Optional[int] = None, + locations: Optional[List[str]] = None, + properties: Optional[Dict] = None, + docker_args: Optional[Union[str, List[str]]] = None, + shm_size: Optional[str] = None, + # pylint: disable=unused-argument + **kwargs: Any, + ) -> None: + """Set resources for Command. + + :keyword instance_type: The type of compute instance to run the job on. If not specified, the job will run on + the default compute target. + :paramtype instance_type: Optional[Union[str, List[str]]] + :keyword instance_count: The number of instances to run the job on. If not specified, the job will run on a + single instance. + :paramtype instance_count: Optional[int] + :keyword locations: The list of locations where the job will run. If not specified, the job will run on the + default compute target. + :paramtype locations: Optional[List[str]] + :keyword properties: The properties of the job. + :paramtype properties: Optional[dict] + :keyword docker_args: The Docker arguments for the job. + :paramtype docker_args: Optional[Union[str,List[str]]] + :keyword shm_size: The size of the docker container's shared memory block. This should be in the + format of (number)(unit) where the number has to be greater than 0 and the unit can be one of + b(bytes), k(kilobytes), m(megabytes), or g(gigabytes). + :paramtype shm_size: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_command_configurations.py + :start-after: [START command_set_resources] + :end-before: [END command_set_resources] + :language: python + :dedent: 8 + :caption: Setting resources on a Command. + """ + if self.resources is None: + self.resources = JobResourceConfiguration() + + if locations is not None: + self.resources.locations = locations + if instance_type is not None: + self.resources.instance_type = instance_type + if instance_count is not None: + self.resources.instance_count = instance_count + if properties is not None: + self.resources.properties = properties + if docker_args is not None: + self.resources.docker_args = docker_args + if shm_size is not None: + self.resources.shm_size = shm_size + + # Save the resources to internal component as well, otherwise calling sweep() will loose the settings + if isinstance(self.component, CommandComponent): + self.component.resources = self.resources + + def set_limits(self, *, timeout: int, **kwargs: Any) -> None: # pylint: disable=unused-argument + """Set limits for Command. + + :keyword timeout: The timeout for the job in seconds. + :paramtype timeout: int + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_command_configurations.py + :start-after: [START command_set_limits] + :end-before: [END command_set_limits] + :language: python + :dedent: 8 + :caption: Setting a timeout limit of 10 seconds on a Command. 
+ """ + if isinstance(self.limits, CommandJobLimits): + self.limits.timeout = timeout + else: + self.limits = CommandJobLimits(timeout=timeout) + + def set_queue_settings(self, *, job_tier: Optional[str] = None, priority: Optional[str] = None) -> None: + """Set QueueSettings for the job. + + :keyword job_tier: The job tier. Accepted values are "Spot", "Basic", "Standard", or "Premium". + :paramtype job_tier: Optional[str] + :keyword priority: The priority of the job on the compute. Defaults to "Medium". + :paramtype priority: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_command_configurations.py + :start-after: [START command_set_queue_settings] + :end-before: [END command_set_queue_settings] + :language: python + :dedent: 8 + :caption: Configuring queue settings on a Command. + """ + if isinstance(self.queue_settings, QueueSettings): + self.queue_settings.job_tier = job_tier + self.queue_settings.priority = priority + else: + self.queue_settings = QueueSettings(job_tier=job_tier, priority=priority) + + def sweep( + self, + *, + primary_metric: str, + goal: str, + sampling_algorithm: str = "random", + compute: Optional[str] = None, + max_concurrent_trials: Optional[int] = None, + max_total_trials: Optional[int] = None, + timeout: Optional[int] = None, + trial_timeout: Optional[int] = None, + early_termination_policy: Optional[Union[EarlyTerminationPolicy, str]] = None, + search_space: Optional[ + Dict[ + str, + Union[ + Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ], + ] + ] = None, + identity: Optional[ + Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + queue_settings: Optional[QueueSettings] = None, + job_tier: Optional[str] = None, + priority: Optional[str] = None, + ) -> Sweep: + """Turns the command into a sweep node with extra sweep run setting. The command component + in the current command node will be used as its trial component. A command node can sweep + multiple times, and the generated sweep node will share the same trial component. + + :keyword primary_metric: The primary metric of the sweep objective - e.g. AUC (Area Under the Curve). + The metric must be logged while running the trial component. + :paramtype primary_metric: str + :keyword goal: The goal of the Sweep objective. Accepted values are "minimize" or "maximize". + :paramtype goal: str + :keyword sampling_algorithm: The sampling algorithm to use inside the search space. + Acceptable values are "random", "grid", or "bayesian". Defaults to "random". + :paramtype sampling_algorithm: str + :keyword compute: The target compute to run the node on. If not specified, the current node's compute + will be used. + :paramtype compute: Optional[str] + :keyword max_total_trials: The maximum number of total trials to run. This value will overwrite the value in + CommandJob.limits if specified. + :paramtype max_total_trials: Optional[int] + :keyword max_concurrent_trials: The maximum number of concurrent trials for the Sweep job. + :paramtype max_concurrent_trials: Optional[int] + :keyword timeout: The maximum run duration in seconds, after which the job will be cancelled. + :paramtype timeout: Optional[int] + :keyword trial_timeout: The Sweep Job trial timeout value, in seconds. + :paramtype trial_timeout: Optional[int] + :keyword early_termination_policy: The early termination policy of the sweep node. 
Acceptable + values are "bandit", "median_stopping", or "truncation_selection". Defaults to None. + :paramtype early_termination_policy: Optional[Union[~azure.ai.ml.sweep.BanditPolicy, + ~azure.ai.ml.sweep.TruncationSelectionPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy, str]] + :keyword identity: The identity that the job will use while running on compute. + :paramtype identity: Optional[Union[ + ~azure.ai.ml.ManagedIdentityConfiguration, + ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration]] + :keyword search_space: The search space to use for the sweep job. + :paramtype search_space: Optional[Dict[str, Union[ + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform + + ]]] + + :keyword queue_settings: The queue settings for the job. + :paramtype queue_settings: Optional[~azure.ai.ml.entities.QueueSettings] + :keyword job_tier: **Experimental** The job tier. Accepted values are "Spot", "Basic", + "Standard", or "Premium". + :paramtype job_tier: Optional[str] + :keyword priority: **Experimental** The compute priority. Accepted values are "low", + "medium", and "high". + :paramtype priority: Optional[str] + :return: A Sweep node with the component from current Command node as its trial component. + :rtype: ~azure.ai.ml.entities.Sweep + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_bandit_policy] + :end-before: [END configure_sweep_job_bandit_policy] + :language: python + :dedent: 8 + :caption: Creating a Sweep node from a Command job. + """ + self._swept = True + # inputs & outputs are already built in source Command obj + inputs, inputs_search_space = Sweep._get_origin_inputs_and_search_space(self.inputs) + if search_space: + inputs_search_space.update(search_space) + + if not queue_settings: + queue_settings = self.queue_settings + if queue_settings is not None: + if job_tier is not None: + queue_settings.job_tier = job_tier + if priority is not None: + queue_settings.priority = priority + + sweep_node = Sweep( + trial=copy.deepcopy( + self.component + ), # Make a copy of the underneath Component so that the original node can still be used. 
+ compute=self.compute if compute is None else compute, + objective=Objective(goal=goal, primary_metric=primary_metric), + sampling_algorithm=sampling_algorithm, + inputs=inputs, + outputs=self._get_origin_job_outputs(), + search_space=inputs_search_space, + early_termination=early_termination_policy, + name=self.name, + description=self.description, + display_name=self.display_name, + tags=self.tags, + properties=self.properties, + experiment_name=self.experiment_name, + identity=self.identity if not identity else identity, + _from_component_func=True, + queue_settings=queue_settings, + ) + sweep_node.set_limits( + max_total_trials=max_total_trials, + max_concurrent_trials=max_concurrent_trials, + timeout=timeout, + trial_timeout=trial_timeout, + ) + return sweep_node + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "component": (str, CommandComponent), + "environment": (str, Environment), + "environment_variables": dict, + "resources": (dict, JobResourceConfiguration), + "limits": (dict, CommandJobLimits), + "code": (str, os.PathLike), + } + + def _to_job(self) -> CommandJob: + if isinstance(self.component, CommandComponent): + return CommandJob( + id=self.id, + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + properties=self.properties, + command=self.component.command, + experiment_name=self.experiment_name, + code=self.component.code, + compute=self.compute, + status=self.status, + environment=self.environment, + distribution=self.distribution, + identity=self.identity, + environment_variables=self.environment_variables, + resources=self.resources, + limits=self.limits, + inputs=self._job_inputs, + outputs=self._job_outputs, + services=self.services, + creation_context=self.creation_context, + parameters=self.parameters, + queue_settings=self.queue_settings, + parent_job_name=self.parent_job_name, + ) + + return CommandJob( + id=self.id, + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + properties=self.properties, + command=None, + experiment_name=self.experiment_name, + code=None, + compute=self.compute, + status=self.status, + environment=self.environment, + distribution=self.distribution, + identity=self.identity, + environment_variables=self.environment_variables, + resources=self.resources, + limits=self.limits, + inputs=self._job_inputs, + outputs=self._job_outputs, + services=self.services, + creation_context=self.creation_context, + parameters=self.parameters, + queue_settings=self.queue_settings, + parent_job_name=self.parent_job_name, + ) + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return ["resources", "distribution", "limits", "environment_variables", "queue_settings"] + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj = super()._to_rest_object(**kwargs) + for key, value in { + "componentId": self._get_component_id(), + "distribution": get_rest_dict_for_node_attrs(self.distribution, clear_empty_value=True), + "limits": get_rest_dict_for_node_attrs(self.limits, clear_empty_value=True), + "resources": get_rest_dict_for_node_attrs(self.resources, clear_empty_value=True), + "services": get_rest_dict_for_node_attrs(self.services), + "identity": get_rest_dict_for_node_attrs(self.identity), + "queue_settings": get_rest_dict_for_node_attrs(self.queue_settings, clear_empty_value=True), + }.items(): + if value is not None: + rest_obj[key] = value + return cast(dict, convert_ordered_dict_to_dict(rest_obj)) + + 
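# Illustrative sketch (not part of the diffed module): how the Command helpers
# documented above (set_resources, set_limits, set_queue_settings and sweep)
# might be combined on a node built with the command() factory. The script,
# environment reference, compute target and metric name are placeholder
# assumptions, not values taken from this package.
from azure.ai.ml import command
from azure.ai.ml.sweep import Uniform

# A sweep distribution may be passed directly as an input; calling .sweep()
# later lifts it into the search space of the generated Sweep node.
node = command(
    command="python train.py --lr ${{inputs.lr}}",
    environment="azureml:my-training-env:1",  # assumed environment reference
    compute="cpu-cluster",                    # assumed compute target
    inputs={"lr": Uniform(min_value=0.001, max_value=0.1)},
)
node.set_resources(instance_count=2, shm_size="2g")
node.set_limits(timeout=3600)
node.set_queue_settings(job_tier="Standard")

# The command node's component becomes the trial component of the sweep node.
sweep_node = node.sweep(
    primary_metric="accuracy",  # assumed metric logged by the trial script
    goal="maximize",
    sampling_algorithm="random",
)
sweep_node.set_limits(max_total_trials=20, max_concurrent_trials=4)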
@classmethod
 + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Command":
 + from .command_func import command
 +
 + loaded_data = load_from_dict(CommandJobSchema, data, context, additional_message, **kwargs)
 +
 + # resources and limits properties are flattened in the command() function; extract them and set them separately
 + resources = loaded_data.pop("resources", None)
 + limits = loaded_data.pop("limits", None)
 +
 + command_job: Command = command(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
 +
 + command_job.resources = resources
 + command_job.limits = limits
 + return command_job
 +
 + @classmethod
 + def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
 + obj = BaseNode._from_rest_object_to_init_params(obj)
 +
 + if "resources" in obj and obj["resources"]:
 + obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])
 +
 + # services, sweep won't have services
 + if "services" in obj and obj["services"]:
 + # pipeline node rest objects are dicts while _from_rest_job_services expects RestJobService
 + services = {}
 + for service_name, service in obj["services"].items():
 + # in the rest object of a pipeline job, a service will be transferred to a dict as
 + # an attribute of a node, but JobService._from_rest_object expects a
 + # RestJobService, so we need to convert it back. Here we convert the dict to a
 + # dummy rest object which may work as a RestJobService instead.
 + services[service_name] = from_rest_dict_to_dummy_rest_object(service)
 + obj["services"] = JobServiceBase._from_rest_job_services(services)
 +
 + # handle limits
 + if "limits" in obj and obj["limits"]:
 + obj["limits"] = CommandJobLimits._from_rest_object(obj["limits"])
 +
 + if "identity" in obj and obj["identity"]:
 + obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"])
 +
 + if "queue_settings" in obj and obj["queue_settings"]:
 + obj["queue_settings"] = QueueSettings._from_rest_object(obj["queue_settings"])
 +
 + return obj
 +
 + @classmethod
 + def _load_from_rest_job(cls, obj: JobBase) -> "Command":
 + from .command_func import command
 +
 + rest_command_job: RestCommandJob = obj.properties
 +
 + command_job: Command = command(
 + name=obj.name,
 + display_name=rest_command_job.display_name,
 + description=rest_command_job.description,
 + tags=rest_command_job.tags,
 + properties=rest_command_job.properties,
 + command=rest_command_job.command,
 + experiment_name=rest_command_job.experiment_name,
 + services=JobServiceBase._from_rest_job_services(rest_command_job.services),
 + status=rest_command_job.status,
 + creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
 + code=rest_command_job.code_id,
 + compute=rest_command_job.compute_id,
 + environment=rest_command_job.environment_id,
 + distribution=DistributionConfiguration._from_rest_object(rest_command_job.distribution),
 + parameters=rest_command_job.parameters,
 + identity=(
 + _BaseJobIdentityConfiguration._from_rest_object(rest_command_job.identity)
 + if rest_command_job.identity
 + else None
 + ),
 + environment_variables=rest_command_job.environment_variables,
 + inputs=from_rest_inputs_to_dataset_literal(rest_command_job.inputs),
 + outputs=from_rest_data_outputs(rest_command_job.outputs),
 + )
 + command_job._id = obj.id
 + command_job.resources = cast(
 + JobResourceConfiguration, JobResourceConfiguration._from_rest_object(rest_command_job.resources)
 + )
 + command_job.limits = CommandJobLimits._from_rest_object(rest_command_job.limits)
 + command_job.queue_settings = 
QueueSettings._from_rest_object(rest_command_job.queue_settings) + if isinstance(command_job.component, CommandComponent): + command_job.component._source = ( + ComponentSource.REMOTE_WORKSPACE_JOB + ) # This is used by pipeline job telemetries. + + # Handle special case of local job + if ( + command_job.resources is not None + and command_job.resources.properties is not None + and command_job.resources.properties.get(LOCAL_COMPUTE_PROPERTY, None) + ): + command_job.compute = LOCAL_COMPUTE_TARGET + command_job.resources.properties.pop(LOCAL_COMPUTE_PROPERTY) + return command_job + + def _build_inputs(self) -> Dict: + inputs = super(Command, self)._build_inputs() + built_inputs = {} + # Validate and remove non-specified inputs + for key, value in inputs.items(): + if value is not None: + built_inputs[key] = value + + return built_inputs + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline import CommandSchema + + return CommandSchema(context=context) + + # pylint: disable-next=docstring-missing-param + def __call__(self, *args: Any, **kwargs: Any) -> "Command": + """Call Command as a function will return a new instance each time. + + :return: A Command node + :rtype: Command + """ + if isinstance(self._component, CommandComponent): + # call this to validate inputs + node: Command = self._component(*args, **kwargs) + # merge inputs + for name, original_input in self.inputs.items(): + if name not in kwargs: + # use setattr here to make sure owner of input won't change + setattr(node.inputs, name, original_input._data) + node._job_inputs[name] = original_input._data + # get outputs + for name, original_output in self.outputs.items(): + # use setattr here to make sure owner of input won't change + if not isinstance(original_output, str): + setattr(node.outputs, name, original_output._data) + node._job_outputs[name] = original_output._data + self._refine_optional_inputs_with_no_value(node, kwargs) + # set default values: compute, environment_variables, outputs + # won't copy name to be able to distinguish if a node's name is assigned by user + # e.g. node_1 = command_func() + # In above example, node_1.name will be None so we can apply node_1 as it's name + node.compute = self.compute + node.tags = self.tags + # Pass through the display name only if the display name is not system generated. + node.display_name = self.display_name if self.display_name != self.name else None + node.environment = copy.deepcopy(self.environment) + # deep copy for complex object + node.environment_variables = copy.deepcopy(self.environment_variables) + node.limits = copy.deepcopy(self.limits) + node.distribution = copy.deepcopy(self.distribution) + node.resources = copy.deepcopy(self.resources) + node.queue_settings = copy.deepcopy(self.queue_settings) + node.services = copy.deepcopy(self.services) + node.identity = copy.deepcopy(self.identity) + return node + msg = "Command can be called as a function only when referenced component is {}, currently got {}." + raise ValidationException( + message=msg.format(type(CommandComponent), self._component), + no_personal_data_message=msg.format(type(CommandComponent), "self._component"), + target=ErrorTarget.COMMAND_JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + +@overload +def _resolve_job_services(services: Optional[Dict]): ... 
+ + +@overload +def _resolve_job_services( + services: Dict[str, Union[JobServiceBase, Dict]], +) -> Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]: ... + + +def _resolve_job_services( + services: Optional[Dict[str, Union[JobServiceBase, Dict]]], +) -> Optional[Dict]: + """Resolve normal dict to dict[str, JobService] + + :param services: A dict that maps service names to either a JobServiceBase object, or a Dict used to build one + :type services: Optional[Dict[str, Union[JobServiceBase, Dict]]] + :return: + * None if `services` is None + * A map of job service names to job services + :rtype: Optional[ + Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]] + ] + """ + if services is None: + return None + + if not isinstance(services, dict): + msg = f"Services must be a dict, got {type(services)} instead." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.COMMAND_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + + result = {} + for name, service in services.items(): + if isinstance(service, dict): + service = load_from_dict(JobServiceSchema, service, context={BASE_PATH_CONTEXT_KEY: "."}) + elif not isinstance( + service, (JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService) + ): + msg = f"Service value for key {name!r} must be a dict or JobService object, got {type(service)} instead." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.COMMAND_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + result[name] = service + return result diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command_func.py new file mode 100644 index 00000000..c542f880 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command_func.py @@ -0,0 +1,314 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from azure.ai.ml.constants._common import AssetTypes, LegacyAssetTypes +from azure.ai.ml.constants._component import ComponentSource +from azure.ai.ml.entities._assets.environment import Environment +from azure.ai.ml.entities._component.command_component import CommandComponent +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.distribution import ( + DistributionConfiguration, + MpiDistribution, + PyTorchDistribution, + RayDistribution, + TensorFlowDistribution, +) +from azure.ai.ml.entities._job.job_service import ( + JobService, + JupyterLabJobService, + SshJobService, + TensorBoardJobService, + VsCodeJobService, +) +from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin +from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution +from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException + +from .command import Command + +SUPPORTED_INPUTS = [ + LegacyAssetTypes.PATH, + AssetTypes.URI_FILE, + AssetTypes.URI_FOLDER, + AssetTypes.CUSTOM_MODEL, + AssetTypes.MLFLOW_MODEL, + AssetTypes.MLTABLE, + AssetTypes.TRITON_MODEL, +] + + +def _parse_input(input_value: Union[Input, Dict, SweepDistribution, str, bool, int, float]) -> Tuple: + component_input = None + job_input: Optional[Union[Input, Dict, SweepDistribution, str, bool, int, float]] = None + + if isinstance(input_value, Input): + component_input = Input(**input_value._to_dict()) + input_type = input_value.type + if input_type in SUPPORTED_INPUTS: + job_input = Input(**input_value._to_dict()) + elif isinstance(input_value, dict): + # if user provided dict, we try to parse it to Input. + # for job input, only parse for path type + input_type = input_value.get("type", None) + if input_type in SUPPORTED_INPUTS: + job_input = Input(**input_value) + component_input = Input(**input_value) + elif isinstance(input_value, (SweepDistribution, str, bool, int, float)): + # Input bindings are not supported + component_input = ComponentTranslatableMixin._to_input_builder_function(input_value) + job_input = input_value + else: + msg = f"Unsupported input type: {type(input_value)}" + msg += ", only Input, dict, str, bool, int and float are supported." 
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return component_input, job_input + + +def _parse_output(output_value: Optional[Union[Output, Dict, str]]) -> Tuple: + component_output = None + job_output: Optional[Union[Output, Dict, str]] = None + + if isinstance(output_value, Output): + component_output = Output(**output_value._to_dict()) + job_output = Output(**output_value._to_dict()) + elif not output_value: + # output value can be None or empty dictionary + # None output value will be packed into a JobOutput object with mode = ReadWriteMount & type = UriFolder + component_output = ComponentTranslatableMixin._to_output(output_value) + job_output = output_value + elif isinstance(output_value, dict): # When output value is a non-empty dictionary + job_output = Output(**output_value) + component_output = Output(**output_value) + elif isinstance(output_value, str): # When output is passed in from pipeline job yaml + job_output = output_value + else: + msg = f"Unsupported output type: {type(output_value)}, only Output and dict are supported." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return component_output, job_output + + +def _parse_inputs_outputs(io_dict: Dict, parse_func: Callable) -> Tuple[Dict, Dict]: + component_io_dict, job_io_dict = {}, {} + if io_dict: + for key, val in io_dict.items(): + component_io, job_io = parse_func(val) + component_io_dict[key] = component_io + job_io_dict[key] = job_io + return component_io_dict, job_io_dict + + +def command( + *, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + display_name: Optional[str] = None, + command: Optional[str] = None, # pylint: disable=redefined-outer-name + experiment_name: Optional[str] = None, + environment: Optional[Union[str, Environment]] = None, + environment_variables: Optional[Dict] = None, + distribution: Optional[ + Union[ + Dict, + MpiDistribution, + TensorFlowDistribution, + PyTorchDistribution, + RayDistribution, + DistributionConfiguration, + ] + ] = None, + compute: Optional[str] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + instance_count: Optional[int] = None, + instance_type: Optional[str] = None, + locations: Optional[List[str]] = None, + docker_args: Optional[Union[str, List[str]]] = None, + shm_size: Optional[str] = None, + timeout: Optional[int] = None, + code: Optional[Union[str, os.PathLike]] = None, + identity: Optional[Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]] = None, + is_deterministic: bool = True, + services: Optional[ + Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]] + ] = None, + job_tier: Optional[str] = None, + priority: Optional[str] = None, + parent_job_name: Optional[str] = None, + **kwargs: Any, +) -> Command: + """Creates a Command object which can be used inside a dsl.pipeline function or used as a standalone Command job. + + :keyword name: The name of the Command job or component. + :paramtype name: Optional[str] + :keyword description: The description of the Command. Defaults to None. + :paramtype description: Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None. 
+ :paramtype tags: Optional[dict[str, str]] + :keyword properties: The job property dictionary. Defaults to None. + :paramtype properties: Optional[dict[str, str]] + :keyword display_name: The display name of the job. Defaults to a randomly generated name. + :paramtype display_name: Optional[str] + :keyword command: The command to be executed. Defaults to None. + :paramtype command: Optional[str] + :keyword experiment_name: The name of the experiment that the job will be created under. Defaults to current + directory name. + :paramtype experiment_name: Optional[str] + :keyword environment: The environment that the job will run in. + :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + :keyword environment_variables: A dictionary of environment variable names and values. + These environment variables are set on the process where user script is being executed. + Defaults to None. + :paramtype environment_variables: Optional[dict[str, str]] + :keyword distribution: The configuration for distributed jobs. Defaults to None. + :paramtype distribution: Optional[Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution, + ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]] + :keyword compute: The compute target the job will run on. Defaults to default compute. + :paramtype compute: Optional[str] + :keyword inputs: A mapping of input names to input data sources used in the job. Defaults to None. + :paramtype inputs: Optional[dict[str, Union[~azure.ai.ml.Input, str, bool, int, float, Enum]]] + :keyword outputs: A mapping of output names to output data sources used in the job. Defaults to None. + :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]] + :keyword instance_count: The number of instances or nodes to be used by the compute target. Defaults to 1. + :paramtype instance_count: Optional[int] + :keyword instance_type: The type of VM to be used by the compute target. + :paramtype instance_type: Optional[str] + :keyword locations: The list of locations where the job will run. + :paramtype locations: Optional[List[str]] + :keyword docker_args: Extra arguments to pass to the Docker run command. This would override any + parameters that have already been set by the system, or in this section. This parameter is only + supported for Azure ML compute types. Defaults to None. + :paramtype docker_args: Optional[Union[str,List[str]]] + :keyword shm_size: The size of the Docker container's shared memory block. This should be in the + format of (number)(unit) where the number has to be greater than 0 and the unit can be one of + b(bytes), k(kilobytes), m(megabytes), or g(gigabytes). + :paramtype shm_size: Optional[str] + :keyword timeout: The number, in seconds, after which the job will be cancelled. + :paramtype timeout: Optional[int] + :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url + pointing to a remote location. + :paramtype code: Optional[Union[str, os.PathLike]] + :keyword identity: The identity that the command job will use while running on compute. + :paramtype identity: Optional[Union[ + ~azure.ai.ml.entities.ManagedIdentityConfiguration, + ~azure.ai.ml.entities.AmlTokenConfiguration, + ~azure.ai.ml.entities.UserIdentityConfiguration]] + :keyword is_deterministic: Specifies whether the Command will return the same output given the same input. + Defaults to True. 
When True, if a Command Component is deterministic and has been run before in the + current workspace with the same input and settings, it will reuse results from a previously submitted + job when used as a node or step in a pipeline. In that scenario, no compute resources will be used. + :paramtype is_deterministic: bool + :keyword services: The interactive services for the node. Defaults to None. This is an experimental parameter, + and may change at any time. Please see https://aka.ms/azuremlexperimental for more information. + :paramtype services: Optional[dict[str, Union[~azure.ai.ml.entities.JobService, + ~azure.ai.ml.entities.JupyterLabJobService, ~azure.ai.ml.entities.SshJobService, + ~azure.ai.ml.entities.TensorBoardJobService, ~azure.ai.ml.entities.VsCodeJobService]]] + :keyword job_tier: The job tier. Accepted values are "Spot", "Basic", "Standard", or "Premium". + :paramtype job_tier: Optional[str] + :keyword priority: The priority of the job on the compute. Accepted values are "low", "medium", and "high". + Defaults to "medium". + :paramtype priority: Optional[str] + :keyword parent_job_name: parent job id for command job + :paramtype parent_job_name: Optional[str] + :return: A Command object. + :rtype: ~azure.ai.ml.entities.Command + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_command_configurations.py + :start-after: [START command_function] + :end-before: [END command_function] + :language: python + :dedent: 8 + :caption: Creating a Command Job using the command() builder method. + """ + # pylint: disable=too-many-locals + inputs = inputs or {} + outputs = outputs or {} + component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input) + # job inputs can not be None + job_inputs = {k: v for k, v in job_inputs.items() if v is not None} + component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output) + + component = kwargs.pop("component", None) + if component is None: + component = CommandComponent( + name=name, + tags=tags, + code=code, + command=command, + environment=environment, + display_name=display_name, + description=description, + inputs=component_inputs, + outputs=component_outputs, + distribution=distribution, + environment_variables=environment_variables, + _source=ComponentSource.BUILDER, + is_deterministic=is_deterministic, + **kwargs, + ) + command_obj = Command( + component=component, + name=name, + description=description, + tags=tags, + properties=properties, + display_name=display_name, + experiment_name=experiment_name, + compute=compute, + inputs=job_inputs, + outputs=job_outputs, + identity=identity, + distribution=distribution, + environment=environment, + environment_variables=environment_variables, + services=services, + parent_job_name=parent_job_name, + **kwargs, + ) + + if ( + locations is not None + or instance_count is not None + or instance_type is not None + or docker_args is not None + or shm_size is not None + ): + command_obj.set_resources( + locations=locations, + instance_count=instance_count, + instance_type=instance_type, + docker_args=docker_args, + shm_size=shm_size, + ) + + if timeout is not None: + command_obj.set_limits(timeout=timeout) + + if job_tier is not None or priority is not None: + command_obj.set_queue_settings(job_tier=job_tier, priority=priority) + + return command_obj diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/condition_node.py 
b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/condition_node.py new file mode 100644 index 00000000..5a5ad58b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/condition_node.py @@ -0,0 +1,146 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, List, Optional + +from azure.ai.ml._schema import PathAwareSchema +from azure.ai.ml._utils.utils import is_data_binding_expression +from azure.ai.ml.constants._component import ControlFlowType +from azure.ai.ml.entities._builders import BaseNode +from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode +from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob +from azure.ai.ml.entities._job.pipeline._io import InputOutputBase +from azure.ai.ml.entities._validation import MutableValidationResult + + +class ConditionNode(ControlFlowNode): + """Conditional node in the pipeline. + + Please do not directly use this class. + + :param condition: The condition for the conditional node. + :type condition: Any + :param true_block: The list of nodes to execute when the condition is true. + :type true_block: List[~azure.ai.ml.entities._builders.BaseNode] + :param false_block: The list of nodes to execute when the condition is false. + :type false_block: List[~azure.ai.ml.entities._builders.BaseNode] + """ + + def __init__( + self, condition: Any, *, true_block: Optional[List] = None, false_block: Optional[List] = None, **kwargs: Any + ) -> None: + kwargs.pop("type", None) + super(ConditionNode, self).__init__(type=ControlFlowType.IF_ELSE, **kwargs) + self.condition = condition + if true_block and not isinstance(true_block, list): + true_block = [true_block] + self._true_block = true_block + if false_block and not isinstance(false_block, list): + false_block = [false_block] + self._false_block = false_block + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema: + from azure.ai.ml._schema.pipeline.condition_node import ConditionNodeSchema + + return ConditionNodeSchema(context=context) + + @classmethod + def _from_rest_object(cls, obj: dict) -> "ConditionNode": + return cls(**obj) + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ConditionNode": + """Create a condition node instance from schema parsed dict. + + :param loaded_data: The loaded data + :type loaded_data: Dict + :return: The ConditionNode node + :rtype: ConditionNode + """ + return cls(**loaded_data) + + @property + def true_block(self) -> Optional[List]: + """Get the list of nodes to execute when the condition is true. + + :return: The list of nodes to execute when the condition is true. + :rtype: List[~azure.ai.ml.entities._builders.BaseNode] + """ + return self._true_block + + @property + def false_block(self) -> Optional[List]: + """Get the list of nodes to execute when the condition is false. + + :return: The list of nodes to execute when the condition is false. 
+        :rtype: List[~azure.ai.ml.entities._builders.BaseNode]
+        """
+        return self._false_block
+
+    def _customized_validate(self) -> MutableValidationResult:
+        return self._validate_params()
+
+    def _validate_params(self) -> MutableValidationResult:
+        # pylint: disable=protected-access
+        validation_result = self._create_empty_validation_result()
+        if not isinstance(self.condition, (str, bool, InputOutputBase)):
+            validation_result.append_error(
+                yaml_path="condition",
+                message=f"'condition' of dsl.condition node must be an instance of "
+                f"{str}, {bool} or {InputOutputBase}, got {type(self.condition)}.",
+            )
+
+        # Check if output is a control output.
+        # pylint: disable=protected-access
+        if isinstance(self.condition, InputOutputBase) and self.condition._meta is not None:
+            # pylint: disable=protected-access
+            output_definition = self.condition._meta
+            if output_definition is not None and not output_definition._is_primitive_type:
+                validation_result.append_error(
+                    yaml_path="condition",
+                    message=f"'condition' of dsl.condition node must be a primitive type "
+                    f"with value 'True', got {output_definition._is_primitive_type}",
+                )
+
+        # check if condition is a valid binding expression
+        if isinstance(self.condition, str) and not is_data_binding_expression(
+            self.condition, ["parent"], is_singular=False
+        ):
+            error_tail = "for example, ${{parent.jobs.xxx.outputs.output}}"
+            validation_result.append_error(
+                yaml_path="condition",
+                message=f"'condition' of dsl.condition has invalid binding expression: {self.condition}, {error_tail}",
+            )
+
+        error_msg = (
+            "{!r} of dsl.condition node must be an instance of " f"{BaseNode}, {AutoMLJob} or {str}, " "got {!r}."
+        )
+        blocks = self.true_block if self.true_block else []
+        for block in blocks:
+            if block is not None and not isinstance(block, (BaseNode, AutoMLJob, str)):
+                validation_result.append_error(
+                    yaml_path="true_block", message=error_msg.format("true_block", type(block))
+                )
+        blocks = self.false_block if self.false_block else []
+        for block in blocks:
+            if block is not None and not isinstance(block, (BaseNode, AutoMLJob, str)):
+                validation_result.append_error(
+                    yaml_path="false_block", message=error_msg.format("false_block", type(block))
+                )
+
+        # check if true/false block is a valid binding expression
+        for name, blocks in {"true_block": self.true_block, "false_block": self.false_block}.items():  # type: ignore
+            blocks = blocks if blocks else []
+            for block in blocks:
+                if block is None or not isinstance(block, str):
+                    continue
+                error_tail = "for example, ${{parent.jobs.xxx}}"
+                if not is_data_binding_expression(block, ["parent", "jobs"], is_singular=False):
+                    validation_result.append_error(
+                        yaml_path=name,
+                        message=f"'{name}' of dsl.condition has invalid binding expression: {block}, {error_tail}",
+                    )
+
+        return validation_result
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/control_flow_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/control_flow_node.py
new file mode 100644
index 00000000..c757a1e4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/control_flow_node.py
@@ -0,0 +1,170 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- +import logging +import re +import uuid +from abc import ABC +from typing import Any, Dict, Union, cast # pylint: disable=unused-import + +from marshmallow import ValidationError + +from azure.ai.ml._utils.utils import is_data_binding_expression +from azure.ai.ml.constants._common import CommonYamlFields +from azure.ai.ml.constants._component import ComponentSource, ControlFlowType +from azure.ai.ml.entities._mixins import YamlTranslatableMixin +from azure.ai.ml.entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .._util import convert_ordered_dict_to_dict +from .base_node import BaseNode + +module_logger = logging.getLogger(__name__) + + +# ControlFlowNode did not inherit from BaseNode since it doesn't have inputs/outputs like other nodes. +class ControlFlowNode(YamlTranslatableMixin, PathAwareSchemaValidatableMixin, ABC): + """Base class for control flow node in the pipeline. + + Please do not directly use this class. + + :param kwargs: Additional keyword arguments. + :type kwargs: Dict[str, Union[Any]] + """ + + def __init__(self, **kwargs: Any) -> None: + # TODO(1979547): refactor this + _source = kwargs.pop("_source", None) + self._source = _source if _source else ComponentSource.DSL + _from_component_func = kwargs.pop("_from_component_func", False) + self._type = kwargs.get("type", None) + self._instance_id = str(uuid.uuid4()) + self.name = None + if _from_component_func: + # add current control flow node in pipeline stack for dsl scenario and remove the body from the pipeline + # stack. + self._register_in_current_pipeline_component_builder() + + @property + def type(self) -> Any: + """Get the type of the control flow node. + + :return: The type of the control flow node. + :rtype: self._type + """ + return self._type + + def _to_dict(self) -> Dict: + return dict(self._dump_for_validation()) + + def _to_rest_object(self, **kwargs: Any) -> dict: # pylint: disable=unused-argument + """Convert self to a rest object for remote call. + + :return: The rest object + :rtype: dict + """ + rest_obj = self._to_dict() + rest_obj["_source"] = self._source + return cast(dict, convert_ordered_dict_to_dict(rest_obj)) + + def _register_in_current_pipeline_component_builder(self) -> None: + """Register this node in current pipeline component builder by adding self to a global stack.""" + from azure.ai.ml.dsl._pipeline_component_builder import _add_component_to_current_definition_builder + + _add_component_to_current_definition_builder(self) # type: ignore[arg-type] + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException: + return ValidationException( + message=message, + no_personal_data_message=no_personal_data_message, + target=ErrorTarget.PIPELINE, + ) + + +class LoopNode(ControlFlowNode, ABC): + """Base class for loop node in the pipeline. + + Please do not directly use this class. + + :param body: The body of the loop node. + :type body: ~azure.ai.ml.entities._builders.BaseNode + :param kwargs: Additional keyword arguments. + :type kwargs: Dict[str, Union[Any]] + """ + + def __init__(self, *, body: BaseNode, **kwargs: Any) -> None: + self._body = body + super(LoopNode, self).__init__(**kwargs) + # always set the referenced control flow node instance id to the body. 
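+        # Recording this loop node's instance id on the body lets _validate_body detect the case where the same
+        # body node is attached to more than one loop node, which is reported as a validation error.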
+ self.body._set_referenced_control_flow_node_instance_id(self._instance_id) + + @property + def body(self) -> BaseNode: + """Get the body of the loop node. + + :return: The body of the loop node. + :rtype: ~azure.ai.ml.entities._builders.BaseNode + """ + return self._body + + _extra_body_types = None + + @classmethod + def _attr_type_map(cls) -> dict: + from .command import Command + from .pipeline import Pipeline + + enable_body_type = (Command, Pipeline) + if cls._extra_body_types is not None: + enable_body_type = enable_body_type + cls._extra_body_types + return { + "body": enable_body_type, + } + + @classmethod + def _get_body_from_pipeline_jobs(cls, pipeline_jobs: Dict[str, BaseNode], body_name: str) -> BaseNode: + # Get body object from pipeline job list. + if body_name not in pipeline_jobs: + raise ValidationError( + message=f'Cannot find the do-while loop body "{body_name}" in the pipeline.', + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return pipeline_jobs[body_name] + + def _validate_body(self) -> MutableValidationResult: + # pylint: disable=protected-access + validation_result = self._create_empty_validation_result() + + if self._instance_id != self.body._referenced_control_flow_node_instance_id: + # When the body is used in another loop node record the error message in validation result. + validation_result.append_error("body", "The body of loop node cannot be promoted as another loop again.") + return validation_result + + def _get_body_binding_str(self) -> str: + return "${{parent.jobs.%s}}" % self.body.name + + @staticmethod + def _get_data_binding_expression_value(expression: str, regex: str) -> str: + try: + if is_data_binding_expression(expression): + return str(re.findall(regex, expression)[0]) + + return expression + except Exception: # pylint: disable=W0718 + module_logger.warning("Cannot get the value from data binding expression %s.", expression) + return expression + + @staticmethod + def _is_loop_node_dict(obj: Any) -> bool: + return obj.get(CommonYamlFields.TYPE, None) in [ControlFlowType.DO_WHILE, ControlFlowType.PARALLEL_FOR] + + @classmethod + def _from_rest_object(cls, obj: dict, pipeline_jobs: dict) -> "LoopNode": + from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory + + node_type = obj.get(CommonYamlFields.TYPE, None) + load_from_rest_obj_func = pipeline_node_factory.get_load_from_rest_object_func(_type=node_type) + return load_from_rest_obj_func(obj, pipeline_jobs) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer.py new file mode 100644 index 00000000..83e88a48 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer.py @@ -0,0 +1,575 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from marshmallow import Schema + +from azure.ai.ml._restclient.v2022_10_01_preview.models import JobBase +from azure.ai.ml._schema.job.data_transfer_job import ( + DataTransferCopyJobSchema, + DataTransferExportJobSchema, + DataTransferImportJobSchema, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes +from azure.ai.ml.constants._component import DataTransferTaskType, ExternalDataType, NodeType +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._component.datatransfer_component import ( + DataTransferComponent, + DataTransferCopyComponent, + DataTransferExportComponent, + DataTransferImportComponent, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem +from azure.ai.ml.entities._job.data_transfer.data_transfer_job import ( + DataTransferCopyJob, + DataTransferExportJob, + DataTransferImportJob, +) +from azure.ai.ml.entities._validation.core import MutableValidationResult +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from ..._schema import PathAwareSchema +from .._job.pipeline._io import NodeOutput +from .._util import convert_ordered_dict_to_dict, load_from_dict, validate_attribute_type +from .base_node import BaseNode + +module_logger = logging.getLogger(__name__) + + +def _build_source_sink(io_dict: Optional[Union[Dict, Database, FileSystem]]) -> Optional[Union[Database, FileSystem]]: + if io_dict is None: + return io_dict + if isinstance(io_dict, (Database, FileSystem)): + component_io = io_dict + else: + if isinstance(io_dict, dict): + data_type = io_dict.pop("type", None) + if data_type == ExternalDataType.DATABASE: + component_io = Database(**io_dict) + elif data_type == ExternalDataType.FILE_SYSTEM: + component_io = FileSystem(**io_dict) + else: + msg = "Type in source or sink only support {} and {}, currently got {}." + raise ValidationException( + message=msg.format( + ExternalDataType.DATABASE, + ExternalDataType.FILE_SYSTEM, + data_type, + ), + no_personal_data_message=msg.format( + ExternalDataType.DATABASE, + ExternalDataType.FILE_SYSTEM, + "data_type", + ), + target=ErrorTarget.DATA_TRANSFER_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + else: + msg = "Source or sink only support dict, Database and FileSystem" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.DATA_TRANSFER_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return component_io + + +class DataTransfer(BaseNode): + """Base class for data transfer node, used for data transfer component version consumption. + + You should not instantiate this class directly. 
+ """ + + def __init__( + self, + *, + component: Union[str, DataTransferCopyComponent, DataTransferImportComponent], + compute: Optional[str] = None, + inputs: Optional[Dict[str, Union[NodeOutput, Input, str]]] = None, + outputs: Optional[Dict[str, Union[str, Output]]] = None, + **kwargs: Any, + ): + # resolve normal dict to dict[str, JobService] + kwargs.pop("type", None) + super().__init__( + type=NodeType.DATA_TRANSFER, + inputs=inputs, + outputs=outputs, + component=component, + compute=compute, + **kwargs, + ) + + @property + def component(self) -> Union[str, DataTransferComponent]: + res: Union[str, DataTransferComponent] = self._component + return res + + @classmethod + def _load_from_rest_job(cls, obj: JobBase) -> "DataTransfer": + # Todo: need update rest api + raise NotImplementedError("Not support submit standalone job for now") + + @classmethod + def _get_supported_outputs_types(cls) -> Tuple: + return str, Output + + def _build_inputs(self) -> Dict: + inputs = super(DataTransfer, self)._build_inputs() + built_inputs = {} + # Validate and remove non-specified inputs + for key, value in inputs.items(): + if value is not None: + built_inputs[key] = value + + return built_inputs + + +@experimental +class DataTransferCopy(DataTransfer): + """Base class for data transfer copy node. + + You should not instantiate this class directly. Instead, you should + create from builder function: copy_data. + + :param component: Id or instance of the data transfer component/job to be run for the step + :type component: DataTransferCopyComponent + :param inputs: Inputs to the data transfer. + :type inputs: Dict[str, Union[NodeOutput, Input, str]] + :param outputs: Mapping of output data bindings used in the job. + :type outputs: Dict[str, Union[str, Output, dict]] + :param name: Name of the data transfer. + :type name: str + :param description: Description of the data transfer. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param display_name: Display name of the job. + :type display_name: str + :param experiment_name: Name of the experiment the job will be created under, + if None is provided, default will be set to current directory name. + :type experiment_name: str + :param compute: The compute target the job runs on. + :type compute: str + :param data_copy_mode: data copy mode in copy task, possible value is "merge_with_overwrite", "fail_if_conflict". + :type data_copy_mode: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if DataTransferCopy cannot be successfully validated. + Details will be provided in the error message. 
+ """ + + def __init__( + self, + *, + component: Union[str, DataTransferCopyComponent], + compute: Optional[str] = None, + inputs: Optional[Dict[str, Union[NodeOutput, Input, str]]] = None, + outputs: Optional[Dict[str, Union[str, Output]]] = None, + data_copy_mode: Optional[str] = None, + **kwargs: Any, + ): + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + super().__init__( + inputs=inputs, + outputs=outputs, + component=component, + compute=compute, + **kwargs, + ) + # init mark for _AttrDict + self._init = True + self.task = DataTransferTaskType.COPY_DATA + self.data_copy_mode = data_copy_mode + is_component = isinstance(component, DataTransferCopyComponent) + if is_component: + _component: DataTransferCopyComponent = cast(DataTransferCopyComponent, component) + self.task = _component.task or self.task + self.data_copy_mode = _component.data_copy_mode or self.data_copy_mode + self._init = False + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "component": (str, DataTransferCopyComponent), + } + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline import DataTransferCopySchema + + return DataTransferCopySchema(context=context) + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return ["type", "task", "data_copy_mode"] + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj = super()._to_rest_object(**kwargs) + for key, value in { + "componentId": self._get_component_id(), + "data_copy_mode": self.data_copy_mode, + }.items(): + if value is not None: + rest_obj[key] = value + return cast(dict, convert_ordered_dict_to_dict(rest_obj)) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> Any: + from .data_transfer_func import copy_data + + loaded_data = load_from_dict(DataTransferCopyJobSchema, data, context, additional_message, **kwargs) + data_transfer_job = copy_data(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + return data_transfer_job + + def _to_job(self) -> DataTransferCopyJob: + return DataTransferCopyJob( + experiment_name=self.experiment_name, + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + status=self.status, + inputs=self._job_inputs, + outputs=self._job_outputs, + services=self.services, + compute=self.compute, + data_copy_mode=self.data_copy_mode, + ) + + # pylint: disable-next=docstring-missing-param + def __call__(self, *args: Any, **kwargs: Any) -> "DataTransferCopy": + """Call DataTransferCopy as a function will return a new instance each time. 
+ + :return: A DataTransferCopy node + :rtype: DataTransferCopy + """ + if isinstance(self._component, Component): + # call this to validate inputs + node: DataTransferCopy = self._component(*args, **kwargs) + # merge inputs + for name, original_input in self.inputs.items(): + if name not in kwargs: + # use setattr here to make sure owner of input won't change + setattr(node.inputs, name, original_input._data) + node._job_inputs[name] = original_input._data + # get outputs + for name, original_output in self.outputs.items(): + # use setattr here to make sure owner of input won't change + if not isinstance(original_output, str): + setattr(node.outputs, name, original_output._data) + self._refine_optional_inputs_with_no_value(node, kwargs) + # set default values: compute, environment_variables, outputs + node._name = self.name + node.compute = self.compute + node.tags = self.tags + # Pass through the display name only if the display name is not system generated. + node.display_name = self.display_name if self.display_name != self.name else None + return node + msg = "copy_data can be called as a function only when referenced component is {}, currently got {}." + raise ValidationException( + message=msg.format(type(Component), self._component), + no_personal_data_message=msg.format(type(Component), "self._component"), + target=ErrorTarget.DATA_TRANSFER_JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + +@experimental +class DataTransferImport(DataTransfer): + """Base class for data transfer import node. + + You should not instantiate this class directly. Instead, you should + create from builder function: import_data. + + :param component: Id of the data transfer built in component to be run for the step + :type component: str + :param source: The data source of file system or database + :type source: Union[Dict, Database, FileSystem] + :param outputs: Mapping of output data bindings used in the job. + :type outputs: Dict[str, Union[str, Output, dict]] + :param name: Name of the data transfer. + :type name: str + :param description: Description of the data transfer. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param display_name: Display name of the job. + :type display_name: str + :param experiment_name: Name of the experiment the job will be created under, + if None is provided, default will be set to current directory name. + :type experiment_name: str + :param compute: The compute target the job runs on. + :type compute: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if DataTransferImport cannot be successfully validated. + Details will be provided in the error message. 
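+
+    A minimal usage sketch (not taken from this module) is shown below; it assumes ``import_data`` and ``Database``
+    are exposed from ``azure.ai.ml.data_transfer``, the ``Database`` keyword arguments follow the public SDK surface
+    rather than anything shown in this file, and the compute name, connection, and query are placeholders:
+
+    .. code-block:: python
+
+        from azure.ai.ml import Output
+        from azure.ai.ml.data_transfer import Database, import_data
+
+        source = Database(query="select * from my_table", connection="azureml:my_sql_connection")
+        import_node = import_data(
+            compute="my_adf_compute",
+            source=source,
+            outputs={"sink": Output(type="mltable")},
+        )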
+ """ + + def __init__( + self, + *, + component: Union[str, DataTransferImportComponent], + compute: Optional[str] = None, + source: Optional[Union[Dict, Database, FileSystem]] = None, + outputs: Optional[Dict[str, Union[str, Output]]] = None, + **kwargs: Any, + ): + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + super(DataTransferImport, self).__init__( + component=component, + outputs=outputs, + compute=compute, + **kwargs, + ) + # init mark for _AttrDict + self._init = True + self.task = DataTransferTaskType.IMPORT_DATA + is_component = isinstance(component, DataTransferImportComponent) + if is_component: + _component: DataTransferImportComponent = cast(DataTransferImportComponent, component) + self.task = _component.task or self.task + self.source = _build_source_sink(source) + self._init = False + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "component": (str, DataTransferImportComponent), + } + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline import DataTransferImportSchema + + return DataTransferImportSchema(context=context) + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return ["type", "task", "source"] + + def _customized_validate(self) -> MutableValidationResult: + result = super()._customized_validate() + if self.source is None: + result.append_error( + yaml_path="source", + message="Source is a required field for import data task in DataTransfer job", + ) + if len(self.outputs) != 1 or list(self.outputs.keys())[0] != "sink": + result.append_error( + yaml_path="outputs.sink", + message="Outputs field only support one output called sink in import task", + ) + if ( + "sink" in self.outputs + and not isinstance(self.outputs["sink"], str) + and isinstance(self.outputs["sink"]._data, Output) + ): + sink_output = self.outputs["sink"]._data + if self.source is not None: + + if (self.source.type == ExternalDataType.DATABASE and sink_output.type != AssetTypes.MLTABLE) or ( + self.source.type == ExternalDataType.FILE_SYSTEM and sink_output.type != AssetTypes.URI_FOLDER + ): + result.append_error( + yaml_path="outputs.sink.type", + message="Outputs field only support type {} for {} and {} for {}".format( + AssetTypes.MLTABLE, + ExternalDataType.DATABASE, + AssetTypes.URI_FOLDER, + ExternalDataType.FILE_SYSTEM, + ), + ) + return result + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj = super()._to_rest_object(**kwargs) + for key, value in { + "componentId": self._get_component_id(), + }.items(): + if value is not None: + rest_obj[key] = value + return cast(dict, convert_ordered_dict_to_dict(rest_obj)) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "DataTransferImport": + from .data_transfer_func import import_data + + loaded_data = load_from_dict(DataTransferImportJobSchema, data, context, additional_message, **kwargs) + data_transfer_job: DataTransferImport = import_data(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + return data_transfer_job + + def _to_job(self) -> DataTransferImportJob: + return DataTransferImportJob( + experiment_name=self.experiment_name, + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + status=self.status, + source=self.source, + outputs=self._job_outputs, + services=self.services, + 
compute=self.compute, + ) + + +@experimental +class DataTransferExport(DataTransfer): + """Base class for data transfer export node. + + You should not instantiate this class directly. Instead, you should + create from builder function: export_data. + + :param component: Id of the data transfer built in component to be run for the step + :type component: str + :param sink: The sink of external data and databases. + :type sink: Union[Dict, Database, FileSystem] + :param inputs: Mapping of input data bindings used in the job. + :type inputs: Dict[str, Union[NodeOutput, Input, str, Input]] + :param name: Name of the data transfer. + :type name: str + :param description: Description of the data transfer. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param display_name: Display name of the job. + :type display_name: str + :param experiment_name: Name of the experiment the job will be created under, + if None is provided, default will be set to current directory name. + :type experiment_name: str + :param compute: The compute target the job runs on. + :type compute: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if DataTransferExport cannot be successfully validated. + Details will be provided in the error message. + """ + + def __init__( + self, + *, + component: Union[str, DataTransferCopyComponent, DataTransferImportComponent], + compute: Optional[str] = None, + sink: Optional[Union[Dict, Database, FileSystem]] = None, + inputs: Optional[Dict[str, Union[NodeOutput, Input, str]]] = None, + **kwargs: Any, + ): + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + super(DataTransferExport, self).__init__( + component=component, + inputs=inputs, + compute=compute, + **kwargs, + ) + # init mark for _AttrDict + self._init = True + self.task = DataTransferTaskType.EXPORT_DATA + is_component = isinstance(component, DataTransferExportComponent) + if is_component: + _component: DataTransferExportComponent = cast(DataTransferExportComponent, component) + self.task = _component.task or self.task + self.sink = sink + self._init = False + + @property + def sink(self) -> Optional[Union[Dict, Database, FileSystem]]: + """The sink of external data and databases. + + :return: The sink of external data and databases. 
+ :rtype: Union[None, Database, FileSystem] + """ + return self._sink + + @sink.setter + def sink(self, value: Union[Dict, Database, FileSystem]) -> None: + self._sink = _build_source_sink(value) + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "component": (str, DataTransferExportComponent), + } + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline import DataTransferExportSchema + + return DataTransferExportSchema(context=context) + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return ["type", "task", "sink"] + + def _customized_validate(self) -> MutableValidationResult: + result = super()._customized_validate() + if self.sink is None: + result.append_error( + yaml_path="sink", + message="Sink is a required field for export data task in DataTransfer job", + ) + if len(self.inputs) != 1 or list(self.inputs.keys())[0] != "source": + result.append_error( + yaml_path="inputs.source", + message="Inputs field only support one input called source in export task", + ) + if "source" in self.inputs and isinstance(self.inputs["source"]._data, Input): + source_input = self.inputs["source"]._data + if self.sink is not None and not isinstance(self.sink, Dict): + if (self.sink.type == ExternalDataType.DATABASE and source_input.type != AssetTypes.URI_FILE) or ( + self.sink.type == ExternalDataType.FILE_SYSTEM and source_input.type != AssetTypes.URI_FOLDER + ): + result.append_error( + yaml_path="inputs.source.type", + message="Inputs field only support type {} for {} and {} for {}".format( + AssetTypes.URI_FILE, + ExternalDataType.DATABASE, + AssetTypes.URI_FOLDER, + ExternalDataType.FILE_SYSTEM, + ), + ) + + return result + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj = super()._to_rest_object(**kwargs) + for key, value in { + "componentId": self._get_component_id(), + }.items(): + if value is not None: + rest_obj[key] = value + return cast(dict, convert_ordered_dict_to_dict(rest_obj)) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "DataTransferExport": + from .data_transfer_func import export_data + + loaded_data = load_from_dict(DataTransferExportJobSchema, data, context, additional_message, **kwargs) + data_transfer_job: DataTransferExport = export_data(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + return data_transfer_job + + def _to_job(self) -> DataTransferExportJob: + return DataTransferExportJob( + experiment_name=self.experiment_name, + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + status=self.status, + sink=self.sink, + inputs=self._job_inputs, + services=self.services, + compute=self.compute, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer_func.py new file mode 100644 index 00000000..423c125b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer_func.py @@ -0,0 +1,335 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access + +from typing import Any, Callable, Dict, Optional, Tuple, Union + +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import AssetTypes, LegacyAssetTypes +from azure.ai.ml.constants._component import ComponentSource, DataTransferBuiltinComponentUri, ExternalDataType +from azure.ai.ml.entities._builders.base_node import pipeline_node_decorator +from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem +from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin +from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput +from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException + +from .data_transfer import DataTransferCopy, DataTransferExport, DataTransferImport, _build_source_sink + +SUPPORTED_INPUTS = [ + LegacyAssetTypes.PATH, + AssetTypes.URI_FILE, + AssetTypes.URI_FOLDER, + AssetTypes.CUSTOM_MODEL, + AssetTypes.MLFLOW_MODEL, + AssetTypes.MLTABLE, + AssetTypes.TRITON_MODEL, +] + + +def _parse_input(input_value: Union[Input, dict, str, PipelineInput, NodeOutput]) -> Tuple: + component_input = None + job_input: Union[Input, dict, str, PipelineInput, NodeOutput] = "" + + if isinstance(input_value, Input): + component_input = Input(**input_value._to_dict()) + input_type = input_value.type + if input_type in SUPPORTED_INPUTS: + job_input = Input(**input_value._to_dict()) + elif isinstance(input_value, dict): + # if user provided dict, we try to parse it to Input. + # for job input, only parse for path type + input_type = input_value.get("type", None) + if input_type in SUPPORTED_INPUTS: + job_input = Input(**input_value) + component_input = Input(**input_value) + elif isinstance(input_value, str): + # Input bindings + component_input = ComponentTranslatableMixin._to_input_builder_function(input_value) + job_input = input_value + elif isinstance(input_value, (PipelineInput, NodeOutput)): + data: Any = None + # datatransfer node can accept PipelineInput/NodeOutput for export task. + if input_value._data is None or isinstance(input_value._data, Output): + data = Input(type=input_value.type, mode=input_value.mode) + else: + data = input_value._data + component_input, _ = _parse_input(data) + job_input = input_value + else: + msg = ( + f"Unsupported input type: {type(input_value)}, only Input, dict, str, PipelineInput and NodeOutput are " + f"supported." 
+ ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return component_input, job_input + + +def _parse_output(output_value: Union[Output, Dict]) -> Tuple: + component_output = None + job_output: Union[Output, Dict] = {} + + if isinstance(output_value, Output): + component_output = Output(**output_value._to_dict()) + job_output = Output(**output_value._to_dict()) + elif not output_value: + # output value can be None or empty dictionary + # None output value will be packed into a JobOutput object with mode = ReadWriteMount & type = UriFolder + component_output = ComponentTranslatableMixin._to_output(output_value) + job_output = output_value + elif isinstance(output_value, dict): # When output value is a non-empty dictionary + job_output = Output(**output_value) + component_output = Output(**output_value) + elif isinstance(output_value, str): # When output is passed in from pipeline job yaml + job_output = output_value + else: + msg = f"Unsupported output type: {type(output_value)}, only Output and dict are supported." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return component_output, job_output + + +def _parse_inputs_outputs(io_dict: Optional[Dict], parse_func: Callable) -> Tuple[Dict, Dict]: + component_io_dict, job_io_dict = {}, {} + if io_dict: + for key, val in io_dict.items(): + component_io, job_io = parse_func(val) + component_io_dict[key] = component_io + job_io_dict[key] = job_io + return component_io_dict, job_io_dict + + +@experimental +def copy_data( + *, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + is_deterministic: bool = True, + data_copy_mode: Optional[str] = None, + **kwargs: Any, +) -> DataTransferCopy: + """Create a DataTransferCopy object which can be used inside dsl.pipeline as a function. + + :keyword name: The name of the job. + :paramtype name: str + :keyword description: Description of the job. + :paramtype description: str + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :paramtype tags: dict[str, str] + :keyword display_name: Display name of the job. + :paramtype display_name: str + :keyword experiment_name: Name of the experiment the job will be created under. + :paramtype experiment_name: str + :keyword compute: The compute resource the job runs on. + :paramtype compute: str + :keyword inputs: Mapping of inputs data bindings used in the job. + :paramtype inputs: dict + :keyword outputs: Mapping of outputs data bindings used in the job. + :paramtype outputs: dict + :keyword is_deterministic: Specify whether the command will return same output given same input. + If a command (component) is deterministic, when use it as a node/step in a pipeline, + it will reuse results from a previous submitted job in current workspace which has same inputs and settings. + In this case, this step will not use any compute resource. + Default to be True, specify is_deterministic=False if you would like to avoid such reuse behavior. + :paramtype is_deterministic: bool + :keyword data_copy_mode: data copy mode in copy task, possible value is "merge_with_overwrite", "fail_if_conflict". 
+ :paramtype data_copy_mode: str + :return: A DataTransferCopy object. + :rtype: ~azure.ai.ml.entities._component.datatransfer_component.DataTransferCopyComponent + """ + inputs = inputs or {} + outputs = outputs or {} + component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input) + # job inputs can not be None + job_inputs = {k: v for k, v in job_inputs.items() if v is not None} + component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output) + component = kwargs.pop("component", None) + if component is None: + component = DataTransferCopyComponent( + name=name, + tags=tags, + display_name=display_name, + description=description, + inputs=component_inputs, + outputs=component_outputs, + data_copy_mode=data_copy_mode, + _source=ComponentSource.BUILDER, + is_deterministic=is_deterministic, + **kwargs, + ) + data_transfer_copy_obj = DataTransferCopy( + component=component, + name=name, + description=description, + tags=tags, + display_name=display_name, + experiment_name=experiment_name, + compute=compute, + inputs=job_inputs, + outputs=job_outputs, + data_copy_mode=data_copy_mode, + **kwargs, + ) + return data_transfer_copy_obj + + +@experimental +@pipeline_node_decorator +def import_data( + *, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + source: Optional[Union[Dict, Database, FileSystem]] = None, + outputs: Optional[Dict] = None, + **kwargs: Any, +) -> DataTransferImport: + """Create a DataTransferImport object which can be used inside dsl.pipeline. + + :keyword name: The name of the job. + :paramtype name: str + :keyword description: Description of the job. + :paramtype description: str + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :paramtype tags: dict[str, str] + :keyword display_name: Display name of the job. + :paramtype display_name: str + :keyword experiment_name: Name of the experiment the job will be created under. + :paramtype experiment_name: str + :keyword compute: The compute resource the job runs on. + :paramtype compute: str + :keyword source: The data source of file system or database. + :paramtype source: Union[Dict, ~azure.ai.ml.entities._inputs_outputs.external_data.Database, + ~azure.ai.ml.entities._inputs_outputs.external_data.FileSystem] + :keyword outputs: Mapping of outputs data bindings used in the job. + The default will be an output port with the key "sink" and type "mltable". + :paramtype outputs: dict + :return: A DataTransferImport object. 
+ :rtype: ~azure.ai.ml.entities._job.pipeline._component_translatable.DataTransferImport + """ + source = _build_source_sink(source) + outputs = outputs or {"sink": Output(type=AssetTypes.MLTABLE)} + # # job inputs can not be None + # job_inputs = {k: v for k, v in job_inputs.items() if v is not None} + _, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output) + component = kwargs.pop("component", None) + update_source = False + if component is None: + if source and source.type == ExternalDataType.DATABASE: + component = DataTransferBuiltinComponentUri.IMPORT_DATABASE + else: + component = DataTransferBuiltinComponentUri.IMPORT_FILE_SYSTEM + update_source = True + + data_transfer_import_obj = DataTransferImport( + component=component, + name=name, + description=description, + tags=tags, + display_name=display_name, + experiment_name=experiment_name, + compute=compute, + source=source, + outputs=job_outputs, + **kwargs, + ) + if update_source: + data_transfer_import_obj._source = ComponentSource.BUILTIN + + return data_transfer_import_obj + + +@experimental +@pipeline_node_decorator +def export_data( + *, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + sink: Optional[Union[Dict, Database, FileSystem]] = None, + inputs: Optional[Dict] = None, + **kwargs: Any, +) -> DataTransferExport: + """Create a DataTransferExport object which can be used inside dsl.pipeline. + + :keyword name: The name of the job. + :paramtype name: str + :keyword description: Description of the job. + :paramtype description: str + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :paramtype tags: dict[str, str] + :keyword display_name: Display name of the job. + :paramtype display_name: str + :keyword experiment_name: Name of the experiment the job will be created under. + :paramtype experiment_name: str + :keyword compute: The compute resource the job runs on. + :paramtype compute: str + :keyword sink: The sink of external data and databases. + :paramtype sink: Union[ + Dict, + ~azure.ai.ml.entities._inputs_outputs.external_data.Database, + ~azure.ai.ml.entities._inputs_outputs.external_data.FileSystem] + :keyword inputs: Mapping of inputs data bindings used in the job. + :paramtype inputs: dict + :return: A DataTransferExport object. + :rtype: ~azure.ai.ml.entities._job.pipeline._component_translatable.DataTransferExport + :raises ValidationException: If sink is not provided or exporting file system is not supported. + """ + sink = _build_source_sink(sink) + _, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input) + # job inputs can not be None + job_inputs = {k: v for k, v in job_inputs.items() if v is not None} + component = kwargs.pop("component", None) + update_source = False + if component is None: + if sink and sink.type == ExternalDataType.DATABASE: + component = DataTransferBuiltinComponentUri.EXPORT_DATABASE + else: + msg = "Sink is a required field for export data task and we don't support exporting file system for now." 
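+            # Only Database sinks are mapped to a builtin export component; a missing sink or a
+            # FileSystem sink fails validation here.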
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + update_source = True + + data_transfer_export_obj = DataTransferExport( + component=component, + name=name, + description=description, + tags=tags, + display_name=display_name, + experiment_name=experiment_name, + compute=compute, + sink=sink, + inputs=job_inputs, + **kwargs, + ) + if update_source: + data_transfer_export_obj._source = ComponentSource.BUILTIN + + return data_transfer_export_obj diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/do_while.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/do_while.py new file mode 100644 index 00000000..ecfd51ca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/do_while.py @@ -0,0 +1,357 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import logging +from typing import Any, Dict, Optional, Union + +from typing_extensions import Literal + +from azure.ai.ml._schema.pipeline.control_flow_job import DoWhileSchema +from azure.ai.ml.constants._component import DO_WHILE_MAX_ITERATION, ControlFlowType +from azure.ai.ml.entities._job.job_limits import DoWhileJobLimits +from azure.ai.ml.entities._job.pipeline._io import InputOutputBase, NodeInput, NodeOutput +from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob +from azure.ai.ml.entities._validation import MutableValidationResult + +from .._util import load_from_dict, validate_attribute_type +from .base_node import BaseNode +from .control_flow_node import LoopNode +from .pipeline import Pipeline + +module_logger = logging.getLogger(__name__) + + +class DoWhile(LoopNode): + """Do-while loop node in the pipeline job. By specifying the loop body and loop termination condition in this class, + a job-level do while loop can be implemented. It will be initialized when calling dsl.do_while or when loading the + pipeline yml containing do_while node. Please do not manually initialize this class. + + :param body: Pipeline job for the do-while loop body. + :type body: ~azure.ai.ml.entities._builders.pipeline.Pipeline + :param condition: Boolean type control output of body as do-while loop condition. + :type condition: ~azure.ai.ml.entities.Output + :param mapping: Output-Input mapping for each round of the do-while loop. + Key is the last round output of the body. Value is the input port for the current body. + :type mapping: dict[Union[str, ~azure.ai.ml.entities.Output], + Union[str, ~azure.ai.ml.entities.Input, list]] + :param limits: Limits in running the do-while node. + :type limits: Union[dict, ~azure.ai.ml.entities._job.job_limits.DoWhileJobLimits] + :raises ValidationError: If the initialization parameters are not of valid types. 
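+
+    A minimal usage sketch (illustrative only; ``body_component``, ``initial_value`` and the
+    port names are placeholders, not part of this module):
+
+    .. code-block:: python
+
+        from azure.ai.ml.dsl import pipeline
+        from azure.ai.ml.dsl._do_while import do_while
+
+        @pipeline()
+        def train_until_converged():
+            body = body_component(accumulator=initial_value)
+            # Feed each round's output back into the next round's input; stop when the body's
+            # boolean control output "condition" flips, or after at most 5 iterations.
+            do_while(
+                body=body,
+                condition=body.outputs.condition,
+                mapping={body.outputs.accumulator: body.inputs.accumulator},
+                max_iteration_count=5,
+            )
+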
+ """ + + def __init__( + self, + *, + body: Union[Pipeline, BaseNode], + condition: Optional[Union[str, NodeInput, NodeOutput]], + mapping: Dict, + limits: Optional[Union[dict, DoWhileJobLimits]] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + + kwargs.pop("type", None) + super(DoWhile, self).__init__( + type=ControlFlowType.DO_WHILE, + body=body, + **kwargs, + ) + + # init mark for _AttrDict + self._init = True + self._mapping = mapping or {} + self._condition = condition + self._limits = limits + self._init = False + + @property + def mapping(self) -> Dict: + """Get the output-input mapping for each round of the do-while loop. + + :return: Output-Input mapping for each round of the do-while loop. + :rtype: dict[Union[str, ~azure.ai.ml.entities.Output], + Union[str, ~azure.ai.ml.entities.Input, list]] + """ + return self._mapping + + @property + def condition(self) -> Optional[Union[str, NodeInput, NodeOutput]]: + """Get the boolean type control output of the body as the do-while loop condition. + + :return: Control output of the body as the do-while loop condition. + :rtype: ~azure.ai.ml.entities.Output + """ + return self._condition + + @property + def limits(self) -> Union[Dict, DoWhileJobLimits, None]: + """Get the limits in running the do-while node. + + :return: Limits in running the do-while node. + :rtype: Union[dict, ~azure.ai.ml.entities._job.job_limits.DoWhileJobLimits] + """ + return self._limits + + @classmethod + def _attr_type_map(cls) -> dict: + return { + **super(DoWhile, cls)._attr_type_map(), + "mapping": dict, + "limits": (dict, DoWhileJobLimits), + } + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "DoWhile": + loaded_data = load_from_dict(DoWhileSchema, data, context, additional_message, **kwargs) + + return cls(**loaded_data) + + @classmethod + def _get_port_obj( + cls, body: BaseNode, port_name: str, is_input: bool = True, validate_port: bool = True + ) -> Union[str, NodeInput, NodeOutput]: + if is_input: + port = body.inputs.get(port_name, None) + else: + port = body.outputs.get(port_name, None) + if port is None: + if validate_port: + raise cls._create_validation_error( + message=f"Cannot find {port_name} in do_while loop body {'inputs' if is_input else 'outputs'}.", + no_personal_data_message=f"Miss port in do_while loop body {'inputs' if is_input else 'outputs'}.", + ) + return port_name + + res: Union[str, NodeInput, NodeOutput] = port + return res + + @classmethod + def _create_instance_from_schema_dict( + cls, pipeline_jobs: Dict[str, BaseNode], loaded_data: Dict, validate_port: bool = True + ) -> "DoWhile": + """Create a do_while instance from schema parsed dict. + + :param pipeline_jobs: The pipeline jobs + :type pipeline_jobs: Dict[str, BaseNode] + :param loaded_data: The loaded data + :type loaded_data: Dict + :param validate_port: Whether to raise if inputs/outputs are not present. Defaults to True + :type validate_port: bool + :return: The DoWhile node + :rtype: DoWhile + """ + + # Get body object from pipeline job list. 
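+        # (A serialized body is a data-binding expression such as "${{parent.jobs.<node_name>}}";
+        # the regex below pulls out <node_name> so the node can be looked up in pipeline_jobs.)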
+        body_name = cls._get_data_binding_expression_value(loaded_data.pop("body"), regex=r"\{\{.*\.jobs\.(.*)\}\}")
+        body = cls._get_body_from_pipeline_jobs(pipeline_jobs, body_name)
+
+        # Convert mapping key-value pairs to input/output objects
+        mapping = {}
+        for k, v in loaded_data.pop("mapping", {}).items():
+            output_name = cls._get_data_binding_expression_value(k, regex=r"\{\{.*\.%s\.outputs\.(.*)\}\}" % body_name)
+            input_names = v if isinstance(v, list) else [v]
+            input_names = [
+                cls._get_data_binding_expression_value(item, regex=r"\{\{.*\.%s\.inputs\.(.*)\}\}" % body_name)
+                for item in input_names
+            ]
+            mapping[output_name] = [cls._get_port_obj(body, item, validate_port=validate_port) for item in input_names]
+
+        limits = loaded_data.pop("limits", None)
+
+        if "condition" in loaded_data:
+            # Convert condition to output object
+            condition_name = cls._get_data_binding_expression_value(
+                loaded_data.pop("condition"), regex=r"\{\{.*\.%s\.outputs\.(.*)\}\}" % body_name
+            )
+            condition_value = cls._get_port_obj(body, condition_name, is_input=False, validate_port=validate_port)
+        else:
+            condition_value = None
+
+        do_while_instance = DoWhile(
+            body=body,
+            mapping=mapping,
+            condition=condition_value,
+            **loaded_data,
+        )
+        do_while_instance.set_limits(**limits)
+
+        return do_while_instance
+
+    @classmethod
+    def _create_schema_for_validation(cls, context: Any) -> DoWhileSchema:
+        return DoWhileSchema(context=context)
+
+    @classmethod
+    def _from_rest_object(cls, obj: dict, pipeline_jobs: dict) -> "DoWhile":
+        # pylint: disable=protected-access
+
+        obj = BaseNode._from_rest_object_to_init_params(obj)
+        return cls._create_instance_from_schema_dict(pipeline_jobs, obj, validate_port=False)
+
+    def set_limits(
+        self,
+        *,
+        max_iteration_count: int,
+        # pylint: disable=unused-argument
+        **kwargs: Any,
+    ) -> None:
+        """
+        Set the maximum iteration count for the do-while job.
+
+        The range of the iteration count is (0, 1000].
+
+        :keyword max_iteration_count: The maximum iteration count for the do-while job.
+        :paramtype max_iteration_count: int
+        """
+        if isinstance(self.limits, DoWhileJobLimits):
+            self.limits._max_iteration_count = max_iteration_count  # pylint: disable=protected-access
+        else:
+            self._limits = DoWhileJobLimits(max_iteration_count=max_iteration_count)
+
+    def _customized_validate(self) -> MutableValidationResult:
+        validation_result = self._validate_loop_condition()
+        validation_result.merge_with(self._validate_body())
+        validation_result.merge_with(self._validate_do_while_limit())
+        validation_result.merge_with(self._validate_body_output_mapping())
+        return validation_result
+
+    def _validate_port(
+        self,
+        port: Union[str, NodeInput, NodeOutput],
+        node_ports: Dict[str, Union[NodeInput, NodeOutput]],
+        port_type: Literal["input", "output"],
+        yaml_path: str,
+    ) -> MutableValidationResult:
+        """Validate that an input/output port exists in the do-while body.
+
+        :param port: Either:
+            * The name of an input or output
+            * An input object
+            * An output object
+        :type port: Union[str, NodeInput, NodeOutput]
+        :param node_ports: The node inputs/outputs
+        :type node_ports: Dict[str, Union[NodeInput, NodeOutput]]
+        :param port_type: The port type
+        :type port_type: Literal["input", "output"]
+        :param yaml_path: The yaml path
+        :type yaml_path: str
+        :return: The validation result
+        :rtype: MutableValidationResult
+        """
+        validation_result = self._create_empty_validation_result()
+        if isinstance(port, str):
+            port_obj = node_ports.get(port, None)
+        else:
+            port_obj = port
+        if (
+            port_obj is not None
+            and port_obj._owner is not None  # pylint: disable=protected-access
+            and not isinstance(port_obj._owner, PipelineJob)  # pylint: disable=protected-access
+            and port_obj._owner._instance_id != self.body._instance_id  # pylint: disable=protected-access
+        ):
+            # Check that the port owner is the do-while body.
+            validation_result.append_error(
+                yaml_path=yaml_path,
+                message=(
+                    f"{port_obj._port_name} is the {port_type} of {port_obj._owner.name}, "  # pylint: disable=protected-access
+                    f"do_while only accepts {port_type}s of the body: {self.body.name}."
+                ),
+            )
+        elif port_obj is None or port_obj._port_name not in node_ports:  # pylint: disable=protected-access
+            # Check that the port exists in the do-while body.
+            validation_result.append_error(
+                yaml_path=yaml_path,
+                message=(
+                    f"The {port_type} of mapping {port_obj._port_name if port_obj else port} does not "  # pylint: disable=protected-access
+                    f"exist in {self.body.name} {port_type}s; existing {port_type}s: {node_ports.keys()}"
+                ),
+            )
+        return validation_result
+
+    def _validate_loop_condition(self) -> MutableValidationResult:
+        # pylint: disable=protected-access
+        validation_result = self._create_empty_validation_result()
+        if self.condition is not None:
+            # Check that the condition exists in the do-while body.
+            validation_result.merge_with(
+                self._validate_port(self.condition, self.body.outputs, port_type="output", yaml_path="condition")
+            )
+            if validation_result.passed:
+                # Check that the condition is a control output.
+                condition_name = self.condition if isinstance(self.condition, str) else self.condition._port_name
+                if not self.body._outputs[condition_name]._is_primitive_type:
+                    validation_result.append_error(
+                        yaml_path="condition",
+                        message=(
+                            f"{condition_name} is not a control output and is not a primitive type. "
+                            "The condition of do_while must be a control output or a primitive-type output of the body."
+                        ),
+                    )
+        return validation_result
+
+    def _validate_do_while_limit(self) -> MutableValidationResult:
+        validation_result = self._create_empty_validation_result()
+        if isinstance(self.limits, DoWhileJobLimits):
+            if not self.limits or self.limits.max_iteration_count is None:
+                return validation_result
+            if isinstance(self.limits.max_iteration_count, InputOutputBase):
+                validation_result.append_error(
+                    yaml_path="limit.max_iteration_count",
+                    message="The max iteration count cannot be linked with a primitive type input.",
+                )
+            elif self.limits.max_iteration_count > DO_WHILE_MAX_ITERATION or self.limits.max_iteration_count < 0:
+                validation_result.append_error(
+                    yaml_path="limit.max_iteration_count",
+                    message=f"The max iteration count cannot be less than 0 or larger than {DO_WHILE_MAX_ITERATION}.",
+                )
+        return validation_result
+
+    def _validate_body_output_mapping(self) -> MutableValidationResult:
+        # pylint: disable=protected-access
+        validation_result = self._create_empty_validation_result()
+        if not isinstance(self.mapping, dict):
+            validation_result.append_error(
+                yaml_path="mapping", message=f"Mapping expects a dict but was given a {type(self.mapping)}."
+            )
+        else:
+            # Record the mapping relationship between input and output
+            input_output_mapping: Dict = {}
+            # Validate that mapping inputs and outputs come from the do-while body
+            for output, inputs in self.mapping.items():
+                # pylint: disable=protected-access
+                output_name = output if isinstance(output, str) else output._port_name
+                validate_results = self._validate_port(
+                    output, self.body.outputs, port_type="output", yaml_path="mapping"
+                )
+                if validate_results.passed:
+                    is_primitive_output = self.body._outputs[output_name]._is_primitive_type
+                    inputs = inputs if isinstance(inputs, list) else [inputs]
+                    for item in inputs:
+                        input_validate_results = self._validate_port(
+                            item, self.body.inputs, port_type="input", yaml_path="mapping"
+                        )
+                        validation_result.merge_with(input_validate_results)
+                        # pylint: disable=protected-access
+                        input_name = item if isinstance(item, str) else item._port_name
+                        input_output_mapping[input_name] = input_output_mapping.get(input_name, []) + [output_name]
+                        is_primitive_type = self.body._inputs[input_name]._meta._is_primitive_type
+
+                        if input_validate_results.passed and not is_primitive_output and is_primitive_type:
+                            validate_results.append_error(
+                                yaml_path="mapping",
+                                message=(
+                                    f"{output_name} is a non-primitive type output and {input_name} "
+                                    "is a primitive input. A non-primitive type output cannot be connected "
+                                    "to a primitive type input."
+                                ),
+                            )
+
+                validation_result.merge_with(validate_results)
+            # Validate whether an input is linked to multiple outputs
+            for _input, outputs in input_output_mapping.items():
+                if len(outputs) > 1:
+                    validation_result.append_error(
+                        yaml_path="mapping", message=f"Input {_input} has been linked to multiple outputs {outputs}."
+                    )
+        return validation_result
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/fl_scatter_gather.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/fl_scatter_gather.py
new file mode 100644
index 00000000..0ad6b0e2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/fl_scatter_gather.py
@@ -0,0 +1,886 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +from azure.ai.ml import Output +from azure.ai.ml._schema import PathAwareSchema +from azure.ai.ml._schema.pipeline.control_flow_job import FLScatterGatherSchema +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import AssetTypes +from azure.ai.ml.dsl import pipeline +from azure.ai.ml.dsl._do_while import do_while +from azure.ai.ml.entities._assets.federated_learning_silo import FederatedLearningSilo +from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode +from azure.ai.ml.entities._builders.pipeline import Pipeline +from azure.ai.ml.entities._component.command_component import CommandComponent +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._inputs_outputs.input import Input +from azure.ai.ml.entities._job.pipeline._io.mixin import NodeIOMixin +from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob +from azure.ai.ml.entities._util import convert_ordered_dict_to_dict +from azure.ai.ml.entities._validation import MutableValidationResult + +from .subcomponents import create_scatter_output_table + +# TODO 2293610: add support for more types of outputs besides uri_folder and mltable +# Likely types that ought to be mergeable: string, int, uri_file +MERGE_COMPONENT_MAPPING = { + "mltable": create_scatter_output_table, + "uri_folder": create_scatter_output_table, +} + + +ANCHORABLE_OUTPUT_TYPES = {AssetTypes.MLTABLE, AssetTypes.URI_FOLDER} + +ANCHORING_PATH_ROOT = "root" + + +# big TODO: For some reason, surfacing this file in __init__.py causes +# a circular import exception on the first attempted import +# In notebooks, the second import succeeds, but then causes a silent failure where the +# MLDesigner component created by the subcomponents.create_scatter_output_table function +# will produce a ComponentExecutor object instead of the actual component. +# TODO 2293541: Add telemetry of some sort +# pylint: disable=too-many-instance-attributes +class FLScatterGather(ControlFlowNode, NodeIOMixin): + """A node which creates a federated learning scatter-gather loop as a pipeline subgraph. + Intended for use inside a pipeline job. This is initialized when calling + `dsl.fl_scatter_gather()` or when loading a serialized version of this node from YAML. + Please do not manually initialize this class. + + :param silo_configs: List of federated learning silo configurations. + :type silo_configs: List[~azure.ai.ml.entities._assets.federated_learning_silo.FederatedLearningSilo] + :param silo_component: Component representing the silo for federated learning. + :type silo_component: ~azure.ai.ml.entities.Component + :param aggregation_component: Component representing the aggregation step. + :type aggregation_component: ~azure.ai.ml.entities.Component + :param aggregation_compute: The compute resource for the aggregation step. + :type aggregation_compute: str + :param aggregation_datastore: The datastore for the aggregation step. + :type aggregation_datastore: str + :param shared_silo_kwargs: Keyword arguments shared across all silos. + :type shared_silo_kwargs: dict + :param aggregation_kwargs: Keyword arguments specific to the aggregation step. + :type aggregation_kwargs: dict + :param silo_to_aggregation_argument_map: Mapping of silo to aggregation arguments. 
+ :type silo_to_aggregation_argument_map: dict + :param aggregation_to_silo_argument_map: Mapping of aggregation to silo arguments. + :type aggregation_to_silo_argument_map: dict + :param max_iterations: The maximum number of iterations for the scatter-gather loop. + :type max_iterations: int + :param create_default_mappings_if_needed: Whether to create default argument mappings if needed. + :type create_default_mappings_if_needed: bool + """ + + # See node class for input descriptions, no point maintaining + # double descriptions between a wrapper its interior. + def __init__( + self, + *, + silo_configs: List[FederatedLearningSilo], + silo_component: Component, + aggregation_component: Component, + aggregation_compute: Optional[str] = None, + aggregation_datastore: Optional[str] = None, + shared_silo_kwargs: Optional[Dict] = None, + aggregation_kwargs: Optional[Dict] = None, + silo_to_aggregation_argument_map: Optional[Dict] = None, + aggregation_to_silo_argument_map: Optional[Dict] = None, + max_iterations: int = 1, + create_default_mappings_if_needed: bool = False, + **kwargs: Any, + ) -> None: + # auto-create X_to_Y_argument_map values if allowed and needed. + if create_default_mappings_if_needed: + ( + silo_to_aggregation_argument_map, + aggregation_to_silo_argument_map, + ) = FLScatterGather._try_create_default_mappings( + silo_component, + aggregation_component, + silo_to_aggregation_argument_map, + aggregation_to_silo_argument_map, + ) + + # input validation. + FLScatterGather.validate_inputs( + silo_configs=silo_configs, + silo_component=silo_component, + aggregation_component=aggregation_component, + shared_silo_kwargs=shared_silo_kwargs, + aggregation_compute=aggregation_compute, + aggregation_datastore=aggregation_datastore, + aggregation_kwargs=aggregation_kwargs, + silo_to_aggregation_argument_map=silo_to_aggregation_argument_map, + aggregation_to_silo_argument_map=aggregation_to_silo_argument_map, + max_iterations=max_iterations, + ) + + # store inputs + self.silo_configs = silo_configs + self.aggregation_compute = aggregation_compute + self.aggregation_datastore = aggregation_datastore + self.silo_component = silo_component + self.aggregation_component = aggregation_component + self.shared_silo_kwargs = shared_silo_kwargs + self.aggregation_kwargs = aggregation_kwargs + self.silo_to_aggregation_argument_map = silo_to_aggregation_argument_map + self.aggregation_to_silo_argument_map = aggregation_to_silo_argument_map + self.max_iterations = max_iterations + self._init = True # Needed by parent class to work properly + + self.scatter_gather_graph = self.scatter_gather() + + # set SG node flag for telemetry + self.scatter_gather_graph.properties["azureml.telemetry.attribution"] = "FederatedLearningSGJobFlag" + self.scatter_gather_graph._to_rest_object() + + # set output to final aggregation step's output + self._outputs = self.scatter_gather_graph.outputs + super(FLScatterGather, self).__init__( + type=JobType.COMPONENT, + component=None, + inputs=None, + outputs=self.scatter_gather_graph.outputs, + name=None, + display_name=None, + description=None, + tags=None, + properties=None, + comment=None, + compute=None, + experiment_name=None, + ) + + def scatter_gather(self) -> PipelineJob: + """Executes the scatter-gather loop by creating and executing a pipeline subgraph. + Returns the outputs of the final aggregation step. + + :return: Outputs of the final aggregation step. 
+ :rtype: list[~azure.ai.ml.Output] + """ + + @pipeline( + func=None, + name="Scatter gather", + description="It includes all steps that need to be executed in silo and aggregation", + ) + # pylint: disable-next=docstring-missing-return,docstring-missing-rtype + def scatter_gather_iteration_body(**silo_inputs: Input) -> PipelineJob: + """ + Performs a scatter-gather iteration by running copies of the silo step on different + computes/datstores according to this node's silo configs. The outputs of these + silo components are then merged by an internal helper component. The merged values + are then inputted into the user-provided aggregation component. Returns the executed aggregation component. + + Kwargs are a dictionary of names and Inputs to be injected into each executed silo step. This dictionary is + merged with silo-specific inputs before each executed. + """ + + silo_outputs = [] + # TODO 2293586 replace this for-loop with a parallel-for node + for silo_config in self.silo_configs: + silo_inputs.update(silo_config.inputs) + executed_silo_component = self.silo_component(**silo_inputs) + for v, k in executed_silo_component.inputs.items(): + if v in silo_config.inputs and k.type == "uri_folder": + k.mode = "ro_mount" + FLScatterGather._anchor_step( + pipeline_step=executed_silo_component, + compute=silo_config.compute, + internal_datastore=silo_config.datastore, + orchestrator_datastore=self.aggregation_datastore, + ) + # add to silo outputs list + silo_outputs.append(executed_silo_component) + + # produce internal argument-merging components and record them in local subgraph + merge_comp_mapping = self._inject_merge_components(silo_outputs) + + # produce aggregate step inputs by merging static kwargs and mapped arguments from + # internal merge components + agg_inputs: Dict = {} + if self.aggregation_kwargs is not None: + agg_inputs.update(self.aggregation_kwargs) + internal_merge_outputs = { + self._get_aggregator_input_name(k): v.outputs.aggregated_output for k, v in merge_comp_mapping.items() + } + agg_inputs.update(internal_merge_outputs) + + # run the user aggregation step + executed_aggregation_component = self.aggregation_component(**agg_inputs) + # Set mode of aggregated mltable inputs as eval mount to allow files referenced within the table + # to be accessible by the component + for name, agg_input in executed_aggregation_component.inputs.items(): + if ( + self.silo_to_aggregation_argument_map is not None + and name in self.silo_to_aggregation_argument_map.values() + and agg_input.type == "mltable" + ): + agg_input.mode = "eval_download" + + # Anchor both the internal merge components and the user-supplied aggregation step + # to the aggregation compute and datastore + if self.aggregation_compute is not None and self.aggregation_datastore is not None: + # internal merge component is also siloed to wherever the aggregation component lives. 
+ for executed_merge_component in merge_comp_mapping.values(): + FLScatterGather._anchor_step( + pipeline_step=executed_merge_component, + compute=self.aggregation_compute, + internal_datastore=self.aggregation_datastore, + orchestrator_datastore=self.aggregation_datastore, + ) + FLScatterGather._anchor_step( + pipeline_step=executed_aggregation_component, + compute=self.aggregation_compute, + internal_datastore=self.aggregation_datastore, + orchestrator_datastore=self.aggregation_datastore, + ) + res: PipelineJob = executed_aggregation_component.outputs + return res + + @pipeline(func=None, name="Scatter gather graph") + # pylint: disable-next=docstring-missing-return,docstring-missing-rtype + def create_scatter_gather_graph() -> PipelineJob: + """ + Creates a scatter-gather graph by executing the scatter_gather_iteration_body + function in a do-while loop. The loop terminates when the user-supplied + termination condition is met. + """ + + silo_inputs: Dict = {} + if self.shared_silo_kwargs is not None: + # Start with static inputs + silo_inputs.update(self.shared_silo_kwargs) + + # merge in inputs passed in from previous iteration's aggregate step) + if self.aggregation_to_silo_argument_map is not None: + silo_inputs.update({v: None for v in self.aggregation_to_silo_argument_map.values()}) + + scatter_gather_body = scatter_gather_iteration_body(**silo_inputs) + + # map aggregation outputs to scatter inputs + if self.aggregation_to_silo_argument_map is not None: + do_while_mapping = { + k: getattr(scatter_gather_body.inputs, v) for k, v in self.aggregation_to_silo_argument_map.items() + } + + do_while( + body=scatter_gather_body, # type: ignore[arg-type] + mapping=do_while_mapping, # pylint: disable=possibly-used-before-assignment + max_iteration_count=self.max_iterations, + ) + res_scatter: PipelineJob = scatter_gather_body.outputs # type: ignore[assignment] + return res_scatter + + res: PipelineJob = create_scatter_gather_graph() + return res + + @classmethod + def _get_fl_datastore_path( + cls, + datastore_name: Optional[str], + output_name: str, + unique_id: str = "${{name}}", + iteration_num: Optional[int] = None, + ) -> str: + """Construct a path string using the inputted values. The important aspect is that this produces a + path with a specified datastore. + + :param datastore_name: The datastore to use in the constructed path. + :type datastore_name: str + :param output_name: The name of the output value that this path is assumed to belong to. + Is injected into the path. + :type output_name: str + :param unique_id: An additional string to inject if needed. Defaults to ${{name}}, which is the + output name again. + :type unique_id: str + :param iteration_num: The iteration number of the current scatter-gather iteration. + If set, inject this into the resulting path string. + :type iteration_num: Optional[int] + :return: A data path string containing the various aforementioned inputs. + :rtype: str + + """ + data_path = f"azureml://datastores/{datastore_name}/paths/federated_learning/{output_name}/{unique_id}/" + if iteration_num: + data_path += f"iteration_{iteration_num}/" + return data_path + + @classmethod + def _check_datastore(cls, path: str, expected_datastore: Optional[str]) -> bool: + """Perform a simple regex check to try determine if the datastore in the inputted path string + matches the expected_datastore. + + + :param path: An output pathstring. + :type path: str + :param expected_datastore: A datastore name. 
+ :type expected_datastore: str + :return: Whether or not the expected_datastore was found in the path at the expected location. + :rtype: bool + """ + match = re.match("(.*datastore/)([^/]*)(/.*)", path) + if match: + groups = match.groups() + if groups[1] == expected_datastore: + return True + return False + + @classmethod + def _check_or_set_datastore( + cls, + name: str, + output: Output, + target_datastore: Optional[str], + iteration_num: Optional[int] = None, + ) -> MutableValidationResult: + """Tries to assign output.path to a value which includes the target_datastore if it's not already + set. If the output's path is already set, return a warning if it doesn't match the target_datastore. + + :param name: The name of the output to modify + :type name: str + :param output: The output object to examine and potentially change the datastore of. + :type output: Output + :param target_datastore: The name of the datastore to try applying to the output + :type target_datastore: str + :param iteration_num: the current iteration in the scatter gather loop. If set, include this in the generated + path. + :type iteration_num: Optional[int] + :return: A validation result containing any problems that arose. Contains a warning if the examined output + already contains a datastore that does not match 'target_datastore'. + :rtype: MutableValidationResult + """ + validation_result = cls._create_empty_validation_result() + if not hasattr(output, "path") or not output.path: + output.path = cls._get_fl_datastore_path(target_datastore, name, iteration_num=iteration_num) + # Double check the path's datastore leads to the target if it's already set. + elif not cls._check_datastore(output.path, target_datastore): + validation_result.append_warning( + yaml_path=name, + message=f"Output '{name}' has an undetermined datastore, or a datstore" + + f" that does not match the expected datastore for this output, which is '{target_datastore}'." + + " Make sure this is intended.", + ) + return validation_result + + # TODO 2293705: Add anchoring for more resource types. + @classmethod + def _anchor_step( + cls, + pipeline_step: Union[Pipeline, CommandComponent], + compute: str, + internal_datastore: str, + orchestrator_datastore: Optional[str], + iteration: Optional[int] = 0, + _path: str = "root", + ) -> MutableValidationResult: + """Take a pipeline step and recursively enforces the right compute/datastore config. + + :param pipeline_step: a step to anchor + :type pipeline_step: Union[Pipeline, CommandComponent] + :param compute: name of the compute target + :type compute: str + :param internal_datastore: The name of the datastore that should be used for internal output anchoring. + :type internal_datastore: str + :param orchestrator_datastore: The name of the orchestrator/aggregation datastore that should be used for + 'real' output anchoring. + :type orchestrator_datastore: str + :param iteration: The current iteration number in the scatter gather loop. Defaults to 0. + :type iteration: Optional[int] + :param _path: for recursive anchoring, codes the "path" inside the pipeline for messaging + :type _path: str + :return: A validation result containing any issues that were uncovered during anchoring. This function adds + warnings when outputs already have assigned paths which don't contain the expected datastore. 
+ :rtype: MutableValidationResult + """ + + validation_result = cls._create_empty_validation_result() + + # Current step is a pipeline, which means we need to inspect its steps (jobs) and + # potentially anchor those as well. + if pipeline_step.type == "pipeline": + if hasattr(pipeline_step, "component"): + # Current step is probably not the root of the graph + # its outputs should be anchored to the internal_datastore. + for name, output in pipeline_step.outputs.items(): + if not isinstance(output, str): + if output.type in ANCHORABLE_OUTPUT_TYPES: + validation_result.merge_with( + cls._check_or_set_datastore( + name=name, + output=output, + target_datastore=orchestrator_datastore, + iteration_num=iteration, + ) + ) + + # then we need to anchor the internal component of this step + # The outputs of this sub-component are a deep copy of the outputs of this step + # This is dangerous, and we need to make sure they both use the same datastore, + # so we keep datastore types identical across this recursive call. + cls._anchor_step( + pipeline_step.component, # type: ignore + compute, + internal_datastore=internal_datastore, + orchestrator_datastore=orchestrator_datastore, + _path=f"{_path}.component", + ) + + else: + # This is a pipeline step with multiple jobs beneath it. + # Anchor its outputs... + for name, output in pipeline_step.outputs.items(): + if not isinstance(output, str): + if output.type in ANCHORABLE_OUTPUT_TYPES: + validation_result.merge_with( + cls._check_or_set_datastore( + name=name, + output=output, + target_datastore=orchestrator_datastore, + iteration_num=iteration, + ) + ) + # ...then recursively anchor each job inside the pipeline + if not isinstance(pipeline_step, CommandComponent): + for job_key in pipeline_step.jobs: + job = pipeline_step.jobs[job_key] + # replace orchestrator with internal datastore, jobs components + # should either use the local datastore + # or have already had their outputs re-assigned. + cls._anchor_step( + job, + compute, + internal_datastore=internal_datastore, + orchestrator_datastore=internal_datastore, + _path=f"{_path}.jobs.{job_key}", + ) + + elif pipeline_step.type == "command": + # if the current step is a command component + # make sure the compute corresponds to the silo + if not isinstance(pipeline_step, CommandComponent) and pipeline_step.compute is None: + pipeline_step.compute = compute + # then anchor each of the job's outputs + for name, output in pipeline_step.outputs.items(): + if not isinstance(output, str): + if output.type in ANCHORABLE_OUTPUT_TYPES: + validation_result.merge_with( + cls._check_or_set_datastore( + name=name, + output=output, + target_datastore=orchestrator_datastore, + iteration_num=iteration, + ) + ) + else: + # TODO revisit this and add support for anchoring more things + raise NotImplementedError(f"under path={_path}: step type={pipeline_step.type} is not supported") + + return validation_result + + # Making this a class method allows for easier, isolated testing, and allows careful + # users to call this as a pre-init step. + # TODO: Might be worth migrating this to a schema validation class, but out of scope for now. 
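+    # Illustrative pre-init check (sketch only; the silo/component variables below are
+    # placeholders supplied by the caller):
+    #
+    #     FLScatterGather.validate_inputs(
+    #         silo_configs=silos,
+    #         silo_component=silo_comp,
+    #         aggregation_component=agg_comp,
+    #         shared_silo_kwargs={},
+    #         aggregation_compute="agg-cluster",
+    #         aggregation_datastore="agg_datastore",
+    #         aggregation_kwargs={},
+    #         silo_to_aggregation_argument_map={},
+    #         aggregation_to_silo_argument_map={},
+    #         max_iterations=3,
+    #         raise_error=True,  # raise instead of returning the MutableValidationResult
+    #     )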
+ # pylint: disable=too-many-statements,too-many-branches, too-many-locals + @classmethod + def validate_inputs( + cls, + *, + silo_configs: List[FederatedLearningSilo], + silo_component: Component, + aggregation_component: Component, + shared_silo_kwargs: Optional[Dict], + aggregation_compute: Optional[str], + aggregation_datastore: Optional[str], + aggregation_kwargs: Optional[Dict], + silo_to_aggregation_argument_map: Optional[Dict], + aggregation_to_silo_argument_map: Optional[Dict], + max_iterations: int, + raise_error: bool = False, + ) -> MutableValidationResult: + """Validates the inputs for the scatter-gather node. + + :keyword silo_configs: List of federated learning silo configurations. + :paramtype silo_configs: List[~azure.ai.ml.entities._assets.federated_learning_silo.FederatedLearningSilo] + :keyword silo_component: Component representing the silo for federated learning. + :paramtype silo_component: ~azure.ai.ml.entities.Component + :keyword aggregation_component: Component representing the aggregation step. + :paramtype aggregation_component: ~azure.ai.ml.entities.Component + :keyword shared_silo_kwargs: Keyword arguments shared across all silos. + :paramtype shared_silo_kwargs: Dict + :keyword aggregation_compute: The compute resource for the aggregation step. + :paramtype aggregation_compute: str + :keyword aggregation_datastore: The datastore for the aggregation step. + :paramtype aggregation_datastore: str + :keyword aggregation_kwargs: Keyword arguments specific to the aggregation step. + :paramtype aggregation_kwargs: Dict + :keyword silo_to_aggregation_argument_map: Mapping of silo to aggregation arguments. + :paramtype silo_to_aggregation_argument_map: Dict + :keyword aggregation_to_silo_argument_map: Mapping of aggregation to silo arguments. + :paramtype aggregation_to_silo_argument_map: Dict + :keyword max_iterations: The maximum number of iterations for the scatter-gather loop. + :paramtype max_iterations: int + :keyword raise_error: Whether to raise an exception if validation fails. Defaults to False. + :paramtype raise_error: bool + :return: The validation result. 
+        :rtype: ~azure.ai.ml.entities._validation.MutableValidationResult
+        """
+        validation_result = cls._create_empty_validation_result()
+
+        # saved values for validation later on
+        silo_inputs = None
+        silo_outputs = None
+        agg_inputs = None
+        agg_outputs = None
+        # validate silo component
+        if silo_component is None:
+            validation_result.append_error(
+                yaml_path="silo_component",
+                message="silo_component is a required argument for the scatter gather node.",
+            )
+        else:
+            # ensure that silo component has both inputs and outputs
+            if not hasattr(silo_component, "inputs"):
+                validation_result.append_error(
+                    yaml_path="silo_component",
+                    message="silo_component is missing 'inputs' attribute;"
+                    + " it does not appear to be a valid component that can be used in a scatter-gather loop.",
+                )
+            else:
+                silo_inputs = silo_component.inputs
+            if not hasattr(silo_component, "outputs"):
+                validation_result.append_error(
+                    yaml_path="silo_component",
+                    message="silo_component is missing 'outputs' attribute;"
+                    + " it does not appear to be a valid component that can be used in a scatter-gather loop.",
+                )
+            else:
+                silo_outputs = silo_component.outputs
+        # validate aggregation component
+        if aggregation_component is None:
+            validation_result.append_error(
+                yaml_path="aggregation_component",
+                message="aggregation_component is a required argument for the scatter gather node.",
+            )
+        else:
+            # ensure that aggregation component has both inputs and outputs
+            if not hasattr(aggregation_component, "inputs"):
+                validation_result.append_error(
+                    yaml_path="aggregation_component",
+                    message="aggregation_component is missing 'inputs' attribute;"
+                    + " it does not appear to be a valid component that can be used in a scatter-gather loop.",
+                )
+            else:
+                agg_inputs = aggregation_component.inputs
+            if not hasattr(aggregation_component, "outputs"):
+                validation_result.append_error(
+                    yaml_path="aggregation_component",
+                    message="aggregation_component is missing 'outputs' attribute;"
+                    + " it does not appear to be a valid component that can be used in a scatter-gather loop.",
+                )
+            else:
+                agg_outputs = aggregation_component.outputs
+
+        # validate silo configs
+        if silo_configs is None:
+            validation_result.append_error(
+                yaml_path="silo_configs",
+                message="silo_configs is a required argument for the scatter gather node.",
+            )
+        elif len(silo_configs) == 0:
+            validation_result.append_error(
+                yaml_path="silo_configs",
+                message="silo_configs cannot be an empty list.",
+            )
+        else:
+            first_silo = silo_configs[0]
+            expected_inputs: List = []
+            if hasattr(first_silo, "inputs"):
+                expected_inputs = first_silo.inputs.keys()  # type: ignore
+            num_expected_inputs = len(expected_inputs)
+            # pylint: disable=consider-using-enumerate
+            for i in range(len(silo_configs)):
+                silo = silo_configs[i]
+                if not hasattr(silo, "compute"):
+                    validation_result.append_error(
+                        yaml_path="silo_configs",
+                        message=f"Silo at index {i} in silo_configs is missing its compute value.",
+                    )
+                if not hasattr(silo, "datastore"):
+                    validation_result.append_error(
+                        yaml_path="silo_configs",
+                        message=f"Silo at index {i} in silo_configs is missing its datastore value.",
+                    )
+                silo_input_len = 0
+                if hasattr(silo, "inputs"):
+                    silo_input_len = len(silo.inputs)
+                    # if inputs exist, make sure the input names are consistent across silo configs
+                    for expected_input_name in expected_inputs:
+                        if expected_input_name not in silo.inputs:
+                            validation_result.append_error(
+                                yaml_path="silo_configs",
+                                message=f"Silo at index {i} is missing an input named 
'{expected_input_name}'," + + "which was listed in the first silo config. " + + "Silos must have consistent inputs names.", + ) + if silo_input_len != num_expected_inputs: + validation_result.append_error( + yaml_path="silo_configs", + message=f"Silo at index {i} has {silo_input_len} inputs, but the first silo established that" + + f"each silo would have {num_expected_inputs} silo-specific inputs.", + ) + + # Make sure both aggregation overrides are set, or not + if aggregation_datastore is None and aggregation_compute is not None: + validation_result.append_error( + yaml_path="aggregation_datastore", + message="aggregation_datastore cannot be unset if aggregation_compute is set.", + ) + elif aggregation_datastore is not None and aggregation_compute is None: + validation_result.append_error( + yaml_path="aggregation_compute", + message="aggregation_compute cannot be unset if aggregation_datastore is set.", + ) + + # validate component kwargs, ensuring that the relevant components contain the specified inputs + if shared_silo_kwargs is None: + validation_result.append_error( + yaml_path="shared_silo_kwargs", + message="shared_silo_kwargs should never be None. Input an empty dictionary instead.", + ) + elif silo_inputs is not None: + for k in shared_silo_kwargs.keys(): + if k not in silo_inputs: + validation_result.append_error( + yaml_path="shared_silo_kwargs", + message=f"shared_silo_kwargs keyword {k} not listed in silo_component's inputs", + ) + if aggregation_kwargs is None: + validation_result.append_error( + yaml_path="aggregation_kwargs", + message="aggregation_kwargs should never be None. Input an empty dictionary instead.", + ) + elif silo_inputs is not None: + for k in aggregation_kwargs.keys(): + if agg_inputs is not None and k not in agg_inputs: + validation_result.append_error( + yaml_path="aggregation_kwargs", + message=f"aggregation_kwargs keyword {k} not listed in aggregation_component's inputs", + ) + + # validate that argument mappings leverage inputs and outputs that actually exist + if aggregation_to_silo_argument_map is None: + validation_result.append_error( + yaml_path="aggregation_to_silo_argument_map", + message="aggregation_to_silo_argument_map should never be None. Input an empty dictionary instead.", + ) + elif silo_inputs is not None and agg_outputs is not None: + for k, v in aggregation_to_silo_argument_map.items(): + if k not in agg_outputs: + validation_result.append_error( + yaml_path="aggregation_to_silo_argument_map", + message=f"aggregation_to_silo_argument_map key {k} " + + "is not a known output of the aggregation component.", + ) + if v not in silo_inputs: + validation_result.append_error( + yaml_path="aggregation_to_silo_argument_map", + message=f"aggregation_to_silo_argument_map value {v} " + + "is not a known input of the silo component.", + ) + # and check the other mapping + if silo_to_aggregation_argument_map is None: + validation_result.append_error( + yaml_path="silo_to_aggregation_argument_map", + message="silo_to_aggregation_argument_map should never be None. 
" + + "Input an empty dictionary instead.", + ) + elif agg_inputs is not None and silo_outputs is not None: + for k, v in silo_to_aggregation_argument_map.items(): + if k not in silo_outputs: + validation_result.append_error( + yaml_path="silo_to_aggregation_argument_map", + message=f"silo_to_aggregation_argument_map key {k }" + + " is not a known output of the silo component.", + ) + if v not in agg_inputs: + validation_result.append_error( + yaml_path="silo_to_aggregation_argument_map", + message=f"silo_to_aggregation_argument_map value {v}" + + " is not a known input of the aggregation component.", + ) + + if max_iterations < 1: + validation_result.append_error( + yaml_path="max_iterations", + message=f"max_iterations must be a positive value, not '{max_iterations}'.", + ) + + return cls._try_raise(validation_result, raise_error=raise_error) + + @classmethod + def _custom_fl_data_path( + cls, + datastore_name: str, + output_name: str, + unique_id: str = "${{name}}", + iteration_num: str = "${{iteration_num}}", + ) -> str: + """Produces a path to store the data during FL training. + + :param datastore_name: name of the Azure ML datastore + :type datastore_name: str + :param output_name: a name unique to this output + :type output_name: str + :param unique_id: a unique id for the run (default: inject run id with ${{name}}) + :type unique_id: str + :param iteration_num: an iteration number if relevant + :type iteration_num: str + :return: direct url to the data path to store the data + :rtype: str + """ + data_path = f"azureml://datastores/{datastore_name}/paths/federated_learning/{output_name}/{unique_id}/" + if iteration_num is not None: + data_path += f"iteration_{iteration_num}/" + + return data_path + + def _get_aggregator_input_name(self, silo_output_name: str) -> Optional[str]: + """Retrieves the aggregator input name + + :param silo_output_name: The silo output name + :type silo_output_name: str + :return: + * Returns aggregator input name that maps to silo_output. + * Returns None if silo_output_name not in silo_to_aggregation_argument_map + :rtype: Optional[str] + """ + if self.silo_to_aggregation_argument_map is None: + return None + + return self.silo_to_aggregation_argument_map.get(silo_output_name) + + @classmethod + def _try_create_default_mappings( + cls, + silo_comp: Optional[Component], + agg_comp: Optional[Component], + silo_agg_map: Optional[Dict], + agg_silo_map: Optional[Dict], + ) -> Tuple[Optional[Dict], Optional[Dict]]: + """ + This function tries to produce dictionaries that link the silo and aggregation + components' outputs to the other's inputs. + The mapping only occurs for inputted mappings that are None, otherwise + the inputted mapping is returned unchanged. + These auto-generated mappings are naive, and simply maps all outputs of a component that have a + identically-named input in the other component. + + This function does nothing if either inputted component is None. This function will also do nothing + for a given mapping if either of the relevant inputs or outputs are None (but not empty). + + Example inputs: + silo_comp.inputs = {"silo_input" : value } + silo_comp.outputs = {"c" : ..., "silo_output2" : ... } + agg_comp.inputs = {"silo_output1" : ... } + agg_comp.outputs = {"agg_output" : ... 
} + silo_agg_map = None + agg_silo_map = {} + + Example returns: + {"silo_output1" : "silo_output1"}, {} + + :param silo_comp: The silo component + :type silo_comp: Optional[Component] + :param agg_comp: The aggregation component + :type agg_comp: Optional[Component] + :param silo_agg_map: Mapping of silo to aggregation arguments. + :type silo_agg_map: Optional[Dict] + :param agg_silo_map: Mapping of aggregation to silo arguments. + :type agg_silo_map: Optional[Dict] + :return: Returns a tuple of the potentially modified silo to aggregation mapping, followed by the aggregation + to silo mapping. + :rtype: Tuple[Optional[Dict], Optional[Dict]] + """ + if silo_comp is None or agg_comp is None: + return silo_agg_map, agg_silo_map + if silo_agg_map is None and silo_comp.outputs is not None and agg_comp.inputs is not None: + silo_agg_map = {output: output for output in silo_comp.outputs.keys() if output in agg_comp.inputs} + if agg_silo_map is None: + agg_silo_map = {output: output for output in agg_comp.outputs.keys() if output in silo_comp.inputs} + return silo_agg_map, agg_silo_map + + @staticmethod + # pylint: disable-next=docstring-missing-rtype + def _get_merge_component(output_type: str) -> Any: + """Gets the merge component to be used based on type of output + + :param output_type: The output type + :type output_type: str + :return: The merge component + """ + return MERGE_COMPONENT_MAPPING[output_type] + + def _inject_merge_components(self, executed_silo_components: Any) -> Dict: + """Add a merge component for each silo output in the silo_to_aggregation_argument_map. + These merge components act as a mediator between the user silo and aggregation steps, reducing + the variable number of silo outputs into a single input for the aggergation step. + + :param executed_silo_components: A list of executed silo steps to extract outputs from. + :type executed_silo_components: + :return: A mapping from silo output names to the corresponding newly created and executed merge component + :rtype: dict + """ + executed_component = executed_silo_components[0] + + merge_comp_mapping = {} + if self.silo_to_aggregation_argument_map is not None: + for ( + silo_output_argument_name, + _, + ) in self.silo_to_aggregation_argument_map.items(): + merge_comp = self._get_merge_component(executed_component.outputs[silo_output_argument_name].type) + merge_component_inputs = { + silo_output_argument_name + + "_silo_" + + str(i): executed_silo_components[i].outputs[silo_output_argument_name] + for i in range(0, len(executed_silo_components)) + } + executed_merge_component = merge_comp(**merge_component_inputs) + for input_obj in executed_merge_component.inputs.values(): + input_obj.mode = "direct" + for output_obj in executed_merge_component.outputs.values(): + output_obj.type = "mltable" + merge_comp_mapping.update({silo_output_argument_name: executed_merge_component}) + + return merge_comp_mapping + + # boilerplate functions - largely copied from other node builders + + @property + def outputs(self) -> Dict[str, Union[str, Output]]: + """Get the outputs of the scatter-gather node. + + :return: The outputs of the scatter-gather node. + :rtype: Dict[str, Union[str, ~azure.ai.ml.Output]] + """ + return self._outputs + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema: + return FLScatterGatherSchema(context=context) + + def _to_rest_object(self, **kwargs: Any) -> dict: + """Convert self to a rest object for remote call. 
+ + :return: The rest object + :rtype: dict + """ + rest_node = super(FLScatterGather, self)._to_rest_object(**kwargs) + rest_node.update({"outputs": self._to_rest_outputs()}) + # TODO: Bug Item number: 2897665 + res: dict = convert_ordered_dict_to_dict(rest_node) # type: ignore + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_func.py new file mode 100644 index 00000000..c9ecabd8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_func.py @@ -0,0 +1,93 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional + +from azure.ai.ml.constants._component import ComponentSource +from azure.ai.ml.entities._component.import_component import ImportComponent +from azure.ai.ml.entities._inputs_outputs import Output +from azure.ai.ml.entities._job.import_job import ImportSource + +from .command_func import _parse_input, _parse_inputs_outputs, _parse_output +from .import_node import Import + + +def import_job( + *, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + source: Optional[ImportSource] = None, + output: Optional[Output] = None, + is_deterministic: bool = True, + **kwargs: Any, +) -> Import: + """Create an Import object which can be used inside dsl.pipeline as a function + and can also be created as a standalone import job. + + :keyword name: Name of the import job or component created. + :paramtype name: str + :keyword description: A friendly description of the import. + :paramtype description: str + :keyword tags: Tags to be attached to this import. + :paramtype tags: Dict + :keyword display_name: A friendly name. + :paramtype display_name: str + :keyword experiment_name: Name of the experiment the job will be created under. + If None is provided, the default will be set to the current directory name. + Will be ignored as a pipeline step. + :paramtype experiment_name: str + :keyword source: Input source parameters used by this import. + :paramtype source: ~azure.ai.ml.entities._job.import_job.ImportSource + :keyword output: The output of this import. + :paramtype output: ~azure.ai.ml.entities.Output + :keyword is_deterministic: Specify whether the command will return the same output given the same input. + If a command (component) is deterministic, when used as a node/step in a pipeline, + it will reuse results from a previously submitted job in the current workspace + which has the same inputs and settings. + In this case, this step will not use any compute resource. + Defaults to True. + :paramtype is_deterministic: bool + :returns: The Import object. 
+ :rtype: ~azure.ai.ml.entities._builders.import_node.Import + """ + inputs = source._to_job_inputs() if source else kwargs.pop("inputs") + outputs = {"output": output} if output else kwargs.pop("outputs") + component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input) + # job inputs can not be None + job_inputs = {k: v for k, v in job_inputs.items() if v is not None} + component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output) + + component = kwargs.pop("component", None) + + if component is None: + component = ImportComponent( + name=name, + tags=tags, + display_name=display_name, + description=description, + source=component_inputs, + output=component_outputs["output"], + _source=ComponentSource.BUILDER, + is_deterministic=is_deterministic, + **kwargs, + ) + + import_obj = Import( + component=component, + name=name, + description=description, + tags=tags, + display_name=display_name, + experiment_name=experiment_name, + inputs=job_inputs, + outputs=job_outputs, + **kwargs, + ) + + return import_obj diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_node.py new file mode 100644 index 00000000..144753d5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_node.py @@ -0,0 +1,205 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=protected-access + +import logging +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from marshmallow import Schema + +from azure.ai.ml._restclient.v2022_02_01_preview.models import CommandJob as RestCommandJob +from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData +from azure.ai.ml._schema.job.import_job import ImportJobSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._component import ComponentSource, NodeType +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._component.import_component import ImportComponent +from azure.ai.ml.entities._inputs_outputs import Output +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, from_rest_inputs_to_dataset_literal +from azure.ai.ml.entities._job.import_job import ImportJob, ImportSource +from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException + +from ..._schema import PathAwareSchema +from .._inputs_outputs import Output +from .._util import convert_ordered_dict_to_dict, load_from_dict, validate_attribute_type +from .base_node import BaseNode + +module_logger = logging.getLogger(__name__) + + +class Import(BaseNode): + """Base class for import node, used for import component version consumption. + + You should not instantiate this class directly. Instead, you should + create from a builder function. + + :param component: Id or instance of the import component/job to be run for the step. + :type component: ~azure.ai.ml.entities._component.import_component.ImportComponent + :param inputs: Input parameters to the import. + :type inputs: Dict[str, str] + :param outputs: Mapping of output data bindings used in the job. + :type outputs: Dict[str, Union[str, ~azure.ai.ml.entities.Output]] + :param name: Name of the import. 
+ :type name: str + :param description: Description of the import. + :type description: str + :param display_name: Display name of the job. + :type display_name: str + :param experiment_name: Name of the experiment the job will be created under, + if None is provided, the default will be set to the current directory name. + :type experiment_name: str + """ + + def __init__( + self, + *, + component: Union[str, ImportComponent], + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + + kwargs.pop("type", None) + kwargs.pop("compute", None) + + self._parameters = kwargs.pop("parameters", {}) + BaseNode.__init__( + self, + type=NodeType.IMPORT, + inputs=inputs, + outputs=outputs, + component=component, + compute=ComputeType.ADF, + **kwargs, + ) + + @classmethod + def _get_supported_inputs_types(cls) -> Type[str]: + # import source parameters type, connection, query, path are always str + return str + + @classmethod + def _get_supported_outputs_types(cls) -> Tuple: + return str, Output + + @property + def component(self) -> Union[str, ImportComponent]: + res: Union[str, ImportComponent] = self._component + return res + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "component": (str, ImportComponent), + } + + def _to_job(self) -> ImportJob: + return ImportJob( + id=self.id, + name=self.name, + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + status=self.status, + source=ImportSource._from_job_inputs(self._job_inputs), + output=self._job_outputs.get("output"), + creation_context=self.creation_context, + ) + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return [] + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj: dict = super()._to_rest_object(**kwargs) + rest_obj.update( + convert_ordered_dict_to_dict( + { + "componentId": self._get_component_id(), + } + ) + ) + return rest_obj + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Import": + from .import_func import import_job + + loaded_data = load_from_dict(ImportJobSchema, data, context, additional_message, **kwargs) + + _import_job: Import = import_job(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + return _import_job + + @classmethod + def _load_from_rest_job(cls, obj: JobBaseData) -> "Import": + from .import_func import import_job + + rest_command_job: RestCommandJob = obj.properties + inputs = from_rest_inputs_to_dataset_literal(rest_command_job.inputs) + outputs = from_rest_data_outputs(rest_command_job.outputs) + + _import_job: Import = import_job( + name=obj.name, + display_name=rest_command_job.display_name, + description=rest_command_job.description, + experiment_name=rest_command_job.experiment_name, + status=rest_command_job.status, + creation_context=obj.system_data, + inputs=inputs, + output=outputs["output"] if "output" in outputs else None, + ) + _import_job._id = obj.id + if isinstance(_import_job.component, ImportComponent): + _import_job.component._source = ( + ComponentSource.REMOTE_WORKSPACE_JOB + ) # This is used by pipeline job telemetries. 
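+        # The REMOTE_WORKSPACE_JOB source marks a component rehydrated from an already
+        # submitted job, as opposed to ComponentSource.BUILDER, which is set by the
+        # import_job builder function.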
+ + return _import_job + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline import ImportSchema + + return ImportSchema(context=context) + + # pylint: disable-next=docstring-missing-param + def __call__(self, *args: Any, **kwargs: Any) -> "Import": + """Call Import as a function will return a new instance each time. + + :return: An Import node. + :rtype: Import + """ + if isinstance(self._component, Component): + # call this to validate inputs + node: Import = self._component(*args, **kwargs) + # merge inputs + for name, original_input in self.inputs.items(): + if name not in kwargs: + # use setattr here to make sure owner of input won't change + setattr(node.inputs, name, original_input._data) + node._job_inputs[name] = original_input._data + # get outputs + for name, original_output in self.outputs.items(): + # use setattr here to make sure owner of input won't change + if not isinstance(original_output, str): + setattr(node.outputs, name, original_output._data) + self._refine_optional_inputs_with_no_value(node, kwargs) + # set default values: compute, environment_variables, outputs + node._name = self.name + node.compute = self.compute + node.tags = self.tags + # Pass through the display name only if the display name is not system generated. + node.display_name = self.display_name if self.display_name != self.name else None + return node + msg = "Import can be called as a function only when referenced component is {}, currently got {}." + raise ValidationException( + message=msg.format(type(Component), self._component), + no_personal_data_message=msg.format(type(Component), "self._component"), + target=ErrorTarget.COMMAND_JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel.py new file mode 100644 index 00000000..db1de797 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel.py @@ -0,0 +1,551 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import copy
+import json
+import logging
+import os
+import re
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
+
+from marshmallow import INCLUDE, Schema
+
+from azure.ai.ml._schema.core.fields import NestedField, UnionField
+from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from azure.ai.ml.entities._credentials import (
+    AmlTokenConfiguration,
+    ManagedIdentityConfiguration,
+    UserIdentityConfiguration,
+    _BaseJobIdentityConfiguration,
+)
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.parallel.run_function import RunFunction
+from azure.ai.ml.entities._job.pipeline._io import NodeOutput
+from azure.ai.ml.exceptions import MlException
+
+from ..._schema import PathAwareSchema
+from ..._utils.utils import is_data_binding_expression
+from ...constants._common import ARM_ID_PREFIX
+from ...constants._component import NodeType
+from .._component.component import Component
+from .._component.flow import FlowComponent
+from .._component.parallel_component import ParallelComponent
+from .._inputs_outputs import Input, Output
+from .._job.job_resource_configuration import JobResourceConfiguration
+from .._job.parallel.parallel_job import ParallelJob
+from .._job.parallel.parallel_task import ParallelTask
+from .._job.parallel.retry_settings import RetrySettings
+from .._job.pipeline._io import NodeWithGroupInputMixin
+from .._util import convert_ordered_dict_to_dict, get_rest_dict_for_node_attrs, validate_attribute_type
+from .base_node import BaseNode
+
+module_logger = logging.getLogger(__name__)
+
+
+class Parallel(BaseNode, NodeWithGroupInputMixin):  # pylint: disable=too-many-instance-attributes
+    """Base class for parallel node, used for parallel component version consumption.
+
+    You should not instantiate this class directly. Instead, you should
+    create it from the builder function: parallel.
+
+    :param component: Id or instance of the parallel component/job to be run for the step
+    :type component: ~azure.ai.ml.entities._component.parallel_component.ParallelComponent
+    :param name: Name of the parallel
+    :type name: str
+    :param description: Description of the command
+    :type description: str
+    :param tags: Tag dictionary. Tags can be added, removed, and updated
+    :type tags: dict[str, str]
+    :param properties: The job property dictionary
+    :type properties: dict[str, str]
+    :param display_name: Display name of the job
+    :type display_name: str
+    :param retry_settings: Retry settings for a failed parallel job run
+    :type retry_settings: BatchRetrySettings
+    :param logging_level: A string of the logging level name
+    :type logging_level: str
+    :param max_concurrency_per_instance: The maximum parallelism that each compute instance has
+    :type max_concurrency_per_instance: int
+    :param error_threshold: The number of item processing failures that should be ignored
+    :type error_threshold: int
+    :param mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored
+    :type mini_batch_error_threshold: int
+    :param task: The parallel task
+    :type task: ParallelTask
+    :param mini_batch_size: For FileDataset input, this field is the number of files
+        a user script can process in one run() call.
+        For TabularDataset input, this field is the approximate size of data
+        the user script can process in one run() call.
+        Example values are 1024, 1024KB, 10MB, and 1GB.
(optional, default value is 10 files + for FileDataset and 1MB for TabularDataset.) + This value could be set through PipelineParameter + :type mini_batch_size: str + :param partition_keys: The keys used to partition dataset into mini-batches. If specified, + the data with the same key will be partitioned into the same mini-batch. + If both partition_keys and mini_batch_size are specified, + the partition keys will take effect. + The input(s) must be partitioned dataset(s), + and the partition_keys must be a subset of the keys of every input dataset for this to work. + :keyword identity: The identity that the command job will use while running on compute. + :paramtype identity: Optional[Union[ + dict[str, str], + ~azure.ai.ml.entities.ManagedIdentityConfiguration, + ~azure.ai.ml.entities.AmlTokenConfiguration, + ~azure.ai.ml.entities.UserIdentityConfiguration]] + :type partition_keys: List + :param input_data: The input data + :type input_data: str + :param inputs: Inputs of the component/job + :type inputs: dict + :param outputs: Outputs of the component/job + :type outputs: dict + """ + + # pylint: disable=too-many-statements + def __init__( + self, + *, + component: Union[ParallelComponent, str], + compute: Optional[str] = None, + inputs: Optional[Dict[str, Union[NodeOutput, Input, str, bool, int, float, Enum]]] = None, + outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None, + retry_settings: Optional[Union[RetrySettings, Dict[str, str]]] = None, + logging_level: Optional[str] = None, + max_concurrency_per_instance: Optional[int] = None, + error_threshold: Optional[int] = None, + mini_batch_error_threshold: Optional[int] = None, + input_data: Optional[str] = None, + task: Optional[Union[ParallelTask, RunFunction, Dict]] = None, + partition_keys: Optional[List] = None, + mini_batch_size: Optional[Union[str, int]] = None, + resources: Optional[JobResourceConfiguration] = None, + environment_variables: Optional[Dict] = None, + identity: Optional[ + Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, Dict] + ] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + kwargs.pop("type", None) + + if isinstance(component, FlowComponent): + # make input definition fit actual inputs for flow component + with component._inputs._fit_inputs(inputs): # type: ignore[attr-defined] + BaseNode.__init__( + self, + type=NodeType.PARALLEL, + component=component, + inputs=inputs, + outputs=outputs, + compute=compute, + **kwargs, + ) + else: + BaseNode.__init__( + self, + type=NodeType.PARALLEL, + component=component, + inputs=inputs, + outputs=outputs, + compute=compute, + **kwargs, + ) + # init mark for _AttrDict + self._init = True + + self._task = task + + if ( + mini_batch_size is not None + and not isinstance(mini_batch_size, int) + and not is_data_binding_expression(mini_batch_size) + ): + """Convert str to int.""" # pylint: disable=pointless-string-statement + pattern = re.compile(r"^\d+([kKmMgG][bB])*$") + if not pattern.match(mini_batch_size): + raise ValueError(r"Parameter mini_batch_size must follow regex rule ^\d+([kKmMgG][bB])*$") + + try: + mini_batch_size = int(mini_batch_size) + except ValueError as e: + if not isinstance(mini_batch_size, int): + unit = mini_batch_size[-2:].lower() + if unit == "kb": + mini_batch_size = int(mini_batch_size[0:-2]) * 1024 + elif unit == "mb": + mini_batch_size = int(mini_batch_size[0:-2]) * 1024 * 1024 + 
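+                    # For example, a mini_batch_size of "100mb" resolves here to
+                    # 100 * 1024 * 1024 = 104857600 bytes; "gb" values below are scaled by 1024 ** 3.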
elif unit == "gb": + mini_batch_size = int(mini_batch_size[0:-2]) * 1024 * 1024 * 1024 + else: + raise ValueError("mini_batch_size unit must be kb, mb or gb") from e + + self.mini_batch_size = mini_batch_size + self.partition_keys = partition_keys + self.input_data = input_data + self._retry_settings = retry_settings + self.logging_level = logging_level + self.max_concurrency_per_instance = max_concurrency_per_instance + self.error_threshold = error_threshold + self.mini_batch_error_threshold = mini_batch_error_threshold + self._resources = resources + self.environment_variables = {} if environment_variables is None else environment_variables + self._identity = identity + if isinstance(self.component, ParallelComponent): + self.resources = cast(JobResourceConfiguration, self.resources) or cast( + JobResourceConfiguration, copy.deepcopy(self.component.resources) + ) + # TODO: Bug Item number: 2897665 + self.retry_settings = self.retry_settings or copy.deepcopy(self.component.retry_settings) # type: ignore + self.input_data = self.input_data or self.component.input_data + self.max_concurrency_per_instance = ( + self.max_concurrency_per_instance or self.component.max_concurrency_per_instance + ) + self.mini_batch_error_threshold = ( + self.mini_batch_error_threshold or self.component.mini_batch_error_threshold + ) + self.mini_batch_size = self.mini_batch_size or self.component.mini_batch_size + self.partition_keys = self.partition_keys or copy.deepcopy(self.component.partition_keys) + + if not self.task: + self.task = self.component.task + # task.code is based on self.component.base_path + self._base_path = self.component.base_path + + self._init = False + + @classmethod + def _get_supported_outputs_types(cls) -> Tuple: + return str, Output + + @property + def retry_settings(self) -> RetrySettings: + """Get the retry settings for the parallel job. + + :return: The retry settings for the parallel job. + :rtype: ~azure.ai.ml.entities._job.parallel.retry_settings.RetrySettings + """ + return self._retry_settings # type: ignore + + @retry_settings.setter + def retry_settings(self, value: Union[RetrySettings, Dict]) -> None: + """Set the retry settings for the parallel job. + + :param value: The retry settings for the parallel job. + :type value: ~azure.ai.ml.entities._job.parallel.retry_settings.RetrySettings or dict + """ + if isinstance(value, dict): + value = RetrySettings(**value) + self._retry_settings = value + + @property + def resources(self) -> Optional[JobResourceConfiguration]: + """Get the resource configuration for the parallel job. + + :return: The resource configuration for the parallel job. + :rtype: ~azure.ai.ml.entities._job.job_resource_configuration.JobResourceConfiguration + """ + return self._resources + + @resources.setter + def resources(self, value: Union[JobResourceConfiguration, Dict]) -> None: + """Set the resource configuration for the parallel job. + + :param value: The resource configuration for the parallel job. + :type value: ~azure.ai.ml.entities._job.job_resource_configuration.JobResourceConfiguration or dict + """ + if isinstance(value, dict): + value = JobResourceConfiguration(**value) + self._resources = value + + @property + def identity( + self, + ) -> Optional[Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, Dict]]: + """The identity that the job will use while running on compute. + + :return: The identity that the job will use while running on compute. 
+ :rtype: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration]] + """ + return self._identity + + @identity.setter + def identity( + self, + value: Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, None], + ) -> None: + """Sets the identity that the job will use while running on compute. + + :param value: The identity that the job will use while running on compute. + :type value: Union[dict[str, str], ~azure.ai.ml.ManagedIdentityConfiguration, + ~azure.ai.ml.AmlTokenConfiguration, ~azure.ai.ml.UserIdentityConfiguration] + """ + if isinstance(value, dict): + identity_schema = UnionField( + [ + NestedField(ManagedIdentitySchema, unknown=INCLUDE), + NestedField(AMLTokenIdentitySchema, unknown=INCLUDE), + NestedField(UserIdentitySchema, unknown=INCLUDE), + ] + ) + value = identity_schema._deserialize(value=value, attr=None, data=None) + self._identity = value + + @property + def component(self) -> Union[str, ParallelComponent]: + """Get the component of the parallel job. + + :return: The component of the parallel job. + :rtype: str or ~azure.ai.ml.entities._component.parallel_component.ParallelComponent + """ + res: Union[str, ParallelComponent] = self._component + return res + + @property + def task(self) -> Optional[ParallelTask]: + """Get the parallel task. + + :return: The parallel task. + :rtype: ~azure.ai.ml.entities._job.parallel.parallel_task.ParallelTask + """ + return self._task # type: ignore + + @task.setter + def task(self, value: Union[ParallelTask, Dict]) -> None: + """Set the parallel task. + + :param value: The parallel task. + :type value: ~azure.ai.ml.entities._job.parallel.parallel_task.ParallelTask or dict + """ + # base path should be reset if task is set via sdk + self._base_path: Optional[Union[str, os.PathLike]] = None + if isinstance(value, dict): + value = ParallelTask(**value) + self._task = value + + def _set_base_path(self, base_path: Optional[Union[str, os.PathLike]]) -> None: + if self._base_path: + return + super(Parallel, self)._set_base_path(base_path) + + def set_resources( + self, + *, + instance_type: Optional[Union[str, List[str]]] = None, + instance_count: Optional[int] = None, + properties: Optional[Dict] = None, + docker_args: Optional[str] = None, + shm_size: Optional[str] = None, + # pylint: disable=unused-argument + **kwargs: Any, + ) -> None: + """Set the resources for the parallel job. + + :keyword instance_type: The instance type or a list of instance types used as supported by the compute target. + :paramtype instance_type: Union[str, List[str]] + :keyword instance_count: The number of instances or nodes used by the compute target. + :paramtype instance_count: int + :keyword properties: The property dictionary for the resources. + :paramtype properties: dict + :keyword docker_args: Extra arguments to pass to the Docker run command. + :paramtype docker_args: str + :keyword shm_size: Size of the Docker container's shared memory block. 
+ :paramtype shm_size: str + """ + if self.resources is None: + self.resources = JobResourceConfiguration() + + if instance_type is not None: + self.resources.instance_type = instance_type + if instance_count is not None: + self.resources.instance_count = instance_count + if properties is not None: + self.resources.properties = properties + if docker_args is not None: + self.resources.docker_args = docker_args + if shm_size is not None: + self.resources.shm_size = shm_size + + # Save the resources to internal component as well, otherwise calling sweep() will loose the settings + if isinstance(self.component, Component): + self.component.resources = self.resources + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "component": (str, ParallelComponent, FlowComponent), + "retry_settings": (dict, RetrySettings), + "resources": (dict, JobResourceConfiguration), + "task": (dict, ParallelTask), + "logging_level": str, + "max_concurrency_per_instance": (str, int), + "error_threshold": (str, int), + "mini_batch_error_threshold": (str, int), + "environment_variables": dict, + } + + def _to_job(self) -> ParallelJob: + return ParallelJob( + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + properties=self.properties, + compute=self.compute, + resources=self.resources, + partition_keys=self.partition_keys, + mini_batch_size=self.mini_batch_size, + task=self.task, + retry_settings=self.retry_settings, + input_data=self.input_data, + logging_level=self.logging_level, + identity=self.identity, + max_concurrency_per_instance=self.max_concurrency_per_instance, + error_threshold=self.error_threshold, + mini_batch_error_threshold=self.mini_batch_error_threshold, + environment_variables=self.environment_variables, + inputs=self._job_inputs, + outputs=self._job_outputs, + ) + + def _parallel_attr_to_dict(self, attr: str, base_type: Type) -> dict: + # Convert parallel attribute to dict + rest_attr = {} + parallel_attr = getattr(self, attr) + if parallel_attr is not None: + if isinstance(parallel_attr, base_type): + rest_attr = parallel_attr._to_dict() + elif isinstance(parallel_attr, dict): + rest_attr = parallel_attr + else: + msg = f"Expecting {base_type} for {attr}, got {type(parallel_attr)} instead." 
+ raise MlException(message=msg, no_personal_data_message=msg) + # TODO: Bug Item number: 2897665 + res: dict = convert_ordered_dict_to_dict(rest_attr) # type: ignore + return res + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return [ + "type", + "resources", + "error_threshold", + "mini_batch_error_threshold", + "environment_variables", + "max_concurrency_per_instance", + "task", + "input_data", + ] + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj: Dict = super(Parallel, self)._to_rest_object(**kwargs) + rest_obj.update( + convert_ordered_dict_to_dict( + { + "componentId": self._get_component_id(), + "retry_settings": get_rest_dict_for_node_attrs(self.retry_settings), + "logging_level": self.logging_level, + "mini_batch_size": self.mini_batch_size, + "partition_keys": ( + json.dumps(self.partition_keys) if self.partition_keys is not None else self.partition_keys + ), + "identity": get_rest_dict_for_node_attrs(self.identity), + "resources": get_rest_dict_for_node_attrs(self.resources), + } + ) + ) + return rest_obj + + @classmethod + def _from_rest_object_to_init_params(cls, obj: dict) -> Dict: + obj = super()._from_rest_object_to_init_params(obj) + # retry_settings + if "retry_settings" in obj and obj["retry_settings"]: + obj["retry_settings"] = RetrySettings._from_dict(obj["retry_settings"]) + + if "task" in obj and obj["task"]: + obj["task"] = ParallelTask._from_dict(obj["task"]) + task_code = obj["task"].code + task_env = obj["task"].environment + # remove azureml: prefix in code and environment which is added in _to_rest_object + if task_code and isinstance(task_code, str) and task_code.startswith(ARM_ID_PREFIX): + obj["task"].code = task_code[len(ARM_ID_PREFIX) :] + if task_env and isinstance(task_env, str) and task_env.startswith(ARM_ID_PREFIX): + obj["task"].environment = task_env[len(ARM_ID_PREFIX) :] + + if "resources" in obj and obj["resources"]: + obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"]) + + if "partition_keys" in obj and obj["partition_keys"]: + obj["partition_keys"] = json.dumps(obj["partition_keys"]) + if "identity" in obj and obj["identity"]: + obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"]) + return obj + + def _build_inputs(self) -> Dict: + inputs = super(Parallel, self)._build_inputs() + built_inputs = {} + # Validate and remove non-specified inputs + for key, value in inputs.items(): + if value is not None: + built_inputs[key] = value + return built_inputs + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline import ParallelSchema + + return ParallelSchema(context=context) + + # pylint: disable-next=docstring-missing-param + def __call__(self, *args: Any, **kwargs: Any) -> "Parallel": + """Call Parallel as a function will return a new instance each time. 
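+
+        Inputs, outputs, and node-level settings (compute, tags, retry settings, task, resources,
+        environment variables, and identity) that are not overridden by the call arguments are
+        copied from this node onto the returned instance.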
+ + :return: A Parallel node + :rtype: Parallel + """ + if isinstance(self._component, Component): + # call this to validate inputs + node: Parallel = self._component(*args, **kwargs) + # merge inputs + for name, original_input in self.inputs.items(): + if name not in kwargs: + # use setattr here to make sure owner of input won't change + setattr(node.inputs, name, original_input._data) + # get outputs + for name, original_output in self.outputs.items(): + # use setattr here to make sure owner of input won't change + if not isinstance(original_output, str): + setattr(node.outputs, name, original_output._data) + self._refine_optional_inputs_with_no_value(node, kwargs) + # set default values: compute, environment_variables, outputs + node._name = self.name + node.compute = self.compute + node.tags = self.tags + node.display_name = self.display_name + node.mini_batch_size = self.mini_batch_size + node.partition_keys = self.partition_keys + node.logging_level = self.logging_level + node.max_concurrency_per_instance = self.max_concurrency_per_instance + node.error_threshold = self.error_threshold + # deep copy for complex object + node.retry_settings = copy.deepcopy(self.retry_settings) + node.input_data = self.input_data + node.task = copy.deepcopy(self.task) + node._base_path = self.base_path + node.resources = copy.deepcopy(self.resources) + node.environment_variables = copy.deepcopy(self.environment_variables) + node.identity = copy.deepcopy(self.identity) + return node + msg = f"Parallel can be called as a function only when referenced component is {type(Component)}, \ + currently got {self._component}." + raise MlException(message=msg, no_personal_data_message=msg) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Job": + raise NotImplementedError() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_for.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_for.py new file mode 100644 index 00000000..1e888f50 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_for.py @@ -0,0 +1,362 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import json +import os +from typing import Any, Dict, Optional, Union + +from azure.ai.ml import Input, Output +from azure.ai.ml._schema import PathAwareSchema +from azure.ai.ml._schema.pipeline.control_flow_job import ParallelForSchema +from azure.ai.ml._utils.utils import is_data_binding_expression +from azure.ai.ml.constants import AssetTypes +from azure.ai.ml.constants._component import ComponentParameterTypes, ControlFlowType +from azure.ai.ml.entities import Component, Pipeline +from azure.ai.ml.entities._builders import BaseNode +from azure.ai.ml.entities._builders.control_flow_node import LoopNode +from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput +from azure.ai.ml.entities._job.pipeline._io.mixin import NodeIOMixin +from azure.ai.ml.entities._util import convert_ordered_dict_to_dict, validate_attribute_type +from azure.ai.ml.entities._validation import MutableValidationResult +from azure.ai.ml.exceptions import UserErrorException + + +class ParallelFor(LoopNode, NodeIOMixin): + """Parallel for loop node in the pipeline job. By specifying the loop body and aggregated items, a job-level + parallel for loop can be implemented. 
It will be initialized when calling dsl.parallel_for or when loading the + pipeline yml containing parallel_for node. Please do not manually initialize this class. + + :param body: Pipeline job for the parallel for loop body. + :type body: ~azure.ai.ml.entities.Pipeline + :param items: The loop body's input which will bind to the loop node. + :type items: typing.Union[list, dict, str, ~azure.ai.ml.entities._job.pipeline._io.NodeOutput, + ~azure.ai.ml.entities._job.pipeline._io.PipelineInput] + :param max_concurrency: Maximum number of concurrent iterations to run. All loop body nodes will be executed + in parallel if not specified. + :type max_concurrency: int + """ + + OUT_TYPE_MAPPING = { + AssetTypes.URI_FILE: AssetTypes.MLTABLE, + AssetTypes.URI_FOLDER: AssetTypes.MLTABLE, + AssetTypes.MLTABLE: AssetTypes.MLTABLE, + AssetTypes.MLFLOW_MODEL: AssetTypes.MLTABLE, + AssetTypes.TRITON_MODEL: AssetTypes.MLTABLE, + AssetTypes.CUSTOM_MODEL: AssetTypes.MLTABLE, + # legacy path support + "path": AssetTypes.MLTABLE, + ComponentParameterTypes.NUMBER: ComponentParameterTypes.STRING, + ComponentParameterTypes.STRING: ComponentParameterTypes.STRING, + ComponentParameterTypes.BOOLEAN: ComponentParameterTypes.STRING, + ComponentParameterTypes.INTEGER: ComponentParameterTypes.STRING, + } + + def __init__( + self, + *, + body: "Pipeline", + items: Union[list, dict, str, PipelineInput, NodeOutput], + max_concurrency: Optional[int] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + + kwargs.pop("type", None) + super(ParallelFor, self).__init__( + type=ControlFlowType.PARALLEL_FOR, + body=body, + **kwargs, + ) + # loop body is incomplete in submission time, so won't validate required inputs + self.body._validate_required_input_not_provided = False + self._outputs: dict = {} + + actual_outputs = kwargs.get("outputs", {}) + # parallel for node shares output meta with body + try: + outputs = self.body._component.outputs + # transform body outputs to aggregate types when available + self._outputs = self._build_outputs_dict( + outputs=actual_outputs, output_definition_dict=self._convert_output_meta(outputs) + ) + except AttributeError: + # when body output not available, create default output builder without meta + self._outputs = self._build_outputs_dict(outputs=actual_outputs) + + self._items = items + + self.max_concurrency = max_concurrency + + @property + def outputs(self) -> Dict[str, Union[str, Output]]: + """Get the outputs of the parallel for loop. + + :return: The dictionary containing the outputs of the parallel for loop. + :rtype: dict[str, Union[str, ~azure.ai.ml.Output]] + """ + return self._outputs + + @property + def items(self) -> Union[list, dict, str, PipelineInput, NodeOutput]: + """Get the loop body's input which will bind to the loop node. + + :return: The input for the loop body. 
+ :rtype: typing.Union[list, dict, str, ~azure.ai.ml.entities._job.pipeline._io.NodeOutput, + ~azure.ai.ml.entities._job.pipeline._io.PipelineInput] + """ + return self._items + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema: + return ParallelForSchema(context=context) + + @classmethod + def _attr_type_map(cls) -> dict: + return { + **super(ParallelFor, cls)._attr_type_map(), + "items": (dict, list, str, PipelineInput, NodeOutput), + } + + @classmethod + # pylint: disable-next=docstring-missing-param + def _to_rest_item(cls, item: dict) -> dict: + """Convert item to rest object. + + :return: The rest object + :rtype: dict + """ + primitive_inputs, asset_inputs = {}, {} + # validate item + for key, val in item.items(): + if isinstance(val, Input): + asset_inputs[key] = val + elif isinstance(val, (PipelineInput, NodeOutput)): + # convert binding object to string + primitive_inputs[key] = str(val) + else: + primitive_inputs[key] = val + return { + # asset type inputs will be converted to JobInput dict: + # {"asset_param": {"uri": "xxx", "job_input_type": "uri_file"}} + **cls._input_entity_to_rest_inputs(input_entity=asset_inputs), + # primitive inputs has primitive type value like this + # {"int_param": 1} + **primitive_inputs, + } + + @classmethod + # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype + def _to_rest_items(cls, items: Union[list, dict, str, NodeOutput, PipelineInput]) -> str: + """Convert items to rest object.""" + # validate items. + cls._validate_items(items=items, raise_error=True, body_component=None) + result: str = "" + # convert items to rest object + if isinstance(items, list): + rest_items_list = [cls._to_rest_item(item=i) for i in items] + result = json.dumps(rest_items_list) + elif isinstance(items, dict): + rest_items_dict = {k: cls._to_rest_item(item=v) for k, v in items.items()} + result = json.dumps(rest_items_dict) + elif isinstance(items, (NodeOutput, PipelineInput)): + result = str(items) + elif isinstance(items, str): + result = items + else: + raise UserErrorException("Unsupported items type: {}".format(type(items))) + return result + + def _to_rest_object(self, **kwargs: Any) -> dict: + """Convert self to a rest object for remote call. 
+ + :return: The rest object + :rtype: dict + """ + rest_node = super(ParallelFor, self)._to_rest_object(**kwargs) + # convert items to rest object + rest_items = self._to_rest_items(items=self.items) + rest_node.update({"items": rest_items, "outputs": self._to_rest_outputs()}) + # TODO: Bug Item number: 2897665 + res: dict = convert_ordered_dict_to_dict(rest_node) # type: ignore + return res + + @classmethod + # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype + def _from_rest_item(cls, rest_item: Any) -> Dict: + """Convert rest item to item.""" + primitive_inputs, asset_inputs = {}, {} + for key, val in rest_item.items(): + if isinstance(val, dict) and val.get("job_input_type"): + asset_inputs[key] = val + else: + primitive_inputs[key] = val + return {**cls._from_rest_inputs(inputs=asset_inputs), **primitive_inputs} + + @classmethod + # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype + def _from_rest_items(cls, rest_items: str) -> Union[dict, list, str]: + """Convert items from rest object.""" + try: + items = json.loads(rest_items) + except json.JSONDecodeError: + # return original items when failed to load + return rest_items + if isinstance(items, list): + return [cls._from_rest_item(rest_item=i) for i in items] + if isinstance(items, dict): + return {k: cls._from_rest_item(rest_item=v) for k, v in items.items()} + return rest_items + + @classmethod + def _from_rest_object(cls, obj: dict, pipeline_jobs: dict) -> "ParallelFor": + # pylint: disable=protected-access + obj = BaseNode._from_rest_object_to_init_params(obj) + obj["items"] = cls._from_rest_items(rest_items=obj.get("items", "")) + return cls._create_instance_from_schema_dict(pipeline_jobs=pipeline_jobs, loaded_data=obj) + + @classmethod + def _create_instance_from_schema_dict(cls, pipeline_jobs: Dict, loaded_data: Dict, **kwargs: Any) -> "ParallelFor": + body_name = cls._get_data_binding_expression_value(loaded_data.pop("body"), regex=r"\{\{.*\.jobs\.(.*)\}\}") + + loaded_data["body"] = cls._get_body_from_pipeline_jobs(pipeline_jobs=pipeline_jobs, body_name=body_name) + return cls(**loaded_data, **kwargs) + + def _convert_output_meta(self, outputs: Dict[str, Union[NodeOutput, Output]]) -> Dict[str, Output]: + """Convert output meta to aggregate types. + + :param outputs: Output meta + :type outputs: Dict[str, Union[NodeOutput, Output]] + :return: Dictionary of aggregate types + :rtype: Dict[str, Output] + """ + # pylint: disable=protected-access + aggregate_outputs = {} + for name, output in outputs.items(): + if output.type in self.OUT_TYPE_MAPPING: + new_type = self.OUT_TYPE_MAPPING[output.type] + else: + # when loop body introduces some new output type, this will be raised as a reminder to support is in + # parallel for + raise UserErrorException( + "Referencing output with type {} is not supported in parallel_for node.".format(output.type) + ) + if isinstance(output, NodeOutput): + output = output._to_job_output() # type: ignore + if isinstance(output, Output): + out_dict = output._to_dict() + out_dict["type"] = new_type + resolved_output = Output(**out_dict) + else: + resolved_output = Output(type=new_type) + aggregate_outputs[name] = resolved_output + return aggregate_outputs + + def _customized_validate(self) -> MutableValidationResult: + """Customized validation for parallel for node. 
+ + :return: The validation result + :rtype: MutableValidationResult + """ + # pylint: disable=protected-access + validation_result = self._validate_body() + validation_result.merge_with( + self._validate_items(items=self.items, raise_error=False, body_component=self.body._component) + ) + return validation_result + + @classmethod + def _validate_items( + cls, + items: Union[list, dict, str, NodeOutput, PipelineInput], + raise_error: bool = True, + body_component: Optional[Union[str, Component]] = None, + ) -> MutableValidationResult: + validation_result = cls._create_empty_validation_result() + if items is not None: + if isinstance(items, str): + # TODO: remove the validation + # try to deserialize str if it's a json string + try: + items = json.loads(items) + except json.JSONDecodeError as e: + if not is_data_binding_expression(items, ["parent"]): + validation_result.append_error( + yaml_path="items", + message=f"Items is neither a valid JSON string due to {e} or a binding string.", + ) + if isinstance(items, dict): + # Validate dict keys + items = list(items.values()) + if isinstance(items, list): + if len(items) > 0: + cls._validate_items_list(items, validation_result, body_component=body_component) + else: + validation_result.append_error(yaml_path="items", message="Items is an empty list/dict.") + else: + validation_result.append_error( + yaml_path="items", + message="Items is required for parallel_for node", + ) + return cls._try_raise(validation_result, raise_error=raise_error) + + @classmethod + def _validate_items_list( + cls, + items: list, + validation_result: MutableValidationResult, + body_component: Optional[Union[str, Component]] = None, + ) -> None: + meta: dict = {} + # all items have to be dict and have matched meta + for item in items: + # item has to be dict + # Note: item can be empty dict when loop_body don't have foreach inputs. + if not isinstance(item, dict): + validation_result.append_error( + yaml_path="items", + message=f"Items has to be list/dict of dict as value, " f"but got {type(item)} for {item}.", + ) + else: + # item has to have matched meta + if meta.keys() != item.keys(): + if not meta.keys(): + meta = item + else: + msg = f"Items should have same keys with body inputs, but got {item.keys()} and {meta.keys()}." + validation_result.append_error(yaml_path="items", message=msg) + # items' keys should appear in body's inputs + if isinstance(body_component, Component) and (not item.keys() <= body_component.inputs.keys()): + msg = f"Item {item} got unmatched inputs with loop body component inputs {body_component.inputs}." + validation_result.append_error(yaml_path="items", message=msg) + # validate item value type + cls._validate_item_value_type(item=item, validation_result=validation_result) + + @classmethod + def _validate_item_value_type(cls, item: dict, validation_result: MutableValidationResult) -> None: + supported_types = (Input, str, bool, int, float, PipelineInput) + for _, val in item.items(): + if not isinstance(val, supported_types): + validation_result.append_error( + yaml_path="items", + message="Unsupported type {} in parallel_for items. 
Supported types are: {}".format( + type(val), supported_types + ), + ) + if isinstance(val, Input): + cls._validate_input_item_value(entry=val, validation_result=validation_result) + + @classmethod + def _validate_input_item_value(cls, entry: Input, validation_result: MutableValidationResult) -> None: + if not isinstance(entry, Input): + return + if not entry.path: + validation_result.append_error( + yaml_path="items", + message=f"Input path not provided for {entry}.", + ) + if isinstance(entry.path, str) and os.path.exists(entry.path): + validation_result.append_error( + yaml_path="items", + message=f"Local file input {entry} is not supported, please create it as a dataset.", + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_func.py new file mode 100644 index 00000000..a8f08d1e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_func.py @@ -0,0 +1,285 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import os +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml.constants._component import ComponentSource +from azure.ai.ml.entities._component.parallel_component import ParallelComponent +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, +) +from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings +from azure.ai.ml.entities._job.parallel.run_function import RunFunction + +from .command_func import _parse_input, _parse_inputs_outputs, _parse_output +from .parallel import Parallel + + +def parallel_run_function( + *, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + retry_settings: Optional[BatchRetrySettings] = None, + environment_variables: Optional[Dict] = None, + logging_level: Optional[str] = None, + max_concurrency_per_instance: Optional[int] = None, + error_threshold: Optional[int] = None, + mini_batch_error_threshold: Optional[int] = None, + task: Optional[RunFunction] = None, + mini_batch_size: Optional[str] = None, + partition_keys: Optional[List] = None, + input_data: Optional[str] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + instance_count: Optional[int] = None, + instance_type: Optional[str] = None, + docker_args: Optional[str] = None, + shm_size: Optional[str] = None, + identity: Optional[Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]] = None, + is_deterministic: bool = True, + **kwargs: Any, +) -> Parallel: + """Create a Parallel object which can be used inside dsl.pipeline as a function and can also be created as a + standalone parallel job. + + For an example of using ParallelRunStep, see the notebook + https://aka.ms/parallel-example-notebook + + .. note:: + + To use parallel_run_function: + + * Create a :class:`azure.ai.ml.entities._builders.Parallel` object to specify how parallel run is performed, + with parameters to control batch size,number of nodes per compute target, and a + reference to your custom Python script. 
+
+    * Build a pipeline with the parallel object as a function, defining inputs and
+      outputs for the step.
+
+    * Submit the pipeline to run.
+
+    .. code:: python
+
+        from azure.ai.ml import Input, Output, parallel
+
+        parallel_run = parallel_run_function(
+            name="batch_score_with_tabular_input",
+            display_name="Batch Score with Tabular Dataset",
+            description="parallel component for batch score",
+            inputs=dict(
+                job_data_path=Input(
+                    type=AssetTypes.MLTABLE,
+                    description="The data to be split and scored in parallel",
+                ),
+                score_model=Input(
+                    type=AssetTypes.URI_FOLDER, description="The model for batch score."
+                ),
+            ),
+            outputs=dict(job_output_path=Output(type=AssetTypes.MLTABLE)),
+            input_data="${{inputs.job_data_path}}",
+            max_concurrency_per_instance=2,  # Optional, default is 1
+            mini_batch_size="100",  # optional
+            mini_batch_error_threshold=5,  # Optional, allowed failed count on mini batch items, default is -1
+            logging_level="DEBUG",  # Optional, default is INFO
+            error_threshold=5,  # Optional, allowed failed count totally, default is -1
+            retry_settings=dict(max_retries=2, timeout=60),  # Optional
+            task=RunFunction(
+                code="./src",
+                entry_script="tabular_batch_inference.py",
+                environment=Environment(
+                    image="mcr.microsoft.com/azureml/openmpi3.1.2-ubuntu18.04",
+                    conda_file="./src/environment_parallel.yml",
+                ),
+                program_arguments="--model ${{inputs.score_model}}",
+                append_row_to="${{outputs.job_output_path}}",  # Optional, if not set, summary_only
+            ),
+        )
+
+    :keyword name: Name of the parallel job or component created.
+    :paramtype name: str
+    :keyword description: A friendly description of the parallel.
+    :paramtype description: str
+    :keyword tags: Tags to be attached to this parallel.
+    :paramtype tags: Dict
+    :keyword properties: The asset property dictionary.
+    :paramtype properties: Dict
+    :keyword display_name: A friendly name.
+    :paramtype display_name: str
+    :keyword experiment_name: Name of the experiment the job will be created under.
+        If None is provided, the default will be set to the current directory name.
+        Will be ignored as a pipeline step.
+    :paramtype experiment_name: str
+    :keyword compute: The name of the compute where the parallel job is executed (will not be used
+        if the parallel is used as a component/function).
+    :paramtype compute: str
+    :keyword retry_settings: Retry settings for a failed parallel component run.
+    :paramtype retry_settings: ~azure.ai.ml.entities._deployment.deployment_settings.BatchRetrySettings
+    :keyword environment_variables: A dictionary of environment variable names and values.
+        These environment variables are set on the process
+        where the user script is being executed.
+    :paramtype environment_variables: Dict[str, str]
+    :keyword logging_level: A string of the logging level name, which is defined in 'logging'.
+        Possible values are 'WARNING', 'INFO', and 'DEBUG'. (optional, default value is 'INFO'.)
+        This value could be set through PipelineParameter.
+    :paramtype logging_level: str
+    :keyword max_concurrency_per_instance: The maximum parallelism that each compute instance has.
+    :paramtype max_concurrency_per_instance: int
+    :keyword error_threshold: The number of record failures for Tabular Dataset and file failures for File Dataset
+        that should be ignored during processing.
+        If the error count goes above this value, then the job will be aborted.
+        Error threshold is for the entire input rather
+        than the individual mini-batch sent to the run() method.
+        The range is [-1, int.max]. -1 indicates ignoring all failures during processing.
+    :paramtype error_threshold: int
+    :keyword mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored.
+    :paramtype mini_batch_error_threshold: int
+    :keyword task: The parallel task.
+    :paramtype task: ~azure.ai.ml.entities._job.parallel.run_function.RunFunction
+    :keyword mini_batch_size: For FileDataset input,
+        this field is the number of files a user script can process in one run() call.
+        For TabularDataset input, this field is the approximate size of data
+        the user script can process in one run() call.
+        Example values are 1024, 1024KB, 10MB, and 1GB.
+        (optional, default value is 10 files for FileDataset and 1MB for TabularDataset.)
+        This value could be set through PipelineParameter.
+    :paramtype mini_batch_size: str
+    :keyword partition_keys: The keys used to partition the dataset into mini-batches. If specified,
+        the data with the same key will be partitioned into the same mini-batch.
+        If both partition_keys and mini_batch_size are specified,
+        the partition keys will take effect.
+        The input(s) must be partitioned dataset(s),
+        and the partition_keys must be a subset of the keys of every input dataset for this to work.
+    :paramtype partition_keys: List
+    :keyword input_data: The input data.
+    :paramtype input_data: str
+    :keyword inputs: A dict of inputs used by this parallel.
+    :paramtype inputs: Dict
+    :keyword outputs: The outputs of this parallel.
+    :paramtype outputs: Dict
+    :keyword instance_count: Optional number of instances or nodes used by the compute target.
+        Defaults to 1.
+    :paramtype instance_count: int
+    :keyword instance_type: Optional type of VM used as supported by the compute target.
+    :paramtype instance_type: str
+    :keyword docker_args: Extra arguments to pass to the Docker run command.
+        This would override any parameters that have already been set by the system,
+        or in this section.
+        This parameter is only supported for Azure ML compute types.
+    :paramtype docker_args: str
+    :keyword shm_size: Size of the Docker container's shared memory block.
+        This should be in the format of (number)(unit) where the number has to be greater than 0
+        and the unit can be one of b(bytes), k(kilobytes), m(megabytes), or g(gigabytes).
+    :paramtype shm_size: str
+    :keyword identity: The identity that the PRS job will use while running on compute.
+    :paramtype identity: Optional[Union[
+        ~azure.ai.ml.entities.ManagedIdentityConfiguration,
+        ~azure.ai.ml.entities.AmlTokenConfiguration,
+        ~azure.ai.ml.entities.UserIdentityConfiguration]]
+    :keyword is_deterministic: Specify whether the parallel will return the same output given the same input.
+        If a parallel (component) is deterministic, when used as a node/step in a pipeline
+        it will reuse results from a previously submitted job in the current workspace
+        which has the same inputs and settings.
+        In this case, this step will not use any compute resource.
+        Defaults to True; specify is_deterministic=False if you would like to avoid such reuse behavior.
+ :paramtype is_deterministic: bool + :return: The parallel node + :rtype: ~azure.ai.ml._builders.parallel.Parallel + """ + # pylint: disable=too-many-locals + inputs = inputs or {} + outputs = outputs or {} + component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input) + # job inputs can not be None + job_inputs = {k: v for k, v in job_inputs.items() if v is not None} + component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output) + + component = kwargs.pop("component", None) + + if component is None: + if task is None: + component = ParallelComponent( + base_path=os.getcwd(), # base path should be current folder + name=name, + tags=tags, + code=None, + display_name=display_name, + description=description, + inputs=component_inputs, + outputs=component_outputs, + retry_settings=retry_settings, # type: ignore[arg-type] + logging_level=logging_level, + max_concurrency_per_instance=max_concurrency_per_instance, + error_threshold=error_threshold, + mini_batch_error_threshold=mini_batch_error_threshold, + task=task, + mini_batch_size=mini_batch_size, + partition_keys=partition_keys, + input_data=input_data, + _source=ComponentSource.BUILDER, + is_deterministic=is_deterministic, + **kwargs, + ) + else: + component = ParallelComponent( + base_path=os.getcwd(), # base path should be current folder + name=name, + tags=tags, + code=task.code, + display_name=display_name, + description=description, + inputs=component_inputs, + outputs=component_outputs, + retry_settings=retry_settings, # type: ignore[arg-type] + logging_level=logging_level, + max_concurrency_per_instance=max_concurrency_per_instance, + error_threshold=error_threshold, + mini_batch_error_threshold=mini_batch_error_threshold, + task=task, + mini_batch_size=mini_batch_size, + partition_keys=partition_keys, + input_data=input_data, + _source=ComponentSource.BUILDER, + is_deterministic=is_deterministic, + **kwargs, + ) + + parallel_obj = Parallel( + component=component, + name=name, + description=description, + tags=tags, + properties=properties, + display_name=display_name, + experiment_name=experiment_name, + compute=compute, + inputs=job_inputs, + outputs=job_outputs, + identity=identity, + environment_variables=environment_variables, + retry_settings=retry_settings, # type: ignore[arg-type] + logging_level=logging_level, + max_concurrency_per_instance=max_concurrency_per_instance, + error_threshold=error_threshold, + mini_batch_error_threshold=mini_batch_error_threshold, + task=task, + mini_batch_size=mini_batch_size, + partition_keys=partition_keys, + input_data=input_data, + **kwargs, + ) + + if instance_count is not None or instance_type is not None or docker_args is not None or shm_size is not None: + parallel_obj.set_resources( + instance_count=instance_count, instance_type=instance_type, docker_args=docker_args, shm_size=shm_size + ) + + return parallel_obj diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/pipeline.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/pipeline.py new file mode 100644 index 00000000..188d9044 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/pipeline.py @@ -0,0 +1,225 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +import logging +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast + +from marshmallow import Schema + +from azure.ai.ml.entities._component.component import Component, NodeType +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._validation import MutableValidationResult + +from ..._schema import PathAwareSchema +from .._job.pipeline.pipeline_job_settings import PipelineJobSettings +from .._util import convert_ordered_dict_to_dict, copy_output_setting, validate_attribute_type +from .base_node import BaseNode + +if TYPE_CHECKING: + from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob + +module_logger = logging.getLogger(__name__) + + +class Pipeline(BaseNode): + """Base class for pipeline node, used for pipeline component version consumption. You should not instantiate this + class directly. Instead, you should use @pipeline decorator to create a pipeline node. + + :param component: Id or instance of the pipeline component/job to be run for the step. + :type component: Union[~azure.ai.ml.entities.Component, str] + :param inputs: Inputs of the pipeline node. + :type inputs: Optional[Dict[str, Union[ + ~azure.ai.ml.entities.Input, + str, bool, int, float, Enum, "Input"]]]. + :param outputs: Outputs of the pipeline node. + :type outputs: Optional[Dict[str, Union[str, ~azure.ai.ml.entities.Output, "Output"]]] + :param settings: Setting of pipeline node, only taking effect for root pipeline job. + :type settings: Optional[~azure.ai.ml.entities._job.pipeline.pipeline_job_settings.PipelineJobSettings] + """ + + def __init__( + self, + *, + component: Union[Component, str], + inputs: Optional[ + Dict[ + str, + Union[ + Input, + str, + bool, + int, + float, + Enum, + "Input", + ], + ] + ] = None, + outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None, + settings: Optional[PipelineJobSettings] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + kwargs.pop("type", None) + + BaseNode.__init__( + self, + type=NodeType.PIPELINE, + component=component, + inputs=inputs, + outputs=outputs, + **kwargs, + ) + # copy pipeline component output's setting to node level + self._copy_pipeline_component_out_setting_to_node() + self._settings: Optional[PipelineJobSettings] = None + self.settings = settings + + @property + def component(self) -> Optional[Union[str, Component]]: + """Id or instance of the pipeline component/job to be run for the step. + + :return: Id or instance of the pipeline component/job. + :rtype: Union[str, ~azure.ai.ml.entities.Component] + """ + res: Union[str, Component] = self._component + return res + + @property + def settings(self) -> Optional[PipelineJobSettings]: + """Settings of the pipeline. + + Note: settings is available only when create node as a job. + i.e. ml_client.jobs.create_or_update(node). + + :return: Settings of the pipeline. + :rtype: ~azure.ai.ml.entities.PipelineJobSettings + """ + if self._settings is None: + self._settings = PipelineJobSettings() + return self._settings + + @settings.setter + def settings(self, value: Union[PipelineJobSettings, Dict]) -> None: + """Set the settings of the pipeline. + + :param value: The settings of the pipeline. 
+ :type value: Union[~azure.ai.ml.entities.PipelineJobSettings, dict] + :raises TypeError: If the value is not an instance of PipelineJobSettings or a dict. + """ + if value is not None: + if isinstance(value, PipelineJobSettings): + # since PipelineJobSettings inherit _AttrDict, we need add this branch to distinguish with dict + pass + elif isinstance(value, dict): + value = PipelineJobSettings(**value) + else: + raise TypeError("settings must be PipelineJobSettings or dict but got {}".format(type(value))) + self._settings = value + + @classmethod + def _get_supported_inputs_types(cls) -> None: + # Return None here to skip validation, + # as input could be custom class object(parameter group). + return None + + @property + def _skip_required_compute_missing_validation(self) -> bool: + return True + + @classmethod + def _get_skip_fields_in_schema_validation(cls) -> List[str]: + # pipeline component must be a file reference when loading from yaml, + # so the created object can't pass schema validation. + return ["component"] + + @classmethod + def _attr_type_map(cls) -> dict: + # Use local import to avoid recursive reference as BaseNode is imported in PipelineComponent. + from azure.ai.ml.entities import PipelineComponent + + return { + "component": (str, PipelineComponent), + } + + def _to_job(self) -> "PipelineJob": + from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob + + return PipelineJob( + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + properties=self.properties, + # Filter None out to avoid case below failed with conflict keys check: + # group: None (user not specified) + # group.xx: 1 (user specified + inputs={k: v for k, v in self._job_inputs.items() if v}, + outputs=self._job_outputs, + component=self.component, + settings=self.settings, + ) + + def _customized_validate(self) -> MutableValidationResult: + """Check unsupported settings when use as a node. + + :return: The validation result + :rtype: MutableValidationResult + """ + # Note: settings is not supported on node, + # jobs.create_or_update(node) will call node._to_job() at first, + # thus won't reach here. 
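+        # In other words, settings assigned to a Pipeline node only take effect when the node is
+        # submitted as a root pipeline job; when used as a sub-pipeline step they are reported as
+        # ignored keys below.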
+ # pylint: disable=protected-access + from azure.ai.ml.entities import PipelineComponent + + validation_result = super(Pipeline, self)._customized_validate() + ignored_keys = PipelineComponent._check_ignored_keys(self) + if ignored_keys: + validation_result.append_warning(message=f"{ignored_keys} ignored on node {self.name!r}.") + if isinstance(self.component, PipelineComponent): + validation_result.merge_with(self.component._customized_validate()) + return validation_result + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj: Dict = super()._to_rest_object(**kwargs) + rest_obj.update( + convert_ordered_dict_to_dict( + { + "componentId": self._get_component_id(), + } + ) + ) + return rest_obj + + def _build_inputs(self) -> Dict: + inputs = super(Pipeline, self)._build_inputs() + built_inputs = {} + # Validate and remove non-specified inputs + for key, value in inputs.items(): + if value is not None: + built_inputs[key] = value + return built_inputs + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline.pipeline_component import PipelineSchema + + return PipelineSchema(context=context) + + def _copy_pipeline_component_out_setting_to_node(self) -> None: + """Copy pipeline component output's setting to node level.""" + from azure.ai.ml.entities import PipelineComponent + from azure.ai.ml.entities._job.pipeline._io import NodeOutput + + if not isinstance(self.component, PipelineComponent): + return + for key, val in self.component.outputs.items(): + node_output = cast(NodeOutput, self.outputs.get(key)) + copy_output_setting(source=val, target=node_output) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Job": + raise NotImplementedError() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark.py new file mode 100644 index 00000000..e72f1334 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark.py @@ -0,0 +1,663 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
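[Illustrative note, not part of the diff] The Pipeline node defined in pipeline.py above is normally produced by calling a loaded pipeline component inside a @pipeline-decorated function rather than instantiated directly. A minimal sketch of that flow, assuming a pipeline component YAML at ./train_pipeline.yml with a training_data input and a trained_model output (the file path and input/output names are hypothetical):

from azure.ai.ml import Input, load_component
from azure.ai.ml.dsl import pipeline

# Load a pipeline component definition from YAML (path is illustrative).
train_pipeline_component = load_component(source="./train_pipeline.yml")

@pipeline(default_compute="cpu-cluster")
def outer_pipeline(training_data: Input):
    # Calling the loaded pipeline component returns a Pipeline node.
    train_node = train_pipeline_component(training_data=training_data)
    return {"trained_model": train_node.outputs.trained_model}

# Build the pipeline job; node-level settings only take effect when submitted as a root job.
pipeline_job = outer_pipeline(training_data=Input(type="uri_folder", path="./data"))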
+# --------------------------------------------------------- +# pylint: disable=protected-access, too-many-instance-attributes + +import copy +import logging +import re +from enum import Enum +from os import PathLike, path +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from marshmallow import INCLUDE, Schema + +from ..._restclient.v2023_04_01_preview.models import JobBase as JobBaseData +from ..._restclient.v2023_04_01_preview.models import SparkJob as RestSparkJob +from ..._schema import NestedField, PathAwareSchema, UnionField +from ..._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema +from ..._schema.job.parameterized_spark import CONF_KEY_MAP +from ..._schema.job.spark_job import SparkJobSchema +from ..._utils.utils import is_url +from ...constants._common import ( + ARM_ID_PREFIX, + BASE_PATH_CONTEXT_KEY, + REGISTRY_URI_FORMAT, + SPARK_ENVIRONMENT_WARNING_MESSAGE, +) +from ...constants._component import NodeType +from ...constants._job.job import SparkConfKey +from ...entities._assets import Environment +from ...entities._component.component import Component +from ...entities._component.spark_component import SparkComponent +from ...entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, + _BaseJobIdentityConfiguration, +) +from ...entities._inputs_outputs import Input, Output +from ...entities._job._input_output_helpers import ( + from_rest_data_outputs, + from_rest_inputs_to_dataset_literal, + validate_inputs_for_args, +) +from ...entities._job.spark_job import SparkJob +from ...entities._job.spark_job_entry import SparkJobEntryType +from ...entities._job.spark_resource_configuration import SparkResourceConfiguration +from ...entities._validation import MutableValidationResult +from ...exceptions import ErrorCategory, ErrorTarget, ValidationException +from .._job.pipeline._io import NodeOutput +from .._job.spark_helpers import ( + _validate_compute_or_resources, + _validate_input_output_mode, + _validate_spark_configurations, +) +from .._job.spark_job_entry_mixin import SparkJobEntry, SparkJobEntryMixin +from .._util import convert_ordered_dict_to_dict, get_rest_dict_for_node_attrs, load_from_dict, validate_attribute_type +from .base_node import BaseNode + +module_logger = logging.getLogger(__name__) + + +class Spark(BaseNode, SparkJobEntryMixin): + """Base class for spark node, used for spark component version consumption. + + You should not instantiate this class directly. Instead, you should + create it from the builder function: spark. + + :param component: The ID or instance of the Spark component or job to be run during the step. + :type component: Union[str, ~azure.ai.ml.entities.SparkComponent] + :param identity: The identity that the Spark job will use while running on compute. + :type identity: Union[Dict[str, str], + ~azure.ai.ml.entities.ManagedIdentityConfiguration, + ~azure.ai.ml.entities.AmlTokenConfiguration, + ~azure.ai.ml.entities.UserIdentityConfiguration + + ] + + :param driver_cores: The number of cores to use for the driver process, only in cluster mode. + :type driver_cores: int + :param driver_memory: The amount of memory to use for the driver process, formatted as strings with a size unit + suffix ("k", "m", "g" or "t") (e.g. "512m", "2g"). + :type driver_memory: str + :param executor_cores: The number of cores to use on each executor. 
+ :type executor_cores: int + :param executor_memory: The amount of memory to use per executor process, formatted as strings with a size unit + suffix ("k", "m", "g" or "t") (e.g. "512m", "2g"). + :type executor_memory: str + :param executor_instances: The initial number of executors. + :type executor_instances: int + :param dynamic_allocation_enabled: Whether to use dynamic resource allocation, which scales the number of + executors registered with this application up and down based on the workload. + :type dynamic_allocation_enabled: bool + :param dynamic_allocation_min_executors: The lower bound for the number of executors if dynamic allocation + is enabled. + :type dynamic_allocation_min_executors: int + :param dynamic_allocation_max_executors: The upper bound for the number of executors if dynamic allocation + is enabled. + :type dynamic_allocation_max_executors: int + :param conf: A dictionary with pre-defined Spark configurations key and values. + :type conf: Dict[str, str] + :param inputs: A mapping of input names to input data sources used in the job. + :type inputs: Dict[str, Union[ + str, + bool, + int, + float, + Enum, + ~azure.ai.ml.entities._job.pipeline._io.NodeOutput, + ~azure.ai.ml.Input + + ]] + + :param outputs: A mapping of output names to output data sources used in the job. + :type outputs: Dict[str, Union[str, ~azure.ai.ml.Output]] + :param args: The arguments for the job. + :type args: str + :param compute: The compute resource the job runs on. + :type compute: str + :param resources: The compute resource configuration for the job. + :type resources: Union[Dict, ~azure.ai.ml.entities.SparkResourceConfiguration] + :param entry: The file or class entry point. + :type entry: Dict[str, str] + :param py_files: The list of .zip, .egg or .py files to place on the PYTHONPATH for Python apps. + :type py_files: List[str] + :param jars: The list of .JAR files to include on the driver and executor classpaths. + :type jars: List[str] + :param files: The list of files to be placed in the working directory of each executor. + :type files: List[str] + :param archives: The list of archives to be extracted into the working directory of each executor. 
+ :type archives: List[str] + """ + + def __init__( + self, + *, + component: Union[str, SparkComponent], + identity: Optional[ + Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + driver_cores: Optional[Union[int, str]] = None, + driver_memory: Optional[str] = None, + executor_cores: Optional[Union[int, str]] = None, + executor_memory: Optional[str] = None, + executor_instances: Optional[Union[int, str]] = None, + dynamic_allocation_enabled: Optional[Union[bool, str]] = None, + dynamic_allocation_min_executors: Optional[Union[int, str]] = None, + dynamic_allocation_max_executors: Optional[Union[int, str]] = None, + conf: Optional[Dict[str, str]] = None, + inputs: Optional[ + Dict[ + str, + Union[ + NodeOutput, + Input, + str, + bool, + int, + float, + Enum, + "Input", + ], + ] + ] = None, + outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None, + compute: Optional[str] = None, + resources: Optional[Union[Dict, SparkResourceConfiguration]] = None, + entry: Union[Dict[str, str], SparkJobEntry, None] = None, + py_files: Optional[List[str]] = None, + jars: Optional[List[str]] = None, + files: Optional[List[str]] = None, + archives: Optional[List[str]] = None, + args: Optional[str] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + kwargs.pop("type", None) + + BaseNode.__init__( + self, type=NodeType.SPARK, inputs=inputs, outputs=outputs, component=component, compute=compute, **kwargs + ) + + # init mark for _AttrDict + self._init = True + SparkJobEntryMixin.__init__(self, entry=entry) + self.conf = conf + self.driver_cores = driver_cores + self.driver_memory = driver_memory + self.executor_cores = executor_cores + self.executor_memory = executor_memory + self.executor_instances = executor_instances + self.dynamic_allocation_enabled = dynamic_allocation_enabled + self.dynamic_allocation_min_executors = dynamic_allocation_min_executors + self.dynamic_allocation_max_executors = dynamic_allocation_max_executors + + is_spark_component = isinstance(component, SparkComponent) + if is_spark_component: + # conf is dict and we need copy component conf here, otherwise node conf setting will affect component + # setting + _component = cast(SparkComponent, component) + self.conf = self.conf or copy.copy(_component.conf) + self.driver_cores = self.driver_cores or _component.driver_cores + self.driver_memory = self.driver_memory or _component.driver_memory + self.executor_cores = self.executor_cores or _component.executor_cores + self.executor_memory = self.executor_memory or _component.executor_memory + self.executor_instances = self.executor_instances or _component.executor_instances + self.dynamic_allocation_enabled = self.dynamic_allocation_enabled or _component.dynamic_allocation_enabled + self.dynamic_allocation_min_executors = ( + self.dynamic_allocation_min_executors or _component.dynamic_allocation_min_executors + ) + self.dynamic_allocation_max_executors = ( + self.dynamic_allocation_max_executors or _component.dynamic_allocation_max_executors + ) + if self.executor_instances is None and str(self.dynamic_allocation_enabled).lower() == "true": + self.executor_instances = self.dynamic_allocation_min_executors + # When create standalone job or pipeline job, following fields will always get value from component or get + # default None, because we will not pass those fields to Spark. 
But in following cases, we expect to get + # correct value from spark._from_rest_object() and then following fields will get from their respective + # keyword arguments. + # 1. when we call regenerated_spark_node=Spark._from_rest_object(spark_node._to_rest_object()) in local test, + # we expect regenerated_spark_node and spark_node are identical. + # 2.when get created remote job through Job._from_rest_object(result) in job operation where component is an + # arm_id, we expect get remote returned values. + # 3.when we load a remote job, component now is an arm_id, we need get entry from node level returned from + # service + self.entry = _component.entry if is_spark_component else entry + self.py_files = _component.py_files if is_spark_component else py_files + self.jars = _component.jars if is_spark_component else jars + self.files = _component.files if is_spark_component else files + self.archives = _component.archives if is_spark_component else archives + self.args = _component.args if is_spark_component else args + self.environment: Any = _component.environment if is_spark_component else None + + self.resources = resources + self.identity = identity + self._swept = False + self._init = False + + @classmethod + def _get_supported_outputs_types(cls) -> Tuple: + return str, Output + + @property + def component(self) -> Union[str, SparkComponent]: + """The ID or instance of the Spark component or job to be run during the step. + + :rtype: ~azure.ai.ml.entities.SparkComponent + """ + res: Union[str, SparkComponent] = self._component + return res + + @property + def resources(self) -> Optional[Union[Dict, SparkResourceConfiguration]]: + """The compute resource configuration for the job. + + :rtype: ~azure.ai.ml.entities.SparkResourceConfiguration + """ + return self._resources # type: ignore + + @resources.setter + def resources(self, value: Optional[Union[Dict, SparkResourceConfiguration]]) -> None: + """Sets the compute resource configuration for the job. + + :param value: The compute resource configuration for the job. + :type value: Union[Dict[str, str], ~azure.ai.ml.entities.SparkResourceConfiguration] + """ + if isinstance(value, dict): + value = SparkResourceConfiguration(**value) + self._resources = value + + @property + def identity( + self, + ) -> Optional[Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]]: + """The identity that the Spark job will use while running on compute. + + :rtype: Union[~azure.ai.ml.entities.ManagedIdentityConfiguration, ~azure.ai.ml.entities.AmlTokenConfiguration, + ~azure.ai.ml.entities.UserIdentityConfiguration] + """ + # If there is no identity from CLI/SDK input: for jobs running on synapse compute (MLCompute Clusters), the + # managed identity is the default; for jobs running on clusterless, the user identity should be the default, + # otherwise use user input identity. + if self._identity is None: + if self.compute is not None: + return ManagedIdentityConfiguration() + if self.resources is not None: + return UserIdentityConfiguration() + return self._identity + + @identity.setter + def identity( + self, + value: Union[Dict[str, str], ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration], + ) -> None: + """Sets the identity that the Spark job will use while running on compute. + + :param value: The identity that the Spark job will use while running on compute. 
+ :type value: Union[Dict[str, str], ~azure.ai.ml.entities.ManagedIdentityConfiguration, + ~azure.ai.ml.entities.AmlTokenConfiguration, ~azure.ai.ml.entities.UserIdentityConfiguration] + """ + if isinstance(value, dict): + identify_schema = UnionField( + [ + NestedField(ManagedIdentitySchema, unknown=INCLUDE), + NestedField(AMLTokenIdentitySchema, unknown=INCLUDE), + NestedField(UserIdentitySchema, unknown=INCLUDE), + ] + ) + value = identify_schema._deserialize(value=value, attr=None, data=None) + self._identity = value + + @property + def code(self) -> Optional[Union[str, PathLike]]: + """The local or remote path pointing at source code. + + :rtype: Union[str, PathLike] + """ + if isinstance(self.component, Component): + _code: Optional[Union[str, PathLike]] = self.component.code + return _code + return None + + @code.setter + def code(self, value: str) -> None: + """Sets the source code to be used for the job. + + :param value: The local or remote path pointing at source code. + :type value: Union[str, PathLike] + """ + if isinstance(self.component, Component): + self.component.code = value + else: + msg = "Can't set code property for a registered component {}" + raise ValidationException( + message=msg.format(self.component), + no_personal_data_message=msg.format(self.component), + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + + @classmethod + def _from_rest_object_to_init_params(cls, obj: dict) -> Dict: + obj = super()._from_rest_object_to_init_params(obj) + + if "resources" in obj and obj["resources"]: + obj["resources"] = SparkResourceConfiguration._from_rest_object(obj["resources"]) + + if "identity" in obj and obj["identity"]: + obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"]) + + if "entry" in obj and obj["entry"]: + obj["entry"] = SparkJobEntry._from_rest_object(obj["entry"]) + if "conf" in obj and obj["conf"]: + # get conf setting value from conf + for field_name, _ in CONF_KEY_MAP.items(): + value = obj["conf"].get(field_name, None) + if value is not None: + obj[field_name] = value + + return obj + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Spark": + from .spark_func import spark + + loaded_data = load_from_dict(SparkJobSchema, data, context, additional_message, **kwargs) + spark_job: Spark = spark(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + return spark_job + + @classmethod + def _load_from_rest_job(cls, obj: JobBaseData) -> "Spark": + from .spark_func import spark + + rest_spark_job: RestSparkJob = obj.properties + rest_spark_conf = copy.copy(rest_spark_job.conf) or {} + + spark_job: Spark = spark( + name=obj.name, + id=obj.id, + entry=SparkJobEntry._from_rest_object(rest_spark_job.entry), + display_name=rest_spark_job.display_name, + description=rest_spark_job.description, + tags=rest_spark_job.tags, + properties=rest_spark_job.properties, + experiment_name=rest_spark_job.experiment_name, + services=rest_spark_job.services, + status=rest_spark_job.status, + creation_context=obj.system_data, + code=rest_spark_job.code_id, + compute=rest_spark_job.compute_id, + environment=rest_spark_job.environment_id, + identity=( + _BaseJobIdentityConfiguration._from_rest_object(rest_spark_job.identity) + if rest_spark_job.identity + else None + ), + args=rest_spark_job.args, + conf=rest_spark_conf, + driver_cores=rest_spark_conf.get( + SparkConfKey.DRIVER_CORES, None + ), # copy fields from conf into the promote attribute in spark + 
driver_memory=rest_spark_conf.get(SparkConfKey.DRIVER_MEMORY, None), + executor_cores=rest_spark_conf.get(SparkConfKey.EXECUTOR_CORES, None), + executor_memory=rest_spark_conf.get(SparkConfKey.EXECUTOR_MEMORY, None), + executor_instances=rest_spark_conf.get(SparkConfKey.EXECUTOR_INSTANCES, None), + dynamic_allocation_enabled=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_ENABLED, None), + dynamic_allocation_min_executors=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_MIN_EXECUTORS, None), + dynamic_allocation_max_executors=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_MAX_EXECUTORS, None), + resources=SparkResourceConfiguration._from_rest_object(rest_spark_job.resources), + inputs=from_rest_inputs_to_dataset_literal(rest_spark_job.inputs), + outputs=from_rest_data_outputs(rest_spark_job.outputs), + ) + return spark_job + + @classmethod + def _attr_type_map(cls) -> dict: + return { + # hack: allow use InternalSparkComponent as component + # "component": (str, SparkComponent), + "environment": (str, Environment), + "resources": (dict, SparkResourceConfiguration), + "code": (str, PathLike), + } + + @property + def _skip_required_compute_missing_validation(self) -> bool: + return self.resources is not None + + def _to_job(self) -> SparkJob: + if isinstance(self.component, SparkComponent): + return SparkJob( + experiment_name=self.experiment_name, + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + code=self.component.code, + entry=self.entry, + py_files=self.py_files, + jars=self.jars, + files=self.files, + archives=self.archives, + identity=self.identity, + driver_cores=self.driver_cores, + driver_memory=self.driver_memory, + executor_cores=self.executor_cores, + executor_memory=self.executor_memory, + executor_instances=self.executor_instances, + dynamic_allocation_enabled=self.dynamic_allocation_enabled, + dynamic_allocation_min_executors=self.dynamic_allocation_min_executors, + dynamic_allocation_max_executors=self.dynamic_allocation_max_executors, + conf=self.conf, + environment=self.environment, + status=self.status, + inputs=self._job_inputs, + outputs=self._job_outputs, + services=self.services, + args=self.args, + compute=self.compute, + resources=self.resources, + ) + + return SparkJob( + experiment_name=self.experiment_name, + name=self.name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + code=self.component, + entry=self.entry, + py_files=self.py_files, + jars=self.jars, + files=self.files, + archives=self.archives, + identity=self.identity, + driver_cores=self.driver_cores, + driver_memory=self.driver_memory, + executor_cores=self.executor_cores, + executor_memory=self.executor_memory, + executor_instances=self.executor_instances, + dynamic_allocation_enabled=self.dynamic_allocation_enabled, + dynamic_allocation_min_executors=self.dynamic_allocation_min_executors, + dynamic_allocation_max_executors=self.dynamic_allocation_max_executors, + conf=self.conf, + environment=self.environment, + status=self.status, + inputs=self._job_inputs, + outputs=self._job_outputs, + services=self.services, + args=self.args, + compute=self.compute, + resources=self.resources, + ) + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline import SparkSchema + + return SparkSchema(context=context) + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return [ + "type", + "resources", + 
"py_files", + "jars", + "files", + "archives", + "identity", + "conf", + "args", + ] + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj: dict = super()._to_rest_object(**kwargs) + rest_obj.update( + convert_ordered_dict_to_dict( + { + "componentId": self._get_component_id(), + "identity": get_rest_dict_for_node_attrs(self.identity), + "resources": get_rest_dict_for_node_attrs(self.resources), + "entry": get_rest_dict_for_node_attrs(self.entry), + } + ) + ) + return rest_obj + + def _build_inputs(self) -> dict: + inputs = super(Spark, self)._build_inputs() + built_inputs = {} + # Validate and remove non-specified inputs + for key, value in inputs.items(): + if value is not None: + built_inputs[key] = value + return built_inputs + + def _customized_validate(self) -> MutableValidationResult: + result = super()._customized_validate() + if ( + isinstance(self.component, SparkComponent) + and isinstance(self.component._environment, Environment) + and self.component._environment.image is not None + ): + result.append_warning( + yaml_path="environment.image", + message=SPARK_ENVIRONMENT_WARNING_MESSAGE, + ) + result.merge_with(self._validate_entry_exist()) + result.merge_with(self._validate_fields()) + return result + + def _validate_entry_exist(self) -> MutableValidationResult: + is_remote_code = isinstance(self.code, str) and ( + self.code.startswith("git+") + or self.code.startswith(REGISTRY_URI_FORMAT) + or self.code.startswith(ARM_ID_PREFIX) + or is_url(self.code) + or bool(self.CODE_ID_RE_PATTERN.match(self.code)) + ) + validation_result = self._create_empty_validation_result() + # validate whether component entry exists to ensure code path is correct, especially when code is default value + if self.code is None or is_remote_code or not isinstance(self.entry, SparkJobEntry): + # skip validate when code is not a local path or code is None, or self.entry is not SparkJobEntry object + pass + else: + if not path.isabs(self.code): + _component: SparkComponent = self.component # type: ignore + code_path = Path(_component.base_path) / self.code + if code_path.exists(): + code_path = code_path.resolve().absolute() + else: + validation_result.append_error( + message=f"Code path {code_path} doesn't exist.", yaml_path="component.code" + ) + entry_path = code_path / self.entry.entry + else: + entry_path = Path(self.code) / self.entry.entry + + if ( + isinstance(self.entry, SparkJobEntry) + and self.entry.entry_type == SparkJobEntryType.SPARK_JOB_FILE_ENTRY + ): + if not entry_path.exists(): + validation_result.append_error( + message=f"Entry {entry_path} doesn't exist.", yaml_path="component.entry" + ) + return validation_result + + def _validate_fields(self) -> MutableValidationResult: + validation_result = self._create_empty_validation_result() + try: + _validate_compute_or_resources(self.compute, self.resources) + except ValidationException as e: + validation_result.append_error(message=str(e), yaml_path="resources") + validation_result.append_error(message=str(e), yaml_path="compute") + + try: + _validate_input_output_mode(self.inputs, self.outputs) + except ValidationException as e: + msg = str(e) + m = re.match(r"(Input|Output) '(\w+)'", msg) + if m: + io_type, io_name = m.groups() + if io_type == "Input": + validation_result.append_error(message=msg, yaml_path=f"inputs.{io_name}") + else: + validation_result.append_error(message=msg, yaml_path=f"outputs.{io_name}") + + try: + _validate_spark_configurations(self) + except ValidationException as e: + 
validation_result.append_error(message=str(e), yaml_path="conf") + + try: + self._validate_entry() + except ValidationException as e: + validation_result.append_error(message=str(e), yaml_path="entry") + + if self.args: + try: + validate_inputs_for_args(self.args, self.inputs) + except ValidationException as e: + validation_result.append_error(message=str(e), yaml_path="args") + return validation_result + + # pylint: disable-next=docstring-missing-param + def __call__(self, *args: Any, **kwargs: Any) -> "Spark": + """Call Spark as a function will return a new instance each time. + + :return: A Spark object + :rtype: Spark + """ + if isinstance(self._component, Component): + # call this to validate inputs + node: Spark = self._component(*args, **kwargs) + # merge inputs + for name, original_input in self.inputs.items(): + if name not in kwargs: + # use setattr here to make sure owner of input won't change + setattr(node.inputs, name, original_input._data) + node._job_inputs[name] = original_input._data + # get outputs + for name, original_output in self.outputs.items(): + # use setattr here to make sure owner of output won't change + if not isinstance(original_output, str): + setattr(node.outputs, name, original_output._data) + self._refine_optional_inputs_with_no_value(node, kwargs) + node._name = self.name + node.compute = self.compute + node.environment = copy.deepcopy(self.environment) + node.resources = copy.deepcopy(self.resources) + return node + + msg = "Spark can be called as a function only when referenced component is {}, currently got {}." + raise ValidationException( + message=msg.format(type(Component), self._component), + no_personal_data_message=msg.format(type(Component), "self._component"), + target=ErrorTarget.SPARK_JOB, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark_func.py new file mode 100644 index 00000000..342f8c44 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark_func.py @@ -0,0 +1,306 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
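[Illustrative note, not part of the diff] The identity defaulting described in Spark.identity above (managed identity on attached compute, user identity on serverless resources) is usually exercised through the spark builder. A hedged sketch of a serverless standalone Spark job; the code folder, entry file, instance type, and runtime version are assumptions made up for the example:

from azure.ai.ml import spark
from azure.ai.ml.entities import UserIdentityConfiguration

# No compute is set, so serverless resources are used and user identity is the default.
spark_node = spark(
    code="./src",                      # local folder containing main.py (illustrative)
    entry={"file": "main.py"},
    driver_cores=1,
    driver_memory="2g",
    executor_cores=2,
    executor_memory="2g",
    executor_instances=2,
    resources={"instance_type": "standard_e4s_v3", "runtime_version": "3.3"},
    identity=UserIdentityConfiguration(),  # explicit here; would also be the default on serverless
)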
+# --------------------------------------------------------- +# pylint: disable=protected-access, too-many-locals + +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AmlToken, ManagedIdentity, UserIdentity +from azure.ai.ml.constants._common import AssetTypes +from azure.ai.ml.constants._component import ComponentSource +from azure.ai.ml.entities import Environment +from azure.ai.ml.entities._component.spark_component import SparkComponent +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin +from azure.ai.ml.entities._job.spark_job_entry import SparkJobEntry +from azure.ai.ml.entities._job.spark_resource_configuration import SparkResourceConfiguration +from azure.ai.ml.exceptions import ErrorTarget, ValidationException + +from .spark import Spark + +SUPPORTED_INPUTS = [AssetTypes.URI_FILE, AssetTypes.URI_FOLDER, AssetTypes.MLTABLE] + + +def _parse_input(input_value: Union[Input, dict, str, bool, int, float]) -> Tuple: + component_input = None + job_input: Union[Input, dict, str, bool, int, float] = "" + + if isinstance(input_value, Input): + component_input = Input(**input_value._to_dict()) + input_type = input_value.type + if input_type in SUPPORTED_INPUTS: + job_input = Input(**input_value._to_dict()) + elif isinstance(input_value, dict): + # if user provided dict, we try to parse it to Input. + # for job input, only parse for path type + input_type = input_value.get("type", None) + if input_type in SUPPORTED_INPUTS: + job_input = Input(**input_value) + component_input = Input(**input_value) + elif isinstance(input_value, (str, bool, int, float)): + # Input bindings are not supported + component_input = ComponentTranslatableMixin._to_input_builder_function(input_value) + job_input = input_value + else: + msg = f"Unsupported input type: {type(input_value)}, only Input, dict, str, bool, int and float are supported." + raise ValidationException(message=msg, no_personal_data_message=msg, target=ErrorTarget.JOB) + return component_input, job_input + + +def _parse_output(output_value: Union[Output, dict]) -> Tuple: + component_output = None + job_output: Union[Output, dict] = {} + + if isinstance(output_value, Output): + component_output = Output(**output_value._to_dict()) + job_output = Output(**output_value._to_dict()) + elif not output_value: + # output value can be None or empty dictionary + # None output value will be packed into a JobOutput object with mode = ReadWriteMount & type = UriFolder + component_output = ComponentTranslatableMixin._to_output(output_value) + job_output = output_value + elif isinstance(output_value, dict): # When output value is a non-empty dictionary + job_output = Output(**output_value) + component_output = Output(**output_value) + elif isinstance(output_value, str): # When output is passed in from pipeline job yaml + job_output = output_value + else: + msg = f"Unsupported output type: {type(output_value)}, only Output and dict are supported." 
+ raise ValidationException(message=msg, no_personal_data_message=msg, target=ErrorTarget.JOB) + return component_output, job_output + + +def _parse_inputs_outputs(io_dict: Dict, parse_func: Callable) -> Tuple[Dict, Dict]: + component_io_dict, job_io_dict = {}, {} + if io_dict: + for key, val in io_dict.items(): + component_io, job_io = parse_func(val) + component_io_dict[key] = component_io + job_io_dict[key] = job_io + return component_io_dict, job_io_dict + + +def spark( + *, + experiment_name: Optional[str] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + code: Optional[Union[str, os.PathLike]] = None, + entry: Union[Dict[str, str], SparkJobEntry, None] = None, + py_files: Optional[List[str]] = None, + jars: Optional[List[str]] = None, + files: Optional[List[str]] = None, + archives: Optional[List[str]] = None, + identity: Optional[Union[Dict[str, str], ManagedIdentity, AmlToken, UserIdentity]] = None, + driver_cores: Optional[int] = None, + driver_memory: Optional[str] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + executor_instances: Optional[int] = None, + dynamic_allocation_enabled: Optional[bool] = None, + dynamic_allocation_min_executors: Optional[int] = None, + dynamic_allocation_max_executors: Optional[int] = None, + conf: Optional[Dict[str, str]] = None, + environment: Optional[Union[str, Environment]] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + args: Optional[str] = None, + compute: Optional[str] = None, + resources: Optional[Union[Dict, SparkResourceConfiguration]] = None, + **kwargs: Any, +) -> Spark: + """Creates a Spark object which can be used inside a dsl.pipeline function or used as a standalone Spark job. + + :keyword experiment_name: The name of the experiment the job will be created under. + :paramtype experiment_name: Optional[str] + :keyword name: The name of the job. + :paramtype name: Optional[str] + :keyword display_name: The job display name. + :paramtype display_name: Optional[str] + :keyword description: The description of the job. Defaults to None. + :paramtype description: Optional[str] + :keyword tags: The dictionary of tags for the job. Tags can be added, removed, and updated. Defaults to None. + :paramtype tags: Optional[dict[str, str]] + :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url + pointing to a remote location. + :type code: Optional[Union[str, os.PathLike]] + :keyword entry: The file or class entry point. + :paramtype entry: Optional[Union[dict[str, str], ~azure.ai.ml.entities.SparkJobEntry]] + :keyword py_files: The list of .zip, .egg or .py files to place on the PYTHONPATH for Python apps. + Defaults to None. + :paramtype py_files: Optional[List[str]] + :keyword jars: The list of .JAR files to include on the driver and executor classpaths. Defaults to None. + :paramtype jars: Optional[List[str]] + :keyword files: The list of files to be placed in the working directory of each executor. Defaults to None. + :paramtype files: Optional[List[str]] + :keyword archives: The list of archives to be extracted into the working directory of each executor. + Defaults to None. + :paramtype archives: Optional[List[str]] + :keyword identity: The identity that the Spark job will use while running on compute. 
+ :paramtype identity: Optional[Union[ + dict[str, str], + ~azure.ai.ml.entities.ManagedIdentityConfiguration, + ~azure.ai.ml.entities.AmlTokenConfiguration, + ~azure.ai.ml.entities.UserIdentityConfiguration]] + :keyword driver_cores: The number of cores to use for the driver process, only in cluster mode. + :paramtype driver_cores: Optional[int] + :keyword driver_memory: The amount of memory to use for the driver process, formatted as strings with a size unit + suffix ("k", "m", "g" or "t") (e.g. "512m", "2g"). + :paramtype driver_memory: Optional[str] + :keyword executor_cores: The number of cores to use on each executor. + :paramtype executor_cores: Optional[int] + :keyword executor_memory: The amount of memory to use per executor process, formatted as strings with a size unit + suffix ("k", "m", "g" or "t") (e.g. "512m", "2g"). + :paramtype executor_memory: Optional[str] + :keyword executor_instances: The initial number of executors. + :paramtype executor_instances: Optional[int] + :keyword dynamic_allocation_enabled: Whether to use dynamic resource allocation, which scales the number of + executors registered with this application up and down based on the workload. + :paramtype dynamic_allocation_enabled: Optional[bool] + :keyword dynamic_allocation_min_executors: The lower bound for the number of executors if dynamic allocation is + enabled. + :paramtype dynamic_allocation_min_executors: Optional[int] + :keyword dynamic_allocation_max_executors: The upper bound for the number of executors if dynamic allocation is + enabled. + :paramtype dynamic_allocation_max_executors: Optional[int] + :keyword conf: A dictionary with pre-defined Spark configurations key and values. Defaults to None. + :paramtype conf: Optional[dict[str, str]] + :keyword environment: The Azure ML environment to run the job in. + :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + :keyword inputs: A mapping of input names to input data used in the job. Defaults to None. + :paramtype inputs: Optional[dict[str, ~azure.ai.ml.Input]] + :keyword outputs: A mapping of output names to output data used in the job. Defaults to None. + :paramtype outputs: Optional[dict[str, ~azure.ai.ml.Output]] + :keyword args: The arguments for the job. + :paramtype args: Optional[str] + :keyword compute: The compute resource the job runs on. + :paramtype compute: Optional[str] + :keyword resources: The compute resource configuration for the job. + :paramtype resources: Optional[Union[dict, ~azure.ai.ml.entities.SparkResourceConfiguration]] + :return: A Spark object. + :rtype: ~azure.ai.ml.entities.Spark + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_function_configuration_1] + :end-before: [END spark_function_configuration_1] + :language: python + :dedent: 8 + :caption: Configuring a SparkJob. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_function_configuration_2] + :end-before: [END spark_function_configuration_2] + :language: python + :dedent: 8 + :caption: Configuring a SparkJob. + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_dsl_pipeline] + :end-before: [END spark_dsl_pipeline] + :language: python + :dedent: 8 + :caption: Building a Spark pipeline using the DSL pipeline decorator + """ + + inputs = inputs or {} + outputs = outputs or {} + component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input) + # job inputs can not be None + job_inputs = {k: v for k, v in job_inputs.items() if v is not None} + component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output) + component = kwargs.pop("component", None) + + if component is None: + component = SparkComponent( + name=name, + display_name=display_name, + tags=tags, + description=description, + code=code, + entry=entry, + py_files=py_files, + jars=jars, + files=files, + archives=archives, + driver_cores=driver_cores, + driver_memory=driver_memory, + executor_cores=executor_cores, + executor_memory=executor_memory, + executor_instances=executor_instances, + dynamic_allocation_enabled=dynamic_allocation_enabled, + dynamic_allocation_min_executors=dynamic_allocation_min_executors, + dynamic_allocation_max_executors=dynamic_allocation_max_executors, + conf=conf, + environment=environment, + inputs=component_inputs, + outputs=component_outputs, + args=args, + _source=ComponentSource.BUILDER, + **kwargs, + ) + if isinstance(component, SparkComponent): + spark_obj = Spark( + experiment_name=experiment_name, + name=name, + display_name=display_name, + tags=tags, + description=description, + component=component, + identity=identity, + driver_cores=driver_cores, + driver_memory=driver_memory, + executor_cores=executor_cores, + executor_memory=executor_memory, + executor_instances=executor_instances, + dynamic_allocation_enabled=dynamic_allocation_enabled, + dynamic_allocation_min_executors=dynamic_allocation_min_executors, + dynamic_allocation_max_executors=dynamic_allocation_max_executors, + conf=conf, + inputs=job_inputs, + outputs=job_outputs, + compute=compute, + resources=resources, + **kwargs, + ) + else: + # when we load a remote job, component now is an arm_id, we need get entry from node level returned from + # service + spark_obj = Spark( + experiment_name=experiment_name, + name=name, + display_name=display_name, + tags=tags, + description=description, + component=component, + identity=identity, + driver_cores=driver_cores, + driver_memory=driver_memory, + executor_cores=executor_cores, + executor_memory=executor_memory, + executor_instances=executor_instances, + dynamic_allocation_enabled=dynamic_allocation_enabled, + dynamic_allocation_min_executors=dynamic_allocation_min_executors, + dynamic_allocation_max_executors=dynamic_allocation_max_executors, + conf=conf, + inputs=job_inputs, + outputs=job_outputs, + compute=compute, + resources=resources, + entry=entry, + py_files=py_files, + jars=jars, + files=files, + archives=archives, + args=args, + **kwargs, + ) + return spark_obj diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/subcomponents.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/subcomponents.py new file mode 100644 index 00000000..9b9ed5d2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/subcomponents.py @@ -0,0 +1,59 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
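[Illustrative note, not part of the diff] Beyond the standalone form, the spark builder defined above is commonly used inside a dsl.pipeline function, binding pipeline inputs and outputs with direct mode. A minimal sketch; the compute name, script, and data path are hypothetical:

from azure.ai.ml import Input, Output, spark
from azure.ai.ml.dsl import pipeline

@pipeline()
def spark_pipeline(raw_data: Input):
    spark_step = spark(
        code="./src",                   # folder containing process.py (illustrative)
        entry={"file": "process.py"},
        driver_cores=1,
        driver_memory="2g",
        executor_cores=2,
        executor_memory="2g",
        executor_instances=2,
        inputs={"raw": raw_data},
        outputs={"processed": Output(type="uri_folder", mode="direct")},
        args="--raw ${{inputs.raw}} --processed ${{outputs.processed}}",
        compute="my-synapse-spark",     # attached Synapse Spark pool (hypothetical name)
    )
    return {"processed": spark_step.outputs.processed}

job = spark_pipeline(
    raw_data=Input(type="uri_folder", path="azureml://datastores/mydata/paths/raw/", mode="direct")
)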
+# --------------------------------------------------------- + +# This file contains mldesigner decorator-produced components +# that are used within node constructors. Keep imports and +# general complexity in this file to a minimum. + +from typing import List + +from mldesigner import Output, command_component + +from azure.ai.ml.constants._common import DefaultOpenEncoding + + +def save_mltable_yaml(path: str, mltable_paths: List[str]) -> None: + """Save MLTable YAML. + + :param path: The path to save the MLTable YAML file. + :type path: str + :param mltable_paths: List of paths to be included in the MLTable. + :type mltable_paths: List[str] + :raises ValueError: If the given path points to a file. + """ + import os + + path = os.path.abspath(path) + + if os.path.isfile(path): + raise ValueError(f"The given path {path} points to a file.") + + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + + save_path = os.path.join(path, "MLTable") + # Do not touch - this is MLTable syntax that is needed to mount these paths + # To the MLTable's inputs + mltable_file_content = "\n".join(["paths:"] + [f"- folder : {path}" for path in mltable_paths]) + + with open(save_path, "w", encoding=DefaultOpenEncoding.WRITE) as f: + f.write(mltable_file_content) + + +# TODO 2293610: add support for more types of outputs besides uri_folder and mltable +@command_component() +def create_scatter_output_table(aggregated_output: Output(type="mltable"), **kwargs: str) -> Output: # type: ignore + """Create scatter output table. + + This function is used by the FL scatter gather node to reduce a dynamic number of silo outputs + into a single input for the user-supplied aggregation step. + + :param aggregated_output: The aggregated output MLTable. + :type aggregated_output: ~mldesigner.Output(type="mltable") + + Keyword arguments represent input names and URI folder paths. + """ + # kwargs keys are inputs names (ex: silo_output_silo_1) + # values are uri_folder paths + save_mltable_yaml(aggregated_output, list(kwargs.values())) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py new file mode 100644 index 00000000..603babbe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py @@ -0,0 +1,454 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
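[Illustrative note, not part of the diff] A quick local illustration of what save_mltable_yaml above writes to disk; the folder paths are hypothetical:

save_mltable_yaml(
    path="./aggregated",
    mltable_paths=[
        "azureml://datastores/silo1/paths/output",
        "azureml://datastores/silo2/paths/output",
    ],
)
# ./aggregated/MLTable now contains:
# paths:
# - folder : azureml://datastores/silo1/paths/output
# - folder : azureml://datastores/silo2/paths/output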
+# --------------------------------------------------------- +# pylint: disable=protected-access + +import logging +from typing import Any, Dict, List, Optional, Tuple, Union + +import pydash +from marshmallow import EXCLUDE, Schema + +from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._component import NodeType +from azure.ai.ml.constants._job.sweep import SearchSpace +from azure.ai.ml.entities._component.command_component import CommandComponent +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.job_limits import SweepJobLimits +from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration +from azure.ai.ml.entities._job.pipeline._io import NodeInput +from azure.ai.ml.entities._job.queue_settings import QueueSettings +from azure.ai.ml.entities._job.sweep.early_termination_policy import ( + BanditPolicy, + EarlyTerminationPolicy, + MedianStoppingPolicy, + TruncationSelectionPolicy, +) +from azure.ai.ml.entities._job.sweep.objective import Objective +from azure.ai.ml.entities._job.sweep.parameterized_sweep import ParameterizedSweep +from azure.ai.ml.entities._job.sweep.sampling_algorithm import SamplingAlgorithm +from azure.ai.ml.entities._job.sweep.search_space import ( + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + SweepDistribution, + Uniform, +) +from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationErrorType, ValidationException +from azure.ai.ml.sweep import SweepJob + +from ..._restclient.v2022_10_01.models import ComponentVersion +from ..._schema import PathAwareSchema +from ..._schema._utils.data_binding_expression import support_data_binding_expression_for_fields +from ..._utils.utils import camel_to_snake +from .base_node import BaseNode + +module_logger = logging.getLogger(__name__) + + +class Sweep(ParameterizedSweep, BaseNode): + """Base class for sweep node. + + This class should not be instantiated directly. Instead, it should be created via the builder function: sweep. + + :param trial: The ID or instance of the command component or job to be run for the step. + :type trial: Union[~azure.ai.ml.entities.CommandComponent, str] + :param compute: The compute definition containing the compute information for the step. + :type compute: str + :param limits: The limits for the sweep node. + :type limits: ~azure.ai.ml.sweep.SweepJobLimits + :param sampling_algorithm: The sampling algorithm to use to sample inside the search space. + Accepted values are: "random", "grid", or "bayesian". + :type sampling_algorithm: str + :param objective: The objective used to determine the target run with the local optimal + hyperparameter in search space. + :type objective: ~azure.ai.ml.sweep.Objective + :param early_termination_policy: The early termination policy of the sweep node. + :type early_termination_policy: Union[ + + ~azure.mgmt.machinelearningservices.models.BanditPolicy, + ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy, + ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy + + ] + + :param search_space: The hyperparameter search space to run trials in. 
+ :type search_space: Dict[str, Union[ + + ~azure.ai.ml.entities.Choice, + ~azure.ai.ml.entities.LogNormal, + ~azure.ai.ml.entities.LogUniform, + ~azure.ai.ml.entities.Normal, + ~azure.ai.ml.entities.QLogNormal, + ~azure.ai.ml.entities.QLogUniform, + ~azure.ai.ml.entities.QNormal, + ~azure.ai.ml.entities.QUniform, + ~azure.ai.ml.entities.Randint, + ~azure.ai.ml.entities.Uniform + + ]] + + :param inputs: Mapping of input data bindings used in the job. + :type inputs: Dict[str, Union[ + + ~azure.ai.ml.Input, + + str, + bool, + int, + float + + ]] + + :param outputs: Mapping of output data bindings used in the job. + :type outputs: Dict[str, Union[str, ~azure.ai.ml.Output]] + :param identity: The identity that the training job will use while running on compute. + :type identity: Union[ + + ~azure.ai.ml.ManagedIdentityConfiguration, + ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration + + ] + + :param queue_settings: The queue settings for the job. + :type queue_settings: ~azure.ai.ml.entities.QueueSettings + :param resources: Compute Resource configuration for the job. + :type resources: Optional[Union[dict, ~azure.ai.ml.entities.ResourceConfiguration]] + """ + + def __init__( + self, + *, + trial: Optional[Union[CommandComponent, str]] = None, + compute: Optional[str] = None, + limits: Optional[SweepJobLimits] = None, + sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None, + objective: Optional[Objective] = None, + early_termination: Optional[ + Union[BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy, EarlyTerminationPolicy, str] + ] = None, + search_space: Optional[ + Dict[ + str, + Union[ + Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ], + ] + ] = None, + inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None, + outputs: Optional[Dict[str, Union[str, Output]]] = None, + identity: Optional[ + Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + queue_settings: Optional[QueueSettings] = None, + resources: Optional[Union[dict, JobResourceConfiguration]] = None, + **kwargs: Any, + ) -> None: + # TODO: get rid of self._job_inputs, self._job_outputs once we have general Input + self._job_inputs, self._job_outputs = inputs, outputs + + kwargs.pop("type", None) + BaseNode.__init__( + self, + type=NodeType.SWEEP, + component=trial, + inputs=inputs, + outputs=outputs, + compute=compute, + **kwargs, + ) + # init mark for _AttrDict + self._init = True + ParameterizedSweep.__init__( + self, + sampling_algorithm=sampling_algorithm, + objective=objective, + limits=limits, + early_termination=early_termination, + search_space=search_space, + queue_settings=queue_settings, + resources=resources, + ) + + self.identity: Any = identity + self._init = False + + @property + def trial(self) -> CommandComponent: + """The ID or instance of the command component or job to be run for the step. + + :rtype: ~azure.ai.ml.entities.CommandComponent + """ + res: CommandComponent = self._component + return res + + @property + def search_space( + self, + ) -> Optional[ + Dict[ + str, + Union[Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform], + ] + ]: + """Dictionary of the hyperparameter search space. + + Each key is the name of a hyperparameter and its value is the parameter expression. 
+ + :rtype: Dict[str, Union[~azure.ai.ml.entities.Choice, ~azure.ai.ml.entities.LogNormal, + ~azure.ai.ml.entities.LogUniform, ~azure.ai.ml.entities.Normal, ~azure.ai.ml.entities.QLogNormal, + ~azure.ai.ml.entities.QLogUniform, ~azure.ai.ml.entities.QNormal, ~azure.ai.ml.entities.QUniform, + ~azure.ai.ml.entities.Randint, ~azure.ai.ml.entities.Uniform]] + """ + return self._search_space + + @search_space.setter + def search_space(self, values: Dict[str, Dict[str, Union[str, int, float, dict]]]) -> None: + """Sets the search space for the sweep job. + + :param values: The search space to set. + :type values: Dict[str, Dict[str, Union[str, int, float, dict]]] + """ + search_space: Dict = {} + for name, value in values.items(): + # If value is a SearchSpace object, directly pass it to job.search_space[name] + search_space[name] = self._value_type_to_class(value) if isinstance(value, dict) else value + self._search_space = search_space + + @classmethod + def _value_type_to_class(cls, value: Any) -> Dict: + value_type = value["type"] + search_space_dict = { + SearchSpace.CHOICE: Choice, + SearchSpace.RANDINT: Randint, + SearchSpace.LOGNORMAL: LogNormal, + SearchSpace.NORMAL: Normal, + SearchSpace.LOGUNIFORM: LogUniform, + SearchSpace.UNIFORM: Uniform, + SearchSpace.QLOGNORMAL: QLogNormal, + SearchSpace.QNORMAL: QNormal, + SearchSpace.QLOGUNIFORM: QLogUniform, + SearchSpace.QUNIFORM: QUniform, + } + + res: dict = search_space_dict[value_type](**value) + return res + + @classmethod + def _get_supported_inputs_types(cls) -> Tuple: + supported_types = super()._get_supported_inputs_types() or () + return ( + SweepDistribution, + *supported_types, + ) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Sweep": + raise NotImplementedError("Sweep._load_from_dict is not supported") + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return [ + "limits", + "sampling_algorithm", + "objective", + "early_termination", + "search_space", + "queue_settings", + "resources", + ] + + def _to_rest_object(self, **kwargs: Any) -> dict: + rest_obj: dict = super(Sweep, self)._to_rest_object(**kwargs) + # hack: ParameterizedSweep.early_termination is not allowed to be None + for key in ["early_termination"]: + if key in rest_obj and rest_obj[key] is None: + del rest_obj[key] + + # hack: only early termination policy does not follow yaml schema now, should be removed after server-side made + # the change + if "early_termination" in rest_obj: + _early_termination: EarlyTerminationPolicy = self.early_termination # type: ignore + rest_obj["early_termination"] = _early_termination._to_rest_object().as_dict() + + rest_obj.update( + { + "type": self.type, + "trial": self._get_trial_component_rest_obj(), + } + ) + return rest_obj + + @classmethod + def _from_rest_object_to_init_params(cls, obj: dict) -> Dict: + obj = super()._from_rest_object_to_init_params(obj) + + # hack: only early termination policy does not follow yaml schema now, should be removed after server-side made + # the change + if "early_termination" in obj and "policy_type" in obj["early_termination"]: + # can't use _from_rest_object here, because obj is a dict instead of an EarlyTerminationPolicy rest object + obj["early_termination"]["type"] = camel_to_snake(obj["early_termination"].pop("policy_type")) + + # TODO: use cls._get_schema() to load from rest object + from azure.ai.ml._schema._sweep.parameterized_sweep import ParameterizedSweepSchema + + schema = 
ParameterizedSweepSchema(context={BASE_PATH_CONTEXT_KEY: "./"}) + support_data_binding_expression_for_fields(schema, ["type", "component", "trial"]) + + base_sweep = schema.load(obj, unknown=EXCLUDE, partial=True) + for key, value in base_sweep.items(): + obj[key] = value + + # trial + trial_component_id = pydash.get(obj, "trial.componentId", None) + obj["trial"] = trial_component_id # check this + + return obj + + def _get_trial_component_rest_obj(self) -> Union[Dict, ComponentVersion, None]: + # trial component to rest object is different from usual component + trial_component_id = self._get_component_id() + if trial_component_id is None: + return None + if isinstance(trial_component_id, str): + return {"componentId": trial_component_id} + if isinstance(trial_component_id, CommandComponent): + return trial_component_id._to_rest_object() + raise UserErrorException(f"invalid trial in sweep node {self.name}: {str(self.trial)}") + + def _to_job(self) -> SweepJob: + command = self.trial.command + if self.search_space is not None: + for key, _ in self.search_space.items(): + if command is not None: + # Double curly brackets to escape + command = command.replace(f"${{{{inputs.{key}}}}}", f"${{{{search_space.{key}}}}}") + + # TODO: raise exception when the trial is a pre-registered component + if command != self.trial.command and isinstance(self.trial, CommandComponent): + self.trial.command = command + + return SweepJob( + name=self.name, + display_name=self.display_name, + description=self.description, + properties=self.properties, + tags=self.tags, + experiment_name=self.experiment_name, + trial=self.trial, + compute=self.compute, + sampling_algorithm=self.sampling_algorithm, + search_space=self.search_space, + limits=self.limits, + early_termination=self.early_termination, # type: ignore[arg-type] + objective=self.objective, + inputs=self._job_inputs, + outputs=self._job_outputs, + identity=self.identity, + queue_settings=self.queue_settings, + resources=self.resources, + ) + + @classmethod + def _get_component_attr_name(cls) -> str: + return "trial" + + def _build_inputs(self) -> Dict: + inputs = super(Sweep, self)._build_inputs() + built_inputs = {} + # Validate and remove non-specified inputs + for key, value in inputs.items(): + if value is not None: + built_inputs[key] = value + + return built_inputs + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + from azure.ai.ml._schema.pipeline.component_job import SweepSchema + + return SweepSchema(context=context) + + @classmethod + def _get_origin_inputs_and_search_space(cls, built_inputs: Optional[Dict[str, NodeInput]]) -> Tuple: + """Separate mixed true inputs & search space definition from inputs of + this node and return them. + + Input will be restored to Input/LiteralInput before returned. 
+ + :param built_inputs: The built inputs + :type built_inputs: Optional[Dict[str, NodeInput]] + :return: A tuple of the inputs and search space + :rtype: Tuple[ + Dict[str, Union[Input, str, bool, int, float]], + Dict[str, SweepDistribution], + ] + """ + search_space: Dict = {} + inputs: Dict = {} + if built_inputs is not None: + for input_name, input_obj in built_inputs.items(): + if isinstance(input_obj, NodeInput): + if isinstance(input_obj._data, SweepDistribution): + search_space[input_name] = input_obj._data + else: + inputs[input_name] = input_obj._data + else: + msg = "unsupported built input type: {}: {}" + raise ValidationException( + message=msg.format(input_name, type(input_obj)), + no_personal_data_message=msg.format("[input_name]", type(input_obj)), + target=ErrorTarget.SWEEP_JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return inputs, search_space + + def _is_input_set(self, input_name: str) -> bool: + if super(Sweep, self)._is_input_set(input_name): + return True + return self.search_space is not None and input_name in self.search_space + + def __setattr__(self, key: Any, value: Any) -> None: + super(Sweep, self).__setattr__(key, value) + if key == "early_termination" and isinstance(self.early_termination, BanditPolicy): + # only one of slack_amount and slack_factor can be specified but default value is 0.0. + # Need to keep track of which one is null. + if self.early_termination.slack_amount == 0.0: + self.early_termination.slack_amount = None # type: ignore[assignment] + if self.early_termination.slack_factor == 0.0: + self.early_termination.slack_factor = None # type: ignore[assignment] + + @property + def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]: + """The early termination policy for the sweep job. + + :rtype: Union[str, ~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy, + ~azure.ai.ml.sweep.TruncationSelectionPolicy] + """ + return self._early_termination + + @early_termination.setter + def early_termination(self, value: Optional[Union[str, EarlyTerminationPolicy]]) -> None: + """Sets the early termination policy for the sweep job. + + :param value: The early termination policy for the sweep job. + :type value: Union[~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy, + ~azure.ai.ml.sweep.TruncationSelectionPolicy, dict[str, Union[str, float, int, bool]]] + """ + if isinstance(value, dict): + early_termination_schema = EarlyTerminationField() + value = early_termination_schema._deserialize(value=value, attr=None, data=None) + self._early_termination = value # type: ignore[assignment] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
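[Illustrative note, not part of the diff] The Sweep node defined in sweep.py above is typically created by calling .sweep() on a command node whose inputs carry search-space distributions, which _get_origin_inputs_and_search_space then separates from the literal inputs. A hedged sketch; the script, environment, and compute names are illustrative:

from azure.ai.ml import command
from azure.ai.ml.sweep import BanditPolicy, Choice, Uniform

trial = command(
    code="./src",  # folder containing train.py (illustrative)
    command="python train.py --lr ${{inputs.lr}} --batch ${{inputs.batch}}",
    inputs={"lr": 0.01, "batch": 32},
    environment="AzureML-sklearn-1.0-ubuntu20.04-py38-cpu@latest",
)

# Passing distributions as inputs marks them as the hyperparameter search space.
trial_for_sweep = trial(lr=Uniform(min_value=0.001, max_value=0.1), batch=Choice([16, 32, 64]))

sweep_node = trial_for_sweep.sweep(
    compute="cpu-cluster",
    sampling_algorithm="random",
    primary_metric="accuracy",
    goal="maximize",
    early_termination_policy=BanditPolicy(slack_factor=0.1, evaluation_interval=2),
)
sweep_node.set_limits(max_total_trials=20, max_concurrent_trials=4)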
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/_additional_includes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/_additional_includes.py new file mode 100644 index 00000000..85f609ca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/_additional_includes.py @@ -0,0 +1,541 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import json +import os +import shutil +import tempfile +import zipfile +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from multiprocessing import cpu_count +from pathlib import Path +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +from azure.ai.ml.constants._common import AzureDevopsArtifactsType +from azure.ai.ml.entities._validation import MutableValidationResult, ValidationResultBuilder + +from ..._utils._artifact_utils import ArtifactCache +from ..._utils._asset_utils import IgnoreFile, get_upload_files_from_folder +from ..._utils.utils import is_concurrent_component_registration_enabled, is_private_preview_enabled +from ...entities._util import _general_copy +from .._assets import Code +from .code import ComponentCodeMixin, ComponentIgnoreFile + +PLACEHOLDER_FILE_NAME = "_placeholder_spec.yaml" + + +class AdditionalIncludes: + """Initialize the AdditionalIncludes object. + + :param origin_code_value: The origin code value. + :type origin_code_value: Optional[str] + :param base_path: The base path for origin code path and additional include configs. + :type base_path: Path + :param configs: The additional include configs. + :type configs: List[Union[str, dict]] + """ + + def __init__( + self, + *, + origin_code_value: Optional[str], + base_path: Path, + configs: Optional[List[Union[str, dict]]] = None, + ) -> None: + self._base_path = base_path + self._origin_code_value = origin_code_value + self._origin_configs = configs + + @property + def origin_configs(self) -> List: + """The origin additional include configs. + Artifact additional include configs haven't been resolved in this property. + + :return: The origin additional include configs. + :rtype: List[Union[str, dict]] + """ + return self._origin_configs or [] + + @property + def resolved_code_path(self) -> Union[None, Path]: + """The resolved origin code path based on base path, if code path is not specified, return None. + We shouldn't change this property name given it's referenced in mldesigner. + + :return: The resolved origin code path. + :rtype: Union[None, Path] + """ + if self._origin_code_value is None: + return None + if os.path.isabs(self._origin_code_value): + return Path(self._origin_code_value) + return (self.base_path / self._origin_code_value).resolve() + + @property + def base_path(self) -> Path: + """Base path for origin code path and additional include configs. + + :return: The base path. + :rtype: Path + """ + return self._base_path + + @property + def with_includes(self) -> bool: + """Whether the additional include configs have been provided. + + :return: True if additional include configs have been provided, False otherwise. 
+ :rtype: bool + """ + return len(self.origin_configs) != 0 + + @classmethod + def _get_artifacts_by_config(cls, artifact_config: Dict[str, str]) -> Union[str, os.PathLike]: + # config key existence has been validated in _validate_additional_include_config + res: Union[str, os.PathLike] = ArtifactCache().get( + organization=artifact_config.get("organization", None), + project=artifact_config.get("project", None), + feed=artifact_config["feed"], + name=artifact_config["name"], + version=artifact_config["version"], + scope=artifact_config.get("scope", "organization"), + resolve=True, + ) + return res + + def _validate_additional_include_config( + self, additional_include_config: Union[Dict, str] + ) -> MutableValidationResult: + validation_result = ValidationResultBuilder.success() + if ( + isinstance(additional_include_config, dict) + and additional_include_config.get("type") == AzureDevopsArtifactsType.ARTIFACT + ): + # for artifact additional include, we validate the required fields in config but won't validate the + # artifact content to avoid downloading it in validation stage + # note that runtime error will be thrown when loading the artifact + for item in ["feed", "name", "version"]: + if item not in additional_include_config: + # TODO: add yaml path after we support list index in yaml path + validation_result.append_error( + "{} are required for artifacts config but got {}.".format( + item, json.dumps(additional_include_config) + ) + ) + elif isinstance(additional_include_config, str): + validation_result.merge_with(self._validate_local_additional_include_config(additional_include_config)) + else: + validation_result.append_error( + message=f"Unexpected format in additional_includes, {additional_include_config}" + ) + return validation_result + + @classmethod + def _resolve_artifact_additional_include_config( + cls, artifact_additional_include_config: Dict[str, str] + ) -> List[Tuple[str, str]]: + """Resolve an artifact additional include config into a list of (local_path, config_info) tuples. + + Configured artifact will be downloaded to local path first; the config_info will be in below format: + %name%:%version% in %feed% + + :param artifact_additional_include_config: Additional include config for an artifact + :type artifact_additional_include_config: Dict[str, str] + :return: A list of 2-tuples of local_path and config_info + :rtype: List[Tuple[str, str]] + """ + result = [] + # Note that we don't validate the artifact config here, since it has already been validated in + # _validate_additional_include_config + artifact_path = cls._get_artifacts_by_config(artifact_additional_include_config) + for item in os.listdir(artifact_path): + config_info = ( + f"{artifact_additional_include_config['name']}:{artifact_additional_include_config['version']} in " + f"{artifact_additional_include_config['feed']}" + ) + result.append((os.path.join(artifact_path, item), config_info)) + return result + + def _resolve_artifact_additional_include_configs( + self, artifact_additional_includes_configs: List[Dict[str, str]] + ) -> List: + additional_include_info_tuples = [] + # Unlike component registration, artifact downloading is a pure download progress; so we can use + # more threads to speed up the downloading process. + # We use 5 threads per CPU core plus 5 extra threads, and the max number of threads is 64. 
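+        # For example, on an 8-core machine this works out to min(64, 8 * 5 + 5) = 45 worker threads;
+        # with 12 or more cores the pool is capped at 64 threads.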
+ num_threads = min(64, (int(cpu_count()) * 5) + 5) + if ( + len(artifact_additional_includes_configs) > 1 + and is_concurrent_component_registration_enabled() + and is_private_preview_enabled() + ): + with ThreadPoolExecutor(max_workers=num_threads) as executor: + all_artifact_pairs_itr = executor.map( + self._resolve_artifact_additional_include_config, artifact_additional_includes_configs + ) + + for artifact_pairs in all_artifact_pairs_itr: + additional_include_info_tuples.extend(artifact_pairs) + else: + all_artifact_pairs_list = list( + map(self._resolve_artifact_additional_include_config, artifact_additional_includes_configs) + ) + + for artifact_pairs in all_artifact_pairs_list: + additional_include_info_tuples.extend(artifact_pairs) + + return additional_include_info_tuples + + @staticmethod + def _copy(src: Path, dst: Path, *, ignore_file: Optional[Any] = None) -> None: + if ignore_file and ignore_file.is_file_excluded(src): + return + if not src.exists(): + raise ValueError(f"Path {src} does not exist.") + if src.is_file(): + _general_copy(src, dst) + if src.is_dir(): + # TODO: should we cover empty folder? + # use os.walk to replace shutil.copytree, which may raise FileExistsError + # for same folder, the expected behavior is merging + # ignore will be also applied during this process + for name in src.glob("*"): + if ignore_file is not None: + AdditionalIncludes._copy(name, dst / name.name, ignore_file=ignore_file.merge(name)) + + @staticmethod + def _is_folder_to_compress(path: Path) -> bool: + """Check if the additional include needs to compress corresponding folder as a zip. + + For example, given additional include /mnt/c/hello.zip + 1) if a file named /mnt/c/hello.zip already exists, return False (simply copy) + 2) if a folder named /mnt/c/hello exists, return True (compress as a zip and copy) + + :param path: Given path in additional include. + :type path: Path + :return: If the path need to be compressed as a zip file. + :rtype: bool + """ + if path.suffix != ".zip": + return False + # if zip file exists, simply copy as other additional includes + if path.exists(): + return False + # remove .zip suffix and check whether the folder exists + stem_path = path.parent / path.stem + return stem_path.is_dir() + + def _resolve_folder_to_compress(self, include: str, dst_path: Path, ignore_file: IgnoreFile) -> None: + """resolve the zip additional include, need to compress corresponding folder. + + :param include: The path, relative to :attr:`AdditionalIncludes.base_path`, to zip + :type include: str + :param dst_path: The path to write the zipfile to + :type dst_path: Path + :param ignore_file: The ignore file to use to filter files + :type ignore_file: IgnoreFile + """ + zip_additional_include = (self.base_path / include).resolve() + folder_to_zip = zip_additional_include.parent / zip_additional_include.stem + zip_file = dst_path / zip_additional_include.name + with zipfile.ZipFile(zip_file, "w") as zf: + zf.write(folder_to_zip, os.path.relpath(folder_to_zip, folder_to_zip.parent)) # write root in zip + paths = [path for path, _ in get_upload_files_from_folder(folder_to_zip, ignore_file=ignore_file)] + # sort the paths to make sure the zip file (namelist) is deterministic + for path in sorted(paths): + zf.write(path, os.path.relpath(path, folder_to_zip.parent)) + + def _get_resolved_additional_include_configs(self) -> List[str]: + """ + Resolve additional include configs to a list of local_paths and return it. 
+
+        Additional includes is a list of include files, including local paths and Azure DevOps Artifacts.
+        The YAML format of additional_includes looks like below:
+            additional_includes:
+             - your/local/path
+             - type: artifact
+               organization: devops_organization
+               project: devops_project
+               feed: artifacts_feed_name
+               name: universal_package_name
+               version: package_version
+               scope: scope_type
+        Artifact packages will be downloaded from DevOps in this function and replaced by the local paths
+        of the downloaded artifacts; plain local paths will be returned directly.
+        If there are conflicts among artifacts, a runtime error will be raised. Note that we won't check
+        conflicts between artifacts and local paths, or conflicts among local paths. Reasons are:
+        1. There can be an ignore_file in local paths, which makes it hard to check the conflict and may lead to
+        breaking changes;
+        2. Conflicts among artifacts are more likely to happen, since a user may refer to 2 artifacts of the same
+        name but with a different version & feed.
+        3. According to the current design, folders in local paths will be merged, while artifact conflicts can be
+        identified by folder name conflicts and are not allowed.
+
+        :return: Path list of additional_includes
+        :rtype: List[str]
+        """
+        additional_include_configs_in_local_path = []
+
+        artifact_additional_include_configs = []
+        for additional_include_config in self.origin_configs:
+            if isinstance(additional_include_config, str):
+                # add local additional include configs directly
+                additional_include_configs_in_local_path.append(additional_include_config)
+            else:
+                # artifact additional include config will be downloaded and resolved to a local path later
+                # note that there is no more validation for artifact additional include config here, since it has
+                # already been validated in _validate_additional_include_config
+                artifact_additional_include_configs.append(additional_include_config)
+
+        artifact_additional_include_info_tuples = self._resolve_artifact_additional_include_configs(
+            artifact_additional_include_configs
+        )
+        additional_include_configs_in_local_path.extend(
+            local_path for local_path, _ in artifact_additional_include_info_tuples
+        )
+
+        # check file conflicts among artifact packages
+        # given this is not in the validate stage, we will raise an error if there are conflicting files
+        conflict_files: dict = defaultdict(set)
+        for local_path, config_info in artifact_additional_include_info_tuples:
+            file_name = Path(local_path).name
+            conflict_files[file_name].add(config_info)
+
+        conflict_files = {k: v for k, v in conflict_files.items() if len(v) > 1}
+        if conflict_files:
+            raise RuntimeError(f"There are conflict files in additional include: {conflict_files}")
+
+        return additional_include_configs_in_local_path
+
+    def _validate_local_additional_include_config(
+        self, local_path: str, config_info: Optional[str] = None
+    ) -> MutableValidationResult:
+        """Validate local additional include config.
+
+        Note that we will check the file conflicts between each local additional include and the origin code, but
+        won't check the file conflicts among local additional includes for now.
+
+        :param local_path: The local path
+        :type local_path: str
+        :param config_info: The config info
+        :type config_info: Optional[str]
+        :return: The validation result.
+ :rtype: ~azure.ai.ml.entities._validation.MutableValidationResult + """ + validation_result = ValidationResultBuilder.success() + include_path = self.base_path / local_path + # if additional include has not supported characters, resolve will fail and raise OSError + try: + src_path = include_path.resolve() + except OSError: + # no need to include potential yaml file name in error message as it will be covered by + # validation message construction. + error_msg = ( + f"Failed to resolve additional include " f"{config_info or local_path} " f"based on {self.base_path}." + ) + validation_result.append_error(message=error_msg) + return validation_result + + if not src_path.exists() and not self._is_folder_to_compress(src_path): + error_msg = f"Unable to find additional include {config_info or local_path}" + validation_result.append_error(message=error_msg) + return validation_result + + if len(src_path.parents) == 0: + error_msg = "Root directory is not supported for additional includes." + validation_result.append_error(message=error_msg) + return validation_result + + dst_path = Path(self.resolved_code_path) / src_path.name if self.resolved_code_path else None + if dst_path: + if dst_path.is_symlink(): + # if destination path is symbolic link, check if it points to the same file/folder as source path + if dst_path.resolve() != src_path.resolve(): + error_msg = f"A symbolic link already exists for additional include {config_info or local_path}." + validation_result.append_error(message=error_msg) + return validation_result + elif dst_path.exists(): + error_msg = f"A file already exists for additional include {config_info or local_path}." + validation_result.append_error(message=error_msg) + return validation_result + + def validate(self) -> MutableValidationResult: + """Validate the AdditionalIncludes object. + + :return: The validation result. + :rtype: ~azure.ai.ml.entities._validation.MutableValidationResult + """ + validation_result = ValidationResultBuilder.success() + for additional_include_config in self.origin_configs: + validation_result.merge_with(self._validate_additional_include_config(additional_include_config)) + return validation_result + + def _copy_origin_code(self, target_path: Path) -> ComponentIgnoreFile: + """Copy origin code to target path. 
+ + :param target_path: The destination to copy to + :type target_path: Path + :return: The component ignore file for the origin path + :rtype: ComponentIgnoreFile + """ + # code can be either file or folder, as additional includes exists, need to copy to temporary folder + if self.resolved_code_path is None: + # if additional include configs exist but no origin code path, return a dummy ignore file + return ComponentIgnoreFile( + self.base_path, + ) + + if Path(self.resolved_code_path).is_file(): + # use a dummy ignore file to save base path + root_ignore_file = ComponentIgnoreFile( + Path(self.resolved_code_path).parent, + skip_ignore_file=True, + ) + self._copy( + Path(self.resolved_code_path), + target_path / Path(self.resolved_code_path).name, + ignore_file=root_ignore_file, + ) + else: + # current implementation of ignore file is based on absolute path, so it cannot be shared + root_ignore_file = ComponentIgnoreFile(self.resolved_code_path) + self._copy(self.resolved_code_path, target_path, ignore_file=root_ignore_file) + return root_ignore_file + + @contextmanager + def merge_local_code_and_additional_includes(self) -> Generator: + """Merge code and potential additional includes into a temporary folder and return the absolute path of it. + + If no additional includes are specified, just return the absolute path of the original code path. + If no original code path is specified, return None. + + :return: The absolute path of the merged code and additional includes. + :rtype: Path + """ + if not self.with_includes: + if self.resolved_code_path is None: + yield None + else: + yield self.resolved_code_path.absolute() + return + + # for now, upload path of a code asset will include the folder name of the code path (name of folder or + # parent name of file). For example, if code path is /mnt/c/code-a, upload path will be xxx/code-a + # which means that the upload path will change every time as we will merge additional includes into a temp + # folder. To avoid this, we will copy the code path to a child folder with a fixed name under the temp folder, + # then the child folder will be used in upload path. + # This issue shouldn't impact users as there is a separate asset existence check before uploading. + # We still make this change as: + # 1. We will always need to record for twice as upload path will be changed for first time uploading + # 2. 
This will improve the stability of the code asset existence check - AssetNotChanged check in + # BlobStorageClient will be a backup check + tmp_folder_path = Path(tempfile.mkdtemp(), "code_with_additional_includes") + tmp_folder_path.mkdir(parents=True, exist_ok=True) + + root_ignore_file = self._copy_origin_code(tmp_folder_path) + + # resolve additional includes + base_path = self.base_path + # additional includes from artifact will be downloaded to a temp local path on calling + # self.includes, so no need to add specific logic for artifact + + # TODO: skip ignored files defined in code when copying additional includes + # copy additional includes disregarding ignore files as current ignore file implementation + # is based on absolute path, which is not suitable for additional includes + for additional_include_local_path in self._get_resolved_additional_include_configs(): + src_path = Path(additional_include_local_path) + if not src_path.is_absolute(): + src_path = (base_path / additional_include_local_path).resolve() + dst_path = (tmp_folder_path / src_path.name).resolve() + + root_ignore_file.rebase(src_path.parent) + if self._is_folder_to_compress(src_path): + self._resolve_folder_to_compress( + additional_include_local_path, + Path(tmp_folder_path), + # actual src path is without .zip suffix + ignore_file=root_ignore_file.merge(src_path.parent / src_path.stem), + ) + # early continue as the folder is compressed as a zip file + continue + + # no need to check if src_path exists as it is already validated + if src_path.is_file(): + self._copy(src_path, dst_path, ignore_file=root_ignore_file) + elif src_path.is_dir(): + self._copy( + src_path, + dst_path, + # root ignore file on parent + ignore file on src_path + ignore_file=root_ignore_file.merge(src_path), + ) + else: + raise ValueError(f"Unable to find additional include {additional_include_local_path}.") + try: + yield tmp_folder_path.absolute() + + finally: + # clean up tmp folder as it can be very disk space consuming + shutil.rmtree(tmp_folder_path, ignore_errors=True) + + +class AdditionalIncludesMixin(ComponentCodeMixin): + @classmethod + def _get_additional_includes_field_name(cls) -> str: + """Get the field name for additional includes. 
+ + :return: The field name + :rtype: str + """ + return "additional_includes" + + def _get_all_additional_includes_configs(self) -> List: + return getattr(self, self._get_additional_includes_field_name(), []) + + def _append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation( + self, base_validation_result: Optional[MutableValidationResult] = None + ) -> bool: + is_reliable: bool = super()._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation( + base_validation_result + ) + additional_includes_obj = self._generate_additional_includes_obj() + + if base_validation_result is not None: + base_validation_result.merge_with( + additional_includes_obj.validate(), field_name=self._get_additional_includes_field_name() + ) + # if additional includes is specified, origin code will be merged with additional includes into a temp folder + # before registered as a code asset, so origin code value is not reliable for local path validation + if additional_includes_obj.with_includes: + return False + return is_reliable + + def _generate_additional_includes_obj(self) -> AdditionalIncludes: + return AdditionalIncludes( + base_path=self._get_base_path_for_code(), + configs=self._get_all_additional_includes_configs(), + origin_code_value=self._get_origin_code_in_str(), + ) + + @contextmanager + def _try_build_local_code(self) -> Generator: + """Build final code when origin code is a local code. + + Will merge code path with additional includes into a temp folder if additional includes is specified. + + :return: The built Code object + :rtype: Iterable[Optional[Code]] + """ + # will try to merge code and additional includes even if code is None + tmp_code_dir: Any + with self._generate_additional_includes_obj().merge_local_code_and_additional_includes() as tmp_code_dir: + if tmp_code_dir is None: + yield None + else: + yield Code( + base_path=self._get_base_path_for_code(), + path=tmp_code_dir, + ignore_file=ComponentIgnoreFile(tmp_code_dir), + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/automl_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/automl_component.py new file mode 100644 index 00000000..3e7be727 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/automl_component.py @@ -0,0 +1,42 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Optional + +from azure.ai.ml._schema import PathAwareSchema +from azure.ai.ml._schema.component.automl_component import AutoMLComponentSchema +from azure.ai.ml.constants._common import COMPONENT_TYPE +from azure.ai.ml.constants._component import NodeType +from azure.ai.ml.entities._component.component import Component + + +class AutoMLComponent(Component): + """AutoML component entity, used to define an automl component. + + AutoML Component will only be used "internally" for the mentioned scenarios that need it. AutoML Component schema is + not intended to be used by the end users and therefore it won't be provided to the end users and it won't have + public documentation for the users. + + :param task: Task of the component. 
+ :type task: str + """ + + def __init__( + self, + *, + task: Optional[str] = None, + **kwargs: Any, + ) -> None: + kwargs[COMPONENT_TYPE] = NodeType.AUTOML + super(AutoMLComponent, self).__init__(**kwargs) + self._task = task + + @property + def task(self) -> Optional[str]: + """Returns task of the component.""" + return self._task + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema: + return AutoMLComponentSchema(context=context) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/code.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/code.py new file mode 100644 index 00000000..1f838bec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/code.py @@ -0,0 +1,297 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import os +from contextlib import contextmanager +from enum import Enum +from pathlib import Path +from typing import Any, Generator, List, Optional, Union + +from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, is_registry_id_for_resource +from azure.ai.ml._utils._asset_utils import IgnoreFile, get_ignore_file +from azure.ai.ml._utils.utils import is_private_preview_enabled +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType +from azure.ai.ml.entities._assets import Code +from azure.ai.ml.entities._validation import MutableValidationResult + + +class ComponentIgnoreFile(IgnoreFile): + _COMPONENT_CODE_IGNORES = ["__pycache__"] + """Component-specific ignore file used for ignoring files in a component directory. + + :param directory_path: The directory path for the ignore file. + :type directory_path: Union[str, Path] + :param additional_includes_file_name: Name of the additional includes file in the root directory to be ignored. + :type additional_includes_file_name: str + :param skip_ignore_file: Whether to skip the ignore file, defaults to False. + :type skip_ignore_file: bool + :param extra_ignore_list: List of additional ignore files to be considered during file exclusion. + :type extra_ignore_list: List[~azure.ai.ml._utils._asset_utils.IgnoreFile] + :raises ValueError: If additional include file is not found. + :return: The ComponentIgnoreFile object. + :rtype: ComponentIgnoreFile + """ + + def __init__( + self, + directory_path: Union[str, Path], + *, + additional_includes_file_name: Optional[str] = None, + skip_ignore_file: bool = False, + extra_ignore_list: Optional[List[IgnoreFile]] = None, + ): + self._base_path: Union[str, Path] = Path(directory_path) + self._extra_ignore_list: List[IgnoreFile] = extra_ignore_list or [] + # only the additional include file in root directory is ignored + # additional include files in subdirectories are not processed so keep them + self._additional_includes_file_name = additional_includes_file_name + # note: the parameter changes to directory path in this class, rather than file path + file_path = None if skip_ignore_file else get_ignore_file(directory_path).path + super(ComponentIgnoreFile, self).__init__(file_path=file_path) + + def exists(self) -> bool: + """Check if the ignore file exists. + + :return: True + :rtype: bool + """ + return True + + @property + def base_path(self) -> Union[str, Path]: + """Get the base path of the ignore file. + + :return: The base path. 
+ :rtype: Path + """ + # for component ignore file, the base path can be different from file.parent + return self._base_path + + def rebase(self, directory_path: Union[str, Path]) -> "ComponentIgnoreFile": + """Rebase the ignore file to a new directory. + + :param directory_path: The new directory path. + :type directory_path: Union[str, Path] + :return: The rebased ComponentIgnoreFile object. + :rtype: ComponentIgnoreFile + """ + self._base_path = directory_path + return self + + def is_file_excluded(self, file_path: Union[str, Path]) -> bool: + """Check if a file should be excluded based on the ignore file rules. + + :param file_path: The file path. + :type file_path: Union[str, Path] + :return: True if the file should be excluded, False otherwise. + :rtype: bool + """ + if self._additional_includes_file_name and self._get_rel_path(file_path) == self._additional_includes_file_name: + return True + for ignore_file in self._extra_ignore_list: + if ignore_file.is_file_excluded(file_path): + return True + res: bool = super(ComponentIgnoreFile, self).is_file_excluded(file_path) + return res + + def merge(self, other_path: Path) -> "ComponentIgnoreFile": + """Merge the ignore list from another ComponentIgnoreFile object. + + :param other_path: The path of the other ignore file. + :type other_path: Path + :return: The merged ComponentIgnoreFile object. + :rtype: ComponentIgnoreFile + """ + if other_path.is_file(): + return self + return ComponentIgnoreFile(other_path, extra_ignore_list=self._extra_ignore_list + [self]) + + def _get_ignore_list(self) -> List[str]: + """Retrieves the list of ignores from ignore file + + Override to add custom ignores. + + :return: The ignore rules + :rtype: List[str] + """ + if not super(ComponentIgnoreFile, self).exists(): + return self._COMPONENT_CODE_IGNORES + res: list = super(ComponentIgnoreFile, self)._get_ignore_list() + self._COMPONENT_CODE_IGNORES + return res + + +class CodeType(Enum): + """Code type.""" + + LOCAL = "local" + NONE = "none" + GIT = "git" + ARM_ID = "arm_id" + UNKNOWN = "unknown" + + +def _get_code_type(origin_code_value: Optional[str]) -> CodeType: + if origin_code_value is None: + return CodeType.NONE + if not isinstance(origin_code_value, str): + # note that: + # 1. Code & CodeOperation are not public for now + # 2. AnonymousCodeSchema is not within CodeField + # 3. Code will be returned as an arm id as an attribute of a component when getting a component from remote + # So origin_code_value should never be a Code object, or an exception will be raised + # in validation stage. + return CodeType.UNKNOWN + if is_ARM_id_for_resource(origin_code_value, AzureMLResourceType.CODE) or is_registry_id_for_resource( + origin_code_value + ): + return CodeType.ARM_ID + if origin_code_value.startswith("git+"): + return CodeType.GIT + return CodeType.LOCAL + + +class ComponentCodeMixin: + """Mixin class for components with local files as part of the component. Those local files will be uploaded to + blob storage and further referenced as a code asset in arm id. In below docstring, we will refer to those local + files as "code". + + The major interface of this mixin is self._customized_code_validate and self._build_code. + self._customized_code_validate will return a validation result indicating whether the code is valid. + self._build_code will return a temp Code object for server-side code asset creation. + """ + + def _get_base_path_for_code(self) -> Path: + """Get base path for additional includes. 
+ + :return: The base path + :rtype: Path + """ + if hasattr(self, BASE_PATH_CONTEXT_KEY): + return Path(getattr(self, BASE_PATH_CONTEXT_KEY)) + raise NotImplementedError( + "Component must have a base_path attribute to use ComponentCodeMixin. " + "Please set base_path in __init__ or override _get_base_path_for_code." + ) + + @classmethod + def _get_code_field_name(cls) -> str: + """Get the field name for code. + + Will be used to get origin code value by default and will be used as field name of validation diagnostics. + + :return: Code field name + :rtype: str + """ + return "code" + + def _get_origin_code_value(self) -> Union[str, os.PathLike, None]: + """Get origin code value. + Origin code value is either an absolute path or a relative path to base path if it's a local path. + Additional includes are only supported for component types with code attribute. Origin code path will be copied + to a temp folder along with additional includes to form a new code content. + """ + return getattr(self, self._get_code_field_name(), None) + + def _fill_back_code_value(self, value: str) -> None: + """Fill resolved code value back to the component. + + :param value: resolved code value + :type value: str + :return: no return + :rtype: None + """ + return setattr(self, self._get_code_field_name(), value) + + def _get_origin_code_in_str(self) -> Optional[str]: + """Get origin code value in str to simplify following logic.""" + origin_code_value = self._get_origin_code_value() + if origin_code_value is None: + return None + if isinstance(origin_code_value, Path): + return origin_code_value.as_posix() + return str(origin_code_value) + + def _append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation( + self, base_validation_result: Optional[MutableValidationResult] = None + ) -> bool: + """Append diagnostics from customized validation logic to the base validation result and check if origin code + value is valid for path validation. + + For customized validation logic, this method shouldn't cover the validation logic duplicated with schema + validation, like local code existence check. + For the check, as "code" includes file dependencies of a component, other fields may depend on those files. + However, the origin code value may not be reliable for validation of those fields. For example: + 1. origin code value can be a remote git path or an arm id of a code asset. + 2. some file operations may be done during build_code, which makes final code content different from what we can + get from origin code value. + So, we use this function to check if origin code value is reliable for further local path validation. + + :param base_validation_result: base validation result to append diagnostics to. + :type base_validation_result: MutableValidationResult + :return: whether origin code value is reliable for further local path validation. + :rtype: bool + """ + # If private features are enable and component has code value of type str we need to check + # that it is a valid git path case. 
Otherwise, we should throw a ValidationError + # saying that the code value is not valid + code_type = _get_code_type(self._get_origin_code_in_str()) + if code_type == CodeType.GIT and not is_private_preview_enabled(): + if base_validation_result is not None: + base_validation_result.append_error( + message="Not a valid code value: git paths are not supported.", + yaml_path=self._get_code_field_name(), + ) + return code_type == CodeType.LOCAL + + @contextmanager + def _build_code(self) -> Generator: + """Create a Code object if necessary based on origin code value and yield it. + + :return: If built code is the same as its origin value, do nothing and yield None. + Otherwise, yield a Code object pointing to the code. + :rtype: Iterable[Optional[Code]] + """ + origin_code_value = self._get_origin_code_in_str() + code_type = _get_code_type(origin_code_value) + + if code_type == CodeType.GIT: + # git also need to be resolved into arm id + yield Code(path=origin_code_value) + elif code_type in [CodeType.LOCAL, CodeType.NONE]: + code: Any + # false-positive by pylint, hence disable it + # (https://github.com/pylint-dev/pylint/blob/main/doc/data/messages + # /c/contextmanager-generator-missing-cleanup/details.rst) + with self._try_build_local_code() as code: # pylint:disable=contextmanager-generator-missing-cleanup + yield code + else: + # arm id, None and unknown need no extra resolution + yield None + + @contextmanager + def _try_build_local_code(self) -> Generator: + """Extract the logic of _build_code for local code for further override. + + :return: The Code object if could be constructed, None otherwise + :rtype: Iterable[Optional[Code]] + """ + origin_code_value = self._get_origin_code_in_str() + if origin_code_value is None: + yield None + else: + base_path = self._get_base_path_for_code() + absolute_path: Union[str, Path] = ( + origin_code_value if os.path.isabs(origin_code_value) else base_path / origin_code_value + ) + + yield Code( + base_path=base_path, + path=origin_code_value, + ignore_file=ComponentIgnoreFile(absolute_path), + ) + + def _with_local_code(self) -> bool: + # TODO: remove this method after we have a better way to do this judge in cache_utils + origin_code_value = self._get_origin_code_in_str() + code_type = _get_code_type(origin_code_value) + return code_type == CodeType.LOCAL diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/command_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/command_component.py new file mode 100644 index 00000000..9bdcd3d1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/command_component.py @@ -0,0 +1,300 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +import os +from typing import Any, Dict, List, Optional, Union, cast + +from marshmallow import Schema + +from azure.ai.ml._schema.component.command_component import CommandComponentSchema +from azure.ai.ml.constants._common import COMPONENT_TYPE +from azure.ai.ml.constants._component import NodeType +from azure.ai.ml.entities._assets import Environment +from azure.ai.ml.entities._job.distribution import ( + DistributionConfiguration, + MpiDistribution, + PyTorchDistribution, + RayDistribution, + TensorFlowDistribution, +) +from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration +from azure.ai.ml.entities._job.parameterized_command import ParameterizedCommand +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +from ..._restclient.v2022_10_01.models import ComponentVersion +from ..._schema import PathAwareSchema +from ..._utils.utils import get_all_data_binding_expressions, parse_args_description_from_docstring +from .._util import convert_ordered_dict_to_dict, validate_attribute_type +from .._validation import MutableValidationResult +from ._additional_includes import AdditionalIncludesMixin +from .component import Component + +# pylint: disable=protected-access + + +class CommandComponent(Component, ParameterizedCommand, AdditionalIncludesMixin): + """Command component version, used to define a Command Component or Job. + + :keyword name: The name of the Command job or component. + :paramtype name: Optional[str] + :keyword version: The version of the Command job or component. + :paramtype version: Optional[str] + :keyword description: The description of the component. Defaults to None. + :paramtype description: Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None. + :paramtype tags: Optional[dict] + :keyword display_name: The display name of the component. + :paramtype display_name: Optional[str] + :keyword command: The command to be executed. + :paramtype command: Optional[str] + :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing + to a remote location. + :type code: Optional[str] + :keyword environment: The environment that the job will run in. + :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + :keyword distribution: The configuration for distributed jobs. Defaults to None. + :paramtype distribution: Optional[Union[~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution, + ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]] + :keyword resources: The compute resource configuration for the command. + :paramtype resources: Optional[~azure.ai.ml.entities.JobResourceConfiguration] + :keyword inputs: A mapping of input names to input data sources used in the job. Defaults to None. + :paramtype inputs: Optional[dict[str, Union[ + ~azure.ai.ml.Input, + str, + bool, + int, + float, + Enum, + ]]] + :keyword outputs: A mapping of output names to output data sources used in the job. Defaults to None. + :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]] + :keyword instance_count: The number of instances or nodes to be used by the compute target. Defaults to 1. + :paramtype instance_count: Optional[int] + :keyword is_deterministic: Specifies whether the Command will return the same output given the same input. + Defaults to True. 
When True, if a Command (component) is deterministic and has been run before in the + current workspace with the same input and settings, it will reuse results from a previous submitted job + when used as a node or step in a pipeline. In that scenario, no compute resources will be used. + :paramtype is_deterministic: Optional[bool] + :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None. + :paramtype additional_includes: Optional[List[str]] + :keyword properties: The job property dictionary. Defaults to None. + :paramtype properties: Optional[dict[str, str]] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if CommandComponent cannot be successfully validated. + Details will be provided in the error message. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_command_configurations.py + :start-after: [START command_component_definition] + :end-before: [END command_component_definition] + :language: python + :dedent: 8 + :caption: Creating a CommandComponent. + """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + command: Optional[str] = None, + code: Optional[Union[str, os.PathLike]] = None, + environment: Optional[Union[str, Environment]] = None, + distribution: Optional[ + Union[ + Dict, + MpiDistribution, + TensorFlowDistribution, + PyTorchDistribution, + RayDistribution, + DistributionConfiguration, + ] + ] = None, + resources: Optional[JobResourceConfiguration] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + instance_count: Optional[int] = None, # promoted property from resources.instance_count + is_deterministic: bool = True, + additional_includes: Optional[List] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + + kwargs[COMPONENT_TYPE] = NodeType.COMMAND + + # Component backend doesn't support environment_variables yet, + # this is to support the case of CommandComponent being the trial of + # a SweepJob, where environment_variables is stored as part of trial + environment_variables = kwargs.pop("environment_variables", None) + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + display_name=display_name, + inputs=inputs, + outputs=outputs, + is_deterministic=is_deterministic, + properties=properties, + **kwargs, + ) + + # No validation on value passed here because in pipeline job, required code&environment maybe absent + # and fill in later with job defaults. + self.command = command + self.code = code + self.environment_variables = environment_variables + self.environment = environment + self.resources = resources # type: ignore[assignment] + self.distribution = distribution + + # check mutual exclusivity of promoted properties + if self.resources is not None and instance_count is not None: + msg = "instance_count and resources are mutually exclusive" + raise ValidationException( + message=msg, + target=ErrorTarget.COMPONENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + self.instance_count = instance_count + self.additional_includes = additional_includes or [] + + def _to_ordered_dict_for_yaml_dump(self) -> Dict: + """Dump the component content into a sorted yaml string. 
+ + :return: The ordered dict + :rtype: Dict + """ + + obj: dict = super()._to_ordered_dict_for_yaml_dump() + # dict dumped base on schema will transfer code to an absolute path, while we want to keep its original value + if self.code and isinstance(self.code, str): + obj["code"] = self.code + return obj + + @property + def instance_count(self) -> Optional[int]: + """The number of instances or nodes to be used by the compute target. + + :return: The number of instances or nodes. + :rtype: int + """ + return self.resources.instance_count if self.resources and not isinstance(self.resources, dict) else None + + @instance_count.setter + def instance_count(self, value: int) -> None: + """Sets the number of instances or nodes to be used by the compute target. + + :param value: The number of instances of nodes to be used by the compute target. Defaults to 1. + :type value: int + """ + if not value: + return + if not self.resources: + self.resources = JobResourceConfiguration(instance_count=value) + else: + if not isinstance(self.resources, dict): + self.resources.instance_count = value + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "environment": (str, Environment), + "environment_variables": dict, + "resources": (dict, JobResourceConfiguration), + "code": (str, os.PathLike), + } + + def _to_dict(self) -> Dict: + return cast( + dict, convert_ordered_dict_to_dict({**self._other_parameter, **super(CommandComponent, self)._to_dict()}) + ) + + @classmethod + def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict: + # put it here as distribution is shared by some components, e.g. command + distribution = obj.properties.component_spec.pop("distribution", None) + init_kwargs: dict = super()._from_rest_object_to_init_params(obj) + if distribution: + init_kwargs["distribution"] = DistributionConfiguration._from_rest_object(distribution) + return init_kwargs + + def _get_environment_id(self) -> Union[str, None]: + # Return environment id of environment + # handle case when environment is defined inline + if isinstance(self.environment, Environment): + _id: Optional[str] = self.environment.id + return _id + return self.environment + + # region SchemaValidatableMixin + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + return CommandComponentSchema(context=context) + + def _customized_validate(self) -> MutableValidationResult: + validation_result = super(CommandComponent, self)._customized_validate() + self._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(validation_result) + validation_result.merge_with(self._validate_command()) + validation_result.merge_with(self._validate_early_available_output()) + return validation_result + + def _validate_command(self) -> MutableValidationResult: + validation_result = self._create_empty_validation_result() + # command + if self.command: + invalid_expressions = [] + for data_binding_expression in get_all_data_binding_expressions(self.command, is_singular=False): + if not self._is_valid_data_binding_expression(data_binding_expression): + invalid_expressions.append(data_binding_expression) + + if invalid_expressions: + validation_result.append_error( + yaml_path="command", + message="Invalid data binding expression: {}".format(", ".join(invalid_expressions)), + ) + return validation_result + + def _validate_early_available_output(self) -> MutableValidationResult: + validation_result = self._create_empty_validation_result() + for name, output in 
self.outputs.items(): + if output.early_available is True and output._is_primitive_type is not True: + msg = ( + f"Early available output {name!r} requires output is primitive type, " + f"got {output._is_primitive_type!r}." + ) + validation_result.append_error(message=msg, yaml_path=f"outputs.{name}") + return validation_result + + def _is_valid_data_binding_expression(self, data_binding_expression: str) -> bool: + current_obj: Any = self + for item in data_binding_expression.split("."): + if hasattr(current_obj, item): + current_obj = getattr(current_obj, item) + else: + try: + current_obj = current_obj[item] + except Exception: # pylint: disable=W0718 + return False + return True + + # endregion + + @classmethod + def _parse_args_description_from_docstring(cls, docstring: str) -> Dict: + res: dict = parse_args_description_from_docstring(docstring) + return res + + def __str__(self) -> str: + try: + toYaml: str = self._to_yaml() + return toYaml + except BaseException: # pylint: disable=W0718 + toStr: str = super(CommandComponent, self).__str__() + return toStr diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component.py new file mode 100644 index 00000000..c02a3a33 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component.py @@ -0,0 +1,641 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import re +import uuid +from os import PathLike +from pathlib import Path +from typing import IO, TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Optional, Tuple, Union + +from marshmallow import INCLUDE + +from ..._restclient.v2024_01_01_preview.models import ( + ComponentContainer, + ComponentContainerProperties, + ComponentVersion, + ComponentVersionProperties, +) +from ..._schema import PathAwareSchema +from ..._schema.component import ComponentSchema +from ..._utils.utils import dump_yaml_to_file, hash_dict +from ...constants._common import ( + ANONYMOUS_COMPONENT_NAME, + BASE_PATH_CONTEXT_KEY, + PARAMS_OVERRIDE_KEY, + REGISTRY_URI_FORMAT, + SOURCE_PATH_CONTEXT_KEY, + CommonYamlFields, + SchemaUrl, +) +from ...constants._component import ComponentSource, IOConstants, NodeType +from ...entities._assets.asset import Asset +from ...entities._inputs_outputs import Input, Output +from ...entities._mixins import LocalizableMixin, TelemetryMixin, YamlTranslatableMixin +from ...entities._system_data import SystemData +from ...entities._util import find_type_in_override +from ...entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin, RemoteValidatableMixin +from ...exceptions import ErrorCategory, ErrorTarget, ValidationException +from .._inputs_outputs import GroupInput + +if TYPE_CHECKING: + from ...entities.builders import BaseNode +# pylint: disable=protected-access, redefined-builtin +# disable redefined-builtin to use id/type as argument name + + +COMPONENT_PLACEHOLDER = "COMPONENT_PLACEHOLDER" + + +class Component( + Asset, + RemoteValidatableMixin, + TelemetryMixin, + YamlTranslatableMixin, + PathAwareSchemaValidatableMixin, + LocalizableMixin, +): + """Base class for component version, used to define a component. Can't be instantiated directly. + + :param name: Name of the resource. + :type name: str + :param version: Version of the resource. 
+ :type version: str + :param id: Global ID of the resource, Azure Resource Manager ID. + :type id: str + :param type: Type of the command, supported is 'command'. + :type type: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict + :param properties: Internal use only. + :type properties: dict + :param display_name: Display name of the component. + :type display_name: str + :param is_deterministic: Whether the component is deterministic. Defaults to True. + :type is_deterministic: bool + :param inputs: Inputs of the component. + :type inputs: dict + :param outputs: Outputs of the component. + :type outputs: dict + :param yaml_str: The YAML string of the component. + :type yaml_str: str + :param _schema: Schema of the component. + :type _schema: str + :param creation_context: Creation metadata of the component. + :type creation_context: ~azure.ai.ml.entities.SystemData + :param kwargs: Additional parameters for the component. + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Component cannot be successfully validated. + Details will be provided in the error message. + """ + + # pylint: disable=too-many-instance-attributes + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + id: Optional[str] = None, + type: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + display_name: Optional[str] = None, + is_deterministic: bool = True, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + yaml_str: Optional[str] = None, + _schema: Optional[str] = None, + creation_context: Optional[SystemData] = None, + **kwargs: Any, + ) -> None: + self.latest_version = None + self._intellectual_property = kwargs.pop("intellectual_property", None) + # Setting this before super init because when asset init version, _auto_increment_version's value may change + self._auto_increment_version = kwargs.pop("auto_increment", False) + # Get source from id first, then kwargs. + self._source = ( + self._resolve_component_source_from_id(id) if id else kwargs.pop("_source", ComponentSource.CLASS) + ) + # use ANONYMOUS_COMPONENT_NAME instead of guid + is_anonymous = kwargs.pop("is_anonymous", False) + if not name and version is None: + name = ANONYMOUS_COMPONENT_NAME + version = "1" + is_anonymous = True + + super().__init__( + name=name, + version=version, + id=id, + description=description, + tags=tags, + properties=properties, + creation_context=creation_context, + is_anonymous=is_anonymous, + base_path=kwargs.pop(BASE_PATH_CONTEXT_KEY, None), + source_path=kwargs.pop(SOURCE_PATH_CONTEXT_KEY, None), + ) + # store kwargs to self._other_parameter instead of pop to super class to allow component have extra + # fields not defined in current schema. 
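+        # e.g. a YAML field not defined in the schema (say, a hypothetical "my_custom_field") ends up in
+        # self._other_parameter below instead of raising at load time.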
+ + inputs = inputs if inputs else {} + outputs = outputs if outputs else {} + + self.name = name + self._schema = _schema + self._type = type + self._display_name = display_name + self._is_deterministic = is_deterministic + self._inputs = self._build_io(inputs, is_input=True) + self._outputs = self._build_io(outputs, is_input=False) + # Store original yaml + self._yaml_str = yaml_str + self._other_parameter = kwargs + + @property + def _func(self) -> Callable[..., "BaseNode"]: + from azure.ai.ml.entities._job.pipeline._load_component import _generate_component_function + + # validate input/output names before creating component function + validation_result = self._validate_io_names(self.inputs) + validation_result.merge_with(self._validate_io_names(self.outputs)) + self._try_raise(validation_result) + + res: Callable = _generate_component_function(self) + return res + + @property + def type(self) -> Optional[str]: + """Type of the component, default is 'command'. + + :return: Type of the component. + :rtype: str + """ + return self._type + + @property + def display_name(self) -> Optional[str]: + """Display name of the component. + + :return: Display name of the component. + :rtype: str + """ + return self._display_name + + @display_name.setter + def display_name(self, custom_display_name: str) -> None: + """Set display_name of the component. + + :param custom_display_name: The new display name + :type custom_display_name: str + """ + self._display_name = custom_display_name + + @property + def is_deterministic(self) -> Optional[bool]: + """Whether the component is deterministic. + + :return: Whether the component is deterministic + :rtype: bool + """ + return self._is_deterministic + + @property + def inputs(self) -> Dict: + """Inputs of the component. + + :return: Inputs of the component. + :rtype: dict + """ + res: dict = self._inputs + return res + + @property + def outputs(self) -> Dict: + """Outputs of the component. + + :return: Outputs of the component. + :rtype: dict + """ + return self._outputs + + @property + def version(self) -> Optional[str]: + """Version of the component. + + :return: Version of the component. + :rtype: str + """ + return self._version + + @version.setter + def version(self, value: str) -> None: + """Set the version of the component. + + :param value: The version of the component. + :type value: str + """ + if value: + if not isinstance(value, str): + msg = f"Component version must be a string, not type {type(value)}." + raise ValidationException( + message=msg, + target=ErrorTarget.COMPONENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + self._version = value + self._auto_increment_version = self.name and not self._version + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the component content into a file in yaml format. + + :param dest: The destination to receive this component's content. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable. 
+        :type dest: Union[PathLike, str, IO[AnyStr]]
+        """
+        path = kwargs.pop("path", None)
+        yaml_serialized = self._to_dict()
+        dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+    @staticmethod
+    def _resolve_component_source_from_id(  # pylint: disable=docstring-type-do-not-use-class
+        id: Optional[Union["Component", str]],
+    ) -> Any:
+        """Resolve the component source from its id.
+
+        :param id: The component ID
+        :type id: Optional[str]
+        :return: The component source
+        :rtype: Literal[
+            ComponentSource.CLASS,
+            ComponentSource.REMOTE_REGISTRY,
+            ComponentSource.REMOTE_WORKSPACE_COMPONENT
+            ]
+        """
+        if id is None:
+            return ComponentSource.CLASS
+        # Default to workspace source, as the azureml: prefix will be removed for an ARM versioned id.
+        return (
+            ComponentSource.REMOTE_REGISTRY
+            if not isinstance(id, Component) and id.startswith(REGISTRY_URI_FORMAT)
+            else ComponentSource.REMOTE_WORKSPACE_COMPONENT
+        )
+
+    @classmethod
+    def _validate_io_names(cls, io_names: Iterable[str], raise_error: bool = False) -> MutableValidationResult:
+        """Validate input/output names, raising an exception if invalid and raise_error is True.
+
+        :param io_names: The names to validate
+        :type io_names: Iterable[str]
+        :param raise_error: Whether to raise if validation fails. Defaults to False
+        :type raise_error: bool
+        :return: The validation result
+        :rtype: MutableValidationResult
+        """
+        validation_result = cls._create_empty_validation_result()
+        lower2original_kwargs: dict = {}
+
+        for name in io_names:
+            if re.match(IOConstants.VALID_KEY_PATTERN, name) is None:
+                msg = "{!r} is not a valid parameter name; it must be composed of letters, numbers, and underscores."
+                validation_result.append_error(message=msg.format(name), yaml_path=f"inputs.{name}")
+            # validate name conflict
+            lower_key = name.lower()
+            if lower_key in lower2original_kwargs:
+                msg = "Invalid component input names {!r} and {!r}: they are equal when case is ignored."
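# --- Editorial aside (illustration, not part of the diffed source) ---------------------
# A standalone sketch of the case-insensitive duplicate check performed above. The real
# name pattern lives in IOConstants.VALID_KEY_PATTERN; the identifier regex here is an
# assumption used only for illustration.
import re

def find_io_name_conflicts(io_names):
    lower_to_original = {}
    conflicts = []
    for name in io_names:
        if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name) is None:
            conflicts.append((name, "not a valid parameter name"))
            continue
        lower = name.lower()
        if lower in lower_to_original:
            conflicts.append((name, f"conflicts with {lower_to_original[lower]!r} ignoring case"))
        else:
            lower_to_original[lower] = name
    return conflicts

# find_io_name_conflicts(["Data", "data", "1bad"]) ->
# [('data', "conflicts with 'Data' ignoring case"), ('1bad', 'not a valid parameter name')]
# ---------------------------------------------------------------------------------------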
+ validation_result.append_error( + message=msg.format(name, lower2original_kwargs[lower_key]), yaml_path=f"inputs.{name}" + ) + else: + lower2original_kwargs[lower_key] = name + return cls._try_raise(validation_result, raise_error=raise_error) + + @classmethod + def _build_io(cls, io_dict: Union[Dict, Input, Output], is_input: bool) -> Dict: + component_io: dict = {} + for name, port in io_dict.items(): + if is_input: + component_io[name] = port if isinstance(port, Input) else Input(**port) + else: + component_io[name] = port if isinstance(port, Output) else Output(**port) + + if is_input: + # Restore flattened parameters to group + res: dict = GroupInput.restore_flattened_inputs(component_io) + return res + return component_io + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema: + return ComponentSchema(context=context) + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException: + return ValidationException( + message=message, + no_personal_data_message=no_personal_data_message, + target=ErrorTarget.COMPONENT, + ) + + @classmethod + def _is_flow(cls, data: Any) -> bool: + _schema = data.get(CommonYamlFields.SCHEMA, None) + + if _schema and _schema in [SchemaUrl.PROMPTFLOW_FLOW, SchemaUrl.PROMPTFLOW_RUN]: + return True + return False + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Component": + data = data or {} + params_override = params_override or [] + base_path = Path(yaml_path).parent if yaml_path else Path("./") + + type_in_override = find_type_in_override(params_override) + + # type_in_override > type_in_yaml > default (command) + if type_in_override is None: + type_in_override = data.get(CommonYamlFields.TYPE, None) + if type_in_override is None and cls._is_flow(data): + type_in_override = NodeType.FLOW_PARALLEL + if type_in_override is None: + type_in_override = NodeType.COMMAND + data[CommonYamlFields.TYPE] = type_in_override + + from azure.ai.ml.entities._component.component_factory import component_factory + + create_instance_func, _ = component_factory.get_create_funcs( + data, + for_load=True, + ) + new_instance: Component = create_instance_func() + # specific keys must be popped before loading with schema using kwargs + init_kwargs = { + "yaml_str": kwargs.pop("yaml_str", None), + "_source": kwargs.pop("_source", ComponentSource.YAML_COMPONENT), + } + init_kwargs.update( + new_instance._load_with_schema( # pylint: disable=protected-access + data, + context={ + BASE_PATH_CONTEXT_KEY: base_path, + SOURCE_PATH_CONTEXT_KEY: yaml_path, + PARAMS_OVERRIDE_KEY: params_override, + }, + unknown=INCLUDE, + raise_original_exception=True, + **kwargs, + ) + ) + # Set base path separately to avoid doing this in post load, as return types of post load are not unified, + # could be object or dict. + # base_path in context can be changed in loading, so we use original base_path here. 
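# --- Editorial aside (illustration, not part of the diffed source) ---------------------
# _load above resolves the node type with the precedence: params_override > value in the
# YAML > prompt-flow detection > command. A compact sketch of that ordering; the literal
# type strings are illustrative stand-ins for the NodeType constants.
def resolve_component_type(type_in_override, type_in_yaml, looks_like_flow):
    if type_in_override is not None:
        return type_in_override
    if type_in_yaml is not None:
        return type_in_yaml
    if looks_like_flow:
        return "flow_parallel"   # stand-in for NodeType.FLOW_PARALLEL
    return "command"             # stand-in for NodeType.COMMAND

# resolve_component_type(None, "spark", False) -> "spark"
# resolve_component_type(None, None, False)    -> "command"
# ---------------------------------------------------------------------------------------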
+ init_kwargs[BASE_PATH_CONTEXT_KEY] = base_path.absolute() + if yaml_path: + init_kwargs[SOURCE_PATH_CONTEXT_KEY] = Path(yaml_path).absolute().as_posix() + # TODO: Bug Item number: 2883415 + new_instance.__init__( # type: ignore + **init_kwargs, + ) + return new_instance + + @classmethod + def _from_container_rest_object(cls, component_container_rest_object: ComponentContainer) -> "Component": + component_container_details: ComponentContainerProperties = component_container_rest_object.properties + component = Component( + id=component_container_rest_object.id, + name=component_container_rest_object.name, + description=component_container_details.description, + creation_context=SystemData._from_rest_object(component_container_rest_object.system_data), + tags=component_container_details.tags, + properties=component_container_details.properties, + type=NodeType._CONTAINER, + # Set this field to None as it hold a default True in init. + is_deterministic=None, # type: ignore[arg-type] + ) + component.latest_version = component_container_details.latest_version + return component + + @classmethod + def _from_rest_object(cls, obj: ComponentVersion) -> "Component": + # TODO: Remove in PuP with native import job/component type support in MFE/Designer + # Convert command component back to import component private preview + component_spec = obj.properties.component_spec + if component_spec[CommonYamlFields.TYPE] == NodeType.COMMAND and component_spec["command"] == NodeType.IMPORT: + component_spec[CommonYamlFields.TYPE] = NodeType.IMPORT + component_spec["source"] = component_spec.pop("inputs") + component_spec["output"] = component_spec.pop("outputs")["output"] + + # shouldn't block serialization when name is not valid + # maybe override serialization method for name field? + from azure.ai.ml.entities._component.component_factory import component_factory + + create_instance_func, _ = component_factory.get_create_funcs(obj.properties.component_spec, for_load=True) + + instance: Component = create_instance_func() + # TODO: Bug Item number: 2883415 + instance.__init__(**instance._from_rest_object_to_init_params(obj)) # type: ignore + return instance + + @classmethod + def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict: + # Object got from rest data contain _source, we delete it. 
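# --- Editorial aside (illustration, not part of the diffed source) ---------------------
# The _from_rest_object path above is what ml_client.components.get() ultimately
# exercises. A hedged usage sketch; the workspace identifiers and component name are
# placeholders.
from azure.identity import DefaultAzureCredential
from azure.ai.ml import MLClient

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)
remote_component = ml_client.components.get(name="train_model", version="1")
print(remote_component.type, list(remote_component.inputs))
# ---------------------------------------------------------------------------------------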
+ if "_source" in obj.properties.component_spec: + del obj.properties.component_spec["_source"] + + rest_component_version = obj.properties + _type = rest_component_version.component_spec[CommonYamlFields.TYPE] + + # inputs/outputs will be parsed by instance._build_io in instance's __init__ + inputs = rest_component_version.component_spec.pop("inputs", {}) + # parse String -> string, Integer -> integer, etc + for _input in inputs.values(): + _input["type"] = Input._map_from_rest_type(_input["type"]) + outputs = rest_component_version.component_spec.pop("outputs", {}) + + origin_name = rest_component_version.component_spec[CommonYamlFields.NAME] + rest_component_version.component_spec[CommonYamlFields.NAME] = ANONYMOUS_COMPONENT_NAME + init_kwargs = cls._load_with_schema( + rest_component_version.component_spec, context={BASE_PATH_CONTEXT_KEY: Path.cwd()}, unknown=INCLUDE + ) + init_kwargs.update( + { + "id": obj.id, + "is_anonymous": rest_component_version.is_anonymous, + "creation_context": obj.system_data, + "inputs": inputs, + "outputs": outputs, + "name": origin_name, + } + ) + + # remove empty values, because some property only works for specific component, eg: distribution for command + # note that there is an issue that environment == {} will always be true, so use isinstance here + return {k: v for k, v in init_kwargs.items() if v is not None and not (isinstance(v, dict) and not v)} + + def _get_anonymous_hash(self) -> str: + """Return the hash of anonymous component. + + Anonymous Components (same code and interface) will have same hash. + + :return: The component hash + :rtype: str + """ + # omit version since anonymous component's version is random guid + # omit name since name doesn't impact component's uniqueness + return self._get_component_hash(keys_to_omit=["name", "id", "version"]) + + def _get_component_hash(self, keys_to_omit: Optional[Iterable[str]] = None) -> str: + """Return the hash of component. 
+ + :param keys_to_omit: An iterable of keys to omit when computing the component hash + :type keys_to_omit: Optional[Iterable[str]] + :return: The component hash + :rtype: str + """ + component_interface_dict = self._to_dict() + res: str = hash_dict(component_interface_dict, keys_to_omit=keys_to_omit) + return res + + @classmethod + def _get_resource_type(cls) -> str: + return "Microsoft.MachineLearningServices/workspaces/components/versions" + + def _get_resource_name_version(self) -> Tuple: + version: Optional[str] = None + if not self.version and not self._auto_increment_version: + version = str(uuid.uuid4()) + else: + version = self.version + return self.name or ANONYMOUS_COMPONENT_NAME, version + + def _validate(self, raise_error: Optional[bool] = False) -> MutableValidationResult: + origin_name = self.name + # skip name validation for anonymous component as ANONYMOUS_COMPONENT_NAME will be used in component creation + if self._is_anonymous: + self.name = ANONYMOUS_COMPONENT_NAME + try: + return super()._validate(raise_error) + finally: + self.name = origin_name + + def _customized_validate(self) -> MutableValidationResult: + validation_result = super(Component, self)._customized_validate() + + # validate inputs names + validation_result.merge_with(self._validate_io_names(self.inputs, raise_error=False)) + validation_result.merge_with(self._validate_io_names(self.outputs, raise_error=False)) + + return validation_result + + def _get_anonymous_component_name_version(self) -> Tuple: + return ANONYMOUS_COMPONENT_NAME, self._get_anonymous_hash() + + def _get_rest_name_version(self) -> Tuple: + if self._is_anonymous: + return self._get_anonymous_component_name_version() + return self.name, self.version + + def _to_rest_object(self) -> ComponentVersion: + component = self._to_dict() + + # TODO: Remove in PuP with native import job/component type support in MFE/Designer + # Convert import component to command component private preview + if component.get(CommonYamlFields.TYPE, None) == NodeType.IMPORT: + component[CommonYamlFields.TYPE] = NodeType.COMMAND + component["inputs"] = component.pop("source") + component["outputs"] = dict({"output": component.pop("output")}) + # method _to_dict() will remove empty keys + if "tags" not in component: + component["tags"] = {} + component["tags"]["component_type_overwrite"] = NodeType.IMPORT + component["command"] = NodeType.IMPORT + + # add source type to component rest object + component["_source"] = self._source + if self._intellectual_property: + # hack while full pass through supported is worked on for IPP fields + component.pop("intellectual_property") + component["intellectualProperty"] = self._intellectual_property._to_rest_object().serialize() + properties = ComponentVersionProperties( + component_spec=component, + description=self.description, + is_anonymous=self._is_anonymous, + properties=dict(self.properties) if self.properties else {}, + tags=self.tags, + ) + result = ComponentVersion(properties=properties) + if self._is_anonymous: + result.name = ANONYMOUS_COMPONENT_NAME + else: + result.name = self.name + result.properties.properties["client_component_hash"] = self._get_component_hash(keys_to_omit=["version"]) + return result + + def _to_dict(self) -> Dict: + # Replace the name of $schema to schema. 
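# --- Editorial aside (illustration, not part of the diffed source) ---------------------
# A minimal sketch of the "hash the component dict while omitting volatile keys" idea
# used by _get_anonymous_hash/_get_component_hash above. hash_dict is an internal helper;
# this stand-in only illustrates the technique, not the SDK's exact canonicalisation.
import hashlib
import json

def stable_dict_hash(spec, keys_to_omit=()):
    filtered = {k: v for k, v in spec.items() if k not in set(keys_to_omit)}
    canonical = json.dumps(filtered, sort_keys=True, default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

# Two anonymous components that differ only in name/version/id hash identically:
a = {"name": "azureml_anonymous", "version": "1", "command": "python train.py", "inputs": {}}
b = {"name": "x", "version": "2", "command": "python train.py", "inputs": {}}
assert stable_dict_hash(a, ["name", "id", "version"]) == stable_dict_hash(b, ["name", "id", "version"])
# ---------------------------------------------------------------------------------------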
+        component_schema_dict: dict = self._dump_for_validation()
+        component_schema_dict.pop(BASE_PATH_CONTEXT_KEY, None)
+
+        # TODO: handle other_parameters and remove override from subclass
+        return component_schema_dict
+
+    def _localize(self, base_path: str) -> None:
+        """Called on an asset retrieved from the service to clean up remote attributes like id, creation_context,
+        etc. and update base_path.
+
+        :param base_path: The base_path
+        :type base_path: str
+        """
+        if not getattr(self, "id", None):
+            raise ValueError("Only a remote asset can be localized, but got a {} without an id.".format(type(self)))
+        self._id = None
+        self._creation_context = None
+        self._base_path = base_path
+
+    def _get_telemetry_values(self, *args: Any, **kwargs: Any) -> Dict:
+        # Note: is_anonymous is not reliable here; create_or_update will log is_anonymous from its parameter.
+        is_anonymous = self.name is None or ANONYMOUS_COMPONENT_NAME in self.name
+        return {"type": self.type, "source": self._source, "is_anonymous": is_anonymous}
+
+    # pylint: disable-next=docstring-missing-param
+    def __call__(self, *args: Any, **kwargs: Any) -> "BaseNode":
+        """Call the component as a function and get a pipeline node (BaseNode) object.
+
+        :return: The node object
+        :rtype: BaseNode
+        """
+        if args:
+            # raise a clear error message for unsupported positional args
+            if self._func._has_parameters:  # type: ignore
+                _error = f"got {args} for {self.name}"
+                msg = (
+                    f"Component function doesn't support positional arguments, {_error}. "  # type: ignore
+                    f"Please use keyword arguments like: {self._func._func_calling_example}."
+                )
+            else:
+                msg = (
+                    "Component function doesn't have any parameters; "
+                    f"please make sure component {self.name} has inputs. "
+                )
+            raise ValidationException(
+                message=msg,
+                target=ErrorTarget.COMPONENT,
+                no_personal_data_message=msg,
+                error_category=ErrorCategory.USER_ERROR,
+            )
+        return self._func(*args, **kwargs)  # pylint: disable=not-callable
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component_factory.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component_factory.py
new file mode 100644
index 00000000..012dd260
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component_factory.py
@@ -0,0 +1,171 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Callable, Dict, Optional, Tuple + +from marshmallow import Schema + +from ..._restclient.v2022_10_01.models import ComponentVersion +from ..._utils.utils import is_internal_component_data +from ...constants._common import SOURCE_PATH_CONTEXT_KEY +from ...constants._component import DataTransferTaskType, NodeType +from ...entities._component.automl_component import AutoMLComponent +from ...entities._component.command_component import CommandComponent +from ...entities._component.component import Component +from ...entities._component.datatransfer_component import ( + DataTransferCopyComponent, + DataTransferExportComponent, + DataTransferImportComponent, +) +from ...entities._component.import_component import ImportComponent +from ...entities._component.parallel_component import ParallelComponent +from ...entities._component.pipeline_component import PipelineComponent +from ...entities._component.spark_component import SparkComponent +from ...entities._util import get_type_from_spec +from .flow import FlowComponent + + +class _ComponentFactory: + """A class to create component instances from yaml dict or rest objects without hard-coded type check.""" + + def __init__(self) -> None: + self._create_instance_funcs: Dict = {} + self._create_schema_funcs: Dict = {} + + self.register_type( + _type=NodeType.PARALLEL, + create_instance_func=lambda: ParallelComponent.__new__(ParallelComponent), + create_schema_func=ParallelComponent._create_schema_for_validation, + ) + self.register_type( + _type=NodeType.COMMAND, + create_instance_func=lambda: CommandComponent.__new__(CommandComponent), + create_schema_func=CommandComponent._create_schema_for_validation, + ) + self.register_type( + _type=NodeType.IMPORT, + create_instance_func=lambda: ImportComponent.__new__(ImportComponent), + create_schema_func=ImportComponent._create_schema_for_validation, + ) + self.register_type( + _type=NodeType.PIPELINE, + create_instance_func=lambda: PipelineComponent.__new__(PipelineComponent), + create_schema_func=PipelineComponent._create_schema_for_validation, + ) + self.register_type( + _type=NodeType.AUTOML, + create_instance_func=lambda: AutoMLComponent.__new__(AutoMLComponent), + create_schema_func=AutoMLComponent._create_schema_for_validation, + ) + self.register_type( + _type=NodeType.SPARK, + create_instance_func=lambda: SparkComponent.__new__(SparkComponent), + create_schema_func=SparkComponent._create_schema_for_validation, + ) + self.register_type( + _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.COPY_DATA]), + create_instance_func=lambda: DataTransferCopyComponent.__new__(DataTransferCopyComponent), + create_schema_func=DataTransferCopyComponent._create_schema_for_validation, + ) + + self.register_type( + _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.IMPORT_DATA]), + create_instance_func=lambda: DataTransferImportComponent.__new__(DataTransferImportComponent), + create_schema_func=DataTransferImportComponent._create_schema_for_validation, + ) + + self.register_type( + _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.EXPORT_DATA]), + create_instance_func=lambda: DataTransferExportComponent.__new__(DataTransferExportComponent), + create_schema_func=DataTransferExportComponent._create_schema_for_validation, + ) + + self.register_type( + _type=NodeType.FLOW_PARALLEL, + create_instance_func=lambda: FlowComponent.__new__(FlowComponent), + 
create_schema_func=FlowComponent._create_schema_for_validation, + ) + + def get_create_funcs( + self, yaml_spec: dict, for_load: bool = False + ) -> Tuple[Callable[..., Component], Callable[[Any], Schema]]: + """Get registered functions to create an instance and its corresponding schema for the given type. + + :param yaml_spec: The YAML specification. + :type yaml_spec: dict + :param for_load: Whether the function is called for loading a component. Defaults to False. + :type for_load: bool + :return: A tuple containing the create_instance_func and create_schema_func. + :rtype: tuple + """ + + _type = get_type_from_spec(yaml_spec, valid_keys=self._create_instance_funcs) + # SparkComponent and InternalSparkComponent share the same type name, but they are different types. + if for_load and is_internal_component_data(yaml_spec, raise_if_not_enabled=True) and _type == NodeType.SPARK: + from azure.ai.ml._internal._schema.node import NodeType as InternalNodeType + + _type = InternalNodeType.SPARK + + create_instance_func = self._create_instance_funcs[_type] + create_schema_func = self._create_schema_funcs[_type] + return create_instance_func, create_schema_func + + def register_type( + self, + _type: str, + create_instance_func: Callable[..., Component], + create_schema_func: Callable[[Any], Schema], + ) -> None: + """Register a new component type. + + :param _type: The type name of the component. + :type _type: str + :param create_instance_func: A function to create an instance of the component. + :type create_instance_func: Callable[..., ~azure.ai.ml.entities.Component] + :param create_schema_func: A function to create a schema for the component. + :type create_schema_func: Callable[[Any], Schema] + """ + self._create_instance_funcs[_type] = create_instance_func + self._create_schema_funcs[_type] = create_schema_func + + @classmethod + def load_from_dict(cls, *, data: Dict, context: Dict, _type: Optional[str] = None, **kwargs: Any) -> Component: + """Load a component from a YAML dict. + + :keyword data: The YAML dict. + :paramtype data: dict + :keyword context: The context of the YAML dict. + :paramtype context: dict + :keyword _type: The type name of the component. When None, it will be inferred from the YAML dict. + :paramtype _type: str + :return: The loaded component. + :rtype: ~azure.ai.ml.entities.Component + """ + + return Component._load( + data=data, + yaml_path=context.get(SOURCE_PATH_CONTEXT_KEY, None), + params_override=[{"type": _type}] if _type is not None else [], + **kwargs, + ) + + @classmethod + def load_from_rest(cls, *, obj: ComponentVersion, _type: Optional[str] = None) -> Component: + """Load a component from a REST object. + + :keyword obj: The REST object. + :paramtype obj: ComponentVersion + :keyword _type: The type name of the component. When None, it will be inferred from the REST object. + :paramtype _type: str + :return: The loaded component. 
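# --- Editorial aside (illustration, not part of the diffed source) ---------------------
# In practice the factory above is reached indirectly through the public loader, which
# routes YAML through Component._load and component_factory. A hedged usage sketch; the
# YAML path is a placeholder.
from azure.ai.ml import load_component

loaded = load_component(source="./components/train.yml")
print(type(loaded).__name__, loaded.type)  # e.g. CommandComponent command
# ---------------------------------------------------------------------------------------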
+ :rtype: ~azure.ai.ml.entities.Component + """ + if _type is not None: + obj.properties.component_spec["type"] = _type + return Component._from_rest_object(obj) + + +component_factory = _ComponentFactory() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/datatransfer_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/datatransfer_component.py new file mode 100644 index 00000000..e71712ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/datatransfer_component.py @@ -0,0 +1,325 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from pathlib import Path +from typing import Any, Dict, NoReturn, Optional, Union, cast + +from marshmallow import Schema + +from azure.ai.ml._schema.component.data_transfer_component import ( + DataTransferCopyComponentSchema, + DataTransferExportComponentSchema, + DataTransferImportComponentSchema, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, COMPONENT_TYPE, AssetTypes +from azure.ai.ml.constants._component import DataTransferTaskType, ExternalDataType, NodeType +from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem +from azure.ai.ml.entities._inputs_outputs.output import Output +from azure.ai.ml.entities._validation.core import MutableValidationResult +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from ..._schema import PathAwareSchema +from .._util import convert_ordered_dict_to_dict, validate_attribute_type +from .component import Component + + +class DataTransferComponent(Component): + """DataTransfer component version, used to define a data transfer component. + + :param task: Task type in the data transfer component. Possible values are "copy_data", + "import_data", and "export_data". + :type task: str + :param inputs: Mapping of input data bindings used in the job. + :type inputs: dict + :param outputs: Mapping of output data bindings used in the job. + :type outputs: dict + :param kwargs: Additional parameters for the data transfer component. + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated. + Details will be provided in the error message. + """ + + def __init__( + self, + *, + task: Optional[str] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + + kwargs[COMPONENT_TYPE] = NodeType.DATA_TRANSFER + # Set default base path + if BASE_PATH_CONTEXT_KEY not in kwargs: + kwargs[BASE_PATH_CONTEXT_KEY] = Path(".") + + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + self._task = task + + @classmethod + def _attr_type_map(cls) -> dict: + return {} + + @property + def task(self) -> Optional[str]: + """Task type of the component. + + :return: Task type of the component. 
+ :rtype: str + """ + return self._task + + def _to_dict(self) -> Dict: + return cast( + dict, + convert_ordered_dict_to_dict({**self._other_parameter, **super(DataTransferComponent, self)._to_dict()}), + ) + + def __str__(self) -> str: + try: + _toYaml: str = self._to_yaml() + return _toYaml + except BaseException: # pylint: disable=W0718 + _toStr: str = super(DataTransferComponent, self).__str__() + return _toStr + + @classmethod + def _build_source_sink(cls, io_dict: Union[Dict, Database, FileSystem]) -> Union[Database, FileSystem]: + component_io: Union[Database, FileSystem] = Database() + + if isinstance(io_dict, Database): + component_io = Database() + elif isinstance(io_dict, FileSystem): + component_io = FileSystem() + else: + if isinstance(io_dict, dict): + data_type = io_dict.pop("type", None) + if data_type == ExternalDataType.DATABASE: + component_io = Database() + elif data_type == ExternalDataType.FILE_SYSTEM: + component_io = FileSystem() + else: + msg = "Type in source or sink only support {} and {}, currently got {}." + raise ValidationException( + message=msg.format( + ExternalDataType.DATABASE, + ExternalDataType.FILE_SYSTEM, + data_type, + ), + no_personal_data_message=msg.format( + ExternalDataType.DATABASE, + ExternalDataType.FILE_SYSTEM, + "data_type", + ), + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + else: + msg = "Source or sink only support dict, Database and FileSystem" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return component_io + + +@experimental +class DataTransferCopyComponent(DataTransferComponent): + """DataTransfer copy component version, used to define a data transfer copy component. + + :param data_copy_mode: Data copy mode in the copy task. + Possible values are "merge_with_overwrite" and "fail_if_conflict". + :type data_copy_mode: str + :param inputs: Mapping of input data bindings used in the job. + :type inputs: dict + :param outputs: Mapping of output data bindings used in the job. + :type outputs: dict + :param kwargs: Additional parameters for the data transfer copy component. + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated. + Details will be provided in the error message. + """ + + def __init__( + self, + *, + data_copy_mode: Optional[str] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + kwargs["task"] = DataTransferTaskType.COPY_DATA + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + self._data_copy_mode = data_copy_mode + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + return DataTransferCopyComponentSchema(context=context) + + @property + def data_copy_mode(self) -> Optional[str]: + """Data copy mode of the component. + + :return: Data copy mode of the component. 
+        :rtype: str
+        """
+        return self._data_copy_mode
+
+    def _customized_validate(self) -> MutableValidationResult:
+        validation_result = super(DataTransferCopyComponent, self)._customized_validate()
+        validation_result.merge_with(self._validate_input_output_mapping())
+        return validation_result
+
+    def _validate_input_output_mapping(self) -> MutableValidationResult:
+        validation_result = self._create_empty_validation_result()
+        inputs_count = len(self.inputs)
+        outputs_count = len(self.outputs)
+        if outputs_count != 1:
+            msg = "Only a single output is supported in {}, but there are {} outputs."
+            validation_result.append_error(
+                message=msg.format(DataTransferTaskType.COPY_DATA, outputs_count),
+                yaml_path="outputs",
+            )
+        else:
+            input_type = None
+            output_type = None
+            if inputs_count == 1:
+                for _, input_data in self.inputs.items():
+                    input_type = input_data.type
+                for _, output_data in self.outputs.items():
+                    output_type = output_data.type
+                if input_type is None or output_type is None or input_type != output_type:
+                    msg = "Input type {} doesn't exactly match output type {} in task {}."
+                    validation_result.append_error(
+                        message=msg.format(input_type, output_type, DataTransferTaskType.COPY_DATA),
+                        yaml_path="outputs",
+                    )
+            elif inputs_count > 1:
+                for _, output_data in self.outputs.items():
+                    output_type = output_data.type
+                if output_type is None or output_type != AssetTypes.URI_FOLDER:
+                    msg = "Output type {} needs to be {} in task {}."
+                    validation_result.append_error(
+                        message=msg.format(
+                            output_type,
+                            AssetTypes.URI_FOLDER,
+                            DataTransferTaskType.COPY_DATA,
+                        ),
+                        yaml_path="outputs",
+                    )
+            else:
+                msg = "Inputs must be set in task {}."
+                validation_result.append_error(
+                    message=msg.format(DataTransferTaskType.COPY_DATA),
+                    yaml_path="inputs",
+                )
+        return validation_result
+
+
+@experimental
+class DataTransferImportComponent(DataTransferComponent):
+    """DataTransfer import component version, used to define a data transfer import component.
+
+    :param source: The data source of the file system or database.
+    :type source: dict
+    :param outputs: Mapping of output data bindings used in the job.
+        The default value is an output port with the key "sink" and the type "mltable".
+    :type outputs: dict
+    :param kwargs: Additional parameters for the data transfer import component.
+    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
+        Details will be provided in the error message.
+    """
+
+    def __init__(
+        self,
+        *,
+        source: Optional[Dict] = None,
+        outputs: Optional[Dict] = None,
+        **kwargs: Any,
+    ) -> None:
+        outputs = outputs or {"sink": Output(type=AssetTypes.MLTABLE)}
+        kwargs["task"] = DataTransferTaskType.IMPORT_DATA
+        super().__init__(
+            outputs=outputs,
+            **kwargs,
+        )
+
+        source = source if source else {}
+        self.source = self._build_source_sink(source)
+
+    @classmethod
+    def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+        return DataTransferImportComponentSchema(context=context)
+
+    # pylint: disable-next=docstring-missing-param
+    def __call__(self, *args: Any, **kwargs: Any) -> NoReturn:
+        """Import components are not callable; calling always raises a ValidationException."""
+
+        msg = "DataTransfer component is not callable for import task."
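# --- Editorial aside (illustration, not part of the diffed source) ---------------------
# A sketch of a copy component that satisfies the validation above: exactly one output,
# and with a single input the output type must match it. The class is @experimental and
# is constructed directly here only for illustration; names and values are assumptions.
from azure.ai.ml import Input, Output
from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent

copy_component = DataTransferCopyComponent(
    name="copy_uri_folder",
    version="1",
    data_copy_mode="merge_with_overwrite",
    inputs={"source_folder": Input(type="uri_folder")},
    outputs={"destination_folder": Output(type="uri_folder")},
)
# ---------------------------------------------------------------------------------------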
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + ) + + +@experimental +class DataTransferExportComponent(DataTransferComponent): + """DataTransfer export component version, used to define a data transfer export component. + + :param sink: The sink of external data and databases. + :type sink: Union[Dict, Database, FileSystem] + :param inputs: Mapping of input data bindings used in the job. + :type inputs: dict + :param kwargs: Additional parameters for the data transfer export component. + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated. + Details will be provided in the error message. + """ + + def __init__( + self, + *, + inputs: Optional[Dict] = None, + sink: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + kwargs["task"] = DataTransferTaskType.EXPORT_DATA + super().__init__( + inputs=inputs, + **kwargs, + ) + + sink = sink if sink else {} + self.sink = self._build_source_sink(sink) + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + return DataTransferExportComponentSchema(context=context) + + # pylint: disable-next=docstring-missing-param + def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: + """Call ComponentVersion as a function and get a Component object.""" + + msg = "DataTransfer component is not callable for export task." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/flow.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/flow.py new file mode 100644 index 00000000..e4ff06cc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/flow.py @@ -0,0 +1,553 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import contextlib +import json +import os +from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union + +import yaml # type: ignore[import] +from marshmallow import EXCLUDE, Schema, ValidationError + +from azure.ai.ml.constants._common import ( + BASE_PATH_CONTEXT_KEY, + COMPONENT_TYPE, + PROMPTFLOW_AZUREML_OVERRIDE_KEY, + SOURCE_PATH_CONTEXT_KEY, + AssetTypes, + SchemaUrl, +) +from azure.ai.ml.constants._component import ComponentParameterTypes, NodeType + +from ..._restclient.v2022_10_01.models import ComponentVersion +from ..._schema import PathAwareSchema +from ..._schema.component.flow import FlowComponentSchema, FlowSchema, RunSchema +from ...exceptions import ErrorCategory, ErrorTarget, ValidationException +from .. import Environment +from .._inputs_outputs import GroupInput, Input, Output +from ._additional_includes import AdditionalIncludesMixin +from .component import Component + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._builders.parallel import Parallel + +# pylint: disable=protected-access + + +class _FlowPortNames: + """Common yaml fields. + + Common yaml fields are used to define the common fields in yaml files. It can be one of the following values: type, + name, $schema. 
+ """ + + DATA = "data" + RUN_OUTPUTS = "run_outputs" + CONNECTIONS = "connections" + + FLOW_OUTPUTS = "flow_outputs" + DEBUG_INFO = "debug_info" + + +class _FlowComponentPortDict(dict): + def __init__(self, ports: Dict): + self._allow_update_item = True + super().__init__() + for input_port_name, input_port in ports.items(): + self[input_port_name] = input_port + self._allow_update_item = False + + def __setitem__(self, key: Any, value: Any) -> None: + if not self._allow_update_item: + raise RuntimeError("Ports of flow component are not editable.") + super().__setitem__(key, value) + + def __delitem__(self, key: Any) -> None: + if not self._allow_update_item: + raise RuntimeError("Ports of flow component are not editable.") + super().__delitem__(key) + + +class FlowComponentInputDict(_FlowComponentPortDict): + """Input port dictionary for FlowComponent, with fixed input ports.""" + + def __init__(self) -> None: + super().__init__( + { + _FlowPortNames.CONNECTIONS: GroupInput(values={}, _group_class=None), + _FlowPortNames.DATA: Input(type=AssetTypes.URI_FOLDER, optional=False), + _FlowPortNames.FLOW_OUTPUTS: Input(type=AssetTypes.URI_FOLDER, optional=True), + } + ) + + @contextlib.contextmanager + def _fit_inputs(self, inputs: Optional[Dict]) -> Generator: + """Add dynamic input ports to the input port dictionary. + Input ports of a flow component include: + 1. data: required major uri_folder input + 2. run_output: optional uri_folder input + 3. connections.xxx.xxx: group of string parameters, first layer key can be any node name, + but we won't resolve the exact keys in SDK + 4. xxx: input_mapping parameters, key can be any node name, but we won't resolve the exact keys in SDK + + #3 will be grouped into connections, we make it a fixed group input port. + #4 are dynamic input ports, we will add them temporarily in this context manager and remove them + after the context manager is finished. + + :param inputs: The dynamic input to fit. 
+ :type inputs: Dict[str, Any] + :return: None + :rtype: None + """ + dynamic_columns_mapping_keys = [] + dynamic_connections_inputs = defaultdict(list) + from azure.ai.ml.entities._job.pipeline._io import _GroupAttrDict + from azure.ai.ml.entities._job.pipeline._io.mixin import flatten_dict + + flattened_inputs = flatten_dict(inputs, _GroupAttrDict, allow_dict_fields=[_FlowPortNames.CONNECTIONS]) + + for flattened_input_key in flattened_inputs: + if flattened_input_key.startswith(f"{_FlowPortNames.CONNECTIONS}."): + if flattened_input_key.count(".") != 2: + raise ValidationException( + message="flattened connection input prot name must be " + "in the format of connections.<node_name>.<port_name>, " + "but got %s" % flattened_input_key, + no_personal_data_message="flattened connection input prot name must be in the format of " + "connections.<node_name>.<port_name>", + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + ) + _, node_name, param_name = flattened_input_key.split(".") + dynamic_connections_inputs[node_name].append(param_name) + continue + if flattened_input_key not in self: + dynamic_columns_mapping_keys.append(flattened_input_key) + + self._allow_update_item = True + for flattened_input_key in dynamic_columns_mapping_keys: + self[flattened_input_key] = Input(type=ComponentParameterTypes.STRING, optional=True) + if dynamic_connections_inputs: + self[_FlowPortNames.CONNECTIONS] = GroupInput( + values={ + node_name: GroupInput( + values={ + parameter_name: Input( + type=ComponentParameterTypes.STRING, + ) + for parameter_name in param_names + }, + _group_class=None, + ) + for node_name, param_names in dynamic_connections_inputs.items() + }, + _group_class=None, + ) + self._allow_update_item = False + + yield + + self._allow_update_item = True + for flattened_input_key in dynamic_columns_mapping_keys: + del self[flattened_input_key] + self[_FlowPortNames.CONNECTIONS] = GroupInput(values={}, _group_class=None) + self._allow_update_item = False + + +class FlowComponentOutputDict(_FlowComponentPortDict): + """Output port dictionary for FlowComponent, with fixed output ports.""" + + def __init__(self) -> None: + super().__init__( + { + _FlowPortNames.FLOW_OUTPUTS: Output(type=AssetTypes.URI_FOLDER), + _FlowPortNames.DEBUG_INFO: Output(type=AssetTypes.URI_FOLDER), + } + ) + + +class FlowComponent(Component, AdditionalIncludesMixin): + """Flow component version, used to define a Flow Component or Job. + + :keyword name: The name of the Flow job or component. + :type name: Optional[str] + :keyword version: The version of the Flow job or component. + :type version: Optional[str] + :keyword description: The description of the component. Defaults to None. + :type description: Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None. + :type tags: Optional[dict] + :keyword display_name: The display name of the component. + :type display_name: Optional[str] + :keyword flow: The path to the flow directory or flow definition file. Defaults to None and base path of this + component will be used as flow directory. + :type flow: Optional[Union[str, Path]] + :keyword column_mappings: The column mapping for the flow. Defaults to None. + :type column_mapping: Optional[dict[str, str]] + :keyword variant: The variant of the flow. Defaults to None. + :type variant: Optional[str] + :keyword connections: The connections for the flow. Defaults to None. 
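# --- Editorial aside (illustration, not part of the diffed source) ---------------------
# A hedged sketch of the FlowComponent surface documented above. The flow path, column
# mapping, connection names, and environment variable are placeholders, not values taken
# from this repository; the flow directory must exist locally for the constructor to load
# flow.dag.yaml.
from azure.ai.ml.entities._component.flow import FlowComponent

flow_component = FlowComponent(
    name="my_flow",
    version="1",
    flow="./flows/classification/flow.dag.yaml",
    column_mapping={"question": "${data.question}"},
    connections={"llm_node": {"connection": "azure_open_ai_connection"}},
    environment_variables={"PF_LOGGING_LEVEL": "INFO"},
)
# ---------------------------------------------------------------------------------------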
+ :type connections: Optional[dict[str, dict[str, str]]] + :keyword environment_variables: The environment variables for the flow. Defaults to None. + :type environment_variables: Optional[dict[str, str]] + :keyword environment: The environment for the flow component. Defaults to None. + :type environment: Optional[Union[str, Environment]) + :keyword is_deterministic: Specifies whether the Flow will return the same output given the same input. + Defaults to True. When True, if a Flow (component) is deterministic and has been run before in the + current workspace with the same input and settings, it will reuse results from a previous submitted job + when used as a node or step in a pipeline. In that scenario, no compute resources will be used. + :type is_deterministic: Optional[bool] + :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None. + :type additional_includes: Optional[list[str]] + :keyword properties: The job property dictionary. Defaults to None. + :type properties: Optional[dict[str, str]] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FlowComponent cannot be successfully validated. + Details will be provided in the error message. + """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + flow: Optional[Union[str, Path]] = None, + column_mapping: Optional[Dict[str, str]] = None, + variant: Optional[str] = None, + connections: Optional[Dict[str, Dict[str, str]]] = None, + environment_variables: Optional[Dict[str, str]] = None, + environment: Optional[Union[str, Environment]] = None, + is_deterministic: bool = True, + additional_includes: Optional[List] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + kwargs[COMPONENT_TYPE] = NodeType.FLOW_PARALLEL + + # always use flow directory as base path + # Note: we suppose that there is no relative path in run.yaml other than flow. 
+ # If there are any, we will need to rebase them so that they have the same base path as attributes in + # flow.dag.yaml + flow_dir, self._flow = self._get_flow_definition( + flow=flow, + base_path=kwargs.pop(BASE_PATH_CONTEXT_KEY, Path.cwd()), + source_path=kwargs.get(SOURCE_PATH_CONTEXT_KEY, None), + ) + kwargs[BASE_PATH_CONTEXT_KEY] = flow_dir + + super().__init__( + name=name or self._normalize_component_name(flow_dir.name), + version=version or "1", + description=description, + tags=tags, + display_name=display_name, + inputs={}, + outputs={}, + is_deterministic=is_deterministic, + properties=properties, + **kwargs, + ) + self._environment = environment + self._column_mapping = column_mapping or {} + self._variant = variant + self._connections = connections or {} + + self._inputs = FlowComponentInputDict() + self._outputs = FlowComponentOutputDict() + + if flow: + # file existence has been checked in _get_flow_definition + # we don't need to rebase additional_includes as we have updated base_path + with open(Path(self.base_path, self._flow), "r", encoding="utf-8") as f: + flow_content = yaml.safe_load(f.read()) + additional_includes = flow_content.get("additional_includes", None) + # environment variables in run.yaml have higher priority than those in flow.dag.yaml + self._environment_variables = flow_content.get("environment_variables", {}) + self._environment_variables.update(environment_variables or {}) + else: + self._environment_variables = environment_variables or {} + + self._additional_includes = additional_includes or [] + + # unlike other Component, code is a private property in FlowComponent and + # will be used to store the arm id of the created code before constructing rest object + # we haven't used self.flow directly as self.flow can be a path to the flow dag yaml file instead of a directory + self._code_arm_id: Optional[str] = None + + # region valid properties + @property + def flow(self) -> str: + """The path to the flow definition file relative to the flow directory. + + :rtype: str + """ + return self._flow + + @property + def environment(self) -> Optional[Union[str, Environment]]: + """The environment for the flow component. Defaults to None. + + :rtype: Union[str, Environment]) + """ + return self._environment + + @environment.setter + def environment(self, value: Union[str, Environment]) -> None: + """The environment for the flow component. Defaults to None. + + :param value: The column mapping for the flow. + :type value: Union[str, Environment]) + """ + self._environment = value + + @property + def column_mapping(self) -> Dict[str, str]: + """The column mapping for the flow. Defaults to None. + + :rtype: Dict[str, str] + """ + return self._column_mapping + + @column_mapping.setter + def column_mapping(self, value: Optional[Dict[str, str]]) -> None: + """ + The column mapping for the flow. Defaults to None. + + :param value: The column mapping for the flow. + :type value: Optional[Dict[str, str]] + """ + self._column_mapping = value or {} + + @property + def variant(self) -> Optional[str]: + """The variant of the flow. Defaults to None. + + :rtype: Optional[str] + """ + return self._variant + + @variant.setter + def variant(self, value: Optional[str]) -> None: + """The variant of the flow. Defaults to None. + + :param value: The variant of the flow. + :type value: Optional[str] + """ + self._variant = value + + @property + def connections(self) -> Dict[str, Dict[str, str]]: + """The connections for the flow. Defaults to None. 
+ + :rtype: Dict[str, Dict[str, str]] + """ + return self._connections + + @connections.setter + def connections(self, value: Optional[Dict[str, Dict[str, str]]]) -> None: + """ + The connections for the flow. Defaults to None. + + :param value: The connections for the flow. + :type value: Optional[Dict[str, Dict[str, str]]] + """ + self._connections = value or {} + + @property + def environment_variables(self) -> Dict[str, str]: + """The environment variables for the flow. Defaults to None. + + :rtype: Dict[str, str] + """ + return self._environment_variables + + @environment_variables.setter + def environment_variables(self, value: Optional[Dict[str, str]]) -> None: + """The environment variables for the flow. Defaults to None. + + :param value: The environment variables for the flow. + :type value: Optional[Dict[str, str]] + """ + self._environment_variables = value or {} + + @property + def additional_includes(self) -> List: + """A list of shared additional files to be included in the component. Defaults to None. + + :rtype: List + """ + return self._additional_includes + + @additional_includes.setter + def additional_includes(self, value: Optional[List]) -> None: + """A list of shared additional files to be included in the component. Defaults to None. + All local additional includes should be relative to the flow directory. + + :param value: A list of shared additional files to be included in the component. + :type value: Optional[List] + """ + self._additional_includes = value or [] + + # endregion + + @classmethod + def _normalize_component_name(cls, value: str) -> str: + return value.replace("-", "_") + + # region Component + @classmethod + def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict: + raise RuntimeError("FlowComponent does not support loading from REST object.") + + def _to_rest_object(self) -> ComponentVersion: + rest_obj = super()._to_rest_object() + rest_obj.properties.component_spec["code"] = self._code_arm_id + rest_obj.properties.component_spec["flow_file_name"] = self._flow + return rest_obj + + def _func(self, **kwargs: Any) -> "Parallel": # pylint: disable=invalid-overridden-method + from azure.ai.ml.entities._builders.parallel import Parallel + + with self._inputs._fit_inputs(kwargs): # type: ignore[attr-defined] + # pylint: disable=not-callable + return super()._func(**kwargs) # type: ignore + + @classmethod + def _get_flow_definition( + cls, + base_path: Path, + *, + flow: Optional[Union[str, os.PathLike]] = None, + source_path: Optional[Union[str, os.PathLike]] = None, + ) -> Tuple[Path, str]: + """ + Get the path to the flow directory and the file name of the flow dag yaml file. + If flow is not specified, we will assume that the source_path is the path to the flow dag yaml file. + If flow is specified, it can be either a path to the flow dag yaml file or a path to the flow directory. + If flow is a path to the flow directory, we will assume that the flow dag yaml file is named flow.dag.yaml. + + :param base_path: The base path of the flow component. + :type base_path: Path + :keyword flow: The path to the flow directory or flow definition file. Defaults to None and base path of this + component will be used as flow directory. + :type flow: Optional[Union[str, Path]] + :keyword source_path: The source path of the flow component, should be path to the flow dag yaml file + if specified. + :type source_path: Optional[Union[str, os.PathLike]] + :return: The path to the flow directory and the file name of the flow dag yaml file. 
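# --- Editorial aside (illustration, not part of the diffed source) ---------------------
# A compact, simplified sketch of the flow-definition resolution described above: a
# directory resolves to flow.dag.yaml inside it, a file resolves to its parent directory
# plus its own name. The real classmethod also handles source_path and raises the SDK's
# validation errors instead of ValueError.
from pathlib import Path

def resolve_flow(base_path, flow, default_name="flow.dag.yaml"):
    flow_path = Path(flow)
    if not flow_path.is_absolute():
        flow_path = Path(base_path, flow)
    if flow_path.is_dir() and (flow_path / default_name).is_file():
        return flow_path, default_name
    if flow_path.is_file():
        return flow_path.parent, flow_path.name
    raise ValueError(f"Flow path must be a directory containing {default_name} or a file, got {flow_path}")
# ---------------------------------------------------------------------------------------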
+ :rtype: Tuple[Path, str] + """ + flow_file_name = "flow.dag.yaml" + + if flow is None and source_path is None: + raise cls._create_validation_error( + message="Either flow or source_path must be specified.", + no_personal_data_message="Either flow or source_path must be specified.", + ) + + if flow is None: + # Flow component must be created with a local yaml file, so no need to check if source_path exists + if isinstance(source_path, (os.PathLike, str)): + flow_file_name = os.path.basename(source_path) + return Path(base_path), flow_file_name + + flow_path = Path(flow) + if not flow_path.is_absolute(): + # if flow_path points to a symlink, we still use the parent of the symlink as origin code + flow_path = Path(base_path, flow) + + if flow_path.is_dir() and (flow_path / flow_file_name).is_file(): + return flow_path, flow_file_name + + if flow_path.is_file(): + return flow_path.parent, flow_path.name + + raise cls._create_validation_error( + message="Flow path must be a directory containing flow.dag.yaml or a file, but got %s" % flow_path, + no_personal_data_message="Flow path must be a directory or a file", + ) + + # endregion + + # region SchemaValidatableMixin + @classmethod + def _load_with_schema( + cls, data: Any, *, context: Optional[Any] = None, raise_original_exception: bool = False, **kwargs: Any + ) -> Any: + # FlowComponent should be loaded with FlowSchema or FlowRunSchema instead of FlowComponentSchema + context = context or {BASE_PATH_CONTEXT_KEY: Path.cwd()} + _schema = data.get("$schema", None) + if _schema == SchemaUrl.PROMPTFLOW_RUN: + schema = RunSchema(context=context) + elif _schema == SchemaUrl.PROMPTFLOW_FLOW: + schema = FlowSchema(context=context) + else: + raise cls._create_validation_error( + message="$schema must be specified correctly for loading component from flow, but got %s" % _schema, + no_personal_data_message="$schema must be specified for loading component from flow", + ) + + # unlike other component, we should ignore unknown fields in flow to keep init_params clean and avoid + # too much understanding of flow.dag.yaml & run.yaml + kwargs["unknown"] = EXCLUDE + try: + loaded_dict = schema.load(data, **kwargs) + except ValidationError as e: + if raise_original_exception: + raise e + msg = "Trying to load data with schema failed. 
Data:\n%s\nError: %s" % ( + json.dumps(data, indent=4) if isinstance(data, dict) else data, + json.dumps(e.messages, indent=4), + ) + raise cls._create_validation_error( + message=msg, + no_personal_data_message=str(e), + ) from e + loaded_dict.update(loaded_dict.pop(PROMPTFLOW_AZUREML_OVERRIDE_KEY, {})) + return loaded_dict + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + return FlowComponentSchema(context=context) + + # endregion + + # region AdditionalIncludesMixin + def _get_origin_code_value(self) -> Union[str, os.PathLike, None]: + if self._code_arm_id: + return self._code_arm_id + res: Union[str, os.PathLike, None] = self.base_path + return res + + def _fill_back_code_value(self, value: str) -> None: + self._code_arm_id = value + + @contextlib.contextmanager + def _try_build_local_code(self) -> Generator: + # false-positive by pylint, hence disable it + # (https://github.com/pylint-dev/pylint/blob/main/doc/data/messages + # /c/contextmanager-generator-missing-cleanup/details.rst) + with super()._try_build_local_code() as code: # pylint:disable=contextmanager-generator-missing-cleanup + if not code or not code.path: + yield code + return + + if not (Path(code.path) / ".promptflow" / "flow.tools.json").is_file(): + raise self._create_validation_error( + message="Flow component must be created with a ./promptflow/flow.tools.json, " + "please run `pf flow validate` to generate it or skip it in your ignore file.", + no_personal_data_message="Flow component must be created with a ./promptflow/flow.tools.json, " + "please run `pf flow validate` to generate it or skip it in your ignore file.", + ) + # TODO: should we remove additional includes from flow.dag.yaml? for now we suppose it will be removed + # by mldesigner compile if needed + + yield code + + # endregion diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/import_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/import_component.py new file mode 100644 index 00000000..13464a06 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/import_component.py @@ -0,0 +1,96 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from marshmallow import Schema + +from azure.ai.ml._schema.component.import_component import ImportComponentSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, COMPONENT_TYPE +from azure.ai.ml.constants._component import NodeType + +from ..._schema import PathAwareSchema +from ..._utils.utils import parse_args_description_from_docstring +from .._util import convert_ordered_dict_to_dict +from .component import Component + + +class ImportComponent(Component): + """Import component version, used to define an import component. + + :param name: Name of the component. + :type name: str + :param version: Version of the component. + :type version: str + :param description: Description of the component. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict + :param display_name: Display name of the component. + :type display_name: str + :param source: Input source parameters of the component. + :type source: dict + :param output: Output of the component. 
+ :type output: dict + :param is_deterministic: Whether the command component is deterministic. Defaults to True. + :type is_deterministic: bool + :param kwargs: Additional parameters for the import component. + """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + source: Optional[Dict] = None, + output: Optional[Dict] = None, + is_deterministic: bool = True, + **kwargs: Any, + ) -> None: + kwargs[COMPONENT_TYPE] = NodeType.IMPORT + # Set default base path + if BASE_PATH_CONTEXT_KEY not in kwargs: + kwargs[BASE_PATH_CONTEXT_KEY] = Path(".") + + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + display_name=display_name, + inputs=source, + outputs={"output": output} if output else None, + is_deterministic=is_deterministic, + **kwargs, + ) + + self.source = source + self.output = output + + def _to_dict(self) -> Dict: + # TODO: Bug Item number: 2897665 + res: Dict = convert_ordered_dict_to_dict( # type: ignore + {**self._other_parameter, **super(ImportComponent, self)._to_dict()} + ) + return res + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + return ImportComponentSchema(context=context) + + @classmethod + def _parse_args_description_from_docstring(cls, docstring: str) -> Dict: + res: dict = parse_args_description_from_docstring(docstring) + return res + + def __str__(self) -> str: + try: + toYaml: str = self._to_yaml() + return toYaml + except BaseException: # pylint: disable=W0718 + toStr: str = super(ImportComponent, self).__str__() + return toStr diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/parallel_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/parallel_component.py new file mode 100644 index 00000000..3f29b1e1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/parallel_component.py @@ -0,0 +1,305 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import json +import os +import re +from typing import Any, Dict, List, Optional, Union, cast + +from marshmallow import Schema + +from azure.ai.ml._restclient.v2022_10_01.models import ComponentVersion +from azure.ai.ml._schema.component.parallel_component import ParallelComponentSchema +from azure.ai.ml.constants._common import COMPONENT_TYPE +from azure.ai.ml.constants._component import NodeType +from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration +from azure.ai.ml.entities._job.parallel.parallel_task import ParallelTask +from azure.ai.ml.entities._job.parallel.parameterized_parallel import ParameterizedParallel +from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +from ..._schema import PathAwareSchema +from .._util import validate_attribute_type +from .._validation import MutableValidationResult +from .code import ComponentCodeMixin +from .component import Component + + +class ParallelComponent( + Component, ParameterizedParallel, ComponentCodeMixin +): # pylint: disable=too-many-instance-attributes + """Parallel component version, used to define a parallel component. 
+ 
+ :param name: Name of the component. Defaults to None
+ :type name: str
+ :param version: Version of the component. Defaults to None
+ :type version: str
+ :param description: Description of the component. Defaults to None
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None
+ :type tags: dict
+ :param display_name: Display name of the component. Defaults to None
+ :type display_name: str
+ :param retry_settings: Retry settings for failed parallel component runs. Defaults to None
+ :type retry_settings: RetrySettings
+ :param logging_level: A string of the logging level name. Defaults to None
+ :type logging_level: str
+ :param max_concurrency_per_instance: The maximum parallelism that each compute instance supports. Defaults to None
+ :type max_concurrency_per_instance: int
+ :param error_threshold: The number of item processing failures that should be ignored. Defaults to None
+ :type error_threshold: int
+ :param mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored. Defaults to None
+ :type mini_batch_error_threshold: int
+ :param task: The parallel task. Defaults to None
+ :type task: ParallelTask
+ :param mini_batch_size: For FileDataset input, this field is the number of files a user script can process
+ in one run() call. For TabularDataset input, this field is the approximate size of data the user script
+ can process in one run() call. Example values are 1024, 1024KB, 10MB, and 1GB.
+ (Optional; the default value is 10 files for FileDataset and 1MB for TabularDataset.) This value can be set
+ through PipelineParameter.
+ :type mini_batch_size: str
+ :param partition_keys: The keys used to partition the dataset into mini-batches. Defaults to None.
+ If specified, the data with the same key will be partitioned into the same mini-batch.
+ If both partition_keys and mini_batch_size are specified, partition_keys takes effect.
+ The input(s) must be partitioned dataset(s),
+ and the partition_keys must be a subset of the keys of every input dataset for this to work.
+ :type partition_keys: list
+ :param input_data: The input data. Defaults to None
+ :type input_data: str
+ :param resources: Compute resource configuration for the component. Defaults to None
+ :type resources: Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration]
+ :param inputs: Inputs of the component. Defaults to None
+ :type inputs: dict
+ :param outputs: Outputs of the component. Defaults to None
+ :type outputs: dict
+ :param code: Promoted property from task.code. Defaults to None
+ :type code: str
+ :param instance_count: Promoted property from resources.instance_count. Defaults to None
+ :type instance_count: int
+ :param is_deterministic: Whether the parallel component is deterministic. Defaults to True
+ :type is_deterministic: bool
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ParallelComponent cannot be successfully validated.
+ Details will be provided in the error message.
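+ 
+ .. admonition:: Example (illustrative sketch)
+ 
+     A minimal, hypothetical construction for illustration only; the input/output
+     names, source folder, entry script, and environment reference below are
+     placeholders rather than official sample values. The task is passed as a plain
+     dict, which the constructor accepts in addition to a ParallelTask instance::
+ 
+         from azure.ai.ml.entities import ParallelComponent
+ 
+         batch_score = ParallelComponent(
+             name="batch_score",
+             version="1",
+             display_name="Batch scoring",
+             inputs={"input_data": {"type": "uri_folder"}},
+             outputs={"scored_data": {"type": "uri_folder"}},
+             input_data="${{inputs.input_data}}",
+             mini_batch_size="10mb",   # matches the ^\d+([kKmMgG][bB])*$ rule above
+             instance_count=2,         # promoted to resources.instance_count
+             task={
+                 "type": "run_function",           # hypothetical task definition
+                 "code": "./src",                  # hypothetical local source folder
+                 "entry_script": "score.py",       # hypothetical entry script
+                 "environment": "azureml:my-env:1",  # hypothetical environment reference
+             },
+         )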
+ """ + + def __init__( # pylint: disable=too-many-locals + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + display_name: Optional[str] = None, + retry_settings: Optional[RetrySettings] = None, + logging_level: Optional[str] = None, + max_concurrency_per_instance: Optional[int] = None, + error_threshold: Optional[int] = None, + mini_batch_error_threshold: Optional[int] = None, + task: Optional[ParallelTask] = None, + mini_batch_size: Optional[str] = None, + partition_keys: Optional[List] = None, + input_data: Optional[str] = None, + resources: Optional[JobResourceConfiguration] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + code: Optional[str] = None, # promoted property from task.code + instance_count: Optional[int] = None, # promoted property from resources.instance_count + is_deterministic: bool = True, + **kwargs: Any, + ): + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + + kwargs[COMPONENT_TYPE] = NodeType.PARALLEL + + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + display_name=display_name, + inputs=inputs, + outputs=outputs, + is_deterministic=is_deterministic, + **kwargs, + ) + + # No validation on value passed here because in pipeline job, required code&environment maybe absent + # and fill in later with job defaults. + self.task = task + self.mini_batch_size: int = 0 + self.partition_keys = partition_keys + self.input_data = input_data + self.retry_settings = retry_settings + self.logging_level = logging_level + self.max_concurrency_per_instance = max_concurrency_per_instance + self.error_threshold = error_threshold + self.mini_batch_error_threshold = mini_batch_error_threshold + self.resources = resources + + # check mutual exclusivity of promoted properties + if self.resources is not None and instance_count is not None: + msg = "instance_count and resources are mutually exclusive" + raise ValidationException( + message=msg, + target=ErrorTarget.COMPONENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + self.instance_count = instance_count + self.code = code + + if mini_batch_size is not None: + # Convert str to int. + pattern = re.compile(r"^\d+([kKmMgG][bB])*$") + if not pattern.match(mini_batch_size): + raise ValueError(r"Parameter mini_batch_size must follow regex rule ^\d+([kKmMgG][bB])*$") + + try: + self.mini_batch_size = int(mini_batch_size) + except ValueError as e: + unit = mini_batch_size[-2:].lower() + if unit == "kb": + self.mini_batch_size = int(mini_batch_size[0:-2]) * 1024 + elif unit == "mb": + self.mini_batch_size = int(mini_batch_size[0:-2]) * 1024 * 1024 + elif unit == "gb": + self.mini_batch_size = int(mini_batch_size[0:-2]) * 1024 * 1024 * 1024 + else: + raise ValueError("mini_batch_size unit must be kb, mb or gb") from e + + @property + def instance_count(self) -> Optional[int]: + """Return value of promoted property resources.instance_count. + + :return: Value of resources.instance_count. + :rtype: Optional[int] + """ + return self.resources.instance_count if self.resources and not isinstance(self.resources, dict) else None + + @instance_count.setter + def instance_count(self, value: int) -> None: + """Set the value of the promoted property resources.instance_count. + + :param value: The value to set for resources.instance_count. 
+ :type value: int + """ + if not value: + return + if not self.resources: + self.resources = JobResourceConfiguration(instance_count=value) + else: + if not isinstance(self.resources, dict): + self.resources.instance_count = value + + @property + def code(self) -> Optional[str]: + """Return value of promoted property task.code, which is a local or + remote path pointing at source code. + + :return: Value of task.code. + :rtype: Optional[str] + """ + return self.task.code if self.task else None + + @code.setter + def code(self, value: str) -> None: + """Set the value of the promoted property task.code. + + :param value: The value to set for task.code. + :type value: str + """ + if not value: + return + if not self.task: + self.task = ParallelTask(code=value) + else: + self.task.code = value + + def _to_ordered_dict_for_yaml_dump(self) -> Dict: + """Dump the component content into a sorted yaml string. + + :return: The ordered dict + :rtype: Dict + """ + + obj: dict = super()._to_ordered_dict_for_yaml_dump() + # dict dumped base on schema will transfer code to an absolute path, while we want to keep its original value + if self.code and isinstance(self.code, str): + obj["task"]["code"] = self.code + return obj + + @property + def environment(self) -> Optional[str]: + """Return value of promoted property task.environment, indicate the + environment that training job will run in. + + :return: Value of task.environment. + :rtype: Optional[Environment, str] + """ + if self.task: + return cast(Optional[str], self.task.environment) + return None + + @environment.setter + def environment(self, value: str) -> None: + """Set the value of the promoted property task.environment. + + :param value: The value to set for task.environment. + :type value: str + """ + if not value: + return + if not self.task: + self.task = ParallelTask(environment=value) + else: + self.task.environment = value + + def _customized_validate(self) -> MutableValidationResult: + validation_result = super()._customized_validate() + self._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(validation_result) + return validation_result + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "retry_settings": (dict, RetrySettings), + "task": (dict, ParallelTask), + "logging_level": str, + "max_concurrency_per_instance": int, + "input_data": str, + "error_threshold": int, + "mini_batch_error_threshold": int, + "code": (str, os.PathLike), + "resources": (dict, JobResourceConfiguration), + } + + def _to_rest_object(self) -> ComponentVersion: + rest_object = super()._to_rest_object() + # schema required list while backend accept json string + if self.partition_keys: + rest_object.properties.component_spec["partition_keys"] = json.dumps(self.partition_keys) + return rest_object + + @classmethod + def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict: + # schema required list while backend accept json string + # update rest obj as it will be + partition_keys = obj.properties.component_spec.get("partition_keys", None) + if partition_keys: + obj.properties.component_spec["partition_keys"] = json.loads(partition_keys) + res: dict = super()._from_rest_object_to_init_params(obj) + return res + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + return ParallelComponentSchema(context=context) + + def __str__(self) -> str: + try: + toYaml: str = self._to_yaml() + return toYaml + except BaseException: # pylint: disable=W0718 + toStr: str 
= super(ParallelComponent, self).__str__() + return toStr diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/pipeline_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/pipeline_component.py new file mode 100644 index 00000000..229b714d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/pipeline_component.py @@ -0,0 +1,529 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import json +import logging +import os +import re +import time +import typing +from collections import Counter +from typing import Any, Dict, List, Optional, Tuple, Union + +from marshmallow import Schema + +from azure.ai.ml._restclient.v2022_10_01.models import ComponentVersion, ComponentVersionProperties +from azure.ai.ml._schema import PathAwareSchema +from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentSchema +from azure.ai.ml._utils._asset_utils import get_object_hash +from azure.ai.ml._utils.utils import hash_dict, is_data_binding_expression +from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ARM_ID_REGEX_FORMAT, COMPONENT_TYPE +from azure.ai.ml.constants._component import ComponentSource, NodeType +from azure.ai.ml.constants._job.pipeline import ValidationErrorCode +from azure.ai.ml.entities._builders import BaseNode, Command +from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode, LoopNode +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._inputs_outputs import GroupInput, Input +from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob +from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe, try_get_non_arbitrary_attr +from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression +from azure.ai.ml.entities._validation import MutableValidationResult +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +module_logger = logging.getLogger(__name__) + + +class PipelineComponent(Component): + """Pipeline component, currently used to store components in an azure.ai.ml.dsl.pipeline. + + :param name: Name of the component. + :type name: str + :param version: Version of the component. + :type version: str + :param description: Description of the component. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict + :param display_name: Display name of the component. + :type display_name: str + :param inputs: Component inputs. + :type inputs: dict + :param outputs: Component outputs. + :type outputs: dict + :param jobs: Id to components dict inside the pipeline definition. + :type jobs: Dict[str, ~azure.ai.ml.entities._builders.BaseNode] + :param is_deterministic: Whether the pipeline component is deterministic. + :type is_deterministic: bool + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if PipelineComponent cannot be successfully validated. + Details will be provided in the error message. 
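+ 
+ .. admonition:: Example (illustrative sketch)
+ 
+     PipelineComponent objects are normally produced by the @dsl.pipeline decorator
+     or by loading a pipeline component YAML definition rather than being built by
+     hand; the file path below is a hypothetical placeholder::
+ 
+         from azure.ai.ml import load_component
+ 
+         pipeline_component = load_component(source="./pipeline_component.yml")
+         # node-name -> node mapping defined by the YAML's `jobs` section
+         print(list(pipeline_component.jobs.keys()))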
+ """ + + def __init__( + self, + *, + name: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + jobs: Optional[Dict[str, BaseNode]] = None, + is_deterministic: Optional[bool] = None, + **kwargs: Any, + ) -> None: + kwargs[COMPONENT_TYPE] = NodeType.PIPELINE + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + display_name=display_name, + inputs=inputs, + outputs=outputs, + is_deterministic=is_deterministic, # type: ignore[arg-type] + **kwargs, + ) + self._jobs = self._process_jobs(jobs) if jobs else {} + # for telemetry + self._job_types, self._job_sources = self._get_job_type_and_source() + # Private support: create pipeline component from pipeline job + self._source_job_id = kwargs.pop("source_job_id", None) + # TODO: set anonymous hash for reuse + + def _process_jobs(self, jobs: Dict[str, BaseNode]) -> Dict[str, BaseNode]: + """Process and validate jobs. + + :param jobs: A map of node name to node + :type jobs: Dict[str, BaseNode] + :return: The processed jobs + :rtype: Dict[str, BaseNode] + """ + # Remove swept Command + node_names_to_skip = [] + for node_name, job_instance in jobs.items(): + if isinstance(job_instance, Command) and job_instance._swept is True: + node_names_to_skip.append(node_name) + + for key in node_names_to_skip: + del jobs[key] + + # Set path and validate node type. + for _, job_instance in jobs.items(): + if isinstance(job_instance, BaseNode): + job_instance._set_base_path(self.base_path) + + if not isinstance(job_instance, (BaseNode, AutoMLJob, ControlFlowNode)): + msg = f"Not supported pipeline job type: {type(job_instance)}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + return jobs + + def _customized_validate(self) -> MutableValidationResult: + """Validate pipeline component structure. + + :return: The validation result + :rtype: MutableValidationResult + """ + validation_result = super(PipelineComponent, self)._customized_validate() + + # Validate inputs + for input_name, input_value in self.inputs.items(): + if input_value.type is None: + validation_result.append_error( + yaml_path="inputs.{}".format(input_name), + message="Parameter type unknown, please add type annotation or specify input default value.", + error_code=ValidationErrorCode.PARAMETER_TYPE_UNKNOWN, + ) + + # Validate all nodes + for node_name, node in self.jobs.items(): + if isinstance(node, BaseNode): + # Node inputs will be validated. + validation_result.merge_with(node._validate(), "jobs.{}".format(node_name)) + if isinstance(node.component, Component): + # Validate binding if not remote resource. + validation_result.merge_with(self._validate_binding_inputs(node)) + elif isinstance(node, AutoMLJob): + pass + elif isinstance(node, ControlFlowNode): + # Validate control flow node. + validation_result.merge_with(node._validate(), "jobs.{}".format(node_name)) + else: + validation_result.append_error( + yaml_path="jobs.{}".format(node_name), + message=f"Not supported pipeline job type: {type(node)}", + ) + + return validation_result + + def _validate_compute_is_set(self, *, parent_node_name: Optional[str] = None) -> MutableValidationResult: + """Validate compute in pipeline component. 
+ + This function will only be called from pipeline_job._validate_compute_is_set + when both of the pipeline_job.compute and pipeline_job.settings.default_compute is None. + Rules: + - For pipeline node: will call node._component._validate_compute_is_set to validate node compute in sub graph. + - For general node: + - If _skip_required_compute_missing_validation is True, validation will be skipped. + - All the rest of cases without compute will add compute not set error to validation result. + + :keyword parent_node_name: The name of the parent node. + :type parent_node_name: Optional[str] + :return: The validation result + :rtype: MutableValidationResult + """ + + # Note: do not put this into customized validate, as we would like call + # this from pipeline_job._validate_compute_is_set + validation_result = self._create_empty_validation_result() + no_compute_nodes = [] + parent_node_name = parent_node_name if parent_node_name else "" + for node_name, node in self.jobs.items(): + full_node_name = f"{parent_node_name}{node_name}.jobs." + if node.type == NodeType.PIPELINE and isinstance(node._component, PipelineComponent): + validation_result.merge_with(node._component._validate_compute_is_set(parent_node_name=full_node_name)) + continue + if isinstance(node, BaseNode) and node._skip_required_compute_missing_validation: + continue + if has_attr_safe(node, "compute") and node.compute is None: + no_compute_nodes.append(node_name) + + for node_name in no_compute_nodes: + validation_result.append_error( + yaml_path=f"jobs.{parent_node_name}{node_name}.compute", + message="Compute not set", + ) + return validation_result + + def _get_input_binding_dict(self, node: BaseNode) -> Tuple[dict, dict]: + """Return the input binding dict for each node. + + :param node: The node + :type node: BaseNode + :return: A 2-tuple of (binding_dict, optional_binding_in_expression_dict) + :rtype: Tuple[dict, dict] + """ + # pylint: disable=too-many-nested-blocks + binding_inputs = node._build_inputs() + # Collect binding relation dict {'pipeline_input': ['node_input']} + binding_dict: dict = {} + optional_binding_in_expression_dict: dict = {} + for component_input_name, component_binding_input in binding_inputs.items(): + if isinstance(component_binding_input, PipelineExpression): + for pipeline_input_name in component_binding_input._inputs.keys(): + if pipeline_input_name not in self.inputs: + continue + if pipeline_input_name not in binding_dict: + binding_dict[pipeline_input_name] = [] + binding_dict[pipeline_input_name].append(component_input_name) + if pipeline_input_name not in optional_binding_in_expression_dict: + optional_binding_in_expression_dict[pipeline_input_name] = [] + optional_binding_in_expression_dict[pipeline_input_name].append(pipeline_input_name) + else: + if isinstance(component_binding_input, Input): + component_binding_input = component_binding_input.path + if is_data_binding_expression(component_binding_input, ["parent"]): + # data binding may have more than one PipelineInput now + for pipeline_input_name in PipelineExpression.parse_pipeline_inputs_from_data_binding( + component_binding_input + ): + if pipeline_input_name not in self.inputs: + continue + if pipeline_input_name not in binding_dict: + binding_dict[pipeline_input_name] = [] + binding_dict[pipeline_input_name].append(component_input_name) + # for data binding expression "${{parent.inputs.pipeline_input}}", it should not be optional + if len(component_binding_input.replace("${{parent.inputs." 
+ pipeline_input_name + "}}", "")): + if pipeline_input_name not in optional_binding_in_expression_dict: + optional_binding_in_expression_dict[pipeline_input_name] = [] + optional_binding_in_expression_dict[pipeline_input_name].append(pipeline_input_name) + return binding_dict, optional_binding_in_expression_dict + + def _validate_binding_inputs(self, node: BaseNode) -> MutableValidationResult: + """Validate pipeline binding inputs and return all used pipeline input names. + + Mark input as optional if all binding is optional and optional not set. Raise error if pipeline input is + optional but link to required inputs. + + :param node: The node to validate + :type node: BaseNode + :return: The validation result + :rtype: MutableValidationResult + """ + component_definition_inputs = {} + # Add flattened group input into definition inputs. + # e.g. Add {'group_name.item': PipelineInput} for {'group_name': GroupInput} + for name, val in node.component.inputs.items(): + if isinstance(val, GroupInput): + component_definition_inputs.update(val.flatten(group_parameter_name=name)) + component_definition_inputs[name] = val + # Collect binding relation dict {'pipeline_input': ['node_input']} + validation_result = self._create_empty_validation_result() + binding_dict, optional_binding_in_expression_dict = self._get_input_binding_dict(node) + + # Validate links required and optional + for pipeline_input_name, binding_inputs in binding_dict.items(): + pipeline_input = self.inputs[pipeline_input_name] + required_bindings = [] + for name in binding_inputs: + # not check optional/required for pipeline input used in pipeline expression + if name in optional_binding_in_expression_dict.get(pipeline_input_name, []): + continue + if name in component_definition_inputs and component_definition_inputs[name].optional is not True: + required_bindings.append(f"{node.name}.inputs.{name}") + if pipeline_input.optional is None and not required_bindings: + # Set input as optional if all binding is optional and optional not set. + pipeline_input.optional = True + pipeline_input._is_inferred_optional = True + elif pipeline_input.optional is True and required_bindings: + if pipeline_input._is_inferred_optional: + # Change optional=True to None if is inferred by us + pipeline_input.optional = None + else: + # Raise exception if pipeline input is optional set by user but link to required inputs. + validation_result.append_error( + yaml_path="inputs.{}".format(pipeline_input._port_name), + message=f"Pipeline optional Input binding to required inputs: {required_bindings}", + ) + return validation_result + + def _get_job_type_and_source(self) -> Tuple[Dict[str, int], Dict[str, int]]: + """Get job types and sources for telemetry. + + :return: A 2-tuple of + * A map of job type to the number of occurrences + * A map of job source to the number of occurrences + :rtype: Tuple[Dict[str, int], Dict[str, int]] + """ + job_types: list = [] + job_sources = [] + for job in self.jobs.values(): + job_types.append(job.type) + if isinstance(job, BaseNode): + job_sources.append(job._source) + elif isinstance(job, AutoMLJob): + # Consider all automl_job has builder type for now, + # as it's not easy to distinguish their source(yaml/builder). 
+ job_sources.append(ComponentSource.BUILDER) + else: + # Fall back to CLASS + job_sources.append(ComponentSource.CLASS) + return dict(Counter(job_types)), dict(Counter(job_sources)) + + @property + def jobs(self) -> Dict[str, BaseNode]: + """Return a dictionary from component variable name to component object. + + :return: Dictionary mapping component variable names to component objects. + :rtype: Dict[str, ~azure.ai.ml.entities._builders.BaseNode] + """ + return self._jobs + + def _get_anonymous_hash(self) -> str: + """Get anonymous hash for pipeline component. + + :return: The anonymous hash of the pipeline component + :rtype: str + """ + # ideally we should always use rest object to generate hash as it's the same as + # what we send to server-side, but changing the hash function will break reuse of + # existing components except for command component (hash result is the same for + # command component), so we just use rest object to generate hash for pipeline component, + # which doesn't have reuse issue. + component_interface_dict = self._to_rest_object().properties.component_spec + # Hash local inputs in pipeline component jobs + for job_name, job in self.jobs.items(): + if getattr(job, "inputs", None): + for input_name, input_value in job.inputs.items(): + try: + if ( + getattr(input_value, "_data", None) + and isinstance(input_value._data, Input) + and input_value.path + and os.path.exists(input_value.path) + ): + start_time = time.time() + component_interface_dict["jobs"][job_name]["inputs"][input_name]["content_hash"] = ( + get_object_hash(input_value.path) + ) + module_logger.debug( + "Takes %s seconds to calculate the content hash of local input %s", + time.time() - start_time, + input_value.path, + ) + except ValidationException: + pass + hash_value: str = hash_dict( + component_interface_dict, + keys_to_omit=[ + # omit name since anonymous component will have same name + "name", + # omit _source since it doesn't impact component's uniqueness + "_source", + # omit id since it will be set after component is registered + "id", + # omit version since it will be set to this hash later + "version", + ], + ) + return hash_value + + @classmethod + def _load_from_rest_pipeline_job(cls, data: Dict) -> "PipelineComponent": + # TODO: refine this? 
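+ # Builds a lightweight definition from a REST pipeline-job dict: only the keys read
+ # below ("display_name", "description", "inputs", "outputs", "jobs") are consumed.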
+ # Set type as None here to avoid schema validation failed + definition_inputs = {p: {"type": None} for p in data.get("inputs", {}).keys()} + definition_outputs = {p: {"type": None} for p in data.get("outputs", {}).keys()} + return PipelineComponent( + display_name=data.get("display_name"), + description=data.get("description"), + inputs=definition_inputs, + outputs=definition_outputs, + jobs=data.get("jobs"), + _source=ComponentSource.REMOTE_WORKSPACE_JOB, + ) + + @classmethod + def _resolve_sub_nodes(cls, rest_jobs: Dict) -> Dict: + from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory + + sub_nodes = {} + if rest_jobs is None: + return sub_nodes + for node_name, node in rest_jobs.items(): + # TODO: Remove this ad-hoc fix after unified arm id format in object + component_id = node.get("componentId", "") + if isinstance(component_id, str) and re.match(ASSET_ARM_ID_REGEX_FORMAT, component_id): + node["componentId"] = component_id[len(ARM_ID_PREFIX) :] + if not LoopNode._is_loop_node_dict(node): + # skip resolve LoopNode first since it may reference other nodes + # use node factory instead of BaseNode._from_rest_object here as AutoMLJob is not a BaseNode + sub_nodes[node_name] = pipeline_node_factory.load_from_rest_object(obj=node) + for node_name, node in rest_jobs.items(): + if LoopNode._is_loop_node_dict(node): + # resolve LoopNode after all other nodes are resolved + sub_nodes[node_name] = pipeline_node_factory.load_from_rest_object(obj=node, pipeline_jobs=sub_nodes) + return sub_nodes + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + return PipelineComponentSchema(context=context) + + @classmethod + def _get_skip_fields_in_schema_validation(cls) -> typing.List[str]: + # jobs validations are done in _customized_validate() + return ["jobs"] + + @classmethod + def _check_ignored_keys(cls, obj: object) -> List[str]: + """Return ignored keys in obj as a pipeline component when its value be set. + + :param obj: The object to examine + :type obj: object + :return: List of keys to ignore + :rtype: List[str] + """ + examine_mapping = { + "compute": lambda val: val is not None, + "settings": lambda val: val is not None and any(v is not None for v in val._to_dict().values()), + } + # Avoid new attr added by use `try_get_non...` instead of `hasattr` or `getattr` directly. + return [k for k, has_set in examine_mapping.items() if has_set(try_get_non_arbitrary_attr(obj, k))] + + def _get_telemetry_values(self, *args: Any, **kwargs: Any) -> Dict: + telemetry_values: dict = super()._get_telemetry_values() + telemetry_values.update( + { + "source": self._source, + "node_count": len(self.jobs), + "node_type": json.dumps(self._job_types), + "node_source": json.dumps(self._job_sources), + } + ) + return telemetry_values + + @classmethod + def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict: + # Pop jobs to avoid it goes with schema load + jobs = obj.properties.component_spec.pop("jobs", None) + init_params_dict: dict = super()._from_rest_object_to_init_params(obj) + if jobs: + try: + init_params_dict["jobs"] = PipelineComponent._resolve_sub_nodes(jobs) + except Exception as e: # pylint: disable=W0718 + # Skip parse jobs if error exists. 
+ # TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/2052262 + module_logger.debug("Parse pipeline component jobs failed with: %s", e) + return init_params_dict + + def _to_dict(self) -> Dict: + return {**self._other_parameter, **super()._to_dict()} + + def _build_rest_component_jobs(self) -> Dict[str, dict]: + """Build pipeline component jobs to rest. + + :return: A map of job name to rest objects + :rtype: Dict[str, dict] + """ + # Build the jobs to dict + rest_component_jobs = {} + for job_name, job in self.jobs.items(): + if isinstance(job, (BaseNode, ControlFlowNode)): + rest_node_dict = job._to_rest_object() + elif isinstance(job, AutoMLJob): + rest_node_dict = json.loads(json.dumps(job._to_dict(inside_pipeline=True))) + else: + msg = f"Non supported job type in Pipeline jobs: {type(job)}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + rest_component_jobs[job_name] = rest_node_dict + return rest_component_jobs + + def _to_rest_object(self) -> ComponentVersion: + """Check ignored keys and return rest object. + + :return: The component version + :rtype: ComponentVersion + """ + ignored_keys = self._check_ignored_keys(self) + if ignored_keys: + module_logger.warning("%s ignored on pipeline component %r.", ignored_keys, self.name) + component = self._to_dict() + # add source type to component rest object + component["_source"] = self._source + component["jobs"] = self._build_rest_component_jobs() + component["sourceJobId"] = self._source_job_id + if self._intellectual_property: + # hack while full pass through supported is worked on for IPP fields + component.pop("intellectual_property") + component["intellectualProperty"] = self._intellectual_property._to_rest_object().serialize() + properties = ComponentVersionProperties( + component_spec=component, + description=self.description, + is_anonymous=self._is_anonymous, + properties=self.properties, + tags=self.tags, + ) + result = ComponentVersion(properties=properties) + result.name = self.name + return result + + def __str__(self) -> str: + try: + toYaml: str = self._to_yaml() + return toYaml + except BaseException: # pylint: disable=W0718 + toStr: str = super(PipelineComponent, self).__str__() + return toStr diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/spark_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/spark_component.py new file mode 100644 index 00000000..7da65fb6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/spark_component.py @@ -0,0 +1,211 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +import os +from typing import Any, Dict, List, Optional, Union + +from marshmallow import Schema + +from azure.ai.ml._schema.component.spark_component import SparkComponentSchema +from azure.ai.ml.constants._common import COMPONENT_TYPE +from azure.ai.ml.constants._component import NodeType +from azure.ai.ml.constants._job.job import RestSparkConfKey +from azure.ai.ml.entities._assets import Environment +from azure.ai.ml.entities._job.parameterized_spark import ParameterizedSpark + +from ..._schema import PathAwareSchema +from .._job.spark_job_entry_mixin import SparkJobEntry, SparkJobEntryMixin +from .._util import convert_ordered_dict_to_dict, validate_attribute_type +from .._validation import MutableValidationResult +from ._additional_includes import AdditionalIncludesMixin +from .component import Component + + +class SparkComponent( + Component, ParameterizedSpark, SparkJobEntryMixin, AdditionalIncludesMixin +): # pylint: disable=too-many-instance-attributes + """Spark component version, used to define a Spark Component or Job. + + :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing + to a remote location. Defaults to ".", indicating the current directory. + :type code: Union[str, os.PathLike] + :keyword entry: The file or class entry point. + :paramtype entry: Optional[Union[dict[str, str], ~azure.ai.ml.entities.SparkJobEntry]] + :keyword py_files: The list of .zip, .egg or .py files to place on the PYTHONPATH for Python apps. Defaults to None. + :paramtype py_files: Optional[List[str]] + :keyword jars: The list of .JAR files to include on the driver and executor classpaths. Defaults to None. + :paramtype jars: Optional[List[str]] + :keyword files: The list of files to be placed in the working directory of each executor. Defaults to None. + :paramtype files: Optional[List[str]] + :keyword archives: The list of archives to be extracted into the working directory of each executor. + Defaults to None. + :paramtype archives: Optional[List[str]] + :keyword driver_cores: The number of cores to use for the driver process, only in cluster mode. + :paramtype driver_cores: Optional[int] + :keyword driver_memory: The amount of memory to use for the driver process, formatted as strings with a size unit + suffix ("k", "m", "g" or "t") (e.g. "512m", "2g"). + :paramtype driver_memory: Optional[str] + :keyword executor_cores: The number of cores to use on each executor. + :paramtype executor_cores: Optional[int] + :keyword executor_memory: The amount of memory to use per executor process, formatted as strings with a size unit + suffix ("k", "m", "g" or "t") (e.g. "512m", "2g"). + :paramtype executor_memory: Optional[str] + :keyword executor_instances: The initial number of executors. + :paramtype executor_instances: Optional[int] + :keyword dynamic_allocation_enabled: Whether to use dynamic resource allocation, which scales the number of + executors registered with this application up and down based on the workload. Defaults to False. + :paramtype dynamic_allocation_enabled: Optional[bool] + :keyword dynamic_allocation_min_executors: The lower bound for the number of executors if dynamic allocation is + enabled. + :paramtype dynamic_allocation_min_executors: Optional[int] + :keyword dynamic_allocation_max_executors: The upper bound for the number of executors if dynamic allocation is + enabled. 
+ :paramtype dynamic_allocation_max_executors: Optional[int] + :keyword conf: A dictionary with pre-defined Spark configurations key and values. Defaults to None. + :paramtype conf: Optional[dict[str, str]] + :keyword environment: The Azure ML environment to run the job in. + :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + :keyword inputs: A mapping of input names to input data sources used in the job. Defaults to None. + :paramtype inputs: Optional[dict[str, Union[ + ~azure.ai.ml.entities._job.pipeline._io.NodeOutput, + ~azure.ai.ml.Input, + str, + bool, + int, + float, + Enum, + ]]] + :keyword outputs: A mapping of output names to output data sources used in the job. Defaults to None. + :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]] + :keyword args: The arguments for the job. Defaults to None. + :paramtype args: Optional[str] + :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None. + :paramtype additional_includes: Optional[List[str]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_component_definition] + :end-before: [END spark_component_definition] + :language: python + :dedent: 8 + :caption: Creating SparkComponent. + """ + + def __init__( + self, + *, + code: Optional[Union[str, os.PathLike]] = ".", + entry: Optional[Union[Dict[str, str], SparkJobEntry]] = None, + py_files: Optional[List[str]] = None, + jars: Optional[List[str]] = None, + files: Optional[List[str]] = None, + archives: Optional[List[str]] = None, + driver_cores: Optional[Union[int, str]] = None, + driver_memory: Optional[str] = None, + executor_cores: Optional[Union[int, str]] = None, + executor_memory: Optional[str] = None, + executor_instances: Optional[Union[int, str]] = None, + dynamic_allocation_enabled: Optional[Union[bool, str]] = None, + dynamic_allocation_min_executors: Optional[Union[int, str]] = None, + dynamic_allocation_max_executors: Optional[Union[int, str]] = None, + conf: Optional[Dict[str, str]] = None, + environment: Optional[Union[str, Environment]] = None, + inputs: Optional[Dict] = None, + outputs: Optional[Dict] = None, + args: Optional[str] = None, + additional_includes: Optional[List] = None, + **kwargs: Any, + ) -> None: + # validate init params are valid type + validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map()) + + kwargs[COMPONENT_TYPE] = NodeType.SPARK + + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + self.code: Optional[Union[str, os.PathLike]] = code + self.entry = entry + self.py_files = py_files + self.jars = jars + self.files = files + self.archives = archives + self.conf = conf + self.environment = environment + self.args = args + self.additional_includes = additional_includes or [] + # For pipeline spark job, we also allow user to set driver_cores, driver_memory and so on by setting conf. + # If root level fields are not set by user, we promote conf setting to root level to facilitate subsequent + # verification. 
This usually happens when we use to_component(SparkJob) or builder function spark() as a node + # in pipeline sdk + conf = conf or {} + self.driver_cores = driver_cores or conf.get(RestSparkConfKey.DRIVER_CORES, None) + self.driver_memory = driver_memory or conf.get(RestSparkConfKey.DRIVER_MEMORY, None) + self.executor_cores = executor_cores or conf.get(RestSparkConfKey.EXECUTOR_CORES, None) + self.executor_memory = executor_memory or conf.get(RestSparkConfKey.EXECUTOR_MEMORY, None) + self.executor_instances = executor_instances or conf.get(RestSparkConfKey.EXECUTOR_INSTANCES, None) + self.dynamic_allocation_enabled = dynamic_allocation_enabled or conf.get( + RestSparkConfKey.DYNAMIC_ALLOCATION_ENABLED, None + ) + self.dynamic_allocation_min_executors = dynamic_allocation_min_executors or conf.get( + RestSparkConfKey.DYNAMIC_ALLOCATION_MIN_EXECUTORS, None + ) + self.dynamic_allocation_max_executors = dynamic_allocation_max_executors or conf.get( + RestSparkConfKey.DYNAMIC_ALLOCATION_MAX_EXECUTORS, None + ) + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]: + return SparkComponentSchema(context=context) + + @classmethod + def _attr_type_map(cls) -> dict: + return { + "environment": (str, Environment), + "code": (str, os.PathLike), + } + + def _customized_validate(self) -> MutableValidationResult: + validation_result = super()._customized_validate() + self._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(validation_result) + return validation_result + + def _to_dict(self) -> Dict: + # TODO: Bug Item number: 2897665 + res: Dict = convert_ordered_dict_to_dict( # type: ignore + {**self._other_parameter, **super(SparkComponent, self)._to_dict()} + ) + return res + + def _to_ordered_dict_for_yaml_dump(self) -> Dict: + """Dump the component content into a sorted yaml string. + + :return: The ordered dict + :rtype: Dict + """ + + obj: dict = super()._to_ordered_dict_for_yaml_dump() + # dict dumped base on schema will transfer code to an absolute path, while we want to keep its original value + if self.code and isinstance(self.code, str): + obj["code"] = self.code + return obj + + def _get_environment_id(self) -> Union[str, None]: + # Return environment id of environment + # handle case when environment is defined inline + if isinstance(self.environment, Environment): + res: Optional[str] = self.environment.id + return res + return self.environment + + def __str__(self) -> str: + try: + toYaml: str = self._to_yaml() + return toYaml + except BaseException: # pylint: disable=W0718 + toStr: str = super(SparkComponent, self).__str__() + return toStr diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_aml_compute_node_info.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_aml_compute_node_info.py new file mode 100644 index 00000000..823a89ca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_aml_compute_node_info.py @@ -0,0 +1,50 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Dict, Optional + +from azure.ai.ml._restclient.v2022_10_01_preview.models import AmlComputeNodeInformation +from azure.ai.ml._schema.compute.aml_compute_node_info import AmlComputeNodeInfoSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + +class AmlComputeNodeInfo: + """Compute node information related to AmlCompute.""" + + def __init__(self) -> None: + self.node_id = None + self.private_ip_address = None + self.public_ip_address = None + self.port = None + self.node_state = None + self.run_id: Optional[str] = None + + @property + def current_job_name(self) -> Optional[str]: + """The run ID of the current job. + + :return: The run ID of the current job. + :rtype: str + """ + return self.run_id + + @current_job_name.setter + def current_job_name(self, value: str) -> None: + """Set the current job run ID. + + :param value: The job run ID. + :type value: str + """ + self.run_id = value + + @classmethod + def _from_rest_object(cls, rest_obj: AmlComputeNodeInformation) -> "AmlComputeNodeInfo": + result = cls() + result.__dict__.update(rest_obj.as_dict()) + return result + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = AmlComputeNodeInfoSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_custom_applications.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_custom_applications.py new file mode 100644 index 00000000..2ee65e7f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_custom_applications.py @@ -0,0 +1,221 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access,redefined-builtin + +from typing import Any, Dict, List, Optional + +from azure.ai.ml._restclient.v2022_10_01_preview.models import CustomService, Docker +from azure.ai.ml._restclient.v2022_10_01_preview.models import Endpoint as RestEndpoint +from azure.ai.ml._restclient.v2022_10_01_preview.models import EnvironmentVariable as RestEnvironmentVariable +from azure.ai.ml._restclient.v2022_10_01_preview.models import EnvironmentVariableType as RestEnvironmentVariableType +from azure.ai.ml._restclient.v2022_10_01_preview.models import Image as RestImage +from azure.ai.ml._restclient.v2022_10_01_preview.models import ImageType as RestImageType +from azure.ai.ml._restclient.v2022_10_01_preview.models import Protocol +from azure.ai.ml._restclient.v2022_10_01_preview.models import VolumeDefinition as RestVolumeDefinition +from azure.ai.ml._restclient.v2022_10_01_preview.models import VolumeDefinitionType as RestVolumeDefinitionType +from azure.ai.ml.constants._compute import DUPLICATE_APPLICATION_ERROR, INVALID_VALUE_ERROR, CustomApplicationDefaults +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class ImageSettings: + """Specifies an image configuration for a Custom Application. + + :param reference: Image reference URL. + :type reference: str + """ + + def __init__(self, *, reference: str): + self.reference = reference + + def _to_rest_object(self) -> RestImage: + return RestImage(type=RestImageType.DOCKER, reference=self.reference) + + @classmethod + def _from_rest_object(cls, obj: RestImage) -> "ImageSettings": + return ImageSettings(reference=obj.reference) + + +class EndpointsSettings: + """Specifies an endpoint configuration for a Custom Application. + + :param target: Application port inside the container. + :type target: int + :param published: Port over which the application is exposed from container. + :type published: int + """ + + def __init__(self, *, target: int, published: int): + EndpointsSettings._validate_endpoint_settings(target=target, published=published) + self.target = target + self.published = published + + def _to_rest_object(self) -> RestEndpoint: + return RestEndpoint( + name=CustomApplicationDefaults.ENDPOINT_NAME, + target=self.target, + published=self.published, + protocol=Protocol.HTTP, + ) + + @classmethod + def _from_rest_object(cls, obj: RestEndpoint) -> "EndpointsSettings": + return EndpointsSettings(target=obj.target, published=obj.published) + + @classmethod + def _validate_endpoint_settings(cls, target: int, published: int) -> None: + ports = { + CustomApplicationDefaults.TARGET_PORT: target, + CustomApplicationDefaults.PUBLISHED_PORT: published, + } + min_value = CustomApplicationDefaults.PORT_MIN_VALUE + max_value = CustomApplicationDefaults.PORT_MAX_VALUE + + for port_name, port in ports.items(): + message = INVALID_VALUE_ERROR.format(port_name, min_value, max_value) + if not min_value < port < max_value: + raise ValidationException( + message=message, + target=ErrorTarget.COMPUTE, + no_personal_data_message=message, + error_category=ErrorCategory.USER_ERROR, + ) + + +class VolumeSettings: + """Specifies the Bind Mount settings for a Custom Application. + + :param source: The host path of the mount. + :type source: str + :param target: The path in the container for the mount. 
+ :type target: str + """ + + def __init__(self, *, source: str, target: str): + self.source = source + self.target = target + + def _to_rest_object(self) -> RestVolumeDefinition: + return RestVolumeDefinition( + type=RestVolumeDefinitionType.BIND, + read_only=False, + source=self.source, + target=self.target, + ) + + @classmethod + def _from_rest_object(cls, obj: RestVolumeDefinition) -> "VolumeSettings": + return VolumeSettings(source=obj.source, target=obj.target) + + +class CustomApplications: + """Specifies the custom service application configuration. + + :param name: Name of the Custom Application. + :type name: str + :param image: Describes the Image Specifications. + :type image: ImageSettings + :param type: Type of the Custom Application. + :type type: Optional[str] + :param endpoints: Configuring the endpoints for the container. + :type endpoints: List[EndpointsSettings] + :param environment_variables: Environment Variables for the container. + :type environment_variables: Optional[Dict[str, str]] + :param bind_mounts: Configuration of the bind mounts for the container. + :type bind_mounts: Optional[List[VolumeSettings]] + """ + + def __init__( + self, + *, + name: str, + image: ImageSettings, + type: str = CustomApplicationDefaults.DOCKER, + endpoints: List[EndpointsSettings], + environment_variables: Optional[Dict] = None, + bind_mounts: Optional[List[VolumeSettings]] = None, + **kwargs: Any + ): + self.name = name + self.type = type + self.image = image + self.endpoints = endpoints + self.environment_variables = environment_variables + self.bind_mounts = bind_mounts + self.additional_properties = kwargs + + def _to_rest_object(self) -> CustomService: + endpoints = None + if self.endpoints: + endpoints = [endpoint._to_rest_object() for endpoint in self.endpoints] + + environment_variables = None + if self.environment_variables: + environment_variables = { + name: RestEnvironmentVariable(type=RestEnvironmentVariableType.LOCAL, value=value) + for name, value in self.environment_variables.items() + } + + volumes = None + if self.bind_mounts: + volumes = [volume._to_rest_object() for volume in self.bind_mounts] + + return CustomService( + name=self.name, + image=self.image._to_rest_object(), + endpoints=endpoints, + environment_variables=environment_variables, + volumes=volumes, + docker=Docker(privileged=True), + additional_properties={**{"type": self.type}, **self.additional_properties}, + ) + + @classmethod + def _from_rest_object(cls, obj: CustomService) -> "CustomApplications": + endpoints = [] + for endpoint in obj.endpoints: + endpoints.append(EndpointsSettings._from_rest_object(endpoint)) + + environment_variables = ( + {name: value.value for name, value in obj.environment_variables.items()} + if obj.environment_variables + else None + ) + + bind_mounts = [] + if obj.volumes: + for volume in obj.volumes: + bind_mounts.append(VolumeSettings._from_rest_object(volume)) + + return CustomApplications( + name=obj.name, + image=ImageSettings._from_rest_object(obj.image), + endpoints=endpoints, + environment_variables=environment_variables, + bind_mounts=bind_mounts, + type=obj.additional_properties.pop("type", CustomApplicationDefaults.DOCKER), + **obj.additional_properties, + ) + + +def validate_custom_applications(custom_apps: List[CustomApplications]) -> None: + message = DUPLICATE_APPLICATION_ERROR + + names = [app.name for app in custom_apps] + if len(set(names)) != len(names): + raise ValidationException( + message=message.format("application_name"), + 
target=ErrorTarget.COMPUTE, + no_personal_data_message=message.format("application_name"), + error_category=ErrorCategory.USER_ERROR, + ) + + published_ports = [endpoint.published for app in custom_apps for endpoint in app.endpoints] + + if len(set(published_ports)) != len(published_ports): + raise ValidationException( + message=message.format("published_port"), + target=ErrorTarget.COMPUTE, + no_personal_data_message=message.format("published_port"), + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_image_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_image_metadata.py new file mode 100644 index 00000000..342e4a97 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_image_metadata.py @@ -0,0 +1,63 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Optional + + +class ImageMetadata: + """Metadata about the operating system image for the compute instance. + + :param is_latest_os_image_version: Specifies if the compute instance is running on the latest OS image version. + :type is_latest_os_image_version: bool + :param current_image_version: Version of the current image. + :type current_image_version: str + :param latest_image_version: The latest image version. + :type latest_image_version: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START image_metadata] + :end-before: [END image_metadata] + :language: python + :dedent: 8 + :caption: Creating a ImageMetadata object. + """ + + def __init__( + self, + *, + is_latest_os_image_version: Optional[bool], + current_image_version: Optional[str], + latest_image_version: Optional[str] + ) -> None: + self._is_latest_os_image_version = is_latest_os_image_version + self._current_image_version = current_image_version + self._latest_image_version = latest_image_version + + @property + def is_latest_os_image_version(self) -> Optional[bool]: + """Whether or not a compute instance is running on the latest OS image version. + + :return: Boolean indicating if the compute instance is running the latest OS image version. + :rtype: bool + """ + return self._is_latest_os_image_version + + @property + def current_image_version(self) -> Optional[str]: + """The current OS image version number. + + :return: The current OS image version number. + :rtype: str + """ + return self._current_image_version + + @property + def latest_image_version(self) -> Optional[str]: + """The latest OS image version number. + + :return: The latest OS image version number. + :rtype: str + """ + return self._latest_image_version diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_schedule.py new file mode 100644 index 00000000..3616a5cc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_schedule.py @@ -0,0 +1,153 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access +from typing import Any, List, Optional, Union + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputePowerAction +from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeSchedules as RestComputeSchedules +from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeStartStopSchedule as RestComputeStartStopSchedule +from azure.ai.ml._restclient.v2022_10_01_preview.models import ScheduleStatus as ScheduleState +from azure.ai.ml._restclient.v2022_10_01_preview.models import TriggerType +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +from .._schedule.trigger import CronTrigger, RecurrencePattern, RecurrenceTrigger + + +class ComputeStartStopSchedule(RestTranslatableMixin): + """Schedules for compute start or stop scenario. + + :param trigger: The trigger of the schedule. + :type trigger: Union[~azure.ai.ml.entities.CronTrigger, ~azure.ai.ml.entities.RecurrenceTrigger] + :param action: The compute power action. + :type action: ~azure.ai.ml.entities.ComputePowerAction + :param state: The state of the schedule. + :type state: ~azure.ai.ml.entities.ScheduleState + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_start_stop_schedule] + :end-before: [END compute_start_stop_schedule] + :language: python + :dedent: 8 + :caption: Creating a ComputeStartStopSchedule object. + """ + + def __init__( + self, + *, + trigger: Optional[Union[CronTrigger, RecurrenceTrigger]] = None, + action: Optional[ComputePowerAction] = None, + state: ScheduleState = ScheduleState.ENABLED, + **kwargs: Any + ) -> None: + self.trigger = trigger + self.action = action + self.state = state + self._schedule_id: Optional[str] = kwargs.pop("schedule_id", None) + self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None) + + @property + def schedule_id(self) -> Optional[str]: + """The schedule ID. + + :return: The schedule ID. + :rtype: Optional[str] + """ + return self._schedule_id + + @property + def provisioning_state(self) -> Optional[str]: + """The schedule provisioning state. + + :return: The schedule provisioning state. 
+ :rtype: Optional[str] + """ + return self._provisioning_state + + def _to_rest_object(self) -> RestComputeStartStopSchedule: + rest_object = RestComputeStartStopSchedule( + action=self.action, + status=self.state, + ) + + if isinstance(self.trigger, CronTrigger): + rest_object.trigger_type = TriggerType.CRON + rest_object.cron = self.trigger._to_rest_compute_cron_object() + elif isinstance(self.trigger, RecurrenceTrigger): + rest_object.trigger_type = TriggerType.RECURRENCE + rest_object.recurrence = self.trigger._to_rest_compute_recurrence_object() + + return rest_object + + @classmethod + def _from_rest_object(cls, obj: RestComputeStartStopSchedule) -> "ComputeStartStopSchedule": + schedule = ComputeStartStopSchedule( + action=obj.action, + state=obj.status, + schedule_id=obj.id, + provisioning_state=obj.provisioning_status, + ) + + if obj.trigger_type == TriggerType.CRON: + schedule.trigger = CronTrigger( + start_time=obj.cron.start_time, + time_zone=obj.cron.time_zone, + expression=obj.cron.expression, + ) + elif obj.trigger_type == TriggerType.RECURRENCE: + schedule.trigger = RecurrenceTrigger( + start_time=obj.recurrence.start_time, + time_zone=obj.recurrence.time_zone, + frequency=obj.recurrence.frequency, + interval=obj.recurrence.interval, + schedule=RecurrencePattern._from_rest_object(obj.recurrence.schedule), + ) + + return schedule + + +class ComputeSchedules(RestTranslatableMixin): + """Compute schedules. + + :param compute_start_stop: Compute start or stop schedules. + :type compute_start_stop: List[~azure.ai.ml.entities.ComputeStartStopSchedule] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_start_stop_schedule] + :end-before: [END compute_start_stop_schedule] + :language: python + :dedent: 8 + :caption: Creating a ComputeSchedules object. + """ + + def __init__(self, *, compute_start_stop: Optional[List[ComputeStartStopSchedule]] = None) -> None: + self.compute_start_stop = compute_start_stop + + def _to_rest_object(self) -> RestComputeSchedules: + rest_schedules: List[RestComputeStartStopSchedule] = [] + if self.compute_start_stop: + for schedule in self.compute_start_stop: + rest_schedules.append(schedule._to_rest_object()) + + return RestComputeSchedules( + compute_start_stop=rest_schedules, + ) + + @classmethod + def _from_rest_object(cls, obj: RestComputeSchedules) -> "ComputeSchedules": + schedules: List[ComputeStartStopSchedule] = [] + if obj.compute_start_stop: + for schedule in obj.compute_start_stop: + schedules.append(ComputeStartStopSchedule._from_rest_object(schedule)) + + return ComputeSchedules( + compute_start_stop=schedules, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_setup_scripts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_setup_scripts.py new file mode 100644 index 00000000..d2e12fd4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_setup_scripts.py @@ -0,0 +1,90 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access + +import re +from typing import Optional, cast + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ScriptReference as RestScriptReference +from azure.ai.ml._restclient.v2022_10_01_preview.models import ScriptsToExecute as RestScriptsToExecute +from azure.ai.ml._restclient.v2022_10_01_preview.models import SetupScripts as RestSetupScripts +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class ScriptReference(RestTranslatableMixin): + """Script reference. + + :keyword path: The location of scripts in workspace storage. + :paramtype path: Optional[str] + :keyword command: Command line arguments passed to the script to run. + :paramtype command: Optional[str] + :keyword timeout_minutes: Timeout, in minutes, for the script to run. + :paramtype timeout_minutes: Optional[int] + """ + + def __init__( + self, *, path: Optional[str] = None, command: Optional[str] = None, timeout_minutes: Optional[int] = None + ) -> None: + self.path = path + self.command = command + self.timeout_minutes = timeout_minutes + + def _to_rest_object(self) -> RestScriptReference: + return RestScriptReference( + script_source="workspaceStorage", + script_data=self.path, + script_arguments=self.command, + timeout=f"{self.timeout_minutes}m", + ) + + @classmethod + def _from_rest_object(cls, obj: RestScriptReference) -> Optional["ScriptReference"]: + if obj is None: + return obj + timeout_match = re.match(r"(\d+)m", obj.timeout) if obj.timeout else None + timeout_minutes = timeout_match.group(1) if timeout_match else None + script_reference = ScriptReference( + path=obj.script_data if obj.script_data else None, + command=obj.script_arguments if obj.script_arguments else None, + timeout_minutes=cast(Optional[int], timeout_minutes), + ) + return script_reference + + +class SetupScripts(RestTranslatableMixin): + """Customized setup scripts. + + :keyword startup_script: The script to be run every time the compute is started. + :paramtype startup_script: Optional[~azure.ai.ml.entities.ScriptReference] + :keyword creation_script: The script to be run only when the compute is created. 
+ :paramtype creation_script: Optional[~azure.ai.ml.entities.ScriptReference] + """ + + def __init__( + self, *, startup_script: Optional[ScriptReference] = None, creation_script: Optional[ScriptReference] = None + ) -> None: + self.startup_script = startup_script + self.creation_script = creation_script + + def _to_rest_object(self) -> RestScriptsToExecute: + scripts_to_execute = RestScriptsToExecute( + startup_script=self.startup_script._to_rest_object() if self.startup_script else None, + creation_script=self.creation_script._to_rest_object() if self.creation_script else None, + ) + return RestSetupScripts(scripts=scripts_to_execute) + + @classmethod + def _from_rest_object(cls, obj: RestSetupScripts) -> Optional["SetupScripts"]: + if obj is None or obj.scripts is None: + return None + scripts = obj.scripts + setup_scripts = SetupScripts( + startup_script=ScriptReference._from_rest_object( + scripts.startup_script if scripts.startup_script else None + ), + creation_script=ScriptReference._from_rest_object( + scripts.creation_script if scripts.creation_script else None + ), + ) + return setup_scripts diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_usage.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_usage.py new file mode 100644 index 00000000..6702382e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_usage.py @@ -0,0 +1,100 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from abc import abstractmethod +from os import PathLike +from typing import IO, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml._restclient.v2022_10_01_preview.models import Usage as RestUsage +from azure.ai.ml._restclient.v2022_10_01_preview.models import UsageUnit +from azure.ai.ml._schema.compute.usage import UsageSchema +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class UsageName: + def __init__(self, *, value: Optional[str] = None, localized_value: Optional[str] = None) -> None: + """The usage name. + + :param value: The name of the resource. + :type value: Optional[str] + :param localized_value: The localized name of the resource. + :type localized_value: Optional[str] + """ + self.value = value + self.localized_value = localized_value + + +class Usage(RestTranslatableMixin): + """AzureML resource usage. + + :param id: The resource ID. + :type id: Optional[str] + :param aml_workspace_location: The region of the AzureML workspace specified by the ID. + :type aml_workspace_location: Optional[str] + :param type: The resource type. + :type type: Optional[str] + :param unit: The unit of measurement for usage. Accepted value is "Count". + :type unit: Optional[Union[str, ~azure.ai.ml.entities.UsageUnit]] + :param current_value: The current usage of the resource. + :type current_value: Optional[int] + :param limit: The maximum permitted usage for the resource. + :type limit: Optional[int] + :param name: The name of the usage type. 
+ :type name: Optional[~azure.ai.ml.entities.UsageName] + """ + + def __init__( + self, + id: Optional[str] = None, # pylint: disable=redefined-builtin + aml_workspace_location: Optional[str] = None, + type: Optional[str] = None, # pylint: disable=redefined-builtin + unit: Optional[Union[str, UsageUnit]] = None, # enum + current_value: Optional[int] = None, + limit: Optional[int] = None, + name: Optional[UsageName] = None, + ) -> None: + self.id = id + self.aml_workspace_location = aml_workspace_location + self.type = type + self.unit = unit + self.current_value = current_value + self.limit = limit + self.name = name + + @classmethod + def _from_rest_object(cls, obj: RestUsage) -> "Usage": + result = cls() + result.__dict__.update(obj.as_dict()) + return result + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dumps the job content into a file in YAML format. + + :param dest: The local path or file stream to write the YAML content to. + If dest is a file path, a new file will be created. + If dest is an open file, the file will be written to directly. + :type dest: Union[PathLike, str, IO[AnyStr]] + :raises: FileExistsError if dest is a file path and the file already exists. + :raises: IOError if dest is an open file and the file is not writable. + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = UsageSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + @abstractmethod + def _load( + cls, + path: Union[PathLike, str], + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Usage": + pass diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_vm_size.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_vm_size.py new file mode 100644 index 00000000..2f0049f0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_vm_size.py @@ -0,0 +1,104 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from abc import abstractmethod +from os import PathLike +from typing import IO, Any, AnyStr, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2022_10_01_preview.models import VirtualMachineSize +from azure.ai.ml._schema.compute.vm_size import VmSizeSchema +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class VmSize(RestTranslatableMixin): + """Virtual Machine Size. + + :param name: The virtual machine size name. + :type name: Optional[str] + :param family: The virtual machine size family name. + :type family: Optional[str] + :param v_cp_us: The number of vCPUs supported by the virtual machine size. + :type v_cp_us: Optional[int] + :param gpus: The number of GPUs supported by the virtual machine size. + :type gpus: Optional[int] + :param os_vhd_size_mb: The OS VHD disk size, in MB, allowed by the virtual machine size. + :type os_vhd_size_mb: Optional[int] + :param max_resource_volume_mb: The resource volume size, in MB, allowed by the virtual machine + size. 
+ :type max_resource_volume_mb: Optional[int] + :param memory_gb: The amount of memory, in GB, supported by the virtual machine size. + :type memory_gb: Optional[float] + :param low_priority_capable: Specifies if the virtual machine size supports low priority VMs. + :type low_priority_capable: Optional[bool] + :param premium_io: Specifies if the virtual machine size supports premium IO. + :type premium_io: Optional[bool] + :param estimated_vm_prices: The estimated price information for using a VM. + :type estimated_vm_prices: ~azure.mgmt.machinelearningservices.models.EstimatedVMPrices + :param supported_compute_types: Specifies the compute types supported by the virtual machine + size. + :type supported_compute_types: Optional[list[str]] + """ + + def __init__( + self, + name: Optional[str] = None, + family: Optional[str] = None, + v_cp_us: Optional[int] = None, + gpus: Optional[int] = None, + os_vhd_size_mb: Optional[int] = None, + max_resource_volume_mb: Optional[int] = None, + memory_gb: Optional[float] = None, + low_priority_capable: Optional[bool] = None, + premium_io: Optional[bool] = None, + supported_compute_types: Optional[List[str]] = None, + ) -> None: + self.name = name + self.family = family + self.v_cp_us = v_cp_us + self.gpus = gpus + self.os_vhd_size_mb = os_vhd_size_mb + self.max_resource_volume_mb = max_resource_volume_mb + self.memory_gb = memory_gb + self.low_priority_capable = low_priority_capable + self.premium_io = premium_io + self.supported_compute_types = ",".join(map(str, supported_compute_types)) if supported_compute_types else None + + @classmethod + def _from_rest_object(cls, obj: VirtualMachineSize) -> "VmSize": + result = cls() + result.__dict__.update(obj.as_dict()) + return result + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the virtual machine size content into a file in yaml format. + + :param dest: The destination to receive this virtual machine size's content. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable. + :type dest: Union[PathLike, str, IO[AnyStr]] + """ + + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = VmSizeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + @abstractmethod + def _load( + cls, + path: Union[PathLike, str], + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "VmSize": + pass diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/aml_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/aml_compute.py new file mode 100644 index 00000000..3ec7c10f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/aml_compute.py @@ -0,0 +1,281 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
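+# Illustrative sketch of how the read-only Usage and VmSize entities defined in
+# _usage.py and _vm_size.py above are typically obtained (assumes an
+# authenticated MLClient named ml_client and that ComputeOperations exposes
+# list_sizes()/list_usage(), as in current azure-ai-ml releases):
+#
+#     for size in ml_client.compute.list_sizes():
+#         if size.gpus:
+#             # dump() writes the YAML produced by _to_dict(); it raises
+#             # FileExistsError if the target file already exists.
+#             size.dump(f"{size.name}.yml")
+#
+#     for usage in ml_client.compute.list_usage():
+#         print(usage.current_value, "/", usage.limit)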
+# --------------------------------------------------------- + +# pylint: disable=protected-access,too-many-instance-attributes + +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2022_12_01_preview.models import ( + AmlCompute as AmlComputeRest, +) +from azure.ai.ml._restclient.v2022_12_01_preview.models import ( + AmlComputeProperties, + ComputeResource, + ResourceId, + ScaleSettings, + UserAccountCredentials, +) +from azure.ai.ml._schema._utils.utils import get_subnet_str +from azure.ai.ml._schema.compute.aml_compute import AmlComputeSchema +from azure.ai.ml._utils.utils import ( + camel_to_snake, + snake_to_pascal, + to_iso_duration_format, +) +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.constants._compute import ComputeDefaults, ComputeType +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._util import load_from_dict + +from .compute import Compute, NetworkSettings + + +class AmlComputeSshSettings: + """SSH settings to access a AML compute target. + + :param admin_username: SSH user name. + :type admin_username: str + :param admin_password: SSH user password. Defaults to None. + :type admin_password: str + :param ssh_key_value: The SSH RSA private key. Use "ssh-keygen -t + rsa -b 2048" to generate your SSH key pairs. Defaults to None. + :type ssh_key_value: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START aml_compute_ssh_settings] + :end-before: [END aml_compute_ssh_settings] + :language: python + :dedent: 8 + :caption: Configuring an AmlComputeSshSettings object. + """ + + def __init__( + self, + *, + admin_username: str, + admin_password: Optional[str] = None, + ssh_key_value: Optional[str] = None, + ) -> None: + self.admin_username = admin_username + self.admin_password = admin_password + self.ssh_key_value = ssh_key_value + + def _to_user_account_credentials(self) -> UserAccountCredentials: + return UserAccountCredentials( + admin_user_name=self.admin_username, + admin_user_password=self.admin_password, + admin_user_ssh_public_key=self.ssh_key_value, + ) + + @classmethod + def _from_user_account_credentials(cls, credentials: UserAccountCredentials) -> "AmlComputeSshSettings": + return cls( + admin_username=credentials.admin_user_name, + admin_password=credentials.admin_user_password, + ssh_key_value=credentials.admin_user_ssh_public_key, + ) + + +class AmlCompute(Compute): + """AzureML Compute resource. + + :param name: Name of the compute resource. + :type name: str + :param description: Description of the compute resource. + :type description: Optional[str] + :param size: Size of the compute. Defaults to None. + :type size: Optional[str] + :param tags: A set of tags. Contains resource tags defined as key/value pairs. + :type tags: Optional[dict[str, str]] + :param ssh_settings: SSH settings to access the AzureML compute cluster. + :type ssh_settings: Optional[~azure.ai.ml.entities.AmlComputeSshSettings] + :param network_settings: Virtual network settings for the AzureML compute cluster. + :type network_settings: Optional[~azure.ai.ml.entities.NetworkSettings] + :param idle_time_before_scale_down: Node idle time before scaling down. Defaults to None. + :type idle_time_before_scale_down: Optional[int] + :param identity: The identities that are associated with the compute cluster. + :type identity: Optional[~azure.ai.ml.entities.IdentityConfiguration] + :param tier: Virtual Machine tier. 
Accepted values include: "Dedicated", "LowPriority". Defaults to None. + :type tier: Optional[str] + :param min_instances: Minimum number of instances. Defaults to None. + :type min_instances: Optional[int] + :param max_instances: Maximum number of instances. Defaults to None. + :type max_instances: Optional[int] + :param ssh_public_access_enabled: State of the public SSH port. Accepted values are: + * False - Indicates that the public SSH port is closed on all nodes of the cluster. + * True - Indicates that the public SSH port is open on all nodes of the cluster. + * None - Indicates that the public SSH port is closed on all nodes of the cluster if VNet is defined, + else is open all public nodes. + It can be None only during cluster creation time. After creation it will be either True or False. + Defaults to None. + :type ssh_public_access_enabled: Optional[bool] + :param enable_node_public_ip: Enable or disable node public IP address provisioning. Accepted values are: + * True - Indicates that the compute nodes will have public IPs provisioned. + * False - Indicates that the compute nodes will have a private endpoint and no public IPs. + Defaults to True. + :type enable_node_public_ip: bool + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START amlcompute] + :end-before: [END amlcompute] + :language: python + :dedent: 8 + :caption: Creating an AmlCompute object. + """ + + def __init__( + self, + *, + name: str, + description: Optional[str] = None, + size: Optional[str] = None, + tags: Optional[dict] = None, + ssh_public_access_enabled: Optional[bool] = None, + ssh_settings: Optional[AmlComputeSshSettings] = None, + min_instances: Optional[int] = None, + max_instances: Optional[int] = None, + network_settings: Optional[NetworkSettings] = None, + idle_time_before_scale_down: Optional[int] = None, + identity: Optional[IdentityConfiguration] = None, + tier: Optional[str] = None, + enable_node_public_ip: bool = True, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = ComputeType.AMLCOMPUTE + super().__init__( + name=name, + description=description, + location=kwargs.pop("location", None), + tags=tags, + **kwargs, + ) + self.size = size + self.min_instances = min_instances or 0 + self.max_instances = max_instances or 1 + self.idle_time_before_scale_down = idle_time_before_scale_down + self.identity = identity + self.ssh_public_access_enabled = ssh_public_access_enabled + self.ssh_settings = ssh_settings + self.network_settings = network_settings + self.tier = tier + self.enable_node_public_ip = enable_node_public_ip + self.subnet = None + + @classmethod + def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute": + prop = rest_obj.properties + + network_settings = None + if prop.properties.subnet or (prop.properties.enable_node_public_ip is not None): + network_settings = NetworkSettings( + subnet=prop.properties.subnet.id if prop.properties.subnet else None, + ) + + ssh_settings = ( + AmlComputeSshSettings._from_user_account_credentials(prop.properties.user_account_credentials) + if prop.properties.user_account_credentials + else None + ) + + response = AmlCompute( + name=rest_obj.name, + id=rest_obj.id, + description=prop.description, + location=(prop.compute_location if prop.compute_location else rest_obj.location), + tags=rest_obj.tags if rest_obj.tags else None, + provisioning_state=prop.provisioning_state, + provisioning_errors=( + prop.provisioning_errors[0].error.code + if (prop.provisioning_errors and 
len(prop.provisioning_errors) > 0) + else None + ), + size=prop.properties.vm_size, + tier=camel_to_snake(prop.properties.vm_priority), + min_instances=(prop.properties.scale_settings.min_node_count if prop.properties.scale_settings else None), + max_instances=(prop.properties.scale_settings.max_node_count if prop.properties.scale_settings else None), + network_settings=network_settings or None, + ssh_settings=ssh_settings, + ssh_public_access_enabled=(prop.properties.remote_login_port_public_access == "Enabled"), + idle_time_before_scale_down=( + prop.properties.scale_settings.node_idle_time_before_scale_down.total_seconds() + if prop.properties.scale_settings and prop.properties.scale_settings.node_idle_time_before_scale_down + else None + ), + identity=( + IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None + ), + created_on=prop.additional_properties.get("createdOn", None), + enable_node_public_ip=( + prop.properties.enable_node_public_ip if prop.properties.enable_node_public_ip is not None else True + ), + ) + return response + + def _set_full_subnet_name(self, subscription_id: str, rg: str) -> None: + if self.network_settings: + self.subnet = get_subnet_str( + self.network_settings.vnet_name, + self.network_settings.subnet, + subscription_id, + rg, + ) + + def _to_dict(self) -> Dict: + res: dict = AmlComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "AmlCompute": + loaded_data = load_from_dict(AmlComputeSchema, data, context, **kwargs) + return AmlCompute(**loaded_data) + + def _to_rest_object(self) -> ComputeResource: + if self.network_settings and self.network_settings.subnet: + subnet_resource = ResourceId(id=self.subnet) + else: + subnet_resource = None + + # Scale settings is required when creating an AzureML compute cluster. 
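+ # The entity keeps idle_time_before_scale_down as whole seconds
+ # (_load_from_rest above reads it back with total_seconds()), while the REST
+ # ScaleSettings model expects an ISO 8601 duration string, so the value is
+ # converted with to_iso_duration_format before being sent.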
+ scale_settings = ScaleSettings( + max_node_count=self.max_instances, + min_node_count=self.min_instances, + node_idle_time_before_scale_down=( + to_iso_duration_format(int(self.idle_time_before_scale_down)) + if self.idle_time_before_scale_down + else None + ), + ) + remote_login_public_access = "Enabled" + disableLocalAuth = not (self.ssh_public_access_enabled and self.ssh_settings is not None) + if self.ssh_public_access_enabled is not None: + remote_login_public_access = "Enabled" if self.ssh_public_access_enabled else "Disabled" + + else: + remote_login_public_access = "NotSpecified" + aml_prop = AmlComputeProperties( + vm_size=self.size if self.size else ComputeDefaults.VMSIZE, + vm_priority=snake_to_pascal(self.tier), + user_account_credentials=(self.ssh_settings._to_user_account_credentials() if self.ssh_settings else None), + scale_settings=scale_settings, + subnet=subnet_resource, + remote_login_port_public_access=remote_login_public_access, + enable_node_public_ip=self.enable_node_public_ip, + ) + + aml_comp = AmlComputeRest( + description=self.description, + compute_type=self.type, + properties=aml_prop, + disable_local_auth=disableLocalAuth, + ) + return ComputeResource( + location=self.location, + properties=aml_comp, + identity=(self.identity._to_compute_rest_object() if self.identity else None), + tags=self.tags, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute.py new file mode 100644 index 00000000..de18da5a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute.py @@ -0,0 +1,261 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from abc import abstractmethod +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, Optional, Union, cast + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource +from azure.ai.ml._schema.compute.compute import ComputeSchema +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, CommonYamlFields +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._util import find_type_in_override +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class Compute(Resource, RestTranslatableMixin): + """Base class for compute resources. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :param type: The compute type. Accepted values are "amlcompute", "computeinstance", + "virtualmachine", "kubernetes", and "synapsespark". + :type type: str + :param name: Name of the compute resource. + :type name: str + :param location: The resource location. Defaults to workspace location. + :type location: Optional[str] + :param description: Description of the resource. Defaults to None. + :type description: Optional[str] + :param resource_id: ARM resource id of the underlying compute. Defaults to None. + :type resource_id: Optional[str] + :param tags: A set of tags. Contains resource tags defined as key/value pairs. 
+ :type tags: Optional[dict[str, str]] + """ + + def __init__( + self, + name: str, + location: Optional[str] = None, + description: Optional[str] = None, + resource_id: Optional[str] = None, + tags: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + self._type: Optional[str] = kwargs.pop("type", None) + if self._type: + self._type = self._type.lower() + + self._created_on: Optional[str] = kwargs.pop("created_on", None) + self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None) + self._provisioning_errors: Optional[str] = kwargs.pop("provisioning_errors", None) + + super().__init__(name=name, description=description, **kwargs) + self.resource_id = resource_id + self.location = location + self.tags = tags + + @property + def type(self) -> Optional[str]: + """The compute type. + + :return: The compute type. + :rtype: Optional[str] + """ + return self._type + + @property + def created_on(self) -> Optional[str]: + """The compute resource creation timestamp. + + :return: The compute resource creation timestamp. + :rtype: Optional[str] + """ + return self._created_on + + @property + def provisioning_state(self) -> Optional[str]: + """The compute resource's provisioning state. + + :return: The compute resource's provisioning state. + :rtype: Optional[str] + """ + return self._provisioning_state + + @property + def provisioning_errors(self) -> Optional[str]: + """The compute resource provisioning errors. + + :return: The compute resource provisioning errors. + :rtype: Optional[str] + """ + return self._provisioning_errors + + def _to_rest_object(self) -> ComputeResource: + pass + + @classmethod + def _from_rest_object(cls, obj: ComputeResource) -> "Compute": + from azure.ai.ml.entities import ( + AmlCompute, + ComputeInstance, + KubernetesCompute, + SynapseSparkCompute, + UnsupportedCompute, + VirtualMachineCompute, + ) + + mapping = { + ComputeType.AMLCOMPUTE.lower(): AmlCompute, + ComputeType.COMPUTEINSTANCE.lower(): ComputeInstance, + ComputeType.VIRTUALMACHINE.lower(): VirtualMachineCompute, + ComputeType.KUBERNETES.lower(): KubernetesCompute, + ComputeType.SYNAPSESPARK.lower(): SynapseSparkCompute, + } + compute_type = obj.properties.compute_type.lower() if obj.properties.compute_type else None + + class_type = cast( + Optional[Union[AmlCompute, ComputeInstance, VirtualMachineCompute, KubernetesCompute, SynapseSparkCompute]], + mapping.get(compute_type, None), # type: ignore + ) + if class_type: + return class_type._load_from_rest(obj) + _unsupported_from_rest: Compute = UnsupportedCompute._load_from_rest(obj) + return _unsupported_from_rest + + @classmethod + @abstractmethod + def _load_from_rest(cls, rest_obj: ComputeResource) -> "Compute": + pass + + def _set_full_subnet_name(self, subscription_id: str, rg: str) -> None: + pass + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the compute content into a file in yaml format. + + :param dest: The destination to receive this compute's content. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable.'. 
+ :type dest: Union[PathLike, str, IO[AnyStr]] + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def _to_dict(self) -> Dict: + res: dict = ComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Compute": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + from azure.ai.ml.entities import ( + AmlCompute, + ComputeInstance, + KubernetesCompute, + SynapseSparkCompute, + VirtualMachineCompute, + ) + + type_in_override = find_type_in_override(params_override) if params_override else None + compute_type = type_in_override or data.get(CommonYamlFields.TYPE, None) # override takes the priority + if compute_type: + if compute_type.lower() == ComputeType.VIRTUALMACHINE: + _vm_load_from_dict: Compute = VirtualMachineCompute._load_from_dict(data, context, **kwargs) + return _vm_load_from_dict + if compute_type.lower() == ComputeType.AMLCOMPUTE: + _aml_load_from_dict: Compute = AmlCompute._load_from_dict(data, context, **kwargs) + return _aml_load_from_dict + if compute_type.lower() == ComputeType.COMPUTEINSTANCE: + _compute_instance_load_from_dict: Compute = ComputeInstance._load_from_dict(data, context, **kwargs) + return _compute_instance_load_from_dict + if compute_type.lower() == ComputeType.KUBERNETES: + _kub_load_from_dict: Compute = KubernetesCompute._load_from_dict(data, context, **kwargs) + return _kub_load_from_dict + if compute_type.lower() == ComputeType.SYNAPSESPARK: + _synapse_spark_load_from_dict: Compute = SynapseSparkCompute._load_from_dict(data, context, **kwargs) + return _synapse_spark_load_from_dict + msg = f"Unknown compute type: {compute_type}" + raise ValidationException( + message=msg, + target=ErrorTarget.COMPUTE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + @classmethod + @abstractmethod + def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "Compute": + pass + + +class NetworkSettings: + """Network settings for a compute resource. If the workspace and VNet are in different resource groups, + please provide the full URI for subnet and leave vnet_name as None. + + :param vnet_name: The virtual network name. + :type vnet_name: Optional[str] + :param subnet: The subnet name. + :type subnet: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START network_settings] + :end-before: [END network_settings] + :language: python + :dedent: 8 + :caption: Configuring NetworkSettings for an AmlCompute object. + """ + + def __init__( + self, + *, + vnet_name: Optional[str] = None, + subnet: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.vnet_name = vnet_name + self.subnet = subnet + self._public_ip_address: str = kwargs.pop("public_ip_address", None) + self._private_ip_address: str = kwargs.pop("private_ip_address", None) + + @property + def public_ip_address(self) -> str: + """Public IP address of the compute instance. + + :return: Public IP address. 
+ :rtype: str + """ + return self._public_ip_address + + @property + def private_ip_address(self) -> str: + """Private IP address of the compute instance. + + :return: Private IP address. + :rtype: str + """ + return self._private_ip_address diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute_instance.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute_instance.py new file mode 100644 index 00000000..9cbb2528 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute_instance.py @@ -0,0 +1,511 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access,too-many-instance-attributes + +import logging +import re +import warnings +from typing import Any, Dict, List, Optional + +from azure.ai.ml._restclient.v2022_10_01_preview.models import AssignedUser +from azure.ai.ml._restclient.v2023_08_01_preview.models import ComputeInstance as CIRest +from azure.ai.ml._restclient.v2023_08_01_preview.models import ComputeInstanceProperties +from azure.ai.ml._restclient.v2023_08_01_preview.models import ComputeInstanceSshSettings as CiSShSettings +from azure.ai.ml._restclient.v2023_08_01_preview.models import ( + ComputeResource, + PersonalComputeInstanceSettings, + ResourceId, +) +from azure.ai.ml._schema._utils.utils import get_subnet_str +from azure.ai.ml._schema.compute.compute_instance import ComputeInstanceSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.constants._compute import ComputeDefaults, ComputeType +from azure.ai.ml.entities._compute.compute import Compute, NetworkSettings +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._mixins import DictMixin +from azure.ai.ml.entities._util import load_from_dict + +from ._custom_applications import CustomApplications, validate_custom_applications +from ._image_metadata import ImageMetadata +from ._schedule import ComputeSchedules +from ._setup_scripts import SetupScripts + +module_logger = logging.getLogger(__name__) + + +class ComputeInstanceSshSettings: + """Credentials for an administrator user account to SSH into the compute node. + + Can only be configured if `ssh_public_access_enabled` is set to true on compute + resource. + + :param ssh_key_value: The SSH public key of the administrator user account. + :type ssh_key_value: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_instance_ssh_settings] + :end-before: [END compute_instance_ssh_settings] + :language: python + :dedent: 8 + :caption: Configuring ComputeInstanceSshSettings object. + """ + + def __init__( + self, + *, + ssh_key_value: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.ssh_key_value = ssh_key_value + self._ssh_port: str = kwargs.pop("ssh_port", None) + self._admin_username: str = kwargs.pop("admin_username", None) + + @property + def admin_username(self) -> str: + """The name of the administrator user account which can be used to SSH into nodes. + + :return: The name of the administrator user account. + :rtype: str + """ + return self._admin_username + + @property + def ssh_port(self) -> str: + """SSH port. + + :return: SSH port. 
+ :rtype: str + """ + return self._ssh_port + + +class AssignedUserConfiguration(DictMixin): + """Settings to create a compute resource on behalf of another user. + + :param user_tenant_id: Tenant ID of the user to assign the compute target to. + :type user_tenant_id: str + :param user_object_id: Object ID of the user to assign the compute target to. + :type user_object_id: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START assigned_user_configuration] + :end-before: [END assigned_user_configuration] + :language: python + :dedent: 8 + :caption: Creating an AssignedUserConfiguration. + """ + + def __init__(self, *, user_tenant_id: str, user_object_id: str) -> None: + self.user_tenant_id = user_tenant_id + self.user_object_id = user_object_id + + +class ComputeInstance(Compute): + """Compute Instance resource. + + :param name: Name of the compute. + :type name: str + :param location: The resource location. + :type location: Optional[str] + :param description: Description of the resource. + :type description: Optional[str] + :param size: Compute size. + :type size: Optional[str] + :param tags: A set of tags. Contains resource tags defined as key/value pairs. + :type tags: Optional[dict[str, str]] + :param create_on_behalf_of: Configuration to create resource on behalf of another user. Defaults to None. + :type create_on_behalf_of: Optional[~azure.ai.ml.entities.AssignedUserConfiguration] + :ivar state: State of the resource. + :type state: Optional[str] + :ivar last_operation: The last operation. + :type last_operation: Optional[Dict[str, str]] + :ivar applications: Applications associated with the compute instance. + :type applications: Optional[List[Dict[str, str]]] + :param network_settings: Network settings for the compute instance. + :type network_settings: Optional[~azure.ai.ml.entities.NetworkSettings] + :param ssh_settings: SSH settings for the compute instance. + :type ssh_settings: Optional[~azure.ai.ml.entities.ComputeInstanceSshSettings] + :param ssh_public_access_enabled: State of the public SSH port. Defaults to None. + Possible values are: + + * False - Indicates that the public ssh port is closed on all nodes of the cluster. + * True - Indicates that the public ssh port is open on all nodes of the cluster. + * None -Indicates that the public ssh port is closed on all nodes of the cluster if VNet is defined, + else is open all public nodes. It can be default only during cluster creation time, after + creation it will be either True or False. + + :type ssh_public_access_enabled: Optional[bool] + :param schedules: Compute instance schedules. Defaults to None. + :type schedules: Optional[~azure.ai.ml.entities.ComputeSchedules] + :param identity: The identities that are associated with the compute cluster. + :type identity: ~azure.ai.ml.entities.IdentityConfiguration + :param idle_time_before_shutdown: Deprecated. Use the `idle_time_before_shutdown_minutes` parameter instead. + Stops compute instance after user defined period of inactivity. + Time is defined in ISO8601 format. Minimum is 15 minutes, maximum is 3 days. + :type idle_time_before_shutdown: Optional[str] + :param idle_time_before_shutdown_minutes: Stops compute instance after a user defined period of + inactivity in minutes. Minimum is 15 minutes, maximum is 3 days. + :type idle_time_before_shutdown_minutes: Optional[int] + :param enable_node_public_ip: Enable or disable node public IP address provisioning. Defaults to True. 
+ Possible values are: + + * True - Indicates that the compute nodes will have public IPs provisioned. + * False - Indicates that the compute nodes will have a private endpoint and no public IPs. + + :type enable_node_public_ip: Optional[bool] + :param setup_scripts: Details of customized scripts to execute for setting up the cluster. + :type setup_scripts: Optional[~azure.ai.ml.entities.SetupScripts] + :param custom_applications: List of custom applications and their endpoints for the compute instance. + :type custom_applications: Optional[List[~azure.ai.ml.entities.CustomApplications]] + :param enable_sso: Enable or disable single sign-on. Defaults to True. + :type enable_sso: bool + :param enable_root_access: Enable or disable root access. Defaults to True. + :type enable_root_access: bool + :param release_quota_on_stop: Release quota on stop for the compute instance. Defaults to False. + :type release_quota_on_stop: bool + :param enable_os_patching: Enable or disable OS patching for the compute instance. Defaults to False. + :type enable_os_patching: bool + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_instance] + :end-before: [END compute_instance] + :language: python + :dedent: 8 + :caption: Creating a ComputeInstance object. + """ + + def __init__( + self, + *, + name: str, + description: Optional[str] = None, + size: Optional[str] = None, + tags: Optional[dict] = None, + ssh_public_access_enabled: Optional[bool] = None, + create_on_behalf_of: Optional[AssignedUserConfiguration] = None, + network_settings: Optional[NetworkSettings] = None, + ssh_settings: Optional[ComputeInstanceSshSettings] = None, + schedules: Optional[ComputeSchedules] = None, + identity: Optional[IdentityConfiguration] = None, + idle_time_before_shutdown: Optional[str] = None, + idle_time_before_shutdown_minutes: Optional[int] = None, + setup_scripts: Optional[SetupScripts] = None, + enable_node_public_ip: bool = True, + custom_applications: Optional[List[CustomApplications]] = None, + enable_sso: bool = True, + enable_root_access: bool = True, + release_quota_on_stop: bool = False, + enable_os_patching: bool = False, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = ComputeType.COMPUTEINSTANCE + self._state: str = kwargs.pop("state", None) + self._last_operation: dict = kwargs.pop("last_operation", None) + self._os_image_metadata: ImageMetadata = kwargs.pop("os_image_metadata", None) + self._services: list = kwargs.pop("services", None) + super().__init__( + name=name, + location=kwargs.pop("location", None), + resource_id=kwargs.pop("resource_id", None), + description=description, + tags=tags, + **kwargs, + ) + self.size = size + self.ssh_public_access_enabled = ssh_public_access_enabled + self.create_on_behalf_of = create_on_behalf_of + self.network_settings = network_settings + self.ssh_settings = ssh_settings + self.schedules = schedules + self.identity = identity + self.idle_time_before_shutdown = idle_time_before_shutdown + self.idle_time_before_shutdown_minutes = idle_time_before_shutdown_minutes + self.setup_scripts = setup_scripts + self.enable_node_public_ip = enable_node_public_ip + self.enable_sso = enable_sso + self.enable_root_access = enable_root_access + self.release_quota_on_stop = release_quota_on_stop + self.enable_os_patching = enable_os_patching + self.custom_applications = custom_applications + self.subnet = None + + @property + def services(self) -> List[Dict[str, str]]: + """The compute instance's services. 
+ + :return: The compute instance's services. + :rtype: List[Dict[str, str]] + """ + return self._services + + @property + def last_operation(self) -> Dict[str, str]: + """The last operation. + + :return: The last operation. + :rtype: str + """ + return self._last_operation + + @property + def state(self) -> str: + """The state of the compute. + + :return: The state of the compute. + :rtype: str + """ + return self._state + + @property + def os_image_metadata(self) -> ImageMetadata: + """Metadata about the operating system image for this compute instance. + + :return: Operating system image metadata. + :rtype: ~azure.ai.ml.entities.ImageMetadata + """ + return self._os_image_metadata + + def _to_rest_object(self) -> ComputeResource: + if self.network_settings and self.network_settings.subnet: + subnet_resource = ResourceId(id=self.subnet) + else: + subnet_resource = None + + ssh_settings = None + if self.ssh_public_access_enabled is not None or self.ssh_settings is not None: + ssh_settings = CiSShSettings() + ssh_settings.ssh_public_access = "Enabled" if self.ssh_public_access_enabled else "Disabled" + ssh_settings.admin_public_key = ( + self.ssh_settings.ssh_key_value if self.ssh_settings and self.ssh_settings.ssh_key_value else None + ) + + personal_compute_instance_settings = None + if self.create_on_behalf_of: + personal_compute_instance_settings = PersonalComputeInstanceSettings( + assigned_user=AssignedUser( + object_id=self.create_on_behalf_of.user_object_id, + tenant_id=self.create_on_behalf_of.user_tenant_id, + ) + ) + + idle_time_before_shutdown = None + if self.idle_time_before_shutdown_minutes: + idle_time_before_shutdown = f"PT{self.idle_time_before_shutdown_minutes}M" + elif self.idle_time_before_shutdown: + warnings.warn( + """ The property 'idle_time_before_shutdown' is deprecated. 
+ Please use'idle_time_before_shutdown_minutes' instead.""", + DeprecationWarning, + ) + idle_time_before_shutdown = self.idle_time_before_shutdown + + compute_instance_prop = ComputeInstanceProperties( + vm_size=self.size if self.size else ComputeDefaults.VMSIZE, + subnet=subnet_resource, + ssh_settings=ssh_settings, + personal_compute_instance_settings=personal_compute_instance_settings, + idle_time_before_shutdown=idle_time_before_shutdown, + enable_node_public_ip=self.enable_node_public_ip, + enable_sso=self.enable_sso, + enable_root_access=self.enable_root_access, + release_quota_on_stop=self.release_quota_on_stop, + enable_os_patching=self.enable_os_patching, + ) + compute_instance_prop.schedules = self.schedules._to_rest_object() if self.schedules else None + compute_instance_prop.setup_scripts = self.setup_scripts._to_rest_object() if self.setup_scripts else None + if self.custom_applications: + validate_custom_applications(self.custom_applications) + compute_instance_prop.custom_services = [] + for app in self.custom_applications: + compute_instance_prop.custom_services.append(app._to_rest_object()) + compute_instance = CIRest( + description=self.description, + compute_type=self.type, + properties=compute_instance_prop, + ) + return ComputeResource( + location=self.location, + properties=compute_instance, + identity=(self.identity._to_compute_rest_object() if self.identity else None), + tags=self.tags, + ) + + def _to_dict(self) -> Dict: + res: dict = ComputeInstanceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def _set_full_subnet_name(self, subscription_id: str, rg: str) -> None: + if self.network_settings and (self.network_settings.vnet_name or self.network_settings.subnet): + self.subnet = get_subnet_str( + self.network_settings.vnet_name, + self.network_settings.subnet, + subscription_id, + rg, + ) + + @classmethod + def _load_from_rest(cls, rest_obj: ComputeResource) -> "ComputeInstance": + prop = rest_obj.properties + create_on_behalf_of = None + if prop.properties and prop.properties.personal_compute_instance_settings: + create_on_behalf_of = AssignedUserConfiguration( + user_tenant_id=prop.properties.personal_compute_instance_settings.assigned_user.tenant_id, + user_object_id=prop.properties.personal_compute_instance_settings.assigned_user.object_id, + ) + ssh_settings = None + if prop.properties and prop.properties.ssh_settings: + ssh_settings = ComputeInstanceSshSettings( + ssh_key_value=prop.properties.ssh_settings.admin_public_key, + ssh_port=prop.properties.ssh_settings.ssh_port, + admin_username=prop.properties.ssh_settings.admin_user_name, + ) + + network_settings = None + if prop.properties and ( + prop.properties.subnet + or ( + prop.properties.connectivity_endpoints + and ( + prop.properties.connectivity_endpoints.private_ip_address + or prop.properties.connectivity_endpoints.public_ip_address + ) + ) + ): + network_settings = NetworkSettings( + subnet=prop.properties.subnet.id if prop.properties.subnet else None, + public_ip_address=( + prop.properties.connectivity_endpoints.public_ip_address + if prop.properties.connectivity_endpoints + and prop.properties.connectivity_endpoints.public_ip_address + else None + ), + private_ip_address=( + prop.properties.connectivity_endpoints.private_ip_address + if prop.properties.connectivity_endpoints + and prop.properties.connectivity_endpoints.private_ip_address + else None + ), + ) + os_image_metadata = None + if prop.properties and prop.properties.os_image_metadata: + metadata = 
prop.properties.os_image_metadata + os_image_metadata = ImageMetadata( + is_latest_os_image_version=( + metadata.is_latest_os_image_version if metadata.is_latest_os_image_version is not None else None + ), + current_image_version=metadata.current_image_version if metadata.current_image_version else None, + latest_image_version=metadata.latest_image_version if metadata.latest_image_version else None, + ) + + idle_time_before_shutdown = None + idle_time_before_shutdown_minutes = None + idle_time_before_shutdown_pattern = r"PT([0-9]+)M" + if prop.properties and prop.properties.idle_time_before_shutdown: + idle_time_before_shutdown = prop.properties.idle_time_before_shutdown + idle_time_match = re.match( + pattern=idle_time_before_shutdown_pattern, + string=idle_time_before_shutdown, + ) + idle_time_before_shutdown_minutes = int(idle_time_match[1]) if idle_time_match else None + custom_applications = None + if prop.properties and prop.properties.custom_services: + custom_applications = [] + for app in prop.properties.custom_services: + custom_applications.append(CustomApplications._from_rest_object(app)) + response = ComputeInstance( + name=rest_obj.name, + id=rest_obj.id, + description=prop.description, + location=rest_obj.location, + resource_id=prop.resource_id, + tags=rest_obj.tags if rest_obj.tags else None, + provisioning_state=prop.provisioning_state, + provisioning_errors=( + prop.provisioning_errors[0].error.code + if (prop.provisioning_errors and len(prop.provisioning_errors) > 0) + else None + ), + size=prop.properties.vm_size if prop.properties else None, + state=prop.properties.state if prop.properties else None, + last_operation=( + prop.properties.last_operation.as_dict() if prop.properties and prop.properties.last_operation else None + ), + services=( + [app.as_dict() for app in prop.properties.applications] + if prop.properties and prop.properties.applications + else None + ), + created_on=( + rest_obj.properties.created_on.strftime("%Y-%m-%dT%H:%M:%S.%f%z") + if rest_obj.properties and rest_obj.properties.created_on is not None + else None + ), + create_on_behalf_of=create_on_behalf_of, + network_settings=network_settings, + ssh_settings=ssh_settings, + ssh_public_access_enabled=( + _ssh_public_access_to_bool(prop.properties.ssh_settings.ssh_public_access) + if (prop.properties and prop.properties.ssh_settings and prop.properties.ssh_settings.ssh_public_access) + else None + ), + schedules=( + ComputeSchedules._from_rest_object(prop.properties.schedules) + if prop.properties and prop.properties.schedules and prop.properties.schedules.compute_start_stop + else None + ), + identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None, + setup_scripts=( + SetupScripts._from_rest_object(prop.properties.setup_scripts) + if prop.properties and prop.properties.setup_scripts + else None + ), + idle_time_before_shutdown=idle_time_before_shutdown, + idle_time_before_shutdown_minutes=idle_time_before_shutdown_minutes, + os_image_metadata=os_image_metadata, + enable_node_public_ip=( + prop.properties.enable_node_public_ip + if (prop.properties and prop.properties.enable_node_public_ip is not None) + else True + ), + custom_applications=custom_applications, + enable_sso=( + prop.properties.enable_sso if (prop.properties and prop.properties.enable_sso is not None) else True + ), + enable_root_access=( + prop.properties.enable_root_access + if (prop.properties and prop.properties.enable_root_access is not None) + else True + ), + 
release_quota_on_stop=( + prop.properties.release_quota_on_stop + if (prop.properties and prop.properties.release_quota_on_stop is not None) + else False + ), + enable_os_patching=( + prop.properties.enable_os_patching + if (prop.properties and prop.properties.enable_os_patching is not None) + else False + ), + ) + return response + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "ComputeInstance": + loaded_data = load_from_dict(ComputeInstanceSchema, data, context, **kwargs) + return ComputeInstance(**loaded_data) + + +def _ssh_public_access_to_bool(value: str) -> Optional[bool]: + if value.lower() == "disabled": + return False + if value.lower() == "enabled": + return True + return None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/kubernetes_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/kubernetes_compute.py new file mode 100644 index 00000000..bc8c2c28 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/kubernetes_compute.py @@ -0,0 +1,105 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource, Kubernetes, KubernetesProperties +from azure.ai.ml._schema.compute.kubernetes_compute import KubernetesComputeSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.entities._compute.compute import Compute +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._util import load_from_dict + + +class KubernetesCompute(Compute): + """Kubernetes Compute resource. + + :param namespace: The namespace of the KubernetesCompute. Defaults to "default". + :type namespace: Optional[str] + :param properties: The properties of the Kubernetes compute resource. + :type properties: Optional[Dict] + :param identity: The identities that are associated with the compute cluster. + :type identity: ~azure.ai.ml.entities.IdentityConfiguration + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START kubernetes_compute] + :end-before: [END kubernetes_compute] + :language: python + :dedent: 8 + :caption: Creating a KubernetesCompute object. 
+ """ + + def __init__( + self, + *, + namespace: str = "default", + properties: Optional[Dict[str, Any]] = None, + identity: Optional[IdentityConfiguration] = None, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = ComputeType.KUBERNETES + super().__init__(**kwargs) + self.namespace = namespace + self.properties = properties if properties else {} + if "properties" in self.properties: + self.properties["properties"]["namespace"] = namespace + self.identity = identity + + @classmethod + def _load_from_rest(cls, rest_obj: ComputeResource) -> "KubernetesCompute": + prop = rest_obj.properties + return KubernetesCompute( + name=rest_obj.name, + id=rest_obj.id, + description=prop.description, + location=rest_obj.location, + resource_id=prop.resource_id, + tags=rest_obj.tags if rest_obj.tags else None, + provisioning_state=prop.provisioning_state, + provisioning_errors=( + prop.provisioning_errors[0].error.code + if (prop.provisioning_errors and len(prop.provisioning_errors) > 0) + else None + ), + created_on=prop.additional_properties.get("createdOn", None), + properties=prop.properties.as_dict() if prop.properties else None, + namespace=prop.properties.namespace, + identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None, + ) + + def _to_dict(self) -> Dict: + res: dict = KubernetesComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "KubernetesCompute": + if not data: + data = {"namespace": "default"} + if "namespace" not in data: + data["namespace"] = "default" + + loaded_data = load_from_dict(KubernetesComputeSchema, data, context, **kwargs) + return KubernetesCompute(**loaded_data) + + def _to_rest_object(self) -> ComputeResource: + kubernetes_prop = KubernetesProperties.from_dict(self.properties) + kubernetes_prop.namespace = self.namespace + kubernetes_comp = Kubernetes( + resource_id=self.resource_id, + compute_location=self.location, + description=self.description, + properties=kubernetes_prop, + ) + return ComputeResource( + location=self.location, + properties=kubernetes_comp, + name=self.name, + identity=(self.identity._to_compute_rest_object() if self.identity else None), + tags=self.tags, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/synapsespark_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/synapsespark_compute.py new file mode 100644 index 00000000..99b366cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/synapsespark_compute.py @@ -0,0 +1,234 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
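+# Illustrative sketch of attaching the KubernetesCompute defined in
+# kubernetes_compute.py above (assumes an authenticated MLClient named
+# ml_client; the subscription, resource group and cluster names are
+# placeholders):
+#
+#     from azure.ai.ml.entities import KubernetesCompute
+#
+#     k8s = KubernetesCompute(
+#         name="k8s-compute",
+#         resource_id=(
+#             "/subscriptions/<sub-id>/resourceGroups/<rg>/providers/"
+#             "Microsoft.Kubernetes/connectedClusters/<cluster-name>"
+#         ),
+#         namespace="azureml",
+#     )
+#     ml_client.compute.begin_create_or_update(k8s).result()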
+# --------------------------------------------------------- +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ( + AutoPauseProperties, + AutoScaleProperties, + ComputeResource, + SynapseSpark, +) +from azure.ai.ml._schema.compute.synapsespark_compute import SynapseSparkComputeSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.entities import Compute +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._util import load_from_dict + + +class AutoScaleSettings: + """Auto-scale settings for Synapse Spark compute. + + :keyword min_node_count: The minimum compute node count. + :paramtype min_node_count: Optional[int] + :keyword max_node_count: The maximum compute node count. + :paramtype max_node_count: Optional[int] + :keyword enabled: Specifies if auto-scale is enabled. + :paramtype enabled: Optional[bool] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START synapse_spark_compute_configuration] + :end-before: [END synapse_spark_compute_configuration] + :language: python + :dedent: 8 + :caption: Configuring AutoScaleSettings on SynapseSparkCompute. + """ + + def __init__( + self, + *, + min_node_count: Optional[int] = None, + max_node_count: Optional[int] = None, + enabled: Optional[bool] = None, + ) -> None: + self.min_node_count = min_node_count + self.max_node_count = max_node_count + self.auto_scale_enabled = enabled + + def _to_auto_scale_settings(self) -> AutoScaleProperties: + return AutoScaleProperties( + min_node_count=self.min_node_count, + max_node_count=self.max_node_count, + auto_scale_enabled=self.auto_scale_enabled, + ) + + @classmethod + def _from_auto_scale_settings(cls, autoscaleprops: AutoScaleProperties) -> "AutoScaleSettings": + return cls( + min_node_count=autoscaleprops.min_node_count, + max_node_count=autoscaleprops.max_node_count, + enabled=autoscaleprops.enabled, + ) + + +class AutoPauseSettings: + """Auto pause settings for Synapse Spark compute. + + :keyword delay_in_minutes: The time delay in minutes before pausing cluster. + :paramtype delay_in_minutes: Optional[int] + :keyword enabled: Specifies if auto-pause is enabled. + :paramtype enabled: Optional[bool] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START synapse_spark_compute_configuration] + :end-before: [END synapse_spark_compute_configuration] + :language: python + :dedent: 8 + :caption: Configuring AutoPauseSettings on SynapseSparkCompute. + """ + + def __init__(self, *, delay_in_minutes: Optional[int] = None, enabled: Optional[bool] = None) -> None: + self.delay_in_minutes = delay_in_minutes + self.auto_pause_enabled = enabled + + def _to_auto_pause_settings(self) -> AutoPauseProperties: + return AutoPauseProperties( + delay_in_minutes=self.delay_in_minutes, + auto_pause_enabled=self.auto_pause_enabled, + ) + + @classmethod + def _from_auto_pause_settings(cls, autopauseprops: AutoPauseProperties) -> "AutoPauseSettings": + return cls( + delay_in_minutes=autopauseprops.delay_in_minutes, + enabled=autopauseprops.enabled, + ) + + +@experimental +class SynapseSparkCompute(Compute): + """SynapseSpark Compute resource. + + :keyword name: The name of the compute. 
+ :paramtype name: str + :keyword description: The description of the resource. Defaults to None. + :paramtype description: Optional[str] + :keyword tags: The set of resource tags defined as key/value pairs. Defaults to None. + :paramtype tags: Optional[[dict[str, str]] + :keyword node_count: The number of nodes in the compute. + :paramtype node_count: Optional[int] + :keyword node_family: The node family of the compute. + :paramtype node_family: Optional[str] + :keyword node_size: The size of the node. + :paramtype node_size: Optional[str] + :keyword spark_version: The version of Spark to use. + :paramtype spark_version: Optional[str] + :keyword identity: The configuration of identities that are associated with the compute cluster. + :paramtype identity: Optional[~azure.ai.ml.entities.IdentityConfiguration] + :keyword scale_settings: The scale settings for the compute. + :paramtype scale_settings: Optional[~azure.ai.ml.entities.AutoScaleSettings] + :keyword auto_pause_settings: The auto pause settings for the compute. + :paramtype auto_pause_settings: Optional[~azure.ai.ml.entities.AutoPauseSettings] + :keyword kwargs: Additional keyword arguments passed to the parent class. + :paramtype kwargs: Optional[dict] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START synapse_spark_compute_configuration] + :end-before: [END synapse_spark_compute_configuration] + :language: python + :dedent: 8 + :caption: Creating Synapse Spark compute. + """ + + def __init__( + self, + *, + name: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + node_count: Optional[int] = None, + node_family: Optional[str] = None, + node_size: Optional[str] = None, + spark_version: Optional[str] = None, + identity: Optional[IdentityConfiguration] = None, + scale_settings: Optional[AutoScaleSettings] = None, + auto_pause_settings: Optional[AutoPauseSettings] = None, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = ComputeType.SYNAPSESPARK + super().__init__(name=name, description=description, location=kwargs.pop("location", None), tags=tags, **kwargs) + self.identity = identity + self.node_count = node_count + self.node_family = node_family + self.node_size = node_size + self.spark_version = spark_version + self.scale_settings = scale_settings + self.auto_pause_settings = auto_pause_settings + + @classmethod + def _load_from_rest(cls, rest_obj: ComputeResource) -> "SynapseSparkCompute": + prop = rest_obj.properties + scale_settings = ( + # pylint: disable=protected-access + AutoScaleSettings._from_auto_scale_settings(prop.properties.auto_scale_properties) + if prop.properties.auto_scale_properties + else None + ) + + auto_pause_settings = ( + # pylint: disable=protected-access + AutoPauseSettings._from_auto_pause_settings(prop.properties.auto_pause_properties) + if prop.properties.auto_pause_properties + else None + ) + + return SynapseSparkCompute( + name=rest_obj.name, + id=rest_obj.id, + description=prop.description, + location=rest_obj.location, + resource_id=prop.resource_id, + tags=rest_obj.tags if rest_obj.tags else None, + created_on=prop.created_on if prop.properties else None, + node_count=prop.properties.node_count if prop.properties else None, + node_family=prop.properties.node_size_family if prop.properties else None, + node_size=prop.properties.node_size if prop.properties else None, + spark_version=prop.properties.spark_version if prop.properties else None, + # pylint: disable=protected-access + 
identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None, + scale_settings=scale_settings, + auto_pause_settings=auto_pause_settings, + provisioning_state=prop.provisioning_state, + provisioning_errors=( + prop.provisioning_errors[0].error.code + if (prop.provisioning_errors and len(prop.provisioning_errors) > 0) + else None + ), + ) + + def _to_dict(self) -> Dict: + res: dict = SynapseSparkComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "SynapseSparkCompute": + loaded_data = load_from_dict(SynapseSparkComputeSchema, data, context, **kwargs) + return SynapseSparkCompute(**loaded_data) + + def _to_rest_object(self) -> ComputeResource: + synapsespark_comp = SynapseSpark( + name=self.name, + compute_type=self.type, + resource_id=self.resource_id, + description=self.description, + ) + return ComputeResource( + location=self.location, + properties=synapsespark_comp, + name=self.name, + identity=( + # pylint: disable=protected-access + self.identity._to_compute_rest_object() + if self.identity + else None + ), + tags=self.tags, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/unsupported_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/unsupported_compute.py new file mode 100644 index 00000000..258fbf6b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/unsupported_compute.py @@ -0,0 +1,62 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Any, Dict + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource +from azure.ai.ml.constants._common import TYPE +from azure.ai.ml.entities._compute.compute import Compute +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class UnsupportedCompute(Compute): + """Unsupported compute resource. + + Only used for displaying compute properties for resources not fully supported in the SDK. + """ + + def __init__( + self, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = "*** Unsupported Compute Type ***" + super().__init__(**kwargs) + + @classmethod + def _load_from_rest(cls, rest_obj: ComputeResource) -> "UnsupportedCompute": + prop = rest_obj.properties + if hasattr(rest_obj, "tags"): + # TODO(2294131): remove this when DataFactory object has no tags got fixed + tags = rest_obj.tags + else: + tags = None + response = UnsupportedCompute( + name=rest_obj.name, + id=rest_obj.id, + description=prop.description, + location=rest_obj.location, + resource_id=prop.resource_id, + tags=tags, + provisioning_state=prop.provisioning_state, + created_on=prop.additional_properties.get("createdOn", None), + ) + return response + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "UnsupportedCompute": + msg = "Cannot create unsupported compute type." + raise ValidationException( + message=msg, + target=ErrorTarget.COMPUTE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + def _to_rest_object(self) -> ComputeResource: + msg = "Cannot create unsupported compute type." 
+ raise ValidationException( + message=msg, + target=ErrorTarget.COMPUTE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/virtual_machine_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/virtual_machine_compute.py new file mode 100644 index 00000000..90c3ec63 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/virtual_machine_compute.py @@ -0,0 +1,172 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from pathlib import Path +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource +from azure.ai.ml._restclient.v2022_10_01_preview.models import VirtualMachine as VMResource +from azure.ai.ml._restclient.v2022_10_01_preview.models import ( + VirtualMachineSchemaProperties, + VirtualMachineSshCredentials, +) +from azure.ai.ml._schema.compute.virtual_machine_compute import VirtualMachineComputeSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE, DefaultOpenEncoding +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.entities._compute.compute import Compute +from azure.ai.ml.entities._util import load_from_dict + + +class VirtualMachineSshSettings: + """SSH settings for a virtual machine. + + :param admin_username: The admin user name. Defaults to None. + :type admin_username: str + :param admin_password: The admin user password. Defaults to None. + Required if `ssh_private_key_file` is not specified. + :type admin_password: Optional[str] + :param ssh_port: The ssh port number. Default is 22. + :type ssh_port: int + :param ssh_private_key_file: Path to the file containing the SSH rsa private key. + Use "ssh-keygen -t rsa -b 2048" to generate your SSH key pairs. + Required if admin_password is not specified. + :type ssh_private_key_file: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START vm_ssh_settings] + :end-before: [END vm_ssh_settings] + :language: python + :dedent: 8 + :caption: Configuring a VirtualMachineSshSettings object. + """ + + def __init__( + self, + *, + admin_username: Optional[str], + admin_password: Optional[str] = None, + ssh_port: Optional[int] = 22, + ssh_private_key_file: Optional[str] = None, + ) -> None: + self.admin_username = admin_username + self.admin_password = admin_password + self.ssh_port = ssh_port + self.ssh_private_key_file = ssh_private_key_file + + +class VirtualMachineCompute(Compute): + """Virtual Machine Compute resource. + + :param name: Name of the compute resource. + :type name: str + :param description: Description of the resource. Defaults to None. + :type description: Optional[str] + :param resource_id: ARM resource ID of the underlying compute resource. + :type resource_id: str + :param tags: A set of tags. Contains resource tags defined as key/value pairs. + :type tags: Optional[dict] + :param ssh_settings: SSH settings. Defaults to None. + :type ssh_settings: Optional[~azure.ai.ml.entities.VirtualMachineSshSettings] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START vm_compute] + :end-before: [END vm_compute] + :language: python + :dedent: 8 + :caption: Configuring a VirtualMachineCompute object. 
+ """ + + def __init__( + self, + *, + name: str, + description: Optional[str] = None, + resource_id: str, + tags: Optional[dict] = None, + ssh_settings: Optional[VirtualMachineSshSettings] = None, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = ComputeType.VIRTUALMACHINE + self._public_key_data: str = kwargs.pop("public_key_data", None) + super().__init__( + name=name, + location=kwargs.pop("location", None), + description=description, + resource_id=resource_id, + tags=tags, + **kwargs, + ) + self.ssh_settings = ssh_settings + + @property + def public_key_data(self) -> str: + """Public key data. + + :return: Public key data. + :rtype: str + """ + return self._public_key_data + + @classmethod + def _load_from_rest(cls, rest_obj: ComputeResource) -> "VirtualMachineCompute": + prop = rest_obj.properties + credentials = prop.properties.administrator_account if prop.properties else None + ssh_settings_param = None + if credentials or (prop.properties and prop.properties.ssh_port): + ssh_settings_param = VirtualMachineSshSettings( + admin_username=credentials.username if credentials else None, + admin_password=credentials.password if credentials else None, + ssh_port=prop.properties.ssh_port if prop.properties else None, + ) + response = VirtualMachineCompute( + name=rest_obj.name, + id=rest_obj.id, + description=prop.description, + location=rest_obj.location, + resource_id=prop.resource_id, + tags=rest_obj.tags if rest_obj.tags else None, + public_key_data=credentials.public_key_data if credentials else None, + provisioning_state=prop.provisioning_state, + provisioning_errors=( + prop.provisioning_errors[0].error.code + if (prop.provisioning_errors and len(prop.provisioning_errors) > 0) + else None + ), + ssh_settings=ssh_settings_param, + ) + return response + + def _to_dict(self) -> Dict: + res: dict = VirtualMachineComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "VirtualMachineCompute": + loaded_data = load_from_dict(VirtualMachineComputeSchema, data, context, **kwargs) + return VirtualMachineCompute(**loaded_data) + + def _to_rest_object(self) -> ComputeResource: + ssh_key_value = None + if self.ssh_settings and self.ssh_settings.ssh_private_key_file: + ssh_key_value = Path(self.ssh_settings.ssh_private_key_file).read_text(encoding=DefaultOpenEncoding.READ) + credentials = VirtualMachineSshCredentials( + username=self.ssh_settings.admin_username if self.ssh_settings else None, + password=self.ssh_settings.admin_password if self.ssh_settings else None, + public_key_data=self.public_key_data, + private_key_data=ssh_key_value, + ) + if self.ssh_settings is not None: + properties = VirtualMachineSchemaProperties( + ssh_port=self.ssh_settings.ssh_port, administrator_account=credentials + ) + vm_compute = VMResource( + properties=properties, # pylint: disable=possibly-used-before-assignment + resource_id=self.resource_id, + description=self.description, + ) + resource = ComputeResource(name=self.name, location=self.location, tags=self.tags, properties=vm_compute) + return resource diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py new file mode 100644 index 00000000..b4d8e01d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py @@ -0,0 +1,964 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft 
Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access,redefined-builtin + +from abc import ABC +from typing import Any, Dict, List, Optional, Type, Union + +from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata +from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentityConfiguration +from azure.ai.ml._restclient.v2022_01_01_preview.models import ManagedIdentity as RestWorkspaceConnectionManagedIdentity +from azure.ai.ml._restclient.v2022_01_01_preview.models import ( + PersonalAccessToken as RestWorkspaceConnectionPersonalAccessToken, +) +from azure.ai.ml._restclient.v2022_01_01_preview.models import ( + ServicePrincipal as RestWorkspaceConnectionServicePrincipal, +) +from azure.ai.ml._restclient.v2022_01_01_preview.models import ( + SharedAccessSignature as RestWorkspaceConnectionSharedAccessSignature, +) +from azure.ai.ml._restclient.v2022_01_01_preview.models import UserAssignedIdentity as RestUserAssignedIdentity +from azure.ai.ml._restclient.v2022_01_01_preview.models import ( + UsernamePassword as RestWorkspaceConnectionUsernamePassword, +) +from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration +from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + AccountKeyDatastoreSecrets as RestAccountKeyDatastoreSecrets, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import AmlToken as RestAmlToken +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + CertificateDatastoreCredentials as RestCertificateDatastoreCredentials, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import CertificateDatastoreSecrets, CredentialsType +from azure.ai.ml._restclient.v2023_04_01_preview.models import IdentityConfiguration as RestJobIdentityConfiguration +from azure.ai.ml._restclient.v2023_04_01_preview.models import IdentityConfigurationType +from azure.ai.ml._restclient.v2023_04_01_preview.models import ManagedIdentity as RestJobManagedIdentity +from azure.ai.ml._restclient.v2023_04_01_preview.models import ManagedServiceIdentity as RestRegistryManagedIdentity +from azure.ai.ml._restclient.v2023_04_01_preview.models import NoneDatastoreCredentials as RestNoneDatastoreCredentials +from azure.ai.ml._restclient.v2023_04_01_preview.models import SasDatastoreCredentials as RestSasDatastoreCredentials +from azure.ai.ml._restclient.v2023_04_01_preview.models import SasDatastoreSecrets as RestSasDatastoreSecrets +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ServicePrincipalDatastoreCredentials as RestServicePrincipalDatastoreCredentials, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ServicePrincipalDatastoreSecrets as RestServicePrincipalDatastoreSecrets, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import UserIdentity as RestUserIdentity +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + WorkspaceConnectionApiKey as RestWorkspaceConnectionApiKey, +) + +# Note, this import needs to match the restclient 
that's imported by the +# Connection class, otherwise some unit tests will start failing +# Due to the mismatch between expected and received classes in WC rest conversions. +from azure.ai.ml._restclient.v2024_04_01_preview.models import ( + AADAuthTypeWorkspaceConnectionProperties, + AccessKeyAuthTypeWorkspaceConnectionProperties, + AccountKeyAuthTypeWorkspaceConnectionProperties, + ApiKeyAuthWorkspaceConnectionProperties, + ConnectionAuthType, + ManagedIdentityAuthTypeWorkspaceConnectionProperties, + NoneAuthTypeWorkspaceConnectionProperties, + PATAuthTypeWorkspaceConnectionProperties, + SASAuthTypeWorkspaceConnectionProperties, + ServicePrincipalAuthTypeWorkspaceConnectionProperties, + UsernamePasswordAuthTypeWorkspaceConnectionProperties, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import _snake_to_camel, camel_to_snake, snake_to_pascal +from azure.ai.ml.constants._common import CommonYamlFields, IdentityType +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin, YamlTranslatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException, ValidationErrorType, ValidationException + + +class _BaseIdentityConfiguration(ABC, DictMixin, RestTranslatableMixin): + def __init__(self) -> None: + self.type: Any = None + + @classmethod + def _get_credential_class_from_rest_type(cls, auth_type: str) -> Type: + # Defined in this file instead of in constants file to avoid risking + # circular imports. This map links rest enums to the corresponding client classes. + # Enums are all lower-cased because rest enums aren't always consistent with their + # camel casing rules. + # Defined in this class because I didn't want this at the bottom of the file, + # but the classes aren't visible to the interpreter at the start of the file. + # Technically most of these classes aren't child of _BaseIdentityConfiguration, but + # I don't care. 
+ REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP = { + ConnectionAuthType.SAS.lower(): SasTokenConfiguration, + ConnectionAuthType.PAT.lower(): PatTokenConfiguration, + ConnectionAuthType.ACCESS_KEY.lower(): AccessKeyConfiguration, + ConnectionAuthType.USERNAME_PASSWORD.lower(): UsernamePasswordConfiguration, + ConnectionAuthType.SERVICE_PRINCIPAL.lower(): ServicePrincipalConfiguration, + ConnectionAuthType.MANAGED_IDENTITY.lower(): ManagedIdentityConfiguration, + ConnectionAuthType.API_KEY.lower(): ApiKeyConfiguration, + ConnectionAuthType.ACCOUNT_KEY.lower(): AccountKeyConfiguration, + ConnectionAuthType.AAD.lower(): AadCredentialConfiguration, + } + if not auth_type: + return NoneCredentialConfiguration + return REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP.get( + _snake_to_camel(auth_type).lower(), NoneCredentialConfiguration + ) + + +class AccountKeyConfiguration(RestTranslatableMixin, DictMixin): + def __init__( + self, + *, + account_key: Optional[str], + ) -> None: + self.type = camel_to_snake(CredentialsType.ACCOUNT_KEY) + self.account_key = account_key + + def _to_datastore_rest_object(self) -> RestAccountKeyDatastoreCredentials: + secrets = RestAccountKeyDatastoreSecrets(key=self.account_key) + return RestAccountKeyDatastoreCredentials(secrets=secrets) + + @classmethod + def _from_datastore_rest_object(cls, obj: RestAccountKeyDatastoreCredentials) -> "AccountKeyConfiguration": + return cls(account_key=obj.secrets.key if obj.secrets else None) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionSharedAccessSignature] + ) -> "AccountKeyConfiguration": + # As far as I can tell, account key configs use the name underlying + # rest object as sas token configs + return cls(account_key=obj.sas if obj is not None and obj.sas else None) + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionSharedAccessSignature: + return RestWorkspaceConnectionSharedAccessSignature(sas=self.account_key) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AccountKeyConfiguration): + return NotImplemented + return self.account_key == other.account_key + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return AccountKeyAuthTypeWorkspaceConnectionProperties + + +class SasTokenConfiguration(RestTranslatableMixin, DictMixin): + def __init__( + self, + *, + sas_token: Optional[str], + ) -> None: + super().__init__() + self.type = camel_to_snake(CredentialsType.SAS) + self.sas_token = sas_token + + def _to_datastore_rest_object(self) -> RestSasDatastoreCredentials: + secrets = RestSasDatastoreSecrets(sas_token=self.sas_token) + return RestSasDatastoreCredentials(secrets=secrets) + + @classmethod + def _from_datastore_rest_object(cls, obj: RestSasDatastoreCredentials) -> "SasTokenConfiguration": + return cls(sas_token=obj.secrets.sas_token if obj.secrets else None) + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionSharedAccessSignature: + return RestWorkspaceConnectionSharedAccessSignature(sas=self.sas_token) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionSharedAccessSignature] + ) -> "SasTokenConfiguration": + return cls(sas_token=obj.sas if obj is not None and obj.sas else None) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SasTokenConfiguration): + return NotImplemented + return self.sas_token == other.sas_token + + def 
__ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return SASAuthTypeWorkspaceConnectionProperties + + +class PatTokenConfiguration(RestTranslatableMixin, DictMixin): + """Personal access token credentials. + + :param pat: Personal access token. + :type pat: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START personal_access_token_configuration] + :end-before: [END personal_access_token_configuration] + :language: python + :dedent: 8 + :caption: Configuring a personal access token configuration for a WorkspaceConnection. + """ + + def __init__(self, *, pat: Optional[str]) -> None: + super().__init__() + self.type = camel_to_snake(ConnectionAuthType.PAT) + self.pat = pat + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionPersonalAccessToken: + return RestWorkspaceConnectionPersonalAccessToken(pat=self.pat) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionPersonalAccessToken] + ) -> "PatTokenConfiguration": + return cls(pat=obj.pat if obj is not None and obj.pat else None) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PatTokenConfiguration): + return NotImplemented + return self.pat == other.pat + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return PATAuthTypeWorkspaceConnectionProperties + + +class UsernamePasswordConfiguration(RestTranslatableMixin, DictMixin): + """Username and password credentials. + + :param username: The username, value should be url-encoded. + :type username: str + :param password: The password, value should be url-encoded. + :type password: str + """ + + def __init__( + self, + *, + username: Optional[str], + password: Optional[str], + ) -> None: + super().__init__() + self.type = camel_to_snake(ConnectionAuthType.USERNAME_PASSWORD) + self.username = username + self.password = password + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionUsernamePassword: + return RestWorkspaceConnectionUsernamePassword(username=self.username, password=self.password) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionUsernamePassword] + ) -> "UsernamePasswordConfiguration": + return cls( + username=obj.username if obj is not None and obj.username else None, + password=obj.password if obj is not None and obj.password else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UsernamePasswordConfiguration): + return NotImplemented + return self.username == other.username and self.password == other.password + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return UsernamePasswordAuthTypeWorkspaceConnectionProperties + + +class BaseTenantCredentials(RestTranslatableMixin, DictMixin, ABC): + """Base class for tenant credentials. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :param authority_url: The authority URL. If None specified, a URL will be retrieved from the metadata in the cloud. + :type authority_url: Optional[str] + :param resource_url: The resource URL. + :type resource_url: Optional[str] + :param tenant_id: The tenant ID. + :type tenant_id: Optional[str] + :param client_id: The client ID. 
+ :type client_id: Optional[str] + """ + + def __init__( + self, + authority_url: str = _get_active_directory_url_from_metadata(), + resource_url: Optional[str] = None, + tenant_id: Optional[str] = None, + client_id: Optional[str] = None, + ) -> None: + super().__init__() + self.authority_url = authority_url + self.resource_url = resource_url + self.tenant_id = tenant_id + self.client_id = client_id + + +class ServicePrincipalConfiguration(BaseTenantCredentials): + """Service Principal credentials configuration. + + :param client_secret: The client secret. + :type client_secret: str + :keyword kwargs: Additional arguments to pass to the parent class. + :paramtype kwargs: Optional[dict] + """ + + def __init__( + self, + *, + client_secret: Optional[str], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.type = camel_to_snake(CredentialsType.SERVICE_PRINCIPAL) + self.client_secret = client_secret + + def _to_datastore_rest_object(self) -> RestServicePrincipalDatastoreCredentials: + secrets = RestServicePrincipalDatastoreSecrets(client_secret=self.client_secret) + return RestServicePrincipalDatastoreCredentials( + authority_url=self.authority_url, + resource_url=self.resource_url, + tenant_id=self.tenant_id, + client_id=self.client_id, + secrets=secrets, + ) + + @classmethod + def _from_datastore_rest_object( + cls, obj: RestServicePrincipalDatastoreCredentials + ) -> "ServicePrincipalConfiguration": + return cls( + authority_url=obj.authority_url, + resource_url=obj.resource_url, + tenant_id=obj.tenant_id, + client_id=obj.client_id, + client_secret=obj.secrets.client_secret if obj.secrets else None, + ) + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionServicePrincipal: + return RestWorkspaceConnectionServicePrincipal( + client_id=self.client_id, + client_secret=self.client_secret, + tenant_id=self.tenant_id, + ) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionServicePrincipal] + ) -> "ServicePrincipalConfiguration": + return cls( + client_id=obj.client_id if obj is not None and obj.client_id else None, + client_secret=obj.client_secret if obj is not None and obj.client_secret else None, + tenant_id=obj.tenant_id if obj is not None and obj.tenant_id else None, + authority_url="", + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ServicePrincipalConfiguration): + return NotImplemented + return ( + self.authority_url == other.authority_url + and self.resource_url == other.resource_url + and self.tenant_id == other.tenant_id + and self.client_id == other.client_id + and self.client_secret == other.client_secret + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return ServicePrincipalAuthTypeWorkspaceConnectionProperties + + +class CertificateConfiguration(BaseTenantCredentials): + def __init__( + self, + certificate: Optional[str] = None, + thumbprint: Optional[str] = None, + **kwargs: str, + ) -> None: + super().__init__(**kwargs) + self.type = CredentialsType.CERTIFICATE + self.certificate = certificate + self.thumbprint = thumbprint + + def _to_datastore_rest_object(self) -> RestCertificateDatastoreCredentials: + secrets = CertificateDatastoreSecrets(certificate=self.certificate) + return RestCertificateDatastoreCredentials( + authority_url=self.authority_url, + resource_uri=self.resource_url, + tenant_id=self.tenant_id, + client_id=self.client_id, + 
thumbprint=self.thumbprint, + secrets=secrets, + ) + + @classmethod + def _from_datastore_rest_object(cls, obj: RestCertificateDatastoreCredentials) -> "CertificateConfiguration": + return cls( + authority_url=obj.authority_url, + resource_url=obj.resource_uri, + tenant_id=obj.tenant_id, + client_id=obj.client_id, + thumbprint=obj.thumbprint, + certificate=obj.secrets.certificate if obj.secrets else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CertificateConfiguration): + return NotImplemented + return ( + self.authority_url == other.authority_url + and self.resource_url == other.resource_url + and self.tenant_id == other.tenant_id + and self.client_id == other.client_id + and self.thumbprint == other.thumbprint + and self.certificate == other.certificate + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class _BaseJobIdentityConfiguration(ABC, RestTranslatableMixin, DictMixin, YamlTranslatableMixin): + def __init__(self) -> None: + self.type = None + + @classmethod + def _from_rest_object(cls, obj: RestJobIdentityConfiguration) -> "RestIdentityConfiguration": + if obj is None: + return None + mapping = { + IdentityConfigurationType.AML_TOKEN: AmlTokenConfiguration, + IdentityConfigurationType.MANAGED: ManagedIdentityConfiguration, + IdentityConfigurationType.USER_IDENTITY: UserIdentityConfiguration, + } + + if isinstance(obj, dict): + # TODO: support data binding expression + obj = RestJobIdentityConfiguration.from_dict(obj) + + identity_class = mapping.get(obj.identity_type, None) + if identity_class: + if obj.identity_type == IdentityConfigurationType.AML_TOKEN: + return AmlTokenConfiguration._from_job_rest_object(obj) + + if obj.identity_type == IdentityConfigurationType.MANAGED: + return ManagedIdentityConfiguration._from_job_rest_object(obj) + + if obj.identity_type == IdentityConfigurationType.USER_IDENTITY: + return UserIdentityConfiguration._from_job_rest_object(obj) + + msg = f"Unknown identity type: {obj.identity_type}" + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + @classmethod + def _load( + cls, + data: Dict, + ) -> Union["ManagedIdentityConfiguration", "UserIdentityConfiguration", "AmlTokenConfiguration"]: + type_str = data.get(CommonYamlFields.TYPE) + if type_str == IdentityType.MANAGED_IDENTITY: + return ManagedIdentityConfiguration._load_from_dict(data) + + if type_str == IdentityType.USER_IDENTITY: + return UserIdentityConfiguration._load_from_dict(data) + + if type_str == IdentityType.AML_TOKEN: + return AmlTokenConfiguration._load_from_dict(data) + + msg = f"Unsupported identity type: {type_str}." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.IDENTITY, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + +class ManagedIdentityConfiguration(_BaseIdentityConfiguration): + """Managed Identity credential configuration. + + :keyword client_id: The client ID of the managed identity. + :paramtype client_id: Optional[str] + :keyword resource_id: The resource ID of the managed identity. + :paramtype resource_id: Optional[str] + :keyword object_id: The object ID. + :paramtype object_id: Optional[str] + :keyword principal_id: The principal ID. 
+ :paramtype principal_id: Optional[str] + """ + + def __init__( + self, + *, + client_id: Optional[str] = None, + resource_id: Optional[str] = None, + object_id: Optional[str] = None, + principal_id: Optional[str] = None, + ) -> None: + super().__init__() + self.type = IdentityType.MANAGED_IDENTITY + self.client_id = client_id + # TODO: Check if both client_id and resource_id are required + self.resource_id = resource_id + self.object_id = object_id + self.principal_id = principal_id + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionManagedIdentity: + return RestWorkspaceConnectionManagedIdentity(client_id=self.client_id, resource_id=self.resource_id) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionManagedIdentity] + ) -> "ManagedIdentityConfiguration": + return cls( + client_id=obj.client_id if obj is not None and obj.client_id else None, + resource_id=obj.resource_id if obj is not None and obj.client_id else None, + ) + + def _to_job_rest_object(self) -> RestJobManagedIdentity: + return RestJobManagedIdentity( + client_id=self.client_id, + object_id=self.object_id, + resource_id=self.resource_id, + ) + + @classmethod + def _from_job_rest_object(cls, obj: RestJobManagedIdentity) -> "ManagedIdentityConfiguration": + return cls( + client_id=obj.client_id, + object_id=obj.client_id, + resource_id=obj.resource_id, + ) + + def _to_identity_configuration_rest_object(self) -> RestUserAssignedIdentity: + return RestUserAssignedIdentity() + + @classmethod + def _from_identity_configuration_rest_object( + cls, rest_obj: RestUserAssignedIdentity, **kwargs: Optional[str] + ) -> "ManagedIdentityConfiguration": + _rid: Optional[str] = kwargs["resource_id"] + result = cls(resource_id=_rid) + result.__dict__.update(rest_obj.as_dict()) + return result + + def _to_online_endpoint_rest_object(self) -> RestUserAssignedIdentityConfiguration: + return RestUserAssignedIdentityConfiguration() + + def _to_workspace_rest_object(self) -> RestUserAssignedIdentityConfiguration: + return RestUserAssignedIdentityConfiguration( + principal_id=self.principal_id, + client_id=self.client_id, + ) + + @classmethod + def _from_workspace_rest_object(cls, obj: RestUserAssignedIdentityConfiguration) -> "ManagedIdentityConfiguration": + return cls( + principal_id=obj.principal_id, + client_id=obj.client_id, + ) + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import ManagedIdentitySchema + + _dict: Dict = ManagedIdentitySchema().dump(self) + return _dict + + @classmethod + def _load_from_dict(cls, data: Dict) -> "ManagedIdentityConfiguration": + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import ManagedIdentitySchema + + _data: ManagedIdentityConfiguration = ManagedIdentitySchema().load(data) + return _data + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ManagedIdentityConfiguration): + return NotImplemented + return self.client_id == other.client_id and self.resource_id == other.resource_id + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return ManagedIdentityAuthTypeWorkspaceConnectionProperties + + +class UserIdentityConfiguration(_BaseIdentityConfiguration): + """User identity configuration. + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_authentication.py + :start-after: [START user_identity_configuration] + :end-before: [END user_identity_configuration] + :language: python + :dedent: 8 + :caption: Configuring a UserIdentityConfiguration for a command(). + """ + + def __init__(self) -> None: + super().__init__() + self.type = IdentityType.USER_IDENTITY + + def _to_job_rest_object(self) -> RestUserIdentity: + return RestUserIdentity() + + @classmethod + # pylint: disable=unused-argument + def _from_job_rest_object(cls, obj: RestUserIdentity) -> "RestUserIdentity": + return cls() + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import UserIdentitySchema + + _dict: Dict = UserIdentitySchema().dump(self) + return _dict + + @classmethod + def _load_from_dict(cls, data: Dict) -> "UserIdentityConfiguration": + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import UserIdentitySchema + + _data: UserIdentityConfiguration = UserIdentitySchema().load(data) + return _data + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UserIdentityConfiguration): + return NotImplemented + res: bool = self._to_job_rest_object() == other._to_job_rest_object() + return res + + +class AmlTokenConfiguration(_BaseIdentityConfiguration): + """AzureML Token identity configuration. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_authentication.py + :start-after: [START aml_token_configuration] + :end-before: [END aml_token_configuration] + :language: python + :dedent: 8 + :caption: Configuring an AmlTokenConfiguration for a command(). + """ + + def __init__(self) -> None: + super().__init__() + self.type = IdentityType.AML_TOKEN + + def _to_job_rest_object(self) -> RestAmlToken: + return RestAmlToken() + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema + + _dict: Dict = AMLTokenIdentitySchema().dump(self) + return _dict + + @classmethod + def _load_from_dict(cls, data: Dict) -> "AmlTokenConfiguration": + # pylint: disable=no-member + from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema + + _data: AmlTokenConfiguration = AMLTokenIdentitySchema().load(data) + return _data + + @classmethod + # pylint: disable=unused-argument + def _from_job_rest_object(cls, obj: RestAmlToken) -> "AmlTokenConfiguration": + return cls() + + +# This class will be used to represent Identity property on compute, endpoint, and registry +class IdentityConfiguration(RestTranslatableMixin): + """Identity configuration used to represent identity property on compute, endpoint, and registry resources. + + :param type: The type of managed identity. + :type type: str + :param user_assigned_identities: A list of ManagedIdentityConfiguration objects. 
+ :type user_assigned_identities: Optional[list[~azure.ai.ml.entities.ManagedIdentityConfiguration]] + """ + + def __init__( + self, + *, + type: str, + user_assigned_identities: Optional[List[ManagedIdentityConfiguration]] = None, + **kwargs: dict, + ) -> None: + self.type = type + self.user_assigned_identities = user_assigned_identities + self.principal_id = kwargs.pop("principal_id", None) + self.tenant_id = kwargs.pop("tenant_id", None) + + def _to_compute_rest_object(self) -> RestIdentityConfiguration: + rest_user_assigned_identities = ( + {uai.resource_id: uai._to_identity_configuration_rest_object() for uai in self.user_assigned_identities} + if self.user_assigned_identities + else None + ) + return RestIdentityConfiguration( + type=snake_to_pascal(self.type), user_assigned_identities=rest_user_assigned_identities + ) + + @classmethod + def _from_compute_rest_object(cls, obj: RestIdentityConfiguration) -> "IdentityConfiguration": + from_rest_user_assigned_identities = ( + [ + ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id) + for (resource_id, uai) in obj.user_assigned_identities.items() + ] + if obj.user_assigned_identities + else None + ) + result = cls( + type=camel_to_snake(obj.type), + user_assigned_identities=from_rest_user_assigned_identities, + ) + result.principal_id = obj.principal_id + result.tenant_id = obj.tenant_id + return result + + def _to_online_endpoint_rest_object(self) -> RestManagedServiceIdentityConfiguration: + rest_user_assigned_identities = ( + {uai.resource_id: uai._to_online_endpoint_rest_object() for uai in self.user_assigned_identities} + if self.user_assigned_identities + else None + ) + + return RestManagedServiceIdentityConfiguration( + type=snake_to_pascal(self.type), + principal_id=self.principal_id, + tenant_id=self.tenant_id, + user_assigned_identities=rest_user_assigned_identities, + ) + + @classmethod + def _from_online_endpoint_rest_object(cls, obj: RestManagedServiceIdentityConfiguration) -> "IdentityConfiguration": + from_rest_user_assigned_identities = ( + [ + ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id) + for (resource_id, uai) in obj.user_assigned_identities.items() + ] + if obj.user_assigned_identities + else None + ) + result = cls( + type=camel_to_snake(obj.type), + user_assigned_identities=from_rest_user_assigned_identities, + ) + result.principal_id = obj.principal_id + result.tenant_id = obj.tenant_id + return result + + @classmethod + def _from_workspace_rest_object(cls, obj: RestManagedServiceIdentityConfiguration) -> "IdentityConfiguration": + from_rest_user_assigned_identities = ( + [ + ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id) + for (resource_id, uai) in obj.user_assigned_identities.items() + ] + if obj.user_assigned_identities + else None + ) + result = cls( + type=camel_to_snake(obj.type), + user_assigned_identities=from_rest_user_assigned_identities, + ) + result.principal_id = obj.principal_id + result.tenant_id = obj.tenant_id + return result + + def _to_workspace_rest_object(self) -> RestManagedServiceIdentityConfiguration: + rest_user_assigned_identities = ( + {uai.resource_id: uai._to_workspace_rest_object() for uai in self.user_assigned_identities} + if self.user_assigned_identities + else None + ) + return RestManagedServiceIdentityConfiguration( + type=snake_to_pascal(self.type), user_assigned_identities=rest_user_assigned_identities + ) + + def 
_to_rest_object(self) -> RestRegistryManagedIdentity: + return RestRegistryManagedIdentity( + type=self.type, + principal_id=self.principal_id, + tenant_id=self.tenant_id, + ) + + @classmethod + def _from_rest_object(cls, obj: RestRegistryManagedIdentity) -> "IdentityConfiguration": + result = cls( + type=obj.type, + user_assigned_identities=None, + ) + result.principal_id = obj.principal_id + result.tenant_id = obj.tenant_id + return result + + +class NoneCredentialConfiguration(RestTranslatableMixin): + """None Credential Configuration. In many uses cases, the presence of + this credential configuration indicates that the user's Entra ID will be + implicitly used instead of any other form of authentication.""" + + def __init__(self) -> None: + self.type = CredentialsType.NONE + + def _to_datastore_rest_object(self) -> RestNoneDatastoreCredentials: + return RestNoneDatastoreCredentials() + + @classmethod + # pylint: disable=unused-argument + def _from_datastore_rest_object(cls, obj: RestNoneDatastoreCredentials) -> "NoneCredentialConfiguration": + return cls() + + def _to_workspace_connection_rest_object(self) -> None: + return None + + def __eq__(self, other: object) -> bool: + if isinstance(other, NoneCredentialConfiguration): + return True + return False + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return NoneAuthTypeWorkspaceConnectionProperties + + +class AadCredentialConfiguration(RestTranslatableMixin): + """Azure Active Directory Credential Configuration""" + + def __init__(self) -> None: + self.type = camel_to_snake(ConnectionAuthType.AAD) + + def _to_datastore_rest_object(self) -> RestNoneDatastoreCredentials: + return RestNoneDatastoreCredentials() + + @classmethod + # pylint: disable=unused-argument + def _from_datastore_rest_object(cls, obj: RestNoneDatastoreCredentials) -> "AadCredentialConfiguration": + return cls() + + # Has no credential object, just a property bag class. + def _to_workspace_connection_rest_object(self) -> None: + return None + + def __eq__(self, other: object) -> bool: + if isinstance(other, AadCredentialConfiguration): + return True + return False + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _get_rest_properties_class(cls) -> Type: + return AADAuthTypeWorkspaceConnectionProperties + + +class AccessKeyConfiguration(RestTranslatableMixin, DictMixin): + """Access Key Credentials. + + :param access_key_id: The access key ID. + :type access_key_id: str + :param secret_access_key: The secret access key. 
+ :type secret_access_key: str + """ + + def __init__( + self, + *, + access_key_id: Optional[str], + secret_access_key: Optional[str], + ) -> None: + super().__init__() + self.type = camel_to_snake(ConnectionAuthType.ACCESS_KEY) + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionAccessKey: + return RestWorkspaceConnectionAccessKey( + access_key_id=self.access_key_id, secret_access_key=self.secret_access_key + ) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionAccessKey] + ) -> "AccessKeyConfiguration": + return cls( + access_key_id=obj.access_key_id if obj is not None and obj.access_key_id else None, + secret_access_key=obj.secret_access_key if obj is not None and obj.secret_access_key else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AccessKeyConfiguration): + return NotImplemented + return self.access_key_id == other.access_key_id and self.secret_access_key == other.secret_access_key + + def _get_rest_properties_class(self): + return AccessKeyAuthTypeWorkspaceConnectionProperties + + +@experimental +class ApiKeyConfiguration(RestTranslatableMixin, DictMixin): + """Api Key Credentials. + + :param key: API key id + :type key: str + """ + + def __init__( + self, + *, + key: Optional[str], + ): + super().__init__() + self.type = camel_to_snake(ConnectionAuthType.API_KEY) + self.key = key + + def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionApiKey: + return RestWorkspaceConnectionApiKey( + key=self.key, + ) + + @classmethod + def _from_workspace_connection_rest_object( + cls, obj: Optional[RestWorkspaceConnectionApiKey] + ) -> "ApiKeyConfiguration": + return cls( + key=obj.key if obj is not None and obj.key else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ApiKeyConfiguration): + return NotImplemented + return bool(self.key == other.key) + + def _get_rest_properties_class(self): + return ApiKeyAuthWorkspaceConnectionProperties diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/mltable_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/mltable_metadata.py new file mode 100644 index 00000000..452b2e53 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/mltable_metadata.py @@ -0,0 +1,92 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
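# --- Editorial example (not part of the diffed file) ------------------------
# A minimal sketch of two credential entities from _credentials.py above, built
# with the keyword-only parameters shown in that file. The identity-type string
# and the GUID/ARM-id values are placeholders; the SDK's IdentityType constants
# are the authoritative source for the type value.
from azure.ai.ml.entities._credentials import (
    IdentityConfiguration,
    ManagedIdentityConfiguration,
    PatTokenConfiguration,
)

# A user-assigned managed identity wrapped in the container used by compute,
# endpoints, and registries.
user_assigned = ManagedIdentityConfiguration(
    client_id="00000000-0000-0000-0000-000000000000",  # placeholder GUID
    resource_id="/subscriptions/<sub>/resourceGroups/<rg>/providers/"
    "Microsoft.ManagedIdentity/userAssignedIdentities/<identity-name>",  # placeholder
)
identity = IdentityConfiguration(
    type="user_assigned",  # assumed snake_case value; converted to PascalCase on the wire
    user_assigned_identities=[user_assigned],
)

# A personal-access-token credential, e.g. for a workspace connection.
pat_credential = PatTokenConfiguration(pat="<personal-access-token>")
# -----------------------------------------------------------------------------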
+# --------------------------------------------------------- + +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from marshmallow import INCLUDE + +from azure.ai.ml._schema._data.mltable_metadata_schema import MLTableMetadataSchema +from azure.ai.ml._utils.utils import load_yaml +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._util import load_from_dict + + +class MLTableMetadataPath: + type: str # Literal["pattern", "file", "folder"] + value: Optional[str] + + def __init__(self, *, pathDict: Dict): + if pathDict.get("pattern", None): + self.type = "pattern" + self.value = pathDict.get("pattern") + if pathDict.get("file", None): + self.type = "file" + self.value = pathDict.get("file") + if pathDict.get("folder", None): + self.type = "folder" + self.value = pathDict.get("folder") + + +class MLTableMetadata: + """MLTableMetadata for data assets. + + :param paths: List of paths which the MLTableMetadata refers to. + :type paths: List[MLTableMetadataPath] + :param transformations: Any transformations to be applied to the data referenced in paths. + :type transformations: List[Any] + :param base_path: Base path to resolve relative paths from. + :type base_path: str + """ + + def __init__( + self, + *, + paths: List[MLTableMetadataPath], + transformations: Optional[List[Any]] = None, + base_path: str, + **_kwargs: Any, + ): + self.base_path = base_path + self.paths = paths + self.transformations = transformations + + @classmethod + def load( + cls, + yaml_path: Union[PathLike, str], + **kwargs: Any, + ) -> "MLTableMetadata": + """Construct an MLTable object from yaml file. + + :param yaml_path: Path to a local file as the source. + :type yaml_path: PathLike | str + + :return: Constructed MLTable object. + :rtype: MLTable + """ + yaml_dict = load_yaml(yaml_path) + return cls._load(yaml_data=yaml_dict, yaml_path=yaml_path, **kwargs) + + @classmethod + def _load( + cls, + yaml_data: Optional[Dict], + yaml_path: Optional[Union[PathLike, str]], + **kwargs: Any, + ) -> "MLTableMetadata": + yaml_data = yaml_data or {} + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + } + res: MLTableMetadata = load_from_dict(MLTableMetadataSchema, yaml_data, context, "", unknown=INCLUDE, **kwargs) + return res + + def _to_dict(self) -> Dict: + res: dict = MLTableMetadataSchema(context={BASE_PATH_CONTEXT_KEY: "./"}, unknown=INCLUDE).dump(self) + return res + + def referenced_uris(self) -> List[Optional[str]]: + return [path.value for path in self.paths] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
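# --- Editorial example (not part of the diffed file) ------------------------
# A minimal sketch of the MLTableMetadata helper defined in mltable_metadata.py
# above: load a local MLTable YAML and list the URIs it references. The path is
# a placeholder; the class lives in a private module, so it is imported from the
# path used in this diff rather than a public alias.
from azure.ai.ml.entities._data.mltable_metadata import MLTableMetadata

mltable_meta = MLTableMetadata.load("./my_dataset/MLTable")  # placeholder path
for uri in mltable_meta.referenced_uris():
    # each entry is the pattern/file/folder value of one `paths` item
    print(uri)
# -----------------------------------------------------------------------------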
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/data_import.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/data_import.py new file mode 100644 index 00000000..028d431c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/data_import.py @@ -0,0 +1,130 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from os import PathLike +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_06_01_preview.models import DatabaseSource as RestDatabaseSource +from azure.ai.ml._restclient.v2023_06_01_preview.models import DataImport as RestDataImport +from azure.ai.ml._restclient.v2023_06_01_preview.models import FileSystemSource as RestFileSystemSource +from azure.ai.ml._schema import DataImportSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, AssetTypes +from azure.ai.ml.data_transfer import Database, FileSystem +from azure.ai.ml.entities._assets import Data +from azure.ai.ml.entities._util import load_from_dict + + +@experimental +class DataImport(Data): + """Data asset with a creating data import job. + + :param name: Name of the asset. + :type name: str + :param path: The path to the asset being created by data import job. + :type path: str + :param source: The source of the asset data being copied from. + :type source: Union[Database, FileSystem] + :param version: Version of the resource. + :type version: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param kwargs: A dictionary of additional configuration parameters. 
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + path: str, + source: Union[Database, FileSystem], + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ): + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + properties=properties, + path=path, + **kwargs, + ) + self.source = source + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "DataImport": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + res: DataImport = load_from_dict(DataImportSchema, data, context, **kwargs) + return res + + def _to_rest_object(self) -> RestDataImport: + if isinstance(self.source, Database): + source = RestDatabaseSource( + connection=self.source.connection, + query=self.source.query, + ) + else: + source = RestFileSystemSource( + connection=self.source.connection, + path=self.source.path, + ) + + return RestDataImport( + description=self.description, + properties=self.properties, + tags=self.tags, + data_type=self.type, + data_uri=self.path, + asset_name=self.name, + source=source, + ) + + @classmethod + def _from_rest_object(cls, data_rest_object: RestDataImport) -> "DataImport": + source: Any = None + if isinstance(data_rest_object.source, RestDatabaseSource): + source = Database( + connection=data_rest_object.source.connection, + query=data_rest_object.source.query, + ) + data_type = AssetTypes.MLTABLE + else: + source = FileSystem( + connection=data_rest_object.source.connection, + path=data_rest_object.source.path, + ) + data_type = AssetTypes.URI_FOLDER + + data_import = cls( + name=data_rest_object.asset_name, + path=data_rest_object.data_uri, + source=source, + description=data_rest_object.description, + tags=data_rest_object.tags, + properties=data_rest_object.properties, + type=data_type, + is_anonymous=data_rest_object.is_anonymous, + ) + return data_import diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/schedule.py new file mode 100644 index 00000000..6a51878a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/schedule.py @@ -0,0 +1,115 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
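# --- Editorial example (not part of the diffed file) ------------------------
# A minimal sketch of the DataImport entity defined in data_import.py above,
# using a Database source. Database is imported from azure.ai.ml.data_transfer
# as in that file; its connection/query keywords are inferred from the
# attributes data_import.py reads off the source object, and the connection
# name, query, and output path are placeholders.
from azure.ai.ml.data_transfer import Database
from azure.ai.ml.entities._data_import.data_import import DataImport

data_import = DataImport(
    name="sample_database_import",
    source=Database(
        connection="azureml:my_database_connection",  # placeholder connection
        query="SELECT * FROM my_table",                # placeholder query
    ),
    path="azureml://datastores/workspaceblobstore/paths/imports/${{name}}",  # placeholder
)
# ml_client.data.import_data(data_import=data_import) would start the import job
# (MLClient wiring is not shown in this excerpt).
# -----------------------------------------------------------------------------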
+# --------------------------------------------------------- +# pylint: disable=protected-access +from os import PathLike +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ImportDataAction +from azure.ai.ml._restclient.v2023_04_01_preview.models import Schedule as RestSchedule +from azure.ai.ml._restclient.v2023_04_01_preview.models import ScheduleProperties +from azure.ai.ml._schema._data_import.schedule import ImportDataScheduleSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ScheduleType +from azure.ai.ml.entities._data_import.data_import import DataImport +from azure.ai.ml.entities._schedule.schedule import Schedule +from azure.ai.ml.entities._schedule.trigger import CronTrigger, RecurrenceTrigger, TriggerBase +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict + + +@experimental +class ImportDataSchedule(Schedule): + """ImportDataSchedule object. + + :param name: Name of the schedule. + :type name: str + :param trigger: Trigger of the schedule. + :type trigger: Union[CronTrigger, RecurrenceTrigger] + :param import_data: The schedule action data import definition. + :type import_data: DataImport + :param display_name: Display name of the schedule. + :type display_name: str + :param description: Description of the schedule, defaults to None + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The data import property dictionary. + :type properties: dict[str, str] + """ + + def __init__( + self, + *, + name: str, + trigger: Optional[Union[CronTrigger, RecurrenceTrigger]], + import_data: DataImport, + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ): + super().__init__( + name=name, + trigger=trigger, + display_name=display_name, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + self.import_data = import_data + self._type = ScheduleType.DATA_IMPORT + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "ImportDataSchedule": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + return ImportDataSchedule( + base_path=context[BASE_PATH_CONTEXT_KEY], + **load_from_dict(ImportDataScheduleSchema, data, context, **kwargs), + ) + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> ImportDataScheduleSchema: + return ImportDataScheduleSchema(context=context) + + @classmethod + def _from_rest_object(cls, obj: RestSchedule) -> "ImportDataSchedule": + return cls( + trigger=TriggerBase._from_rest_object(obj.properties.trigger), + import_data=DataImport._from_rest_object(obj.properties.action.data_import_definition), + name=obj.name, + display_name=obj.properties.display_name, + description=obj.properties.description, + tags=obj.properties.tags, + properties=obj.properties.properties, + provisioning_state=obj.properties.provisioning_state, + is_enabled=obj.properties.is_enabled, + 
creation_context=SystemData._from_rest_object(obj.system_data), + ) + + def _to_rest_object(self) -> RestSchedule: + return RestSchedule( + properties=ScheduleProperties( + description=self.description, + properties=self.properties, + tags=self.tags, + action=ImportDataAction(data_import_definition=self.import_data._to_rest_object()), + display_name=self.display_name, + is_enabled=self._is_enabled, + trigger=self.trigger._to_rest_object() if self.trigger is not None else None, + ) + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_constants.py new file mode 100644 index 00000000..97a257ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_constants.py @@ -0,0 +1,8 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# Miscellaneous +HTTPS = "https" +HTTP = "http" +WORKSPACE_BLOB_STORE = "workspaceblobstore" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem.py new file mode 100644 index 00000000..e6c0dc3f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem.py @@ -0,0 +1,121 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from base64 import b64encode +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData +from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType +from azure.ai.ml._restclient.v2023_04_01_preview.models import HdfsDatastore as RestHdfsDatastore +from azure.ai.ml._schema._datastore._on_prem import HdfsSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.entities._datastore.datastore import Datastore +from azure.ai.ml.entities._datastore.utils import _from_rest_datastore_credentials_preview +from azure.ai.ml.entities._util import load_from_dict + +from ._constants import HTTP +from ._on_prem_credentials import KerberosKeytabCredentials, KerberosPasswordCredentials + + +@experimental +class HdfsDatastore(Datastore): + """HDFS datastore that is linked to an Azure ML workspace. + + :param name: Name of the datastore. + :type name: str + :param name_node_address: IP Address or DNS HostName. + :type name_node_address: str + :param hdfs_server_certificate: The TLS cert of the HDFS server (optional). 
+ Needs to be a local path on create and will be a base64 encoded string on get. + :type hdfs_server_certificate: str + :param protocol: http or https + :type protocol: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param credentials: Credentials to use for Azure ML workspace to connect to the storage. + :type credentials: Union[KerberosKeytabCredentials, KerberosPasswordCredentials] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + name_node_address: str, + hdfs_server_certificate: Optional[str] = None, + protocol: str = HTTP, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + credentials: Optional[Union[KerberosKeytabCredentials, KerberosPasswordCredentials]], + **kwargs: Any + ): + kwargs[TYPE] = DatastoreType.HDFS + super().__init__( + name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs + ) + + self.hdfs_server_certificate = hdfs_server_certificate + self.name_node_address = name_node_address + self.protocol = protocol + + def _to_rest_object(self) -> DatastoreData: + use_this_cert = None + if self.hdfs_server_certificate: + with open(self.hdfs_server_certificate, "rb") as f: + use_this_cert = b64encode(f.read()).decode("utf-8") + hdfs_ds = RestHdfsDatastore( + credentials=self.credentials._to_rest_object(), + hdfs_server_certificate=use_this_cert, + name_node_address=self.name_node_address, + protocol=self.protocol, + description=self.description, + tags=self.tags, + ) + return DatastoreData(properties=hdfs_ds) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "HdfsDatastore": + res: HdfsDatastore = load_from_dict(HdfsSchema, data, context, additional_message) + return res + + @classmethod + def _from_rest_object(cls, datastore_resource: DatastoreData) -> "HdfsDatastore": + properties: RestHdfsDatastore = datastore_resource.properties + return HdfsDatastore( + name=datastore_resource.name, + id=datastore_resource.id, + credentials=_from_rest_datastore_credentials_preview(properties.credentials), + hdfs_server_certificate=properties.hdfs_server_certificate, + name_node_address=properties.name_node_address, + protocol=properties.protocol, + description=properties.description, + tags=properties.tags, + ) + + def __eq__(self, other: Any) -> bool: + res: bool = ( + super().__eq__(other) + and self.hdfs_server_certificate == other.hdfs_server_certificate + and self.name_node_address == other.name_node_address + and self.protocol == other.protocol + ) + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def _to_dict(self) -> Dict: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = HdfsSchema(context=context).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem_credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem_credentials.py new file mode 100644 index 00000000..b658851a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem_credentials.py @@ -0,0 +1,128 @@ +# --------------------------------------------------------- 
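# Illustrative usage sketch (editorial, not part of the vendored diff): registering the
# experimental HdfsDatastore defined in _on_prem.py above, using the Kerberos password
# credentials declared in the file that follows. Import paths follow the module layout in
# this diff; "ml_client", the addresses, and the principal/password are placeholders.
from azure.ai.ml.entities._datastore._on_prem import HdfsDatastore
from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosPasswordCredentials

hdfs_store = HdfsDatastore(
    name="my_hdfs_store",
    name_node_address="namenode.contoso.local:9870",
    protocol="http",  # default; use "https" and pass hdfs_server_certificate for TLS
    credentials=KerberosPasswordCredentials(
        kerberos_realm="CONTOSO.LOCAL",
        kerberos_kdc_address="kdc.contoso.local",
        kerberos_principal="svc-azureml",
        kerberos_password="<password>",
    ),
)
# ml_client.datastores.create_or_update(hdfs_store)  # assumed registration call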
+# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from base64 import b64encode +from typing import Any, Optional + +from azure.ai.ml._restclient.v2023_04_01_preview import models as model_preview +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.entities._credentials import NoneCredentialConfiguration + + +# TODO: Move classes in this file to azure.ai.ml.entities._credentials +@experimental +class BaseKerberosCredentials(NoneCredentialConfiguration): + def __init__(self, kerberos_realm: str, kerberos_kdc_address: str, kerberos_principal: str): + super().__init__() + self.kerberos_realm = kerberos_realm + self.kerberos_kdc_address = kerberos_kdc_address + self.kerberos_principal = kerberos_principal + + +@experimental +class KerberosKeytabCredentials(BaseKerberosCredentials): + def __init__( + self, + *, + kerberos_realm: str, + kerberos_kdc_address: str, + kerberos_principal: str, + kerberos_keytab: Optional[str], + **kwargs: Any, + ): + super().__init__( + kerberos_realm=kerberos_realm, + kerberos_kdc_address=kerberos_kdc_address, + kerberos_principal=kerberos_principal, + **kwargs, + ) + self.type = model_preview.CredentialsType.KERBEROS_KEYTAB + self.kerberos_keytab = kerberos_keytab + + def _to_rest_object(self) -> model_preview.KerberosKeytabCredentials: + use_this_keytab = None + if self.kerberos_keytab: + with open(self.kerberos_keytab, "rb") as f: + use_this_keytab = b64encode(f.read()).decode("utf-8") + secrets = model_preview.KerberosKeytabSecrets(kerberos_keytab=use_this_keytab) + return model_preview.KerberosKeytabCredentials( + kerberos_kdc_address=self.kerberos_kdc_address, + kerberos_principal=self.kerberos_principal, + kerberos_realm=self.kerberos_realm, + secrets=secrets, + ) + + @classmethod + def _from_rest_object(cls, obj: model_preview.KerberosKeytabCredentials) -> "KerberosKeytabCredentials": + return cls( + kerberos_kdc_address=obj.kerberos_kdc_address, + kerberos_principal=obj.kerberos_principal, + kerberos_realm=obj.kerberos_realm, + kerberos_keytab=obj.secrets.kerberos_keytab if obj.secrets else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, KerberosKeytabCredentials): + return NotImplemented + return ( + self.kerberos_kdc_address == other.kerberos_kdc_address + and self.kerberos_principal == other.kerberos_principal + and self.kerberos_realm == other.kerberos_realm + and self.kerberos_keytab == other.kerberos_keytab + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +@experimental +class KerberosPasswordCredentials(BaseKerberosCredentials): + def __init__( + self, + *, + kerberos_realm: str, + kerberos_kdc_address: str, + kerberos_principal: str, + kerberos_password: Optional[str], + **kwargs: Any, + ): + super().__init__( + kerberos_realm=kerberos_realm, + kerberos_kdc_address=kerberos_kdc_address, + kerberos_principal=kerberos_principal, + **kwargs, + ) + self.type = model_preview.CredentialsType.KERBEROS_PASSWORD + self.kerberos_password = kerberos_password + + def _to_rest_object(self) -> model_preview.KerberosPasswordCredentials: + secrets = model_preview.KerberosPasswordSecrets(kerberos_password=self.kerberos_password) + return model_preview.KerberosPasswordCredentials( + kerberos_kdc_address=self.kerberos_kdc_address, + kerberos_principal=self.kerberos_principal, + kerberos_realm=self.kerberos_realm, + secrets=secrets, + ) + + @classmethod + def _from_rest_object(cls, obj: 
model_preview.KerberosPasswordCredentials) -> "KerberosPasswordCredentials":
+        return cls(
+            kerberos_kdc_address=obj.kerberos_kdc_address,
+            kerberos_principal=obj.kerberos_principal,
+            kerberos_realm=obj.kerberos_realm,
+            kerberos_password=obj.secrets.kerberos_password if obj.secrets else None,
+        )
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, KerberosPasswordCredentials):
+            return NotImplemented
+        return (
+            self.kerberos_kdc_address == other.kerberos_kdc_address
+            and self.kerberos_principal == other.kerberos_principal
+            and self.kerberos_realm == other.kerberos_realm
+            and self.kerberos_password == other.kerberos_password
+        )
+
+    def __ne__(self, other: object) -> bool:
+        return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/adls_gen1.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/adls_gen1.py
new file mode 100644
index 00000000..c2610703
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/adls_gen1.py
@@ -0,0 +1,106 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+    AzureDataLakeGen1Datastore as RestAzureDatalakeGen1Datastore,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType
+from azure.ai.ml._schema._datastore.adls_gen1 import AzureDataLakeGen1Schema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._credentials import CertificateConfiguration, ServicePrincipalConfiguration
+from azure.ai.ml.entities._datastore.datastore import Datastore
+from azure.ai.ml.entities._datastore.utils import from_rest_datastore_credentials
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class AzureDataLakeGen1Datastore(Datastore):
+    """Azure Data Lake Gen1 datastore that is linked to an Azure ML workspace.
+
+    :param name: Name of the datastore.
+    :type name: str
+    :param store_name: Name of the Azure storage resource.
+    :type store_name: str
+    :param description: Description of the resource.
+    :type description: str
+    :param tags: Tag dictionary. Tags can be added, removed, and updated.
+    :type tags: dict[str, str]
+    :param properties: The asset property dictionary.
+    :type properties: dict[str, str]
+    :param credentials: Credentials to use for Azure ML workspace to connect to the storage.
+    :type credentials: Union[ServicePrincipalConfiguration, CertificateConfiguration]
+    :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + store_name: str, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + credentials: Optional[Union[CertificateConfiguration, ServicePrincipalConfiguration]] = None, + **kwargs: Any + ): + kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN1 + super().__init__( + name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs + ) + + self.store_name = store_name + + def _to_rest_object(self) -> DatastoreData: + gen1_ds = RestAzureDatalakeGen1Datastore( + credentials=self.credentials._to_datastore_rest_object(), + store_name=self.store_name, + description=self.description, + tags=self.tags, + ) + return DatastoreData(properties=gen1_ds) + + @classmethod + def _load_from_dict( + cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any + ) -> "AzureDataLakeGen1Datastore": + res: AzureDataLakeGen1Datastore = load_from_dict( + AzureDataLakeGen1Schema, data, context, additional_message, **kwargs + ) + return res + + @classmethod + def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeGen1Datastore": + properties: RestAzureDatalakeGen1Datastore = datastore_resource.properties + return AzureDataLakeGen1Datastore( + id=datastore_resource.id, + name=datastore_resource.name, + store_name=properties.store_name, + credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type] + description=properties.description, + tags=properties.tags, + ) + + def __eq__(self, other: Any) -> bool: + res: bool = ( + super().__eq__(other) + and self.name == other.name + and self.type == other.type + and self.store_name == other.store_name + and self.credentials == other.credentials + ) + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def _to_dict(self) -> Dict: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = AzureDataLakeGen1Schema(context=context).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/azure_storage.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/azure_storage.py new file mode 100644 index 00000000..0fff1925 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/azure_storage.py @@ -0,0 +1,337 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
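# Illustrative usage sketch (editorial, not part of the vendored diff): registering the
# AzureDataLakeGen1Datastore defined in adls_gen1.py above with service principal
# credentials. The tenant/client IDs, secret, and "ml_client" are placeholders.
from azure.ai.ml.entities import AzureDataLakeGen1Datastore, ServicePrincipalConfiguration

adls_gen1_store = AzureDataLakeGen1Datastore(
    name="adls_gen1_example",
    store_name="myadlsgen1account",
    credentials=ServicePrincipalConfiguration(
        tenant_id="00000000-0000-0000-0000-000000000000",
        client_id="11111111-1111-1111-1111-111111111111",
        client_secret="<client-secret>",
    ),
)
# ml_client.datastores.create_or_update(adls_gen1_store)  # assumed registration call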
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata +from azure.ai.ml._restclient.v2023_04_01_preview.models import AzureBlobDatastore as RestAzureBlobDatastore +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + AzureDataLakeGen2Datastore as RestAzureDataLakeGen2Datastore, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import AzureFileDatastore as RestAzureFileDatastore +from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData +from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType +from azure.ai.ml._schema._datastore import AzureBlobSchema, AzureDataLakeGen2Schema, AzureFileSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.entities._credentials import ( + AccountKeyConfiguration, + CertificateConfiguration, + SasTokenConfiguration, + ServicePrincipalConfiguration, +) +from azure.ai.ml.entities._datastore.datastore import Datastore +from azure.ai.ml.entities._datastore.utils import from_rest_datastore_credentials +from azure.ai.ml.entities._util import load_from_dict + +from ._constants import HTTPS + + +class AzureFileDatastore(Datastore): + """Azure file share that is linked to an Azure ML workspace. + + :param name: Name of the datastore. + :type name: str + :param account_name: Name of the Azure storage account. + :type account_name: str + :param file_share_name: Name of the file share. + :type file_share_name: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param endpoint: Endpoint to use to connect with the Azure storage account + :type endpoint: str + :param protocol: Protocol to use to connect with the Azure storage account + :type protocol: str + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param credentials: Credentials to use for Azure ML workspace to connect to the storage. Defaults to None. + :type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration, + ~azure.ai.ml.entities.SasTokenConfiguration] + :param kwargs: A dictionary of additional configuration parameters. 
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + account_name: str, + file_share_name: str, + description: Optional[str] = None, + tags: Optional[Dict] = None, + endpoint: str = _get_storage_endpoint_from_metadata(), + protocol: str = HTTPS, + properties: Optional[Dict] = None, + credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None, + **kwargs: Any + ): + kwargs[TYPE] = DatastoreType.AZURE_FILE + super().__init__( + name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs + ) + self.file_share_name = file_share_name + self.account_name = account_name + self.endpoint = endpoint + self.protocol = protocol + + def _to_rest_object(self) -> DatastoreData: + file_ds = RestAzureFileDatastore( + account_name=self.account_name, + file_share_name=self.file_share_name, + credentials=self.credentials._to_datastore_rest_object(), + endpoint=self.endpoint, + protocol=self.protocol, + description=self.description, + tags=self.tags, + ) + return DatastoreData(properties=file_ds) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "AzureFileDatastore": + res: AzureFileDatastore = load_from_dict(AzureFileSchema, data, context, additional_message) + return res + + @classmethod + def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureFileDatastore": + properties: RestAzureFileDatastore = datastore_resource.properties + return AzureFileDatastore( + name=datastore_resource.name, + id=datastore_resource.id, + account_name=properties.account_name, + credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type] + endpoint=properties.endpoint, + protocol=properties.protocol, + file_share_name=properties.file_share_name, + description=properties.description, + tags=properties.tags, + ) + + def __eq__(self, other: Any) -> bool: + res: bool = ( + super().__eq__(other) + and self.file_share_name == other.file_share_name + and self.account_name == other.account_name + and self.endpoint == other.endpoint + and self.protocol == other.protocol + ) + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def _to_dict(self) -> Dict: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = AzureFileSchema(context=context).dump(self) + return res + + +class AzureBlobDatastore(Datastore): + """Azure blob storage that is linked to an Azure ML workspace. + + :param name: Name of the datastore. + :type name: str + :param account_name: Name of the Azure storage account. + :type account_name: str + :param container_name: Name of the container. + :type container_name: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param endpoint: Endpoint to use to connect with the Azure storage account. + :type endpoint: str + :param protocol: Protocol to use to connect with the Azure storage account. + :type protocol: str + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param credentials: Credentials to use for Azure ML workspace to connect to the storage. + :type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration, + ~azure.ai.ml.entities.SasTokenConfiguration] + :param kwargs: A dictionary of additional configuration parameters. 
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + account_name: str, + container_name: str, + description: Optional[str] = None, + tags: Optional[Dict] = None, + endpoint: Optional[str] = None, + protocol: str = HTTPS, + properties: Optional[Dict] = None, + credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None, + **kwargs: Any + ): + kwargs[TYPE] = DatastoreType.AZURE_BLOB + super().__init__( + name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs + ) + + self.container_name = container_name + self.account_name = account_name + self.endpoint = endpoint if endpoint else _get_storage_endpoint_from_metadata() + self.protocol = protocol + + def _to_rest_object(self) -> DatastoreData: + blob_ds = RestAzureBlobDatastore( + account_name=self.account_name, + container_name=self.container_name, + credentials=self.credentials._to_datastore_rest_object(), + endpoint=self.endpoint, + protocol=self.protocol, + tags=self.tags, + description=self.description, + ) + return DatastoreData(properties=blob_ds) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "AzureBlobDatastore": + res: AzureBlobDatastore = load_from_dict(AzureBlobSchema, data, context, additional_message) + return res + + @classmethod + def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureBlobDatastore": + properties: RestAzureBlobDatastore = datastore_resource.properties + return AzureBlobDatastore( + name=datastore_resource.name, + id=datastore_resource.id, + account_name=properties.account_name, + credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type] + endpoint=properties.endpoint, + protocol=properties.protocol, + container_name=properties.container_name, + description=properties.description, + tags=properties.tags, + ) + + def __eq__(self, other: Any) -> bool: + res: bool = ( + super().__eq__(other) + and self.container_name == other.container_name + and self.account_name == other.account_name + and self.endpoint == other.endpoint + and self.protocol == other.protocol + ) + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def _to_dict(self) -> Dict: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = AzureBlobSchema(context=context).dump(self) + return res + + +class AzureDataLakeGen2Datastore(Datastore): + """Azure data lake gen 2 that is linked to an Azure ML workspace. + + :param name: Name of the datastore. + :type name: str + :param account_name: Name of the Azure storage account. + :type account_name: str + :param filesystem: The name of the Data Lake Gen2 filesystem. + :type filesystem: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param endpoint: Endpoint to use to connect with the Azure storage account + :type endpoint: str + :param protocol: Protocol to use to connect with the Azure storage account + :type protocol: str + :param credentials: Credentials to use for Azure ML workspace to connect to the storage. + :type credentials: Union[ + ~azure.ai.ml.entities.ServicePrincipalConfiguration, + ~azure.ai.ml.entities.CertificateConfiguration + + ] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param kwargs: A dictionary of additional configuration parameters. 
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + account_name: str, + filesystem: str, + description: Optional[str] = None, + tags: Optional[Dict] = None, + endpoint: str = _get_storage_endpoint_from_metadata(), + protocol: str = HTTPS, + properties: Optional[Dict] = None, + credentials: Optional[Union[ServicePrincipalConfiguration, CertificateConfiguration]] = None, + **kwargs: Any + ): + kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN2 + super().__init__( + name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs + ) + + self.account_name = account_name + self.filesystem = filesystem + self.endpoint = endpoint + self.protocol = protocol + + def _to_rest_object(self) -> DatastoreData: + gen2_ds = RestAzureDataLakeGen2Datastore( + account_name=self.account_name, + filesystem=self.filesystem, + credentials=self.credentials._to_datastore_rest_object(), + endpoint=self.endpoint, + protocol=self.protocol, + description=self.description, + tags=self.tags, + ) + return DatastoreData(properties=gen2_ds) + + @classmethod + def _load_from_dict( + cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any + ) -> "AzureDataLakeGen2Datastore": + res: AzureDataLakeGen2Datastore = load_from_dict(AzureDataLakeGen2Schema, data, context, additional_message) + return res + + @classmethod + def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeGen2Datastore": + properties: RestAzureDataLakeGen2Datastore = datastore_resource.properties + return AzureDataLakeGen2Datastore( + name=datastore_resource.name, + id=datastore_resource.id, + account_name=properties.account_name, + credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type] + endpoint=properties.endpoint, + protocol=properties.protocol, + filesystem=properties.filesystem, + description=properties.description, + tags=properties.tags, + ) + + def __eq__(self, other: Any) -> bool: + res: bool = ( + super().__eq__(other) + and self.filesystem == other.filesystem + and self.account_name == other.account_name + and self.endpoint == other.endpoint + and self.protocol == other.protocol + ) + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def _to_dict(self) -> Dict: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = AzureDataLakeGen2Schema(context=context).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/datastore.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/datastore.py new file mode 100644 index 00000000..bc933cfb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/datastore.py @@ -0,0 +1,221 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
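# Illustrative usage sketch (editorial, not part of the vendored diff): registering the
# AzureBlobDatastore defined in azure_storage.py above with an account key.
# AzureFileDatastore and AzureDataLakeGen2Datastore follow the same pattern with their own
# storage-specific parameters. Account, container, key, and "ml_client" are placeholders.
from azure.ai.ml.entities import AccountKeyConfiguration, AzureBlobDatastore

blob_store = AzureBlobDatastore(
    name="blob_example",
    account_name="mystorageaccount",
    container_name="example-container",
    protocol="https",  # default
    credentials=AccountKeyConfiguration(account_key="<account-key>"),
)
# ml_client.datastores.create_or_update(blob_store)  # assumed registration call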
+# --------------------------------------------------------- + +# pylint: disable=protected-access,redefined-builtin,arguments-renamed + +from abc import ABC, abstractmethod +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData +from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType +from azure.ai.ml._utils.utils import camel_to_snake, dump_yaml_to_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, CommonYamlFields +from azure.ai.ml.entities._credentials import ( + AccountKeyConfiguration, + CertificateConfiguration, + NoneCredentialConfiguration, + SasTokenConfiguration, + ServicePrincipalConfiguration, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._util import find_type_in_override +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + + +class Datastore(Resource, RestTranslatableMixin, ABC): + """Datastore of an Azure ML workspace, abstract class. + + :param name: Name of the datastore. + :type name: str + :param description: Description of the resource. + :type description: str + :param credentials: Credentials to use for Azure ML workspace to connect to the storage. + :type credentials: Optional[Union[ + ~azure.ai.ml.entities.ServicePrincipalConfiguration, + ~azure.ai.ml.entities.CertificateConfiguration, + ~azure.ai.ml.entities.NoneCredentialConfiguration, + ~azure.ai.ml.entities.AccountKeyConfiguration, + ~azure.ai.ml.entities.SasTokenConfiguration + + ]] + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + def __init__( + self, + credentials: Optional[ + Union[ + ServicePrincipalConfiguration, + CertificateConfiguration, + NoneCredentialConfiguration, + AccountKeyConfiguration, + SasTokenConfiguration, + ] + ], + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ): + self._type: str = kwargs.pop("type", None) + super().__init__( + name=name, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + + self.credentials = NoneCredentialConfiguration() if credentials is None else credentials + + @property + def type(self) -> str: + return self._type + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the datastore content into a file in yaml format. + + :param dest: The destination to receive this datastore's content. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable. 
+ :type dest: Union[PathLike, str, IO[AnyStr]] + """ + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, **kwargs) + + @abstractmethod + def _to_dict(self) -> Dict: + pass + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Datastore": + data = data or {} + params_override = params_override or [] + + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + + from azure.ai.ml.entities import ( + AzureBlobDatastore, + AzureDataLakeGen1Datastore, + AzureDataLakeGen2Datastore, + AzureFileDatastore, + OneLakeDatastore, + ) + + # from azure.ai.ml.entities._datastore._on_prem import ( + # HdfsDatastore + # ) + + ds_type: Any = None + type_in_override = find_type_in_override(params_override) + type = type_in_override or data.get( + CommonYamlFields.TYPE, DatastoreType.AZURE_BLOB + ) # override takes the priority + + # yaml expects snake casing, while service side constants are camel casing + if type == camel_to_snake(DatastoreType.AZURE_BLOB): + ds_type = AzureBlobDatastore + elif type == camel_to_snake(DatastoreType.AZURE_FILE): + ds_type = AzureFileDatastore + elif type == camel_to_snake(DatastoreType.AZURE_DATA_LAKE_GEN1): + ds_type = AzureDataLakeGen1Datastore + elif type == camel_to_snake(DatastoreType.AZURE_DATA_LAKE_GEN2): + ds_type = AzureDataLakeGen2Datastore + elif type == camel_to_snake(DatastoreType.ONE_LAKE): + ds_type = OneLakeDatastore + # disable unless preview release + # elif type == camel_to_snake(DatastoreTypePreview.HDFS): + # ds_type = HdfsDatastore + else: + msg = f"Unsupported datastore type: {type}." 
+ raise ValidationException( + message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.DATASTORE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + res: Datastore = ds_type._load_from_dict( + data=data, + context=context, + additional_message="If the datastore type is incorrect, change the 'type' property.", + **kwargs, + ) + return res + + @classmethod + def _from_rest_object(cls, datastore_resource: DatastoreData) -> "Datastore": + from azure.ai.ml.entities import ( + AzureBlobDatastore, + AzureDataLakeGen1Datastore, + AzureDataLakeGen2Datastore, + AzureFileDatastore, + OneLakeDatastore, + ) + + # from azure.ai.ml.entities._datastore._on_prem import ( + # HdfsDatastore + # ) + + datastore_type = datastore_resource.properties.datastore_type + if datastore_type == DatastoreType.AZURE_DATA_LAKE_GEN1: + res_adl_gen1: Datastore = AzureDataLakeGen1Datastore._from_rest_object(datastore_resource) + return res_adl_gen1 + if datastore_type == DatastoreType.AZURE_DATA_LAKE_GEN2: + res_adl_gen2: Datastore = AzureDataLakeGen2Datastore._from_rest_object(datastore_resource) + return res_adl_gen2 + if datastore_type == DatastoreType.AZURE_BLOB: + res_abd: Datastore = AzureBlobDatastore._from_rest_object(datastore_resource) + return res_abd + if datastore_type == DatastoreType.AZURE_FILE: + res_afd: Datastore = AzureFileDatastore._from_rest_object(datastore_resource) + return res_afd + if datastore_type == DatastoreType.ONE_LAKE: + res_old: Datastore = OneLakeDatastore._from_rest_object(datastore_resource) + return res_old + # disable unless preview release + # elif datastore_type == DatastoreTypePreview.HDFS: + # return HdfsDatastore._from_rest_object(datastore_resource) + msg = f"Unsupported datastore type {datastore_resource.properties.contents.type}" + raise ValidationException( + message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.DATASTORE, + no_personal_data_message=msg, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + @classmethod + @abstractmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Datastore": + pass + + def __eq__(self, other: Any) -> bool: + res: bool = self.name == other.name and self.type == other.type and self.credentials == other.credentials + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/one_lake.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/one_lake.py new file mode 100644 index 00000000..9bc06d92 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/one_lake.py @@ -0,0 +1,153 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
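# Illustrative sketch (editorial, not part of the vendored diff): the Datastore._load
# dispatch in datastore.py above picks the concrete subclass from the snake_cased "type"
# field of a YAML definition. A minimal YAML plus the public load_datastore helper might
# look like this; the file name and values are placeholders.
#
#   # blob_example.yml
#   name: blob_example
#   type: azure_blob          # snake_case of DatastoreType.AZURE_BLOB
#   account_name: mystorageaccount
#   container_name: example-container
#
from azure.ai.ml import load_datastore

datastore = load_datastore(source="blob_example.yml")
# Omitting "credentials" in the YAML yields NoneCredentialConfiguration (identity-based access).
# ml_client.datastores.create_or_update(datastore)  # assumed registration call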
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from abc import ABC +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData +from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType +from azure.ai.ml._restclient.v2023_04_01_preview.models import LakeHouseArtifact as RestLakeHouseArtifact +from azure.ai.ml._restclient.v2023_04_01_preview.models import NoneDatastoreCredentials as RestNoneDatastoreCredentials +from azure.ai.ml._restclient.v2023_04_01_preview.models import OneLakeDatastore as RestOneLakeDatastore +from azure.ai.ml._schema._datastore.one_lake import OneLakeSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.entities._credentials import NoneCredentialConfiguration, ServicePrincipalConfiguration +from azure.ai.ml.entities._datastore.datastore import Datastore +from azure.ai.ml.entities._datastore.utils import from_rest_datastore_credentials +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin +from azure.ai.ml.entities._util import load_from_dict + + +@experimental +class OneLakeArtifact(RestTranslatableMixin, DictMixin, ABC): + """OneLake artifact (data source) backing the OneLake workspace. + + :param name: OneLake artifact name/GUID. ex) 01234567-abcd-1234-5678-012345678901 + :type name: str + :param type: OneLake artifact type. Only LakeHouse artifacts are currently supported. + :type type: str + """ + + def __init__(self, *, name: str, type: Optional[str] = None): + super().__init__() + self.name = name + self.type = type + + +@experimental +class LakeHouseArtifact(OneLakeArtifact): + """LakeHouse artifact type for OneLake. + + :param artifact_name: OneLake LakeHouse artifact name/GUID. ex) 01234567-abcd-1234-5678-012345678901 + :type artifact_name: str + """ + + def __init__(self, *, name: str): + super(LakeHouseArtifact, self).__init__(name=name, type="lake_house") + + def _to_datastore_rest_object(self) -> RestLakeHouseArtifact: + return RestLakeHouseArtifact(artifact_name=self.name) + + +@experimental +class OneLakeDatastore(Datastore): + """OneLake datastore that is linked to an Azure ML workspace. + + :param name: Name of the datastore. + :type name: str + :param artifact: OneLake Artifact. Only LakeHouse artifacts are currently supported. + :type artifact: ~azure.ai.ml.entities.OneLakeArtifact + :param one_lake_workspace_name: OneLake workspace name/GUID. ex) 01234567-abcd-1234-5678-012345678901 + :type one_lake_workspace_name: str + :param endpoint: OneLake endpoint to use for the datastore. ex) https://onelake.dfs.fabric.microsoft.com + :type endpoint: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param credentials: Credentials to use to authenticate against OneLake. + :type credentials: Union[ + ~azure.ai.ml.entities.ServicePrincipalConfiguration, ~azure.ai.ml.entities.NoneCredentialConfiguration] + :param kwargs: A dictionary of additional configuration parameters. 
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + artifact: OneLakeArtifact, + one_lake_workspace_name: str, + endpoint: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + credentials: Optional[Union[NoneCredentialConfiguration, ServicePrincipalConfiguration]] = None, + **kwargs: Any + ): + kwargs[TYPE] = DatastoreType.ONE_LAKE + super().__init__( + name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs + ) + self.artifact = artifact + self.one_lake_workspace_name = one_lake_workspace_name + self.endpoint = endpoint + + def _to_rest_object(self) -> DatastoreData: + one_lake_ds = RestOneLakeDatastore( + credentials=( + RestNoneDatastoreCredentials() + if self.credentials is None + else self.credentials._to_datastore_rest_object() + ), + artifact=RestLakeHouseArtifact(artifact_name=self.artifact["name"]), + one_lake_workspace_name=self.one_lake_workspace_name, + endpoint=self.endpoint, + description=self.description, + tags=self.tags, + ) + return DatastoreData(properties=one_lake_ds) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "OneLakeDatastore": + res: OneLakeDatastore = load_from_dict(OneLakeSchema, data, context, additional_message, **kwargs) + return res + + @classmethod + def _from_rest_object(cls, datastore_resource: DatastoreData) -> "OneLakeDatastore": + properties: RestOneLakeDatastore = datastore_resource.properties + return OneLakeDatastore( + name=datastore_resource.name, + id=datastore_resource.id, + artifact=LakeHouseArtifact(name=properties.artifact.artifact_name), + one_lake_workspace_name=properties.one_lake_workspace_name, + endpoint=properties.endpoint, + credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type] + description=properties.description, + tags=properties.tags, + ) + + def __eq__(self, other: Any) -> bool: + res: bool = ( + super().__eq__(other) + and self.one_lake_workspace_name == other.one_lake_workspace_name + and self.artifact.type == other.artifact["type"] + and self.artifact.name == other.artifact["name"] + and self.endpoint == other.endpoint + ) + return res + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def _to_dict(self) -> Dict: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = OneLakeSchema(context=context).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/utils.py new file mode 100644 index 00000000..538f9590 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/utils.py @@ -0,0 +1,70 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
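# Illustrative usage sketch (editorial, not part of the vendored diff): registering the
# experimental OneLakeDatastore from one_lake.py above, backed by a LakeHouse artifact.
# The GUIDs, endpoint, and "ml_client" are placeholders supplied by the Fabric/OneLake
# workspace, not values from this code.
from azure.ai.ml.entities._datastore.one_lake import LakeHouseArtifact, OneLakeDatastore

one_lake_store = OneLakeDatastore(
    name="onelake_example",
    artifact=LakeHouseArtifact(name="01234567-abcd-1234-5678-012345678901"),
    one_lake_workspace_name="89abcdef-0123-4567-89ab-cdef01234567",
    endpoint="https://onelake.dfs.fabric.microsoft.com",
)
# Credentials default to NoneCredentialConfiguration (identity-based access);
# a ServicePrincipalConfiguration can be passed instead.
# ml_client.datastores.create_or_update(one_lake_store)  # assumed registration call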
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Optional, Union, cast + +from azure.ai.ml._restclient.v2023_04_01_preview import models +from azure.ai.ml._restclient.v2024_07_01_preview import models as models2024 +from azure.ai.ml.entities._credentials import ( + AccountKeyConfiguration, + CertificateConfiguration, + NoneCredentialConfiguration, + SasTokenConfiguration, + ServicePrincipalConfiguration, +) +from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosKeytabCredentials, KerberosPasswordCredentials + + +def from_rest_datastore_credentials( + rest_credentials: models.DatastoreCredentials, +) -> Union[ + AccountKeyConfiguration, + SasTokenConfiguration, + ServicePrincipalConfiguration, + CertificateConfiguration, + NoneCredentialConfiguration, +]: + config_class: Any = NoneCredentialConfiguration + + if isinstance(rest_credentials, (models.AccountKeyDatastoreCredentials, models2024.AccountKeyDatastoreCredentials)): + # we are no more using key for key base account. + # https://github.com/Azure/azure-sdk-for-python/pull/35716 + if isinstance(rest_credentials.secrets, models2024.SasDatastoreSecrets): + config_class = SasTokenConfiguration + else: + config_class = AccountKeyConfiguration + elif isinstance(rest_credentials, (models.SasDatastoreCredentials, models2024.SasDatastoreCredentials)): + config_class = SasTokenConfiguration + elif isinstance( + rest_credentials, (models.ServicePrincipalDatastoreCredentials, models2024.ServicePrincipalDatastoreCredentials) + ): + config_class = ServicePrincipalConfiguration + elif isinstance( + rest_credentials, (models.CertificateDatastoreCredentials, models2024.CertificateDatastoreCredentials) + ): + config_class = CertificateConfiguration + + return cast( + Union[ + AccountKeyConfiguration, + SasTokenConfiguration, + ServicePrincipalConfiguration, + CertificateConfiguration, + NoneCredentialConfiguration, + ], + config_class._from_datastore_rest_object(rest_credentials), + ) + + +def _from_rest_datastore_credentials_preview( + rest_credentials: models.DatastoreCredentials, +) -> Optional[Union[KerberosKeytabCredentials, KerberosPasswordCredentials]]: + if isinstance(rest_credentials, models.KerberosKeytabCredentials): + return KerberosKeytabCredentials._from_rest_object(rest_credentials) + if isinstance(rest_credentials, models.KerberosPasswordCredentials): + return KerberosPasswordCredentials._from_rest_object(rest_credentials) + + return None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
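# Illustrative sketch (editorial, not part of the vendored diff) of how
# from_rest_datastore_credentials in utils.py above maps generated REST credential models
# onto SDK configuration classes. The REST model constructor shapes below are assumptions
# about the generated client, shown only to make the mapping concrete.
from azure.ai.ml._restclient.v2023_04_01_preview import models
from azure.ai.ml.entities._credentials import SasTokenConfiguration
from azure.ai.ml.entities._datastore.utils import from_rest_datastore_credentials

rest_credentials = models.SasDatastoreCredentials(
    secrets=models.SasDatastoreSecrets(sas_token="?sv=2023-01-03&sig=...")  # assumed shape
)
config = from_rest_datastore_credentials(rest_credentials)
assert isinstance(config, SasTokenConfiguration)
# Credential types that are not recognized fall through to NoneCredentialConfiguration.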
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_deployment.py new file mode 100644 index 00000000..59b23eb8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_deployment.py @@ -0,0 +1,356 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from os import PathLike +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2024_01_01_preview.models import BatchDeployment as BatchDeploymentData +from azure.ai.ml._restclient.v2024_01_01_preview.models import BatchDeploymentProperties as RestBatchDeployment +from azure.ai.ml._restclient.v2024_01_01_preview.models import BatchOutputAction +from azure.ai.ml._restclient.v2024_01_01_preview.models import CodeConfiguration as RestCodeConfiguration +from azure.ai.ml._restclient.v2024_01_01_preview.models import IdAssetReference +from azure.ai.ml._schema._deployment.batch.batch_deployment import BatchDeploymentSchema +from azure.ai.ml._utils._arm_id_utils import _parse_endpoint_name_from_deployment_id +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction +from azure.ai.ml.entities._assets import Environment, Model +from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings +from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .code_configuration import CodeConfiguration +from .deployment import Deployment + +module_logger = logging.getLogger(__name__) + + +class BatchDeployment(Deployment): + """Batch endpoint deployment entity. + + :param name: the name of the batch deployment + :type name: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param model: Model entity for the endpoint deployment, defaults to None + :type model: Union[str, Model] + :param code_configuration: defaults to None + :type code_configuration: CodeConfiguration + :param environment: Environment entity for the endpoint deployment., defaults to None + :type environment: Union[str, Environment] + :param compute: Compute target for batch inference operation. + :type compute: str + :param output_action: Indicates how the output will be organized. Possible values include: + "summary_only", "append_row". 
Defaults to "append_row" + :type output_action: str or ~azure.ai.ml.constants._deployment.BatchDeploymentOutputAction + :param output_file_name: Customized output file name for append_row output action, defaults to "predictions.csv" + :type output_file_name: str + :param max_concurrency_per_instance: Indicates maximum number of parallelism per instance, defaults to 1 + :type max_concurrency_per_instance: int + :param error_threshold: Error threshold, if the error count for the entire input goes above + this value, + the batch inference will be aborted. Range is [-1, int.MaxValue] + -1 value indicates, ignore all failures during batch inference + For FileDataset count of file failures + For TabularDataset, this is the count of record failures, defaults to -1 + :type error_threshold: int + :param retry_settings: Retry settings for a batch inference operation, defaults to None + :type retry_settings: BatchRetrySettings + :param resources: Indicates compute configuration for the job. + :type resources: ~azure.mgmt.machinelearningservices.models.ResourceConfiguration + :param logging_level: Logging level for batch inference operation, defaults to "info" + :type logging_level: str + :param mini_batch_size: Size of the mini-batch passed to each batch invocation, defaults to 10 + :type mini_batch_size: int + :param environment_variables: Environment variables that will be set in deployment. + :type environment_variables: dict + :param code_path: Folder path to local code assets. Equivalent to code_configuration.code. + :type code_path: Union[str, PathLike] + :param scoring_script: Scoring script name. Equivalent to code_configuration.code.scoring_script. + :type scoring_script: Union[str, PathLike] + :param instance_count: Number of instances the interfering will run on. Equivalent to resources.instance_count. + :type instance_count: int + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if BatchDeployment cannot be successfully validated. + Details will be provided in the error message. 
+ """ + + def __init__( + self, + *, + name: str, + endpoint_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, str]] = None, + model: Optional[Union[str, Model]] = None, + code_configuration: Optional[CodeConfiguration] = None, + environment: Optional[Union[str, Environment]] = None, + compute: Optional[str] = None, + resources: Optional[ResourceConfiguration] = None, + output_file_name: Optional[str] = None, + output_action: Optional[Union[BatchDeploymentOutputAction, str]] = None, + error_threshold: Optional[int] = None, + retry_settings: Optional[BatchRetrySettings] = None, + logging_level: Optional[str] = None, + mini_batch_size: Optional[int] = None, + max_concurrency_per_instance: Optional[int] = None, + environment_variables: Optional[Dict[str, str]] = None, + code_path: Optional[Union[str, PathLike]] = None, # promoted property from code_configuration.code + scoring_script: Optional[ + Union[str, PathLike] + ] = None, # promoted property from code_configuration.scoring_script + instance_count: Optional[int] = None, # promoted property from resources.instance_count + **kwargs: Any, + ) -> None: + self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None) + + super(BatchDeployment, self).__init__( + name=name, + endpoint_name=endpoint_name, + properties=properties, + tags=tags, + description=description, + model=model, + code_configuration=code_configuration, + environment=environment, + environment_variables=environment_variables, + code_path=code_path, + scoring_script=scoring_script, + **kwargs, + ) + + self.compute = compute + self.resources = resources + self.output_action = output_action + self.output_file_name = output_file_name + self.error_threshold = error_threshold + self.retry_settings = retry_settings + self.logging_level = logging_level + self.mini_batch_size = mini_batch_size + self.max_concurrency_per_instance = max_concurrency_per_instance + + if self.resources and instance_count: + msg = "Can't set instance_count when resources is provided." + raise ValidationException( + message=msg, + target=ErrorTarget.BATCH_DEPLOYMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if not self.resources and instance_count: + self.resources = ResourceConfiguration(instance_count=instance_count) + + @property + def instance_count(self) -> Optional[int]: + return self.resources.instance_count if self.resources else None + + @instance_count.setter + def instance_count(self, value: int) -> None: + if not self.resources: + self.resources = ResourceConfiguration() + + self.resources.instance_count = value + + @property + def provisioning_state(self) -> Optional[str]: + """Batch deployment provisioning state, readonly. + + :return: Batch deployment provisioning state. 
+ :rtype: Optional[str] + """ + return self._provisioning_state + + def _to_dict(self) -> Dict: + res: dict = BatchDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _rest_output_action_to_yaml_output_action(cls, rest_output_action: str) -> str: + output_switcher = { + BatchOutputAction.APPEND_ROW: BatchDeploymentOutputAction.APPEND_ROW, + BatchOutputAction.SUMMARY_ONLY: BatchDeploymentOutputAction.SUMMARY_ONLY, + } + + return output_switcher.get(rest_output_action, rest_output_action) + + @classmethod + def _yaml_output_action_to_rest_output_action(cls, yaml_output_action: Any) -> str: + output_switcher = { + BatchDeploymentOutputAction.APPEND_ROW: BatchOutputAction.APPEND_ROW, + BatchDeploymentOutputAction.SUMMARY_ONLY: BatchOutputAction.SUMMARY_ONLY, + } + + return output_switcher.get(yaml_output_action, yaml_output_action) + + # pylint: disable=arguments-differ + def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore + self._validate() + code_config = ( + RestCodeConfiguration( + code_id=self.code_configuration.code, + scoring_script=self.code_configuration.scoring_script, + ) + if self.code_configuration + else None + ) + model = IdAssetReference(asset_id=self.model) if self.model else None + environment = self.environment + + batch_deployment: RestBatchDeployment = None + if isinstance(self.output_action, str): + batch_deployment = RestBatchDeployment( + compute=self.compute, + description=self.description, + resources=self.resources._to_rest_object() if self.resources else None, + code_configuration=code_config, + environment_id=environment, + model=model, + output_file_name=self.output_file_name, + output_action=BatchDeployment._yaml_output_action_to_rest_output_action(self.output_action), + error_threshold=self.error_threshold, + retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None, + logging_level=self.logging_level, + mini_batch_size=self.mini_batch_size, + max_concurrency_per_instance=self.max_concurrency_per_instance, + environment_variables=self.environment_variables, + properties=self.properties, + ) + else: + batch_deployment = RestBatchDeployment( + compute=self.compute, + description=self.description, + resources=self.resources._to_rest_object() if self.resources else None, + code_configuration=code_config, + environment_id=environment, + model=model, + output_file_name=self.output_file_name, + output_action=None, + error_threshold=self.error_threshold, + retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None, + logging_level=self.logging_level, + mini_batch_size=self.mini_batch_size, + max_concurrency_per_instance=self.max_concurrency_per_instance, + environment_variables=self.environment_variables, + properties=self.properties, + ) + + return BatchDeploymentData(location=location, properties=batch_deployment, tags=self.tags) + + @classmethod + def _from_rest_object( # pylint: disable=arguments-renamed + cls, deployment: BatchDeploymentData + ) -> BatchDeploymentData: + modelId = deployment.properties.model.asset_id if deployment.properties.model else None + + if ( + hasattr(deployment.properties, "deployment_configuration") + and deployment.properties.deployment_configuration is not None + ): + settings = deployment.properties.deployment_configuration.settings + deployment_comp_settings = { + "deployment_configuration_type": deployment.properties.deployment_configuration.deployment_configuration_type, # pylint: 
disable=line-too-long + "componentDeployment.Settings.continue_on_step_failure": settings.get( + "ComponentDeployment.Settings.continue_on_step_failure", None + ), + "default_datastore": settings.get("default_datastore", None), + "default_compute": settings.get("default_compute", None), + } + properties = {} + if deployment.properties.properties: + properties.update(deployment.properties.properties) + properties.update(deployment_comp_settings) + else: + properties = deployment.properties.properties + + code_configuration = ( + CodeConfiguration._from_rest_code_configuration(deployment.properties.code_configuration) + if deployment.properties.code_configuration + else None + ) + deployment = BatchDeployment( + name=deployment.name, + description=deployment.properties.description, + id=deployment.id, + tags=deployment.tags, + model=modelId, + environment=deployment.properties.environment_id, + code_configuration=code_configuration, + output_file_name=( + deployment.properties.output_file_name + if cls._rest_output_action_to_yaml_output_action(deployment.properties.output_action) + == BatchDeploymentOutputAction.APPEND_ROW + else None + ), + output_action=cls._rest_output_action_to_yaml_output_action(deployment.properties.output_action), + error_threshold=deployment.properties.error_threshold, + retry_settings=BatchRetrySettings._from_rest_object(deployment.properties.retry_settings), + logging_level=deployment.properties.logging_level, + mini_batch_size=deployment.properties.mini_batch_size, + compute=deployment.properties.compute, + resources=ResourceConfiguration._from_rest_object(deployment.properties.resources), + environment_variables=deployment.properties.environment_variables, + max_concurrency_per_instance=deployment.properties.max_concurrency_per_instance, + endpoint_name=_parse_endpoint_name_from_deployment_id(deployment.id), + properties=properties, + creation_context=SystemData._from_rest_object(deployment.system_data), + provisioning_state=deployment.properties.provisioning_state, + ) + + return deployment + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "BatchDeployment": + data = data or {} + params_override = params_override or [] + cls._update_params(params_override) + + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(), + PARAMS_OVERRIDE_KEY: params_override, + } + res: BatchDeployment = load_from_dict(BatchDeploymentSchema, data, context, **kwargs) + return res + + def _validate(self) -> None: + self._validate_output_action() + + @classmethod + def _update_params(cls, params_override: Any) -> None: + for param in params_override: + endpoint_name = param.get("endpoint_name") + if isinstance(endpoint_name, str): + param["endpoint_name"] = endpoint_name.lower() + + def _validate_output_action(self) -> None: + if ( + self.output_action + and self.output_action == BatchDeploymentOutputAction.SUMMARY_ONLY + and self.output_file_name + ): + msg = "When output_action is set to {}, the output_file_name need not to be specified." 
+ msg = msg.format(BatchDeploymentOutputAction.SUMMARY_ONLY) + raise ValidationException( + message=msg, + target=ErrorTarget.BATCH_DEPLOYMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_job.py new file mode 100644 index 00000000..c078f479 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_job.py @@ -0,0 +1,38 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict + +from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import BatchJobResource + + +class BatchJob(object): + """Batch jobs that are created with batch deployments/endpoints invocation. + + This class shouldn't be instantiated directly. Instead, it is used as the return type of batch deployment/endpoint + invocation and job listing. + """ + + def __init__(self, **kwargs: Any): + self.id = kwargs.get("id", None) + self.name = kwargs.get("name", None) + self.type = kwargs.get("type", None) + self.status = kwargs.get("status", None) + + def _to_dict(self) -> Dict: + return { + "id": self.id, + "name": self.name, + "type": self.type, + "status": self.status, + } + + @classmethod + def _from_rest_object(cls, obj: BatchJobResource) -> "BatchJob": + return cls( + id=obj.id, + name=obj.name, + type=obj.type, + status=obj.properties.status, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/code_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/code_configuration.py new file mode 100644 index 00000000..cbae647d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/code_configuration.py @@ -0,0 +1,93 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +import os +from typing import Optional, Union + +from azure.ai.ml._restclient.v2022_05_01.models import CodeConfiguration as RestCodeConfiguration +from azure.ai.ml.entities._assets import Code +from azure.ai.ml.entities._mixins import DictMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +module_logger = logging.getLogger(__name__) + + +class CodeConfiguration(DictMixin): + """Code configuration for a scoring job. + + :param code: The code directory containing the scoring script. The code can be an Code object, an ARM resource ID + of an existing code asset, a local path, or "http:", "https:", or "azureml:" url pointing to a remote location. + :type code: Optional[Union[~azure.ai.ml.entities.Code, str]] + :param scoring_script: The scoring script file path relative to the code directory. + :type scoring_script: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START code_configuration] + :end-before: [END code_configuration] + :language: python + :dedent: 8 + :caption: Creating a CodeConfiguration for a BatchDeployment. 
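+
+    A minimal inline sketch (the directory and file names below are illustrative
+    placeholders, and the import assumes the usual public re-export in
+    ``azure.ai.ml.entities``):
+
+    .. code-block:: python
+
+        from azure.ai.ml.entities import CodeConfiguration
+
+        # "./src" is assumed to contain the hypothetical scoring script "score.py"
+        code_config = CodeConfiguration(code="./src", scoring_script="score.py")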
+ """ + + def __init__( + self, + code: Optional[Union[str, os.PathLike]] = None, + scoring_script: Optional[Union[str, os.PathLike]] = None, + ) -> None: + self.code: Optional[Union[str, os.PathLike]] = code + self._scoring_script: Optional[Union[str, os.PathLike]] = scoring_script + + @property + def scoring_script(self) -> Optional[Union[str, os.PathLike]]: + """The scoring script file path relative to the code directory. + + :rtype: str + """ + return self._scoring_script + + def _to_rest_code_configuration(self) -> RestCodeConfiguration: + return RestCodeConfiguration(code_id=self.code, scoring_script=self.scoring_script) + + def _validate(self) -> None: + if self.code and not self.scoring_script: + msg = "scoring script can't be empty" + raise ValidationException( + message=msg, + target=ErrorTarget.CODE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + + @staticmethod + def _from_rest_code_configuration(code_configuration: RestCodeConfiguration) -> Optional["CodeConfiguration"]: + if code_configuration: + return CodeConfiguration( + code=code_configuration.code_id, + scoring_script=code_configuration.scoring_script, + ) + return None + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CodeConfiguration): + return NotImplemented + if not other: + return False + # only compare mutable fields + return ( + self.scoring_script == other.scoring_script + and ( + isinstance(self.code, Code) + and isinstance(other.code, Code) + or isinstance(self.code, str) + and isinstance(other.code, str) + ) + and self.code == other.code + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/container_resource_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/container_resource_settings.py new file mode 100644 index 00000000..0d0bc15d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/container_resource_settings.py @@ -0,0 +1,74 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=arguments-renamed + +import logging +from typing import Optional + +from azure.ai.ml._restclient.v2022_05_01.models import ContainerResourceSettings +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +class ResourceSettings(RestTranslatableMixin): + """Resource settings for a container. + + This class uses Kubernetes Resource unit formats. For more information, see + https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/. + + :param cpu: The CPU resource settings for a container. + :type cpu: Optional[str] + :param memory: The memory resource settings for a container. + :type memory: Optional[str] + :param gpu: The GPU resource settings for a container. + :type gpu: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START resource_requirements_configuration] + :end-before: [END resource_requirements_configuration] + :language: python + :dedent: 8 + :caption: Configuring ResourceSettings for a Kubernetes deployment. 
+ """ + + def __init__(self, cpu: Optional[str] = None, memory: Optional[str] = None, gpu: Optional[str] = None) -> None: + self.cpu = cpu + self.memory = memory + self.gpu = gpu + + def _to_rest_object(self) -> ContainerResourceSettings: + return ContainerResourceSettings(cpu=self.cpu, memory=self.memory, gpu=self.gpu) + + @classmethod + def _from_rest_object(cls, settings: ContainerResourceSettings) -> Optional["ResourceSettings"]: + return ( + ResourceSettings( + cpu=settings.cpu, + memory=settings.memory, + gpu=settings.gpu, + ) + if settings + else None + ) + + def _merge_with(self, other: Optional["ResourceSettings"]) -> None: + if other: + self.cpu = other.cpu or self.cpu + self.memory = other.memory or self.memory + self.gpu = other.gpu or self.gpu + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ResourceSettings): + return NotImplemented + if not other: + return False + # only compare mutable fields + return self.cpu == other.cpu and self.memory == other.memory and self.gpu == other.gpu + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_asset.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_asset.py new file mode 100644 index 00000000..72d24131 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_asset.py @@ -0,0 +1,38 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Dict, Optional + +from azure.ai.ml._schema._deployment.online.data_asset_schema import DataAssetSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + +@experimental +class DataAsset: + """Data Asset entity + + :keyword Optional[str] data_id: Arm id of registered data asset + :keyword Optional[str] name: Name of data asset + :keyword Optional[str] path: Path where the data asset is stored. + :keyword Optional[int] version: Version of data asset. + """ + + def __init__( + self, + *, + data_id: Optional[str] = None, + name: Optional[str] = None, + path: Optional[str] = None, + version: Optional[int] = None, + ): + self.data_id = data_id + self.name = name + self.path = path + self.version = version + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = DataAssetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_collector.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_collector.py new file mode 100644 index 00000000..74277c61 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_collector.py @@ -0,0 +1,84 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access + +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import DataCollector as RestDataCollector +from azure.ai.ml._schema._deployment.online.data_collector_schema import DataCollectorSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._deployment.deployment_collection import DeploymentCollection +from azure.ai.ml.entities._deployment.request_logging import RequestLogging + + +@experimental +class DataCollector: + """Data Capture deployment entity. + + :param collections: Mapping dictionary of strings mapped to DeploymentCollection entities. + :type collections: Mapping[str, DeploymentCollection] + :param rolling_rate: The rolling rate of mdc files, possible values: ["minute", "hour", "day"]. + :type rolling_rate: str + :param sampling_rate: The sampling rate of mdc files, possible values: [0.0, 1.0]. + :type sampling_rate: float + :param request_logging: Logging of request payload parameters. + :type request_logging: RequestLogging + """ + + def __init__( + self, + collections: Dict[str, DeploymentCollection], + *, + rolling_rate: Optional[str] = None, + sampling_rate: Optional[float] = None, + request_logging: Optional[RequestLogging] = None, + **kwargs: Any, + ): # pylint: disable=unused-argument + self.collections = collections + self.rolling_rate = rolling_rate + self.sampling_rate = sampling_rate + self.request_logging = request_logging + + if self.sampling_rate: + for collection in self.collections.values(): + collection.sampling_rate = self.sampling_rate + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = DataCollectorSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _from_rest_object(cls, rest_obj: RestDataCollector) -> "DataCollector": + collections = {} + sampling_rate = None + for k, v in rest_obj.collections.items(): + sampling_rate = v.sampling_rate + collections[k] = DeploymentCollection._from_rest_object(v) + delattr(collections[k], "sampling_rate") + + return DataCollector( + collections=collections, + rolling_rate=rest_obj.rolling_rate, + request_logging=( + RequestLogging._from_rest_object(rest_obj.request_logging) if rest_obj.request_logging else None + ), + sampling_rate=sampling_rate, + ) + + def _to_rest_object(self) -> RestDataCollector: + rest_collections: dict = {} + for collection in self.collections.values(): + collection.sampling_rate = self.sampling_rate + delattr(self, "sampling_rate") + if self.request_logging: + self.request_logging = self.request_logging._to_rest_object() + if self.collections: + rest_collections = {} + for k, v in self.collections.items(): + rest_collections[k] = v._to_rest_object() + return RestDataCollector( + collections=rest_collections, rolling_rate=self.rolling_rate, request_logging=self.request_logging + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment.py new file mode 100644 index 00000000..2f857cfa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment.py @@ -0,0 +1,213 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access,arguments-renamed + +import logging +from abc import abstractmethod +from os import PathLike +from typing import IO, TYPE_CHECKING, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml._restclient.v2022_02_01_preview.models import BatchDeploymentData +from azure.ai.ml._restclient.v2022_05_01.models import OnlineDeploymentData +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.exceptions import ( + DeploymentException, + ErrorCategory, + ErrorTarget, + ValidationErrorType, + ValidationException, +) + +from .code_configuration import CodeConfiguration + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._assets._artifacts.model import Model + from azure.ai.ml.entities._assets.environment import Environment + +module_logger = logging.getLogger(__name__) + + +class Deployment(Resource, RestTranslatableMixin): + """Endpoint Deployment base class. + + :param name: Name of the deployment resource, defaults to None + :type name: typing.Optional[str] + :param endpoint_name: Name of the Endpoint resource, defaults to None + :type endpoint_name: typing.Optional[str] + :param description: Description of the deployment resource, defaults to None + :type description: typing.Optional[str] + :param tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None + :type tags: typing.Optional[typing.Dict[str, typing.Any]] + :param properties: The asset property dictionary, defaults to None + :type properties: typing.Optional[typing.Dict[str, typing.Any]] + :param model: The Model entity, defaults to None + :type model: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Model]] + :param code_configuration: Code Configuration, defaults to None + :type code_configuration: typing.Optional[CodeConfiguration] + :param environment: The Environment entity, defaults to None + :type environment: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Environment]] + :param environment_variables: Environment variables that will be set in deployment, defaults to None + :type environment_variables: typing.Optional[typing.Dict[str, str]] + :param code_path: Folder path to local code assets. Equivalent to code_configuration.code.path + , defaults to None + :type code_path: typing.Optional[typing.Union[str, PathLike]] + :param scoring_script: Scoring script name. Equivalent to code_configuration.code.scoring_script + , defaults to None + :type scoring_script: typing.Optional[typing.Union[str, PathLike]] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Deployment cannot be successfully validated. + Exception details will be provided in the error message. + """ + + def __init__( + self, + name: Optional[str] = None, + *, + endpoint_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, + model: Optional[Union[str, "Model"]] = None, + code_configuration: Optional[CodeConfiguration] = None, + environment: Optional[Union[str, "Environment"]] = None, + environment_variables: Optional[Dict[str, str]] = None, + code_path: Optional[Union[str, PathLike]] = None, + scoring_script: Optional[Union[str, PathLike]] = None, + **kwargs: Any, + ): + # MFE is case-insensitive for Name. So convert the name into lower case here. 
+ name = name.lower() if name else None + self.endpoint_name = endpoint_name + self._type: Optional[str] = kwargs.pop("type", None) + + if code_configuration and (code_path or scoring_script): + msg = "code_path and scoring_script are not allowed if code_configuration is provided." + raise ValidationException( + message=msg, + target=ErrorTarget.DEPLOYMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + super().__init__(name, description, tags, properties, **kwargs) + + self.model = model + self.code_configuration = code_configuration + if not self.code_configuration and (code_path or scoring_script): + self.code_configuration = CodeConfiguration(code=code_path, scoring_script=scoring_script) + + self.environment = environment + self.environment_variables = dict(environment_variables) if environment_variables else {} + + @property + def type(self) -> Optional[str]: + """ + Type of deployment. + + :rtype: str + """ + return self._type + + @property + def code_path(self) -> Optional[Union[str, PathLike]]: + """ + The code directory containing the scoring script. + + :rtype: Union[str, PathLike] + """ + return self.code_configuration.code if self.code_configuration and self.code_configuration.code else None + + @code_path.setter + def code_path(self, value: Union[str, PathLike]) -> None: + if not self.code_configuration: + self.code_configuration = CodeConfiguration() + + self.code_configuration.code = value + + @property + def scoring_script(self) -> Optional[Union[str, PathLike]]: + """ + The scoring script file path relative to the code directory. + + :rtype: Union[str, PathLike] + """ + return self.code_configuration.scoring_script if self.code_configuration else None + + @scoring_script.setter + def scoring_script(self, value: Union[str, PathLike]) -> None: + if not self.code_configuration: + self.code_configuration = CodeConfiguration() + + self.code_configuration.scoring_script = value # type: ignore[misc] + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the deployment content into a file in yaml format. + + :param dest: The destination to receive this deployment's content. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable. 
+ :type dest: typing.Union[os.PathLike, str, typing.IO[typing.AnyStr]] + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + @abstractmethod + def _to_dict(self) -> Dict: + pass + + @classmethod + def _from_rest_object( + cls, deployment_rest_object: Union[OnlineDeploymentData, BatchDeploymentData] + ) -> Union[OnlineDeploymentData, BatchDeploymentData]: + from azure.ai.ml.entities._deployment.batch_deployment import BatchDeployment + from azure.ai.ml.entities._deployment.online_deployment import OnlineDeployment + + if isinstance(deployment_rest_object, OnlineDeploymentData): + return OnlineDeployment._from_rest_object(deployment_rest_object) + if isinstance(deployment_rest_object, BatchDeploymentData): + return BatchDeployment._from_rest_object(deployment_rest_object) + + msg = f"Unsupported deployment type {type(deployment_rest_object)}" + raise DeploymentException( + message=msg, + target=ErrorTarget.DEPLOYMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + def _to_rest_object(self) -> Any: + pass + + def _merge_with(self, other: "Deployment") -> None: + if other: + if self.name != other.name: + msg = "The deployment name: {} and {} are not matched when merging." + raise ValidationException( + message=msg.format(self.name, other.name), + target=ErrorTarget.DEPLOYMENT, + no_personal_data_message=msg.format("[name1]", "[name2]"), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + if other.tags: + self.tags: dict = {**self.tags, **other.tags} + if other.properties: + self.properties: dict = {**self.properties, **other.properties} + if other.environment_variables: + self.environment_variables = { + **self.environment_variables, + **other.environment_variables, + } + self.code_configuration = other.code_configuration or self.code_configuration + self.model = other.model or self.model + self.environment = other.environment or self.environment + self.endpoint_name = other.endpoint_name or self.endpoint_name diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_collection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_collection.py new file mode 100644 index 00000000..c1b1c750 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_collection.py @@ -0,0 +1,62 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import Collection as RestCollection +from azure.ai.ml._schema._deployment.online.deployment_collection_schema import DeploymentCollectionSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from .data_asset import DataAsset + + +@experimental +class DeploymentCollection: + """Collection entity + + :param enabled: Is logging for this collection enabled. Possible values include: 'true', 'false'. + :type enabled: str + :param data: Data asset id associated with collection logging. + :type data: str + :param client_id: Client ID associated with collection logging. 
+ :type client_id: str + + """ + + def __init__( + self, + *, + enabled: Optional[str] = None, + data: Optional[Union[str, DataAsset]] = None, + client_id: Optional[str] = None, + **kwargs: Any + ): + self.enabled = enabled # maps to data_collection_mode + self.data = data # maps to data_id + self.sampling_rate = kwargs.get( + "sampling_rate", None + ) # maps to sampling_rate, but it has to be passed from the data_collector root + self.client_id = client_id + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = DeploymentCollectionSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _from_rest_object(cls, rest_obj: RestCollection) -> "DeploymentCollection": + return DeploymentCollection( + enabled="true" if rest_obj.data_collection_mode == "Enabled" else "false", + sampling_rate=rest_obj.sampling_rate, + data=rest_obj.data_id, + client_id=rest_obj.client_id, + ) + + def _to_rest_object(self) -> RestCollection: + return RestCollection( + data_collection_mode="enabled" if str(self.enabled).lower() == "true" else "disabled", + sampling_rate=self.sampling_rate, + data_id=self.data, + client_id=self.client_id, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_settings.py new file mode 100644 index 00000000..0dbfc8fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_settings.py @@ -0,0 +1,200 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=arguments-renamed + +import logging +from typing import Optional + +from azure.ai.ml._restclient.v2022_05_01.models import BatchRetrySettings as RestBatchRetrySettings +from azure.ai.ml._restclient.v2022_05_01.models import OnlineRequestSettings as RestOnlineRequestSettings +from azure.ai.ml._restclient.v2022_05_01.models import ProbeSettings as RestProbeSettings +from azure.ai.ml._utils.utils import ( + from_iso_duration_format, + from_iso_duration_format_ms, + to_iso_duration_format, + to_iso_duration_format_ms, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +class BatchRetrySettings(RestTranslatableMixin): + """Retry settings for batch deployment. + + :param max_retries: Number of retries in failure, defaults to 3 + :type max_retries: int + :param timeout: Timeout in seconds, defaults to 30 + :type timeout: int + """ + + def __init__(self, *, max_retries: Optional[int] = None, timeout: Optional[int] = None): + self.max_retries = max_retries + self.timeout = timeout + + def _to_rest_object(self) -> RestBatchRetrySettings: + return RestBatchRetrySettings( + max_retries=self.max_retries, + timeout=to_iso_duration_format(self.timeout), + ) + + @classmethod + def _from_rest_object(cls, settings: RestBatchRetrySettings) -> Optional["BatchRetrySettings"]: + return ( + BatchRetrySettings( + max_retries=settings.max_retries, + timeout=from_iso_duration_format(settings.timeout), + ) + if settings + else None + ) + + def _merge_with(self, other: "BatchRetrySettings") -> None: + if other: + self.timeout = other.timeout or self.timeout + self.max_retries = other.max_retries or self.max_retries + + +class OnlineRequestSettings(RestTranslatableMixin): + """Request Settings entity. 
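+
+    Controls the request timeout, per-instance concurrency, and queue wait time of an
+    online deployment. A minimal illustrative sketch (the values are placeholders; the
+    import assumes the usual public re-export in ``azure.ai.ml.entities``):
+
+    .. code-block:: python
+
+        from azure.ai.ml.entities import OnlineRequestSettings
+
+        # allow 2 concurrent requests per instance with a 10 second request timeout
+        request_settings = OnlineRequestSettings(
+            max_concurrent_requests_per_instance=2,
+            request_timeout_ms=10_000,
+        )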
+ + :param request_timeout_ms: defaults to 5000 + :type request_timeout_ms: int + :param max_concurrent_requests_per_instance: defaults to 1 + :type max_concurrent_requests_per_instance: int + :param max_queue_wait_ms: defaults to 500 + :type max_queue_wait_ms: int + """ + + def __init__( + self, + max_concurrent_requests_per_instance: Optional[int] = None, + request_timeout_ms: Optional[int] = None, + max_queue_wait_ms: Optional[int] = None, + ): + self.request_timeout_ms = request_timeout_ms + self.max_concurrent_requests_per_instance = max_concurrent_requests_per_instance + self.max_queue_wait_ms = max_queue_wait_ms + + def _to_rest_object(self) -> RestOnlineRequestSettings: + return RestOnlineRequestSettings( + max_queue_wait=to_iso_duration_format_ms(self.max_queue_wait_ms), + max_concurrent_requests_per_instance=self.max_concurrent_requests_per_instance, + request_timeout=to_iso_duration_format_ms(self.request_timeout_ms), + ) + + def _merge_with(self, other: Optional["OnlineRequestSettings"]) -> None: + if other: + self.max_concurrent_requests_per_instance = ( + other.max_concurrent_requests_per_instance or self.max_concurrent_requests_per_instance + ) + self.request_timeout_ms = other.request_timeout_ms or self.request_timeout_ms + self.max_queue_wait_ms = other.max_queue_wait_ms or self.max_queue_wait_ms + + @classmethod + def _from_rest_object(cls, settings: RestOnlineRequestSettings) -> Optional["OnlineRequestSettings"]: + return ( + OnlineRequestSettings( + request_timeout_ms=from_iso_duration_format_ms(settings.request_timeout), + max_concurrent_requests_per_instance=settings.max_concurrent_requests_per_instance, + max_queue_wait_ms=from_iso_duration_format_ms(settings.max_queue_wait), + ) + if settings + else None + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OnlineRequestSettings): + return NotImplemented + if not other: + return False + # only compare mutable fields + return ( + self.max_concurrent_requests_per_instance == other.max_concurrent_requests_per_instance + and self.request_timeout_ms == other.request_timeout_ms + and self.max_queue_wait_ms == other.max_queue_wait_ms + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class ProbeSettings(RestTranslatableMixin): + def __init__( + self, + *, + failure_threshold: Optional[int] = None, + success_threshold: Optional[int] = None, + timeout: Optional[int] = None, + period: Optional[int] = None, + initial_delay: Optional[int] = None, + ): + """Settings on how to probe an endpoint. 
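+
+        A minimal illustrative sketch (the values are placeholders; the import assumes
+        the usual public re-export in ``azure.ai.ml.entities``):
+
+        .. code-block:: python
+
+            from azure.ai.ml.entities import ProbeSettings
+
+            # probe every 10 seconds and tolerate up to 30 consecutive failures
+            liveness_probe = ProbeSettings(failure_threshold=30, period=10)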
+ + :param failure_threshold: Threshold for probe failures, defaults to 30 + :type failure_threshold: int + :param success_threshold: Threshold for probe success, defaults to 1 + :type success_threshold: int + :param timeout: timeout in seconds, defaults to 2 + :type timeout: int + :param period: How often (in seconds) to perform the probe, defaults to 10 + :type period: int + :param initial_delay: How long (in seconds) to wait for the first probe, defaults to 10 + :type initial_delay: int + """ + + self.failure_threshold = failure_threshold + self.success_threshold = success_threshold + self.timeout = timeout + self.period = period + self.initial_delay = initial_delay + + def _to_rest_object(self) -> RestProbeSettings: + return RestProbeSettings( + failure_threshold=self.failure_threshold, + success_threshold=self.success_threshold, + timeout=to_iso_duration_format(self.timeout), + period=to_iso_duration_format(self.period), + initial_delay=to_iso_duration_format(self.initial_delay), + ) + + def _merge_with(self, other: Optional["ProbeSettings"]) -> None: + if other: + self.failure_threshold = other.failure_threshold or self.failure_threshold + self.success_threshold = other.success_threshold or self.success_threshold + self.timeout = other.timeout or self.timeout + self.period = other.period or self.period + self.initial_delay = other.initial_delay or self.initial_delay + + @classmethod + def _from_rest_object(cls, settings: RestProbeSettings) -> Optional["ProbeSettings"]: + return ( + ProbeSettings( + failure_threshold=settings.failure_threshold, + success_threshold=settings.success_threshold, + timeout=from_iso_duration_format(settings.timeout), + period=from_iso_duration_format(settings.period), + initial_delay=from_iso_duration_format(settings.initial_delay), + ) + if settings + else None + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ProbeSettings): + return NotImplemented + if not other: + return False + # only compare mutable fields + return ( + self.failure_threshold == other.failure_threshold + and self.success_threshold == other.success_threshold + and self.timeout == other.timeout + and self.period == other.period + and self.initial_delay == other.initial_delay + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/event_hub.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/event_hub.py new file mode 100644 index 00000000..2729fa50 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/event_hub.py @@ -0,0 +1,32 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, Optional + +from azure.ai.ml._schema._deployment.online.event_hub_schema import EventHubSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._deployment.oversize_data_config import OversizeDataConfig + + +class EventHub: + """Event Hub deployment entity + + :param namespace: Name space of eventhub, provided in format of "{namespace}.{name}". + :type namespace: str + :param oversize_data_config: Oversized payload body configurations. 
+ :type oversize_data_config: OversizeDataConfig + + """ + + # pylint: disable=unused-argument + def __init__( + self, namespace: Optional[str] = None, oversize_data_config: Optional[OversizeDataConfig] = None, **kwargs: Any + ): + self.namespace = namespace + self.oversize_data_config = oversize_data_config + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = EventHubSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/job_definition.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/job_definition.py new file mode 100644 index 00000000..56bebebc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/job_definition.py @@ -0,0 +1,58 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._schema._deployment.batch.job_definition_schema import JobDefinitionSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._job.job import Job + + +@experimental +class JobDefinition: + """Job Definition entity. + + :param type: Job definition type. Allowed value is: pipeline + :type type: str + :param name: Job name + :type name: str + :param job: Job definition + :type job: Union[Job, str] + :param component: Component definition + :type component: Union[Component, str] + :param settings: Job settings + :type settings: Dict[str, Any] + :param description: Job description. + :type description: str + :param tags: Job tags + :type tags: Dict[str, Any] + """ + + def __init__( + self, + # pylint: disable=redefined-builtin + type: str, + name: Optional[str] = None, + job: Optional[Union[Job, str]] = None, + component: Optional[Union[Component, str]] = None, + settings: Optional[Dict[str, Any]] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + # pylint: disable=unused-argument + **kwargs: Any, + ): + self.type = type + self.name = name + self.job = job + self.component = component + self.settings = settings + self.tags = tags + self.description = description + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = JobDefinitionSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment.py new file mode 100644 index 00000000..0ad4fd6f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment.py @@ -0,0 +1,207 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2022_05_01.models import BatchDeploymentData
+from azure.ai.ml._restclient.v2022_05_01.models import BatchDeploymentDetails as RestBatchDeployment
+from azure.ai.ml._restclient.v2022_05_01.models import BatchOutputAction
+from azure.ai.ml._restclient.v2022_05_01.models import CodeConfiguration as RestCodeConfiguration
+from azure.ai.ml._restclient.v2022_05_01.models import IdAssetReference
+from azure.ai.ml._schema._deployment.batch.model_batch_deployment import ModelBatchDeploymentSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction
+from azure.ai.ml.entities._assets import Environment, Model
+from azure.ai.ml.entities._deployment.batch_deployment import BatchDeployment
+from azure.ai.ml.entities._deployment.deployment import Deployment
+from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .code_configuration import CodeConfiguration
+from .model_batch_deployment_settings import ModelBatchDeploymentSettings
+
+
+@experimental
+class ModelBatchDeployment(Deployment):
+    """Model Batch Deployment entity.
+
+    :param name: Name of the deployment resource.
+    :type name: Optional[str]
+    :param endpoint_name: Name of the batch endpoint the deployment belongs to.
+    :type endpoint_name: Optional[str]
+    :param environment: Environment entity or ARM id used to run the deployment.
+    :type environment: Optional[Union[str, Environment]]
+    :param model: Model entity or ARM id to be deployed.
+    :type model: Optional[Union[str, Model]]
+    :param description: Description of the deployment resource.
+    :type description: Optional[str]
+    :param tags: Tag dictionary. Tags can be added, removed, and updated.
+    :type tags: Optional[Dict[str, Any]]
+    :param settings: Run settings for the batch deployment, for example mini_batch_size,
+        output_action and retry_settings.
+    :type settings: Optional[ModelBatchDeploymentSettings]
+    :param resources: Compute resource configuration for the deployment.
+    :type resources: Optional[ResourceConfiguration]
+    :param compute: Compute target to run the batch inference on.
+    :type compute: Optional[str]
+    :param code_configuration: Code configuration for the deployment.
+    :type code_configuration: Optional[CodeConfiguration]
+    :param code_path: Folder path to local code assets. Equivalent to code_configuration.code.
+    :type code_path: Optional[Union[str, PathLike]]
+    :param scoring_script: Scoring script name. Equivalent to code_configuration.scoring_script.
+    :type scoring_script: Optional[Union[str, PathLike]]
+    :param properties: The asset property dictionary.
+ :type properties: dict[str, str] + """ + + def __init__( + self, + *, + name: Optional[str], + endpoint_name: Optional[str] = None, + environment: Optional[Union[str, Environment]] = None, + properties: Optional[Dict[str, str]] = None, + model: Optional[Union[str, Model]] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + settings: Optional[ModelBatchDeploymentSettings] = None, + resources: Optional[ResourceConfiguration] = None, + compute: Optional[str] = None, + code_configuration: Optional[CodeConfiguration] = None, + code_path: Optional[Union[str, PathLike]] = None, # promoted property from code_configuration.code + scoring_script: Optional[ + Union[str, PathLike] + ] = None, # promoted property from code_configuration.scoring_script + **kwargs: Any, + ): + self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None) + super().__init__( + name=name, + endpoint_name=endpoint_name, + properties=properties, + code_path=code_path, + scoring_script=scoring_script, + environment=environment, + model=model, + description=description, + tags=tags, + code_configuration=code_configuration, + **kwargs, + ) + self.compute = compute + self.resources = resources + if settings is not None: + self.settings = ModelBatchDeploymentSettings( + mini_batch_size=settings.mini_batch_size, + instance_count=settings.instance_count, + max_concurrency_per_instance=settings.max_concurrency_per_instance, + output_action=settings.output_action, + output_file_name=settings.output_file_name, + retry_settings=settings.retry_settings, + environment_variables=settings.environment_variables, + error_threshold=settings.error_threshold, + logging_level=settings.logging_level, + ) + if self.resources is not None: + if self.resources.instance_count is None and settings.instance_count is not None: + self.resources.instance_count = settings.instance_count + if self.resources is None and settings.instance_count is not None: + self.resources = ResourceConfiguration(instance_count=settings.instance_count) + + # pylint: disable=arguments-differ + def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore + self._validate() + code_config = ( + RestCodeConfiguration( + code_id=self.code_configuration.code, + scoring_script=self.code_configuration.scoring_script, + ) + if self.code_configuration + else None + ) + deployment_settings = self.settings + model = IdAssetReference(asset_id=self.model) if self.model else None + batch_deployment = RestBatchDeployment( + description=self.description, + environment_id=self.environment, + model=model, + code_configuration=code_config, + output_file_name=deployment_settings.output_file_name, + output_action=BatchDeployment._yaml_output_action_to_rest_output_action( # pylint: disable=protected-access + deployment_settings.output_action + ), + error_threshold=deployment_settings.error_threshold, + resources=self.resources._to_rest_object() if self.resources else None, # pylint: disable=protected-access + retry_settings=( + deployment_settings.retry_settings._to_rest_object() # pylint: disable=protected-access + if deployment_settings.retry_settings + else None + ), + logging_level=deployment_settings.logging_level, + mini_batch_size=deployment_settings.mini_batch_size, + max_concurrency_per_instance=deployment_settings.max_concurrency_per_instance, + environment_variables=deployment_settings.environment_variables, + compute=self.compute, + properties=self.properties, + ) + return BatchDeploymentData(location=location, 
properties=batch_deployment, tags=self.tags)
+
+    @classmethod
+    def _load(
+        cls,
+        data: Optional[Dict] = None,
+        yaml_path: Optional[Union[PathLike, str]] = None,
+        params_override: Optional[list] = None,
+        **kwargs: Any,
+    ) -> "ModelBatchDeployment":
+        data = data or {}
+        params_override = params_override or []
+        cls._update_params(params_override)
+
+        context = {
+            BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(),
+            PARAMS_OVERRIDE_KEY: params_override,
+        }
+        res: ModelBatchDeployment = load_from_dict(ModelBatchDeploymentSchema, data, context, **kwargs)
+        return res
+
+    @classmethod
+    def _update_params(cls, params_override: Any) -> None:
+        for param in params_override:
+            endpoint_name = param.get("endpoint_name")
+            if isinstance(endpoint_name, str):
+                param["endpoint_name"] = endpoint_name.lower()
+
+    @classmethod
+    def _yaml_output_action_to_rest_output_action(cls, yaml_output_action: str) -> str:
+        output_switcher = {
+            BatchDeploymentOutputAction.APPEND_ROW: BatchOutputAction.APPEND_ROW,
+            BatchDeploymentOutputAction.SUMMARY_ONLY: BatchOutputAction.SUMMARY_ONLY,
+        }
+        return output_switcher.get(yaml_output_action, yaml_output_action)
+
+    @property
+    def provisioning_state(self) -> Optional[str]:
+        """Batch deployment provisioning state, readonly.
+
+        :return: Batch deployment provisioning state.
+        :rtype: Optional[str]
+        """
+        return self._provisioning_state
+
+    def _validate(self) -> None:
+        self._validate_output_action()
+
+    def _validate_output_action(self) -> None:
+        if (
+            self.settings.output_action
+            and self.settings.output_action == BatchDeploymentOutputAction.SUMMARY_ONLY
+            and self.settings.output_file_name
+        ):
+            msg = "When output_action is set to {}, output_file_name should not be specified."
+            msg = msg.format(BatchDeploymentOutputAction.SUMMARY_ONLY)
+            raise ValidationException(
+                message=msg,
+                target=ErrorTarget.BATCH_DEPLOYMENT,
+                no_personal_data_message=msg,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.INVALID_VALUE,
+            )
+
+    def _to_dict(self) -> Dict:
+        res: dict = ModelBatchDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+        return res
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment_settings.py
new file mode 100644
index 00000000..36151019
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment_settings.py
@@ -0,0 +1,81 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._schema._deployment.batch.model_batch_deployment_settings import ModelBatchDeploymentSettingsSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction
+from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
+
+
+@experimental
+class ModelBatchDeploymentSettings:
+    """Model Batch Deployment Settings entity.
+
+    :param mini_batch_size: Size of the mini-batch passed to each batch invocation, defaults to 10
+    :type mini_batch_size: int
+    :param instance_count: Number of instances the batch inference will run on.
Equivalent to resources.instance_count. + :type instance_count: int + :param output_action: Indicates how the output will be organized. Possible values include: + "summary_only", "append_row". Defaults to "append_row" + :type output_action: str or ~azure.ai.ml.constants._deployment.BatchDeploymentOutputAction + :param output_file_name: Customized output file name for append_row output action, defaults to "predictions.csv" + :type output_file_name: str + :param max_concurrency_per_instance: Indicates maximum number of parallelism per instance, defaults to 1 + :type max_concurrency_per_instance: int + :param retry_settings: Retry settings for a batch inference operation, defaults to None + :type retry_settings: BatchRetrySettings + :param environment_variables: Environment variables that will be set in deployment. + :type environment_variables: dict + :param error_threshold: Error threshold, if the error count for the entire input goes above + this value, + the batch inference will be aborted. Range is [-1, int.MaxValue] + -1 value indicates, ignore all failures during batch inference + For FileDataset count of file failures + For TabularDataset, this is the count of record failures, defaults to -1 + :type error_threshold: int + :param logging_level: Logging level for batch inference operation, defaults to "info" + :type logging_level: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START model_batch_deployment_settings_entity_create] + :end-before: [END model_batch_deployment_settings_entity_create] + :language: python + :dedent: 8 + :caption: Creating a Model Batch Deployment Settings object. + """ + + def __init__( + self, + *, + mini_batch_size: Optional[int], + instance_count: Optional[int] = None, + max_concurrency_per_instance: Optional[int] = None, + output_action: Optional[BatchDeploymentOutputAction] = None, + output_file_name: Optional[str] = None, + retry_settings: Optional[BatchRetrySettings] = None, + environment_variables: Optional[Dict[str, str]] = None, + error_threshold: Optional[int] = None, + logging_level: Optional[str] = None, + # pylint: disable=unused-argument + **kwargs: Any, + ): + self.mini_batch_size = mini_batch_size + self.instance_count = instance_count + self.max_concurrency_per_instance = max_concurrency_per_instance + self.output_action = output_action + self.output_file_name = output_file_name + self.retry_settings = retry_settings + self.environment_variables = environment_variables + self.error_threshold = error_threshold + self.logging_level = logging_level + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = ModelBatchDeploymentSettingsSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/online_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/online_deployment.py new file mode 100644 index 00000000..131d3293 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/online_deployment.py @@ -0,0 +1,742 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access,arguments-renamed,unidiomatic-typecheck + +import logging +import os +import typing +from abc import abstractmethod +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from azure.ai.ml._restclient.v2023_04_01_preview.models import CodeConfiguration as RestCodeConfiguration +from azure.ai.ml._restclient.v2023_04_01_preview.models import EndpointComputeType +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + KubernetesOnlineDeployment as RestKubernetesOnlineDeployment, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ManagedOnlineDeployment as RestManagedOnlineDeployment +from azure.ai.ml._restclient.v2023_04_01_preview.models import OnlineDeployment as RestOnlineDeploymentData +from azure.ai.ml._restclient.v2023_04_01_preview.models import OnlineDeploymentProperties as RestOnlineDeploymentDetails +from azure.ai.ml._restclient.v2023_04_01_preview.models import Sku as RestSku +from azure.ai.ml._schema._deployment.online.online_deployment import ( + KubernetesOnlineDeploymentSchema, + ManagedOnlineDeploymentSchema, +) +from azure.ai.ml._utils._arm_id_utils import _parse_endpoint_name_from_deployment_id +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, TYPE, ArmConstants +from azure.ai.ml.constants._endpoint import EndpointYamlFields +from azure.ai.ml.entities._assets import Code +from azure.ai.ml.entities._assets._artifacts.model import Model +from azure.ai.ml.entities._assets.environment import Environment +from azure.ai.ml.entities._deployment.code_configuration import CodeConfiguration +from azure.ai.ml.entities._deployment.data_collector import DataCollector +from azure.ai.ml.entities._deployment.deployment_settings import OnlineRequestSettings, ProbeSettings +from azure.ai.ml.entities._deployment.resource_requirements_settings import ResourceRequirementsSettings +from azure.ai.ml.entities._deployment.scale_settings import ( + DefaultScaleSettings, + OnlineScaleSettings, + TargetUtilizationScaleSettings, +) +from azure.ai.ml.entities._endpoint._endpoint_helpers import validate_endpoint_or_deployment_name +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ( + DeploymentException, + ErrorCategory, + ErrorTarget, + ValidationErrorType, + ValidationException, +) + +from .deployment import Deployment + +module_logger = logging.getLogger(__name__) + + +# pylint: disable=too-many-instance-attributes +class OnlineDeployment(Deployment): + """Online endpoint deployment entity. + + :param name: Name of the deployment resource. + :type name: str + :param endpoint_name: Name of the endpoint resource, defaults to None + :type endpoint_name: typing.Optional[str] + :param tags: Tag dictionary. 
Tags can be added, removed, and updated, defaults to None + :type tags: typing.Optional[typing.Dict[str, typing.Any]] + :param properties: The asset property dictionary, defaults to None + :type properties: typing.Optional[typing.Dict[str, typing.Any]] + :param description: Description of the resource, defaults to None + :type description: typing.Optional[str] + :param model: Model entity for the endpoint deployment, defaults to None + :type model: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Model]] + :param data_collector: Data Collector entity for the endpoint deployment, defaults to None + :type data_collector: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.DataCollector]] + :param code_configuration: Code Configuration, defaults to None + :type code_configuration: typing.Optional[~azure.ai.ml.entities.CodeConfiguration] + :param environment: Environment entity for the endpoint deployment, defaults to None + :type environment: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Environment]] + :param app_insights_enabled: Is appinsights enabled, defaults to False + :type app_insights_enabled: typing.Optional[bool] + :param scale_settings: How the online deployment will scale, defaults to None + :type scale_settings: typing.Optional[~azure.ai.ml.entities.OnlineScaleSettings] + :param request_settings: Online Request Settings, defaults to None + :type request_settings: typing.Optional[~azure.ai.ml.entities.OnlineRequestSettings] + :param liveness_probe: Liveness probe settings, defaults to None + :type liveness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings] + :param readiness_probe: Readiness probe settings, defaults to None + :type readiness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings] + :param environment_variables: Environment variables that will be set in deployment, defaults to None + :type environment_variables: typing.Optional[typing.Dict[str, str]] + :param instance_count: The instance count used for this deployment, defaults to None + :type instance_count: typing.Optional[int] + :param instance_type: Azure compute sku, defaults to None + :type instance_type: typing.Optional[str] + :param model_mount_path: The path to mount the model in custom container, defaults to None + :type model_mount_path: typing.Optional[str] + :param code_path: Equivalent to code_configuration.code, will be ignored if code_configuration is present + , defaults to None + :type code_path: typing.Optional[typing.Union[str, os.PathLike]] + :param scoring_script: Equivalent to code_configuration.code.scoring_script. 
+ Will be ignored if code_configuration is present, defaults to None + :type scoring_script: typing.Optional[typing.Union[str, os.PathLike]] + """ + + def __init__( + self, + name: str, + *, + endpoint_name: Optional[str] = None, + tags: Optional[Dict[str, typing.Any]] = None, + properties: Optional[Dict[str, typing.Any]] = None, + description: Optional[str] = None, + model: Optional[Union[str, "Model"]] = None, + data_collector: Optional[DataCollector] = None, + code_configuration: Optional[CodeConfiguration] = None, + environment: Optional[Union[str, "Environment"]] = None, + app_insights_enabled: Optional[bool] = False, + scale_settings: Optional[OnlineScaleSettings] = None, + request_settings: Optional[OnlineRequestSettings] = None, + liveness_probe: Optional[ProbeSettings] = None, + readiness_probe: Optional[ProbeSettings] = None, + environment_variables: Optional[Dict[str, str]] = None, + instance_count: Optional[int] = None, + instance_type: Optional[str] = None, + model_mount_path: Optional[str] = None, + code_path: Optional[Union[str, os.PathLike]] = None, # promoted property from code_configuration.code + scoring_script: Optional[Union[str, os.PathLike]] = None, # promoted property code_configuration.scoring_script + **kwargs: typing.Any, + ): + self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None) + + super(OnlineDeployment, self).__init__( + name=name, + endpoint_name=endpoint_name, + tags=tags, + properties=properties, + description=description, + model=model, + code_configuration=code_configuration, + environment=environment, + environment_variables=environment_variables, + code_path=code_path, + scoring_script=scoring_script, + **kwargs, + ) + + self.app_insights_enabled = app_insights_enabled + self.scale_settings = scale_settings + self.request_settings = request_settings + self.liveness_probe = liveness_probe + self.readiness_probe = readiness_probe + self.instance_count = instance_count + self._arm_type = ArmConstants.ONLINE_DEPLOYMENT_TYPE + self.model_mount_path = model_mount_path + self.instance_type = instance_type + self.data_collector: Any = data_collector + + @property + def provisioning_state(self) -> Optional[str]: + """Deployment provisioning state, readonly. + + :return: Deployment provisioning state. + :rtype: typing.Optional[str] + """ + return self._provisioning_state + + def _generate_dependencies(self) -> Tuple: + """Convert dependencies into ARM id or REST wrapper. + + :return: A 3-tuple of the code configuration, environment ID, and model ID. 
+ :rtype: Tuple[RestCodeConfiguration, str, str] + """ + code = None + + if self.code_configuration: + self.code_configuration._validate() + if self.code_configuration.code is not None: + if isinstance(self.code_configuration.code, str): + code_id = self.code_configuration.code + elif not isinstance(self.code_configuration.code, os.PathLike): + code_id = self.code_configuration.code.id + + code = RestCodeConfiguration( + code_id=code_id, # pylint: disable=possibly-used-before-assignment + scoring_script=self.code_configuration.scoring_script, + ) + + model_id = None + if self.model: + model_id = self.model if isinstance(self.model, str) else self.model.id + + environment_id = None + if self.environment: + environment_id = self.environment if isinstance(self.environment, str) else self.environment.id + + return code, environment_id, model_id + + @abstractmethod + def _to_dict(self) -> Dict: + pass + + @abstractmethod + def _to_arm_resource_param(self, **kwargs: Any) -> Dict: + pass + + @abstractmethod + def _to_rest_object(self) -> RestOnlineDeploymentData: + pass + + @classmethod + def _from_rest_object(cls, deployment: RestOnlineDeploymentData) -> RestOnlineDeploymentDetails: + if deployment.properties.endpoint_compute_type == EndpointComputeType.KUBERNETES: + return KubernetesOnlineDeployment._from_rest_object(deployment) + if deployment.properties.endpoint_compute_type == EndpointComputeType.MANAGED: + return ManagedOnlineDeployment._from_rest_object(deployment) + + msg = f"Unsupported online endpoint type {deployment.properties.endpoint_compute_type}." + raise DeploymentException( + message=msg, + target=ErrorTarget.ONLINE_DEPLOYMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + def _get_arm_resource(self, **kwargs: Any) -> Dict: + resource: dict = super(OnlineDeployment, self)._get_arm_resource(**kwargs) + depends_on = [] + if self.environment and isinstance(self.environment, Environment): + depends_on.append(f"{self.environment._arm_type}Deployment") + if self.code_configuration and self.code_configuration.code and isinstance(self.code_configuration.code, Code): + depends_on.append(f"{self.code_configuration.code._arm_type}Deployment") + if self.model and isinstance(self.model, Model): + depends_on.append(f"{self.model._arm_type}Deployment") + resource[ArmConstants.DEPENDSON_PARAMETER_NAME] = depends_on + return resource + + def _get_arm_resource_and_params(self, **kwargs: Any) -> List: + resource_param_tuple_list = [(self._get_arm_resource(**kwargs), self._to_arm_resource_param(**kwargs))] + if self.environment and isinstance(self.environment, Environment): + resource_param_tuple_list.extend(self.environment._get_arm_resource_and_params()) + if self.code_configuration and self.code_configuration.code and isinstance(self.code_configuration.code, Code): + resource_param_tuple_list.extend(self.code_configuration.code._get_arm_resource_and_params()) + if self.model and isinstance(self.model, Model): + resource_param_tuple_list.extend(self.model._get_arm_resource_and_params()) + return resource_param_tuple_list + + def _validate_name(self) -> None: + if self.name: + validate_endpoint_or_deployment_name(self.name, is_deployment=True) + + def _merge_with(self, other: Any) -> None: + if other: + if self.name != other.name: + msg = "The deployment name: {} and {} are not matched when merging." 
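+                # Merging is only defined for two specifications of the same deployment,
+                # so a name mismatch is surfaced as a user error rather than silently overwritten.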
+ raise ValidationException( + message=msg.format(self.name, other.name), + target=ErrorTarget.ONLINE_DEPLOYMENT, + no_personal_data_message=msg.format("[name1]", "[name2]"), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + super()._merge_with(other) + self.app_insights_enabled = other.app_insights_enabled or self.app_insights_enabled + # Adding noqa: Fix E721 do not compare types, use 'isinstance()' + # isinstance will include checking for subclasses, which is explicitly undesired by a logic. + if self.scale_settings and type(self.scale_settings) == type(other.scale_settings): # noqa + self.scale_settings._merge_with(other.scale_settings) + else: + self.scale_settings = other.scale_settings + if self.request_settings: + self.request_settings._merge_with(other.request_settings) + else: + self.request_settings = other.request_settings + if self.liveness_probe: + self.liveness_probe._merge_with(other.liveness_probe) + else: + self.liveness_probe = other.liveness_probe + if self.readiness_probe: + self.readiness_probe._merge_with(other.readiness_probe) + else: + self.readiness_probe = other.readiness_probe + self.instance_count = other.instance_count or self.instance_count + self.instance_type = other.instance_type or self.instance_type + + @classmethod + def _set_scale_settings(cls, data: dict) -> None: + if not hasattr(data, EndpointYamlFields.SCALE_SETTINGS): + return + + scale_settings = data[EndpointYamlFields.SCALE_SETTINGS] + keyName = TYPE + if scale_settings and scale_settings[keyName] == "default": + scale_copy = scale_settings.copy() + for key in scale_copy: + if key != keyName: + scale_settings.pop(key, None) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[os.PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "OnlineDeployment": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(), + PARAMS_OVERRIDE_KEY: params_override, + } + + deployment_type = data.get("type", None) + + if deployment_type == camel_to_snake(EndpointComputeType.KUBERNETES.value): + res_kub: OnlineDeployment = load_from_dict(KubernetesOnlineDeploymentSchema, data, context, **kwargs) + return res_kub + + res_manage: OnlineDeployment = load_from_dict(ManagedOnlineDeploymentSchema, data, context, **kwargs) + return res_manage + + +class KubernetesOnlineDeployment(OnlineDeployment): + """Kubernetes Online endpoint deployment entity. + + :param name: Name of the deployment resource. + :type name: str + :param endpoint_name: Name of the endpoint resource, defaults to None + :type endpoint_name: typing.Optional[str] + :param tags: Tag dictionary. 
Tags can be added, removed, and updated, defaults to None
+    :type tags: typing.Optional[typing.Dict[str, typing.Any]]
+    :param properties: The asset property dictionary, defaults to None
+    :type properties: typing.Optional[typing.Dict[str, typing.Any]]
+    :param description: Description of the resource, defaults to None
+    :type description: typing.Optional[str]
+    :param model: Model entity for the endpoint deployment, defaults to None
+    :type model: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Model]]
+    :param code_configuration: Code Configuration, defaults to None
+    :type code_configuration: typing.Optional[~azure.ai.ml.entities.CodeConfiguration]
+    :param environment: Environment entity for the endpoint deployment, defaults to None
+    :type environment: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Environment]]
+    :param app_insights_enabled: Whether Application Insights is enabled, defaults to False
+    :type app_insights_enabled: bool
+    :param scale_settings: How the online deployment will scale, defaults to None
+    :type scale_settings: typing.Optional[typing.Union[~azure.ai.ml.entities.DefaultScaleSettings,
+        ~azure.ai.ml.entities.TargetUtilizationScaleSettings]]
+    :param request_settings: Online Request Settings, defaults to None
+    :type request_settings: typing.Optional[~azure.ai.ml.entities.OnlineRequestSettings]
+    :param liveness_probe: Liveness probe settings, defaults to None
+    :type liveness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+    :param readiness_probe: Readiness probe settings, defaults to None
+    :type readiness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+    :param environment_variables: Environment variables that will be set in the deployment, defaults to None
+    :type environment_variables: typing.Optional[typing.Dict[str, str]]
+    :param resources: Resource requirements settings, defaults to None
+    :type resources: typing.Optional[~azure.ai.ml.entities.ResourceRequirementsSettings]
+    :param instance_count: The instance count used for this deployment, defaults to None
+    :type instance_count: typing.Optional[int]
+    :param instance_type: The instance type defined by the Kubernetes cluster admin, defaults to None
+    :type instance_type: typing.Optional[str]
+    :param code_path: Equivalent to code_configuration.code, will be ignored if code_configuration is present,
+        defaults to None
+    :type code_path: typing.Optional[typing.Union[str, os.PathLike]]
+    :param scoring_script: Equivalent to code_configuration.scoring_script.
+ Will be ignored if code_configuration is present, defaults to None + :type scoring_script: typing.Optional[typing.Union[str, os.PathLike]] + """ + + def __init__( + self, + *, + name: str, + endpoint_name: Optional[str] = None, + tags: Optional[Dict[str, typing.Any]] = None, + properties: Optional[Dict[str, typing.Any]] = None, + description: Optional[str] = None, + model: Optional[Union[str, "Model"]] = None, + code_configuration: Optional[CodeConfiguration] = None, + environment: Optional[Union[str, "Environment"]] = None, + app_insights_enabled: bool = False, + scale_settings: Optional[Union[DefaultScaleSettings, TargetUtilizationScaleSettings]] = None, + request_settings: Optional[OnlineRequestSettings] = None, + liveness_probe: Optional[ProbeSettings] = None, + readiness_probe: Optional[ProbeSettings] = None, + environment_variables: Optional[Dict[str, str]] = None, + resources: Optional[ResourceRequirementsSettings] = None, + instance_count: Optional[int] = None, + instance_type: Optional[str] = None, + code_path: Optional[Union[str, os.PathLike]] = None, # promoted property from code_configuration.code + scoring_script: Optional[ + Union[str, os.PathLike] + ] = None, # promoted property from code_configuration.scoring_script + **kwargs: Any, + ): + kwargs["type"] = EndpointComputeType.KUBERNETES.value + super(KubernetesOnlineDeployment, self).__init__( + name=name, + endpoint_name=endpoint_name, + tags=tags, + properties=properties, + description=description, + model=model, + code_configuration=code_configuration, + environment=environment, + environment_variables=environment_variables, + instance_count=instance_count, + instance_type=instance_type, + app_insights_enabled=app_insights_enabled, + scale_settings=scale_settings, + request_settings=request_settings, + liveness_probe=liveness_probe, + readiness_probe=readiness_probe, + code_path=code_path, + scoring_script=scoring_script, + **kwargs, + ) + + self.resources = resources + + def _to_dict(self) -> Dict: + res: dict = KubernetesOnlineDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + # pylint: disable=arguments-differ + def _to_rest_object(self, location: str) -> RestOnlineDeploymentData: # type: ignore + self._validate() + code, environment, model = self._generate_dependencies() + + properties = RestKubernetesOnlineDeployment( + code_configuration=code, + environment_id=environment, + model=model, + model_mount_path=self.model_mount_path, + scale_settings=self.scale_settings._to_rest_object() if self.scale_settings else None, + properties=self.properties, + description=self.description, + environment_variables=self.environment_variables, + app_insights_enabled=self.app_insights_enabled, + request_settings=self.request_settings._to_rest_object() if self.request_settings else None, + liveness_probe=self.liveness_probe._to_rest_object() if self.liveness_probe else None, + readiness_probe=self.readiness_probe._to_rest_object() if self.readiness_probe else None, + container_resource_requirements=self.resources._to_rest_object() if self.resources else None, + instance_type=self.instance_type if self.instance_type else None, + data_collector=self.data_collector._to_rest_object() if self.data_collector else None, + ) + sku = RestSku(name="Default", capacity=self.instance_count) + + return RestOnlineDeploymentData(location=location, properties=properties, tags=self.tags, sku=sku) + + def _to_arm_resource_param(self, **kwargs: Any) -> Dict: + rest_object = self._to_rest_object(**kwargs) + properties 
= rest_object.properties + sku = rest_object.sku + tags = rest_object.tags + + return { + self._arm_type: { + ArmConstants.NAME: self.name, + ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "K8SOnlineDeployment"), + ArmConstants.SKU: self._serialize.body(sku, "Sku"), + ArmConstants.TAGS: tags, + } + } + + def _merge_with(self, other: Any) -> None: + if other: + super()._merge_with(other) + if self.resources: + self.resources._merge_with(other.resources) + else: + self.resources = other.resources + + def _validate(self) -> None: + self._validate_name() + + @classmethod + def _from_rest_object(cls, resource: RestOnlineDeploymentData) -> "KubernetesOnlineDeployment": + deployment = resource.properties + + code_config = ( + CodeConfiguration( + code=deployment.code_configuration.code_id, + scoring_script=deployment.code_configuration.scoring_script, + ) + if deployment.code_configuration + else None + ) + + return KubernetesOnlineDeployment( + id=resource.id, + name=resource.name, + tags=resource.tags, + properties=deployment.properties, + description=deployment.description, + request_settings=OnlineRequestSettings._from_rest_object(deployment.request_settings), + model=deployment.model, + code_configuration=code_config, + environment=deployment.environment_id, + resources=ResourceRequirementsSettings._from_rest_object(deployment.container_resource_requirements), + app_insights_enabled=deployment.app_insights_enabled, + scale_settings=cast( + Optional[Union[DefaultScaleSettings, TargetUtilizationScaleSettings]], + OnlineScaleSettings._from_rest_object(deployment.scale_settings), + ), + liveness_probe=ProbeSettings._from_rest_object(deployment.liveness_probe), + readiness_probe=ProbeSettings._from_rest_object(deployment.readiness_probe), + environment_variables=deployment.environment_variables, + endpoint_name=_parse_endpoint_name_from_deployment_id(resource.id), + instance_count=resource.sku.capacity if resource.sku else None, + instance_type=deployment.instance_type, + data_collector=( + DataCollector._from_rest_object(deployment.data_collector) + if hasattr(deployment, "data_collector") and deployment.data_collector + else None + ), + provisioning_state=deployment.provisioning_state if hasattr(deployment, "provisioning_state") else None, + ) + + +class ManagedOnlineDeployment(OnlineDeployment): + """Managed Online endpoint deployment entity. + + :param name: Name of the deployment resource + :type name: str + :param endpoint_name: Name of the endpoint resource, defaults to None + :type endpoint_name: typing.Optional[str] + :param tags: Tag dictionary. 
Tags can be added, removed, and updated, defaults to None
+    :type tags: typing.Optional[typing.Dict[str, typing.Any]]
+    :param properties: The asset property dictionary, defaults to None
+    :type properties: typing.Optional[typing.Dict[str, typing.Any]]
+    :param description: Description of the resource, defaults to None
+    :type description: typing.Optional[str]
+    :param model: Model entity for the endpoint deployment, defaults to None
+    :type model: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Model]]
+    :param code_configuration: Code Configuration, defaults to None
+    :type code_configuration: typing.Optional[~azure.ai.ml.entities.CodeConfiguration]
+    :param environment: Environment entity for the endpoint deployment, defaults to None
+    :type environment: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Environment]]
+    :param app_insights_enabled: Whether Application Insights is enabled, defaults to False
+    :type app_insights_enabled: bool
+    :param scale_settings: How the online deployment will scale, defaults to None
+    :type scale_settings: typing.Optional[typing.Union[~azure.ai.ml.entities.DefaultScaleSettings,
+        ~azure.ai.ml.entities.TargetUtilizationScaleSettings]]
+    :param request_settings: Online Request Settings, defaults to None
+    :type request_settings: typing.Optional[~azure.ai.ml.entities.OnlineRequestSettings]
+    :param liveness_probe: Liveness probe settings, defaults to None
+    :type liveness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+    :param readiness_probe: Readiness probe settings, defaults to None
+    :type readiness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+    :param environment_variables: Environment variables that will be set in the deployment, defaults to None
+    :type environment_variables: typing.Optional[typing.Dict[str, str]]
+    :param instance_type: Azure compute SKU, defaults to None
+    :type instance_type: typing.Optional[str]
+    :param instance_count: The instance count used for this deployment, defaults to None
+    :type instance_count: typing.Optional[int]
+    :param egress_public_network_access: Whether to restrict communication between a deployment and the
+        Azure resources used by the deployment.
Allowed values are: "enabled", "disabled", defaults to None + :type egress_public_network_access: typing.Optional[str] + :param code_path: Equivalent to code_configuration.code, will be ignored if code_configuration is present + , defaults to None + :type code_path: typing.Optional[typing.Union[str, os.PathLike]] + :param scoring_script_path: Equivalent to code_configuration.scoring_script, will be ignored if + code_configuration is present, defaults to None + :type scoring_script_path: typing.Optional[typing.Union[str, os.PathLike]] + :param data_collector: Data collector, defaults to None + :type data_collector: typing.Optional[typing.List[~azure.ai.ml.entities.DataCollector]] + """ + + def __init__( + self, + *, + name: str, + endpoint_name: Optional[str] = None, + tags: Optional[Dict[str, typing.Any]] = None, + properties: Optional[Dict[str, typing.Any]] = None, + description: Optional[str] = None, + model: Optional[Union[str, "Model"]] = None, + code_configuration: Optional[CodeConfiguration] = None, + environment: Optional[Union[str, "Environment"]] = None, + app_insights_enabled: bool = False, + scale_settings: Optional[Union[DefaultScaleSettings, TargetUtilizationScaleSettings]] = None, + request_settings: Optional[OnlineRequestSettings] = None, + liveness_probe: Optional[ProbeSettings] = None, + readiness_probe: Optional[ProbeSettings] = None, + environment_variables: Optional[Dict[str, str]] = None, + instance_type: Optional[str] = None, + instance_count: Optional[int] = None, + egress_public_network_access: Optional[str] = None, + code_path: Optional[Union[str, os.PathLike]] = None, # promoted property from code_configuration.code + scoring_script: Optional[ + Union[str, os.PathLike] + ] = None, # promoted property from code_configuration.scoring_script + data_collector: Optional[DataCollector] = None, + **kwargs: Any, + ): + kwargs["type"] = EndpointComputeType.MANAGED.value + self.private_network_connection = kwargs.pop("private_network_connection", None) + self.package_model = kwargs.pop("package_model", False) + + super(ManagedOnlineDeployment, self).__init__( + name=name, + endpoint_name=endpoint_name, + tags=tags, + properties=properties, + description=description, + model=model, + code_configuration=code_configuration, + environment=environment, + environment_variables=environment_variables, + app_insights_enabled=app_insights_enabled, + scale_settings=scale_settings, + request_settings=request_settings, + liveness_probe=liveness_probe, + readiness_probe=readiness_probe, + instance_count=instance_count, + instance_type=instance_type, + code_path=code_path, + scoring_script=scoring_script, + data_collector=data_collector, + **kwargs, + ) + + self.readiness_probe = readiness_probe + self.egress_public_network_access = egress_public_network_access + + def _to_dict(self) -> Dict: + res: dict = ManagedOnlineDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + # pylint: disable=arguments-differ + def _to_rest_object(self, location: str) -> RestOnlineDeploymentData: # type: ignore + self._validate() + code, environment, model = self._generate_dependencies() + properties = RestManagedOnlineDeployment( + code_configuration=code, + environment_id=environment, + model=model, + model_mount_path=self.model_mount_path, + scale_settings=self.scale_settings._to_rest_object() if self.scale_settings else None, + properties=self.properties, + description=self.description, + environment_variables=self.environment_variables, + 
app_insights_enabled=self.app_insights_enabled, + request_settings=self.request_settings._to_rest_object() if self.request_settings else None, + liveness_probe=self.liveness_probe._to_rest_object() if self.liveness_probe else None, + instance_type=self.instance_type, + readiness_probe=self.readiness_probe._to_rest_object() if self.readiness_probe else None, + data_collector=self.data_collector._to_rest_object() if self.data_collector else None, + ) + # TODO: SKU name is defaulted to value "Default" since service side requires it. + # Should be removed once service side defaults it. + sku = RestSku(name="Default", capacity=self.instance_count) + + # mfe is expecting private network connection to be in both the attribute level + # as well as in the properties dictionary. + if hasattr(self, "private_network_connection") and self.private_network_connection: + properties.private_network_connection = self.private_network_connection + properties.properties["private-network-connection"] = self.private_network_connection + if hasattr(self, "egress_public_network_access") and self.egress_public_network_access: + properties.egress_public_network_access = self.egress_public_network_access + return RestOnlineDeploymentData(location=location, properties=properties, tags=self.tags, sku=sku) + + def _to_arm_resource_param(self, **kwargs: Any) -> Dict: + rest_object = self._to_rest_object(**kwargs) + properties = rest_object.properties + sku = rest_object.sku + tags = rest_object.tags + + return { + self._arm_type: { + ArmConstants.NAME: self.name, + ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "ManagedOnlineDeployment"), + ArmConstants.SKU: self._serialize.body(sku, "Sku"), + ArmConstants.TAGS: tags, + } + } + + @classmethod + def _from_rest_object(cls, resource: RestOnlineDeploymentData) -> "ManagedOnlineDeployment": + deployment = resource.properties + + code_config = ( + CodeConfiguration( + code=deployment.code_configuration.code_id, + scoring_script=deployment.code_configuration.scoring_script, + ) + if deployment.code_configuration + else None + ) + + return ManagedOnlineDeployment( + id=resource.id, + name=resource.name, + tags=resource.tags, + properties=deployment.properties, + description=deployment.description, + request_settings=OnlineRequestSettings._from_rest_object(deployment.request_settings), + model=(deployment.model if deployment.model else None), + code_configuration=code_config, + environment=deployment.environment_id, + app_insights_enabled=deployment.app_insights_enabled, + scale_settings=OnlineScaleSettings._from_rest_object(deployment.scale_settings), # type: ignore + liveness_probe=ProbeSettings._from_rest_object(deployment.liveness_probe), + environment_variables=deployment.environment_variables, + readiness_probe=ProbeSettings._from_rest_object(deployment.readiness_probe), + instance_type=deployment.instance_type, + endpoint_name=_parse_endpoint_name_from_deployment_id(resource.id), + instance_count=resource.sku.capacity, + private_network_connection=( + deployment.private_network_connection if hasattr(deployment, "private_network_connection") else None + ), + egress_public_network_access=deployment.egress_public_network_access, + data_collector=( + DataCollector._from_rest_object(deployment.data_collector) + if hasattr(deployment, "data_collector") and deployment.data_collector + else None + ), + provisioning_state=deployment.provisioning_state if hasattr(deployment, "provisioning_state") else None, + creation_context=resource.system_data, + ) + + 
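Editor's note: the ManagedOnlineDeployment entity defined above is normally built by user code and submitted through MLClient; it is not constructed from REST payloads by callers. The sketch below illustrates that flow under stated assumptions only: the subscription, resource group, workspace, endpoint and deployment names, the local ./model, ./src, score.py, and conda file paths, and the Standard_DS3_v2 SKU are all placeholders chosen for the example, not values taken from this source.

from azure.ai.ml import MLClient
from azure.ai.ml.entities import (
    CodeConfiguration,
    Environment,
    ManagedOnlineDeployment,
    ManagedOnlineEndpoint,
    Model,
)
from azure.identity import DefaultAzureCredential

# Connect to a workspace; all identifiers below are placeholders.
ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# The endpoint is created first and hosts one or more deployments.
endpoint = ManagedOnlineEndpoint(name="sample-endpoint", auth_mode="key")
ml_client.online_endpoints.begin_create_or_update(endpoint).result()

# Describe the deployment; model, environment, and scoring code are local assets here.
deployment = ManagedOnlineDeployment(
    name="blue",
    endpoint_name="sample-endpoint",
    model=Model(path="./model"),
    environment=Environment(
        image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest",
        conda_file="./environment/conda.yaml",
    ),
    code_configuration=CodeConfiguration(code="./src", scoring_script="score.py"),
    instance_type="Standard_DS3_v2",
    instance_count=1,
)
ml_client.online_deployments.begin_create_or_update(deployment).result()

# Route all traffic to the new deployment via the endpoint entity.
endpoint.traffic = {"blue": 100}
ml_client.online_endpoints.begin_create_or_update(endpoint).result()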
def _merge_with(self, other: Any) -> None: + if other: + super()._merge_with(other) + self.instance_type = other.instance_type or self.instance_type + + def _validate(self) -> None: + self._validate_name() + self._validate_scale_settings() + + def _validate_scale_settings(self) -> None: + if self.scale_settings: + if not isinstance(self.scale_settings, DefaultScaleSettings): + msg = "ManagedOnlineEndpoint supports DefaultScaleSettings only." + raise ValidationException( + message=msg, + target=ErrorTarget.ONLINE_DEPLOYMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/oversize_data_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/oversize_data_config.py new file mode 100644 index 00000000..80338c39 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/oversize_data_config.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, Optional + +from azure.ai.ml._schema._deployment.online.oversize_data_config_schema import OversizeDataConfigSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + +class OversizeDataConfig: + """Oversize Data Config deployment entity. + + :param path: Blob path for Model Data Collector file. + :type path: str + """ + + # pylint: disable=unused-argument + def __init__(self, path: Optional[str] = None, **kwargs: Any): + self.path = path + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = OversizeDataConfigSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/payload_response.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/payload_response.py new file mode 100644 index 00000000..b67d46c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/payload_response.py @@ -0,0 +1,26 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, Optional + +from azure.ai.ml._schema._deployment.online.payload_response_schema import PayloadResponseSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + +class PayloadResponse: + """Response deployment entity + + :param enabled: Is response logging enabled. 
+ :type enabled: str + + """ + + # pylint: disable=unused-argument + def __init__(self, enabled: Optional[str] = None, **kwargs: Any): + self.enabled = enabled + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = PayloadResponseSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/pipeline_component_batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/pipeline_component_batch_deployment.py new file mode 100644 index 00000000..730bc39e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/pipeline_component_batch_deployment.py @@ -0,0 +1,150 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml._restclient.v2024_01_01_preview.models import BatchDeployment as RestBatchDeployment +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + BatchDeploymentProperties, + BatchPipelineComponentDeploymentConfiguration, + IdAssetReference, +) +from azure.ai.ml._schema._deployment.batch.pipeline_component_batch_deployment_schema import ( + PipelineComponentBatchDeploymentSchema, +) +from azure.ai.ml._utils._arm_id_utils import _parse_endpoint_name_from_deployment_id +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities import PipelineComponent +from azure.ai.ml.entities._builders import BaseNode +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._util import load_from_dict + + +@experimental +class PipelineComponentBatchDeployment(Resource): + """Pipeline Component Batch Deployment entity. + + :param type: Job definition type. Allowed value: "pipeline" + :type type: Optional[str] + :param name: Name of the deployment resource. + :type name: Optional[str] + :param description: Description of the deployment resource. + :type description: Optional[str] + :param component: Component definition. + :type component: Optional[Union[Component, str]] + :param settings: Run-time settings for the pipeline job. + :type settings: Optional[Dict[str, Any]] + :param tags: A set of tags. The tags which will be applied to the job. + :type tags: Optional[Dict[str, Any]] + :param job_definition: Arm ID or PipelineJob entity of an existing pipeline job. + :type job_definition: Optional[Dict[str, ~azure.ai.ml.entities._builders.BaseNode]] + :param endpoint_name: Name of the Endpoint resource, defaults to None. 
+ :type endpoint_name: Optional[str] + """ + + def __init__( + self, + *, + name: Optional[str], + endpoint_name: Optional[str] = None, + component: Optional[Union[Component, str]] = None, + settings: Optional[Dict[str, str]] = None, + job_definition: Optional[Dict[str, BaseNode]] = None, + tags: Optional[Dict] = None, + description: Optional[str] = None, + **kwargs: Any, + ): + self._type = kwargs.pop("type", None) + super().__init__(name=name, tags=tags, description=description, **kwargs) + self.component = component + self.endpoint_name = endpoint_name + self.settings = settings + self.job_definition = job_definition + + def _to_rest_object(self, location: str) -> "RestBatchDeployment": + if isinstance(self.component, PipelineComponent): + id_asset_ref = IdAssetReference(asset_id=self.component.id) + + batch_pipeline_config = BatchPipelineComponentDeploymentConfiguration( + settings=self.settings, + tags=self.component.tags, + description=self.component.description, + component_id=id_asset_ref, + ) + else: + id_asset_ref = IdAssetReference(asset_id=self.component) + batch_pipeline_config = BatchPipelineComponentDeploymentConfiguration( + settings=self.settings, component_id=id_asset_ref + ) + return RestBatchDeployment( + location=location, + tags=self.tags, + properties=BatchDeploymentProperties( + deployment_configuration=batch_pipeline_config, + description=self.description, + ), + ) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "PipelineComponentBatchDeployment": + data = data or {} + params_override = params_override or [] + cls._update_params(params_override) + + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(), + PARAMS_OVERRIDE_KEY: params_override, + } + res: PipelineComponentBatchDeployment = load_from_dict( + PipelineComponentBatchDeploymentSchema, data, context, **kwargs + ) + return res + + @classmethod + def _update_params(cls, params_override: Any) -> None: + for param in params_override: + endpoint_name = param.get("endpoint_name") + if isinstance(endpoint_name, str): + param["endpoint_name"] = endpoint_name.lower() + + @classmethod + def _from_rest_object(cls, deployment: RestBatchDeployment) -> "PipelineComponentBatchDeployment": + return PipelineComponentBatchDeployment( + name=deployment.name, + tags=deployment.tags, + component=deployment.properties.additional_properties["deploymentConfiguration"]["componentId"]["assetId"], + settings=deployment.properties.additional_properties["deploymentConfiguration"]["settings"], + endpoint_name=_parse_endpoint_name_from_deployment_id(deployment.id), + ) + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the deployment content into a file in yaml format. + + :param dest: The destination to receive this deployment's content. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable. 
+ :type dest: typing.Union[os.PathLike, str, typing.IO[typing.AnyStr]] + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def _to_dict(self) -> Dict: + res: dict = PipelineComponentBatchDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/request_logging.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/request_logging.py new file mode 100644 index 00000000..20cc83fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/request_logging.py @@ -0,0 +1,39 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, List, Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import RequestLogging as RestRequestLogging +from azure.ai.ml._schema._deployment.online.request_logging_schema import RequestLoggingSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + +@experimental +class RequestLogging: + """Request Logging deployment entity. + + :param capture_headers: Request payload header. + :type capture_headers: list[str] + """ + + def __init__( + self, + *, + capture_headers: Optional[List[str]] = None, + **kwargs: Any, + ): # pylint: disable=unused-argument + self.capture_headers = capture_headers + + @classmethod + def _from_rest_object(cls, rest_obj: RestRequestLogging) -> "RequestLogging": + return RequestLogging(capture_headers=rest_obj.capture_headers) + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = RequestLoggingSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def _to_rest_object(self) -> RestRequestLogging: + return RestRequestLogging(capture_headers=self.capture_headers) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/resource_requirements_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/resource_requirements_settings.py new file mode 100644 index 00000000..9db61aae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/resource_requirements_settings.py @@ -0,0 +1,84 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from typing import Optional + +from azure.ai.ml._restclient.v2022_05_01.models import ContainerResourceRequirements +from azure.ai.ml.entities._deployment.container_resource_settings import ResourceSettings +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +class ResourceRequirementsSettings(RestTranslatableMixin): + """Resource requirements settings for a container. + + :param requests: The minimum resource requests for a container. + :type requests: Optional[~azure.ai.ml.entities.ResourceSettings] + :param limits: The resource limits for a container. + :type limits: Optional[~azure.ai.ml.entities.ResourceSettings] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START resource_requirements_configuration] + :end-before: [END resource_requirements_configuration] + :language: python + :dedent: 8 + :caption: Configuring ResourceRequirementSettings for a Kubernetes deployment. + """ + + def __init__( + self, + requests: Optional[ResourceSettings] = None, + limits: Optional[ResourceSettings] = None, + ) -> None: + self.requests = requests + self.limits = limits + + def _to_rest_object(self) -> ContainerResourceRequirements: + return ContainerResourceRequirements( + container_resource_requests=self.requests._to_rest_object() if self.requests else None, + container_resource_limits=self.limits._to_rest_object() if self.limits else None, + ) + + @classmethod + def _from_rest_object( # pylint: disable=arguments-renamed + cls, settings: ContainerResourceRequirements + ) -> Optional["ResourceRequirementsSettings"]: + requests = settings.container_resource_requests + limits = settings.container_resource_limits + return ( + ResourceRequirementsSettings( + requests=ResourceSettings._from_rest_object(requests), + limits=ResourceSettings._from_rest_object(limits), + ) + if settings + else None + ) + + def _merge_with(self, other: Optional["ResourceRequirementsSettings"]) -> None: + if other: + if self.requests: + self.requests._merge_with(other.requests) + else: + self.requests = other.requests + if self.limits: + self.limits._merge_with(other.limits) + else: + self.limits = other.limits + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ResourceRequirementsSettings): + return NotImplemented + if not other: + return False + # only compare mutable fields + return self.requests == other.requests and self.limits == other.limits + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/run_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/run_settings.py new file mode 100644 index 00000000..f1deac83 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/run_settings.py @@ -0,0 +1,50 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, Optional + +from azure.ai.ml._schema._deployment.batch.run_settings_schema import RunSettingsSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + +@experimental +class RunSettings: + """Run Settings entity. 
+ + :param name: Run settings name + :type name: str + :param display_name: Run settings display name + :type display_name: str + :param experiment_name: Run settings experiment name + :type experiment_name: str + :param description: Run settings description + :type description: str + :param tags: Run settings tags + :type tags: Dict[str, Any] + :param settings: Run settings - settings + :type settings: Dict[str, Any] + """ + + def __init__( + self, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + settings: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): # pylint: disable=unused-argument + self.name = name + self.display_name = display_name + self.experiment_name = experiment_name + self.description = description + self.tags = tags + self.settings = settings + + def _to_dict(self) -> Dict: + # pylint: disable=no-member + res: dict = RunSettingsSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/scale_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/scale_settings.py new file mode 100644 index 00000000..85535ca0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/scale_settings.py @@ -0,0 +1,173 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from abc import abstractmethod +from typing import Any, Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import DefaultScaleSettings as RestDefaultScaleSettings +from azure.ai.ml._restclient.v2023_04_01_preview.models import OnlineScaleSettings as RestOnlineScaleSettings +from azure.ai.ml._restclient.v2023_04_01_preview.models import ScaleType +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + TargetUtilizationScaleSettings as RestTargetUtilizationScaleSettings, +) +from azure.ai.ml._utils.utils import camel_to_snake, from_iso_duration_format, to_iso_duration_format +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.exceptions import DeploymentException, ErrorCategory, ErrorTarget + +module_logger = logging.getLogger(__name__) + + +class OnlineScaleSettings(RestTranslatableMixin): + """Scale settings for online deployment. + + :param type: Type of the scale settings, allowed values are "default" and "target_utilization". + :type type: str + """ + + def __init__( + self, + # pylint: disable=redefined-builtin + type: str, + # pylint: disable=unused-argument + **kwargs: Any, + ): + self.type = camel_to_snake(type) + + @abstractmethod + def _to_rest_object(self) -> RestOnlineScaleSettings: + pass + + def _merge_with(self, other: Any) -> None: + if other: + self.type = other.type or self.type + + @classmethod + def _from_rest_object( # pylint: disable=arguments-renamed + cls, settings: RestOnlineScaleSettings + ) -> "OnlineScaleSettings": + if settings.scale_type == "Default": + return DefaultScaleSettings._from_rest_object(settings) + if settings.scale_type == "TargetUtilization": + return TargetUtilizationScaleSettings._from_rest_object(settings) + + msg = f"Unsupported online scale setting type {settings.scale_type}." 
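+        # Only the "Default" and "TargetUtilization" scale types are modeled by the SDK;
+        # any other value returned by the service is reported as a system error.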
+ raise DeploymentException( + message=msg, + target=ErrorTarget.ONLINE_DEPLOYMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + +class DefaultScaleSettings(OnlineScaleSettings): + """Default scale settings. + + :ivar type: Default scale settings type. Set automatically to "default" for this class. + :vartype type: str + """ + + def __init__(self, **kwargs: Any): + super(DefaultScaleSettings, self).__init__( + type=ScaleType.DEFAULT.value, + ) + + def _to_rest_object(self) -> RestDefaultScaleSettings: + return RestDefaultScaleSettings() + + @classmethod + def _from_rest_object(cls, settings: RestDefaultScaleSettings) -> "DefaultScaleSettings": + return DefaultScaleSettings() + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DefaultScaleSettings): + return NotImplemented + if not other: + return False + # only compare mutable fields + res: bool = self.type.lower() == other.type.lower() + return res + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class TargetUtilizationScaleSettings(OnlineScaleSettings): + """Auto scale settings. + + :param min_instances: Minimum number of the instances + :type min_instances: int + :param max_instances: Maximum number of the instances + :type max_instances: int + :param polling_interval: The polling interval in ISO 8691 format. Only supports duration with + precision as low as Seconds. + :type polling_interval: str + :param target_utilization_percentage: + :type target_utilization_percentage: int + :ivar type: Target utilization scale settings type. Set automatically to "target_utilization" for this class. + :vartype type: str + """ + + def __init__( + self, + *, + min_instances: Optional[int] = None, + max_instances: Optional[int] = None, + polling_interval: Optional[int] = None, + target_utilization_percentage: Optional[int] = None, + **kwargs: Any, + ): + super(TargetUtilizationScaleSettings, self).__init__( + type=ScaleType.TARGET_UTILIZATION.value, + ) + self.min_instances = min_instances + self.max_instances = max_instances + self.polling_interval = polling_interval + self.target_utilization_percentage = target_utilization_percentage + + def _to_rest_object(self) -> RestTargetUtilizationScaleSettings: + return RestTargetUtilizationScaleSettings( + min_instances=self.min_instances, + max_instances=self.max_instances, + polling_interval=to_iso_duration_format(self.polling_interval), + target_utilization_percentage=self.target_utilization_percentage, + ) + + def _merge_with(self, other: Optional["TargetUtilizationScaleSettings"]) -> None: + if other: + super()._merge_with(other) + self.min_instances = other.min_instances or self.min_instances + self.max_instances = other.max_instances or self.max_instances + self.polling_interval = other.polling_interval or self.polling_interval + self.target_utilization_percentage = ( + other.target_utilization_percentage or self.target_utilization_percentage + ) + + @classmethod + def _from_rest_object(cls, settings: RestTargetUtilizationScaleSettings) -> "TargetUtilizationScaleSettings": + return cls( + min_instances=settings.min_instances, + max_instances=settings.max_instances, + polling_interval=from_iso_duration_format(settings.polling_interval), + target_utilization_percentage=settings.target_utilization_percentage, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TargetUtilizationScaleSettings): + return NotImplemented + if not other: + return False + # only compare mutable fields + return ( 
+            self.type.lower() == other.type.lower()
+            and self.min_instances == other.min_instances
+            and self.max_instances == other.max_instances
+            and self.polling_interval == other.polling_interval
+            and self.target_utilization_percentage == other.target_utilization_percentage
+        )
+
+    def __ne__(self, other: object) -> bool:
+        return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py
new file mode 100644
index 00000000..5d62a229
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py
@@ -0,0 +1,62 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import re
+from typing import Any, Optional
+
+from azure.ai.ml.constants._endpoint import EndpointConfigurations
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+
+def validate_endpoint_or_deployment_name(name: Optional[str], is_deployment: bool = False) -> None:
+    """Validates the name of an endpoint or a deployment.
+
+    A valid name of an endpoint or deployment:
+
+    1. Is between 3 and 32 characters long (inclusive of both ends of the range)
+    2. Starts with a letter
+    3. Is followed by 0 or more alphanumeric characters (`a-zA-Z0-9`) or hyphens (`-`)
+    4. Ends with an alphanumeric character (`a-zA-Z0-9`)
+
+    :param name: Either an endpoint or deployment name
+    :type name: str
+    :param is_deployment: Whether the name is a deployment name. Defaults to False
+    :type is_deployment: bool
+    """
+    if name is None:
+        return
+
+    type_str = "a deployment" if is_deployment else "an endpoint"
+    target = ErrorTarget.DEPLOYMENT if is_deployment else ErrorTarget.ENDPOINT
+    if len(name) < EndpointConfigurations.MIN_NAME_LENGTH or len(name) > EndpointConfigurations.MAX_NAME_LENGTH:
+        msg = f"The name for {type_str} must be at least 3 and at most 32 characters long (inclusive of both limits)."
+        raise ValidationException(
+            message=msg,
+            target=target,
+            no_personal_data_message=msg,
+            error_category=ErrorCategory.USER_ERROR,
+            error_type=ValidationErrorType.INVALID_VALUE,
+        )
+    if not re.match(EndpointConfigurations.NAME_REGEX_PATTERN, name):
+        msg = f"""The name for {type_str} must start with an upper- or lowercase letter
+        and only consist of '-'s and alphanumeric characters."""
+        raise ValidationException(
+            message=msg,
+            target=target,
+            no_personal_data_message=msg,
+            error_category=ErrorCategory.USER_ERROR,
+            error_type=ValidationErrorType.INVALID_VALUE,
+        )
+
+
+def validate_identity_type_defined(identity: Any) -> None:
+    if identity and not identity.type:
+        msg = "Identity type not found in provided yaml file."
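+        # An identity section in the YAML must state its "type" explicitly; omitting it
+        # is rejected as a missing-field user error.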
+ raise ValidationException( + message=msg, + target=ErrorTarget.ENDPOINT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py new file mode 100644 index 00000000..4883c828 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py @@ -0,0 +1,134 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpoint as BatchEndpointData +from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpointProperties as RestBatchEndpoint +from azure.ai.ml._schema._endpoint import BatchEndpointSchema +from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel +from azure.ai.ml.constants._common import AAD_TOKEN_YAML, BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities._endpoint._endpoint_helpers import validate_endpoint_or_deployment_name +from azure.ai.ml.entities._util import load_from_dict + +from .endpoint import Endpoint + +module_logger = logging.getLogger(__name__) + + +class BatchEndpoint(Endpoint): + """Batch endpoint entity. + + :param name: Name of the resource. + :type name: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param auth_mode: Possible values include: "AMLToken", "Key", "AADToken", defaults to None + :type auth_mode: str + :param description: Description of the inference endpoint, defaults to None + :type description: str + :param location: defaults to None + :type location: str + :param defaults: Traffic rules on how the traffic will be routed across deployments, defaults to {} + :type defaults: Dict[str, str] + :param default_deployment_name: Equivalent to defaults.default_deployment, will be ignored if defaults is present. + :type default_deployment_name: str + :param scoring_uri: URI to use to perform a prediction, readonly. + :type scoring_uri: str + :param openapi_uri: URI to check the open API definition of the endpoint. 
+ :type openapi_uri: str + """ + + def __init__( + self, + *, + name: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + auth_mode: str = AAD_TOKEN_YAML, + description: Optional[str] = None, + location: Optional[str] = None, + defaults: Optional[Dict[str, str]] = None, + default_deployment_name: Optional[str] = None, + scoring_uri: Optional[str] = None, + openapi_uri: Optional[str] = None, + **kwargs: Any, + ) -> None: + super(BatchEndpoint, self).__init__( + name=name, + tags=tags, + properties=properties, + auth_mode=auth_mode, + description=description, + location=location, + scoring_uri=scoring_uri, + openapi_uri=openapi_uri, + **kwargs, + ) + + self.defaults = defaults + + if not self.defaults and default_deployment_name: + self.defaults = {} + self.defaults["deployment_name"] = default_deployment_name + + def _to_rest_batch_endpoint(self, location: str) -> BatchEndpointData: + validate_endpoint_or_deployment_name(self.name) + batch_endpoint = RestBatchEndpoint( + description=self.description, + auth_mode=snake_to_camel(self.auth_mode), + properties=self.properties, + defaults=self.defaults, + ) + return BatchEndpointData(location=location, tags=self.tags, properties=batch_endpoint) + + @classmethod + def _from_rest_object(cls, obj: BatchEndpointData) -> "BatchEndpoint": + return BatchEndpoint( + id=obj.id, + name=obj.name, + tags=obj.tags, + properties=obj.properties.properties, + auth_mode=camel_to_snake(obj.properties.auth_mode), + description=obj.properties.description, + location=obj.location, + defaults=obj.properties.defaults, + provisioning_state=obj.properties.provisioning_state, + scoring_uri=obj.properties.scoring_uri, + openapi_uri=obj.properties.swagger_uri, + ) + + def dump( + self, + dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + return BatchEndpointSchema(context=context).dump(self) # type: ignore + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "BatchEndpoint": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(), + PARAMS_OVERRIDE_KEY: params_override, + } + res: BatchEndpoint = load_from_dict(BatchEndpointSchema, data, context) + return res + + def _to_dict(self) -> Dict: + res: dict = BatchEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py new file mode 100644 index 00000000..d878742e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py @@ -0,0 +1,145 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import logging +from abc import abstractmethod +from os import PathLike +from typing import IO, Any, AnyStr, Dict, Optional, Union + +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +module_logger = logging.getLogger(__name__) + + +class Endpoint(Resource): # pylint: disable=too-many-instance-attributes + """Endpoint base class. + + :param auth_mode: The authentication mode, defaults to None + :type auth_mode: str + :param location: The location of the endpoint, defaults to None + :type location: str + :param name: Name of the resource. + :type name: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: typing.Optional[typing.Dict[str, str]] + :param properties: The asset property dictionary. + :type properties: typing.Optional[typing.Dict[str, str]] + :param description: Description of the resource. + :type description: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to {} + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword scoring_uri: str, Endpoint URI, readonly + :paramtype scoring_uri: typing.Optional[str] + :keyword openapi_uri: str, Endpoint Open API URI, readonly + :paramtype openapi_uri: typing.Optional[str] + :keyword provisioning_state: str, provisioning state, readonly + :paramtype provisioning_state: typing.Optional[str] + """ + + def __init__( + self, + auth_mode: Optional[str] = None, + location: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + properties: Optional[Dict[str, Any]] = None, + description: Optional[str] = None, + **kwargs: Any, + ): + """Endpoint base class. + + Constructor for Endpoint base class. + + :param auth_mode: The authentication mode, defaults to None + :type auth_mode: str + :param location: The location of the endpoint, defaults to None + :type location: str + :param name: Name of the resource. + :type name: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: typing.Optional[typing.Dict[str, str]] + :param properties: The asset property dictionary. + :type properties: typing.Optional[typing.Dict[str, str]] + :param description: Description of the resource. + :type description: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to {} + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword scoring_uri: str, Endpoint URI, readonly + :paramtype scoring_uri: typing.Optional[str] + :keyword openapi_uri: str, Endpoint Open API URI, readonly + :paramtype openapi_uri: typing.Optional[str] + :keyword provisioning_state: str, provisioning state, readonly + :paramtype provisioning_state: typing.Optional[str] + """ + # MFE is case-insensitive for Name. So convert the name into lower case here. + if name: + name = name.lower() + self._scoring_uri: Optional[str] = kwargs.pop("scoring_uri", None) + self._openapi_uri: Optional[str] = kwargs.pop("openapi_uri", None) + self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None) + super().__init__(name, description, tags, properties, **kwargs) + self.auth_mode = auth_mode + self.location = location + + @property + def scoring_uri(self) -> Optional[str]: + """URI to use to perform a prediction, readonly. 
+ + :return: The scoring URI + :rtype: typing.Optional[str] + """ + return self._scoring_uri + + @property + def openapi_uri(self) -> Optional[str]: + """URI to check the open api definition of the endpoint. + + :return: The open API URI + :rtype: typing.Optional[str] + """ + return self._openapi_uri + + @property + def provisioning_state(self) -> Optional[str]: + """Endpoint provisioning state, readonly. + + :return: Endpoint provisioning state. + :rtype: typing.Optional[str] + """ + return self._provisioning_state + + @abstractmethod + def dump(self, dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, **kwargs: Any) -> Dict: + pass + + @classmethod + @abstractmethod + def _from_rest_object(cls, obj: Any) -> Any: + pass + + def _merge_with(self, other: Any) -> None: + if other: + if self.name != other.name: + msg = "The endpoint name: {} and {} are not matched when merging." + raise ValidationException( + message=msg.format(self.name, other.name), + target=ErrorTarget.ENDPOINT, + no_personal_data_message=msg.format("[name1]", "[name2]"), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + self.description = other.description or self.description + if other.tags: + if self.tags is not None: + self.tags = {**self.tags, **other.tags} + if other.properties: + self.properties = {**self.properties, **other.properties} + self.auth_mode = other.auth_mode or self.auth_mode + if hasattr(other, "traffic"): + self.traffic = other.traffic # pylint: disable=attribute-defined-outside-init + if hasattr(other, "mirror_traffic"): + self.mirror_traffic = other.mirror_traffic # pylint: disable=attribute-defined-outside-init + if hasattr(other, "defaults"): + self.defaults = other.defaults # pylint: disable=attribute-defined-outside-init diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py new file mode 100644 index 00000000..cdd72536 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py @@ -0,0 +1,647 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
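The _merge_with helper above implements update-merge semantics for endpoints: tag and property dictionaries are unioned, scalar fields such as description and auth_mode are overridden only when the incoming object sets them, and a name mismatch raises a ValidationException. A small behavioural sketch, using the concrete ManagedOnlineEndpoint defined later in this diff and illustrative names:

from azure.ai.ml.entities import ManagedOnlineEndpoint

base = ManagedOnlineEndpoint(name="my-endpoint", tags={"team": "ml"}, description="v1")
update = ManagedOnlineEndpoint(name="my-endpoint", tags={"stage": "prod"}, traffic={"blue": 100})

base._merge_with(update)                              # names match, so no ValidationException
assert base.tags == {"team": "ml", "stage": "prod"}   # tag dictionaries are unioned
assert base.traffic == {"blue": 100}                  # traffic is copied from the incoming endpoint
assert base.description == "v1"                       # fields unset on the update are left alone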
+# --------------------------------------------------------- + +# pylint: disable=no-member + +import logging +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, Optional, Union, cast + +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthKeys as RestEndpointAuthKeys +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthMode +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthToken as RestEndpointAuthToken +from azure.ai.ml._restclient.v2022_02_01_preview.models import OnlineEndpointData +from azure.ai.ml._restclient.v2022_02_01_preview.models import OnlineEndpointDetails as RestOnlineEndpoint +from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration +from azure.ai.ml._schema._endpoint import KubernetesOnlineEndpointSchema, ManagedOnlineEndpointSchema +from azure.ai.ml._utils.utils import dict_eq +from azure.ai.ml.constants._common import ( + AAD_TOKEN_YAML, + AML_TOKEN_YAML, + BASE_PATH_CONTEXT_KEY, + KEY, + PARAMS_OVERRIDE_KEY, +) +from azure.ai.ml.constants._endpoint import EndpointYamlFields +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._util import is_compute_in_override, load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException +from azure.core.credentials import AccessToken + +from ._endpoint_helpers import validate_endpoint_or_deployment_name, validate_identity_type_defined +from .endpoint import Endpoint + +module_logger = logging.getLogger(__name__) + + +class OnlineEndpoint(Endpoint): + """Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. 
defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: typing.Optional[str] + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword scoring_uri: Scoring URI, defaults to None + :paramtype scoring_uri: typing.Optional[str] + :keyword openapi_uri: OpenAPI URI, defaults to None + :paramtype openapi_uri: typing.Optional[str] + :keyword provisioning_state: Provisioning state of an endpoint, defaults to None + :paramtype provisioning_state: typing.Optional[str] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None + :paramtype kind: typing.Optional[str] + """ + + def __init__( + self, + *, + name: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, + auth_mode: str = KEY, + description: Optional[str] = None, + location: Optional[str] = None, + traffic: Optional[Dict[str, int]] = None, + mirror_traffic: Optional[Dict[str, int]] = None, + identity: Optional[IdentityConfiguration] = None, + scoring_uri: Optional[str] = None, + openapi_uri: Optional[str] = None, + provisioning_state: Optional[str] = None, + kind: Optional[str] = None, + **kwargs: Any, + ): + """Online endpoint entity. + + Constructor for an Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. 
defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: typing.Optional[str] + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword scoring_uri: Scoring URI, defaults to None + :paramtype scoring_uri: typing.Optional[str] + :keyword openapi_uri: OpenAPI URI, defaults to None + :paramtype openapi_uri: typing.Optional[str] + :keyword provisioning_state: Provisioning state of an endpoint, defaults to None + :paramtype provisioning_state: typing.Optional[str] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None + :type kind: typing.Optional[str] + """ + self._provisioning_state = kwargs.pop("provisioning_state", None) + + super(OnlineEndpoint, self).__init__( + name=name, + properties=properties, + tags=tags, + auth_mode=auth_mode, + description=description, + location=location, + scoring_uri=scoring_uri, + openapi_uri=openapi_uri, + provisioning_state=provisioning_state, + **kwargs, + ) + + self.identity = identity + self.traffic: Dict = dict(traffic) if traffic else {} + self.mirror_traffic: Dict = dict(mirror_traffic) if mirror_traffic else {} + self.kind = kind + + @property + def provisioning_state(self) -> Optional[str]: + """Endpoint provisioning state, readonly. + + :return: Endpoint provisioning state. 
+ :rtype: typing.Optional[str] + """ + return self._provisioning_state + + def _to_rest_online_endpoint(self, location: str) -> OnlineEndpointData: + # pylint: disable=protected-access + identity = ( + self.identity._to_online_endpoint_rest_object() + if self.identity + else RestManagedServiceIdentityConfiguration(type="SystemAssigned") + ) + validate_endpoint_or_deployment_name(self.name) + validate_identity_type_defined(self.identity) + properties = RestOnlineEndpoint( + description=self.description, + auth_mode=OnlineEndpoint._yaml_auth_mode_to_rest_auth_mode(self.auth_mode), + properties=self.properties, + traffic=self.traffic, + mirror_traffic=self.mirror_traffic, + ) + + if hasattr(self, "public_network_access") and self.public_network_access: + properties.public_network_access = self.public_network_access + return OnlineEndpointData( + location=location, + properties=properties, + identity=identity, + tags=self.tags, + ) + + def _to_rest_online_endpoint_traffic_update(self, location: str, no_validation: bool = False) -> OnlineEndpointData: + if not no_validation: + # validate_deployment_name_matches_traffic(self.deployments, self.traffic) + validate_identity_type_defined(self.identity) + # validate_uniqueness_of_deployment_names(self.deployments) + properties = RestOnlineEndpoint( + description=self.description, + auth_mode=OnlineEndpoint._yaml_auth_mode_to_rest_auth_mode(self.auth_mode), + endpoint=self.name, + traffic=self.traffic, + properties=self.properties, + ) + return OnlineEndpointData( + location=location, + properties=properties, + identity=self.identity, + tags=self.tags, + ) + + @classmethod + def _rest_auth_mode_to_yaml_auth_mode(cls, rest_auth_mode: str) -> str: + switcher = { + EndpointAuthMode.AML_TOKEN: AML_TOKEN_YAML, + EndpointAuthMode.AAD_TOKEN: AAD_TOKEN_YAML, + EndpointAuthMode.KEY: KEY, + } + + return switcher.get(rest_auth_mode, rest_auth_mode) + + @classmethod + def _yaml_auth_mode_to_rest_auth_mode(cls, yaml_auth_mode: Optional[str]) -> str: + if yaml_auth_mode is None: + return "" + + yaml_auth_mode = yaml_auth_mode.lower() + + switcher = { + AML_TOKEN_YAML: EndpointAuthMode.AML_TOKEN, + AAD_TOKEN_YAML: EndpointAuthMode.AAD_TOKEN, + KEY: EndpointAuthMode.KEY, + } + + return switcher.get(yaml_auth_mode, yaml_auth_mode) + + @classmethod + def _from_rest_object(cls, obj: OnlineEndpointData) -> "OnlineEndpoint": + auth_mode = cls._rest_auth_mode_to_yaml_auth_mode(obj.properties.auth_mode) + # pylint: disable=protected-access + identity = IdentityConfiguration._from_online_endpoint_rest_object(obj.identity) if obj.identity else None + + endpoint: Any = KubernetesOnlineEndpoint() + + if obj.system_data: + properties_dict = { + "createdBy": obj.system_data.created_by, + "createdAt": obj.system_data.created_at.strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + "lastModifiedAt": obj.system_data.last_modified_at.strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + } + properties_dict.update(obj.properties.properties) + else: + properties_dict = obj.properties.properties + + if obj.properties.compute: + endpoint = KubernetesOnlineEndpoint( + id=obj.id, + name=obj.name, + tags=obj.tags, + properties=properties_dict, + compute=obj.properties.compute, + auth_mode=auth_mode, + description=obj.properties.description, + location=obj.location, + traffic=obj.properties.traffic, + provisioning_state=obj.properties.provisioning_state, + scoring_uri=obj.properties.scoring_uri, + openapi_uri=obj.properties.swagger_uri, + identity=identity, + kind=obj.kind, + ) + else: + endpoint = ManagedOnlineEndpoint( + 
id=obj.id, + name=obj.name, + tags=obj.tags, + properties=properties_dict, + auth_mode=auth_mode, + description=obj.properties.description, + location=obj.location, + traffic=obj.properties.traffic, + mirror_traffic=obj.properties.mirror_traffic, + provisioning_state=obj.properties.provisioning_state, + scoring_uri=obj.properties.scoring_uri, + openapi_uri=obj.properties.swagger_uri, + identity=identity, + kind=obj.kind, + public_network_access=obj.properties.public_network_access, + ) + + return cast(OnlineEndpoint, endpoint) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OnlineEndpoint): + return NotImplemented + if not other: + return False + if self.auth_mode is None or other.auth_mode is None: + return False + + if self.name is None and other.name is None: + return ( + self.auth_mode.lower() == other.auth_mode.lower() + and dict_eq(self.tags, other.tags) + and self.description == other.description + and dict_eq(self.traffic, other.traffic) + ) + + if self.name is not None and other.name is not None: + # only compare mutable fields + return ( + self.name.lower() == other.name.lower() + and self.auth_mode.lower() == other.auth_mode.lower() + and dict_eq(self.tags, other.tags) + and self.description == other.description + and dict_eq(self.traffic, other.traffic) + ) + + return False + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Endpoint": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(), + PARAMS_OVERRIDE_KEY: params_override, + } + + if data.get(EndpointYamlFields.COMPUTE) or is_compute_in_override(params_override): + res_kub: Endpoint = load_from_dict(KubernetesOnlineEndpointSchema, data, context) + return res_kub + + res_managed: Endpoint = load_from_dict(ManagedOnlineEndpointSchema, data, context) + return res_managed + + +class KubernetesOnlineEndpoint(OnlineEndpoint): + """K8s Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. 
Tags can be added, removed, and updated, defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: typing.Optional[str] + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword compute: Compute cluster id, defaults to None + :paramtype compute: typing.Optional[str] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None + :paramtype kind: typing.Optional[str] + """ + + def __init__( + self, + *, + name: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, + auth_mode: str = KEY, + description: Optional[str] = None, + location: Optional[str] = None, + traffic: Optional[Dict[str, int]] = None, + mirror_traffic: Optional[Dict[str, int]] = None, + compute: Optional[str] = None, + identity: Optional[IdentityConfiguration] = None, + kind: Optional[str] = None, + **kwargs: Any, + ): + """K8s Online endpoint entity. + + Constructor for K8s Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. 
Tags can be added, removed, and updated, defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: typing.Optional[str] + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword compute: Compute cluster id, defaults to None + :paramtype compute: typing.Optional[str] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None + :type kind: typing.Optional[str] + """ + super(KubernetesOnlineEndpoint, self).__init__( + name=name, + properties=properties, + tags=tags, + auth_mode=auth_mode, + description=description, + location=location, + traffic=traffic, + mirror_traffic=mirror_traffic, + identity=identity, + kind=kind, + **kwargs, + ) + + self.compute = compute + + def dump( + self, + dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = KubernetesOnlineEndpointSchema(context=context).dump(self) + return res + + def _to_rest_online_endpoint(self, location: str) -> OnlineEndpointData: + resource = super()._to_rest_online_endpoint(location) + resource.properties.compute = self.compute + return resource + + def _to_rest_online_endpoint_traffic_update(self, location: str, no_validation: bool = False) -> OnlineEndpointData: + resource = super()._to_rest_online_endpoint_traffic_update(location, no_validation) + resource.properties.compute = self.compute + return resource + + def _merge_with(self, other: "KubernetesOnlineEndpoint") -> None: + if other: + if self.name != other.name: + msg = "The endpoint name: {} and {} are not matched when merging." + raise ValidationException( + message=msg.format(self.name, other.name), + target=ErrorTarget.ONLINE_ENDPOINT, + no_personal_data_message=msg.format("[name1]", "[name2]"), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + super()._merge_with(other) + self.compute = other.compute or self.compute + + def _to_dict(self) -> Dict: + res: dict = KubernetesOnlineEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + +class ManagedOnlineEndpoint(OnlineEndpoint): + """Managed Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. 
Tags can be added, removed, and updated, defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: str + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None. + :paramtype kind: typing.Optional[str] + :keyword public_network_access: Whether to allow public endpoint connectivity, defaults to None + Allowed values are: "enabled", "disabled" + :type public_network_access: typing.Optional[str] + """ + + def __init__( + self, + *, + name: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, + auth_mode: str = KEY, + description: Optional[str] = None, + location: Optional[str] = None, + traffic: Optional[Dict[str, int]] = None, + mirror_traffic: Optional[Dict[str, int]] = None, + identity: Optional[IdentityConfiguration] = None, + kind: Optional[str] = None, + public_network_access: Optional[str] = None, + **kwargs: Any, + ): + """Managed Online endpoint entity. + + Constructor for Managed Online endpoint entity. + + :keyword name: Name of the resource, defaults to None + :paramtype name: typing.Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None + :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]] + :keyword properties: The asset property dictionary, defaults to None + :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]] + :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY + :type auth_mode: str + :keyword description: Description of the inference endpoint, defaults to None + :paramtype description: typing.Optional[str] + :keyword location: Location of the resource, defaults to None + :paramtype location: typing.Optional[str] + :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None + :paramtype traffic: typing.Optional[typing.Dict[str, int]] + :keyword mirror_traffic: Duplicated live traffic used to inference a single deployment, defaults to None + :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]] + :keyword identity: Identity Configuration, defaults to SystemAssigned + :paramtype identity: typing.Optional[IdentityConfiguration] + :keyword kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None. 
+ :type kind: typing.Optional[str] + :keyword public_network_access: Whether to allow public endpoint connectivity, defaults to None + Allowed values are: "enabled", "disabled" + :type public_network_access: typing.Optional[str] + """ + self.public_network_access = public_network_access + + super(ManagedOnlineEndpoint, self).__init__( + name=name, + properties=properties, + tags=tags, + auth_mode=auth_mode, + description=description, + location=location, + traffic=traffic, + mirror_traffic=mirror_traffic, + identity=identity, + kind=kind, + **kwargs, + ) + + def dump( + self, + dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + context = {BASE_PATH_CONTEXT_KEY: Path(".").parent} + res: dict = ManagedOnlineEndpointSchema(context=context).dump(self) + return res + + def _to_dict(self) -> Dict: + res: dict = ManagedOnlineEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + +class EndpointAuthKeys(RestTranslatableMixin): + """Keys for endpoint authentication. + + :ivar primary_key: The primary key. + :vartype primary_key: str + :ivar secondary_key: The secondary key. + :vartype secondary_key: str + """ + + def __init__(self, **kwargs: Any): + """Constructor for keys for endpoint authentication. + + :keyword primary_key: The primary key. + :paramtype primary_key: str + :keyword secondary_key: The secondary key. + :paramtype secondary_key: str + """ + self.primary_key = kwargs.get("primary_key", None) + self.secondary_key = kwargs.get("secondary_key", None) + + @classmethod + def _from_rest_object(cls, obj: RestEndpointAuthKeys) -> "EndpointAuthKeys": + return cls(primary_key=obj.primary_key, secondary_key=obj.secondary_key) + + def _to_rest_object(self) -> RestEndpointAuthKeys: + return RestEndpointAuthKeys(primary_key=self.primary_key, secondary_key=self.secondary_key) + + +class EndpointAuthToken(RestTranslatableMixin): + """Endpoint authentication token. + + :ivar access_token: Access token for endpoint authentication. + :vartype access_token: str + :ivar expiry_time_utc: Access token expiry time (UTC). + :vartype expiry_time_utc: float + :ivar refresh_after_time_utc: Refresh access token after time (UTC). + :vartype refresh_after_time_utc: float + :ivar token_type: Access token type. + :vartype token_type: str + """ + + def __init__(self, **kwargs: Any): + """Constuctor for Endpoint authentication token. + + :keyword access_token: Access token for endpoint authentication. + :paramtype access_token: str + :keyword expiry_time_utc: Access token expiry time (UTC). + :paramtype expiry_time_utc: float + :keyword refresh_after_time_utc: Refresh access token after time (UTC). + :paramtype refresh_after_time_utc: float + :keyword token_type: Access token type. 
+ :paramtype token_type: str + """ + self.access_token = kwargs.get("access_token", None) + self.expiry_time_utc = kwargs.get("expiry_time_utc", 0) + self.refresh_after_time_utc = kwargs.get("refresh_after_time_utc", 0) + self.token_type = kwargs.get("token_type", None) + + @classmethod + def _from_rest_object(cls, obj: RestEndpointAuthToken) -> "EndpointAuthToken": + return cls( + access_token=obj.access_token, + expiry_time_utc=obj.expiry_time_utc, + refresh_after_time_utc=obj.refresh_after_time_utc, + token_type=obj.token_type, + ) + + def _to_rest_object(self) -> RestEndpointAuthToken: + return RestEndpointAuthToken( + access_token=self.access_token, + expiry_time_utc=self.expiry_time_utc, + refresh_after_time_utc=self.refresh_after_time_utc, + token_type=self.token_type, + ) + + +class EndpointAadToken: + """Endpoint aad token. + + :ivar access_token: Access token for aad authentication. + :vartype access_token: str + :ivar expiry_time_utc: Access token expiry time (UTC). + :vartype expiry_time_utc: float + """ + + def __init__(self, obj: AccessToken): + """Constructor for Endpoint aad token. + + :param obj: Access token object + :type obj: AccessToken + """ + self.access_token = obj.token + self.expiry_time_utc = obj.expires_on diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/data_availability_status.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/data_availability_status.py new file mode 100644 index 00000000..aa438f3b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/data_availability_status.py @@ -0,0 +1,15 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from enum import Enum +from azure.core import CaseInsensitiveEnumMeta + + +class DataAvailabilityStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """DataAvailabilityStatus.""" + + NONE = "None" + PENDING = "Pending" + INCOMPLETE = "Incomplete" + COMPLETE = "Complete" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/delay_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/delay_metadata.py new file mode 100644 index 00000000..66599605 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/delay_metadata.py @@ -0,0 +1,17 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
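EndpointAadToken, added at the end of online_endpoint.py above, is a thin wrapper that re-exposes the fields of an azure.core AccessToken under the endpoint token attribute names. A minimal sketch with placeholder token values, importing from the module path introduced in this diff:

from azure.core.credentials import AccessToken
from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAadToken

raw = AccessToken(token="<aad-access-token>", expires_on=1735689600)  # placeholder token and expiry
aad_token = EndpointAadToken(raw)
assert aad_token.access_token == raw.token
assert aad_token.expiry_time_utc == raw.expires_on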
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + + +from typing import Any, Optional + + +class DelayMetadata(object): + def __init__( + self, *, days: Optional[int] = None, hours: Optional[int] = None, minutes: Optional[int] = None, **kwargs: Any + ): + self.days = days + self.hours = hours + self.minutes = minutes diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature.py new file mode 100644 index 00000000..2cc54815 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature.py @@ -0,0 +1,54 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2023_10_01.models import Feature as RestFeature +from azure.ai.ml._restclient.v2023_10_01.models import FeatureProperties +from azure.ai.ml.entities._feature_store_entity.data_column_type import DataColumnType +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class Feature(RestTranslatableMixin): + """Feature + + :param name: The name of the feature. + :type name: str + :param data_type: The data type of the feature. + :type data_type: ~azure.ai.ml.entities.DataColumnType + :param description: The description of the feature. Defaults to None. + :type description: Optional[str] + :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None. + :type tags: Optional[dict[str, str]] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + data_type: DataColumnType, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + **kwargs: Any + ): + self.name = name + self.data_type = data_type + self.description = description + self.tags = tags + + @classmethod + def _from_rest_object(cls, obj: RestFeature) -> Optional["Feature"]: + if not obj: + return None + feature_rest_object_details: FeatureProperties = obj.properties + return Feature( + name=feature_rest_object_details.feature_name, + data_type=feature_rest_object_details.data_type, + description=feature_rest_object_details.description, + tags=feature_rest_object_details.tags, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_metadata.py new file mode 100644 index 00000000..652908e9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_metadata.py @@ -0,0 +1,39 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, List, Optional + +from azure.ai.ml._restclient.v2023_10_01.models import ( + FeaturesetVersionBackfillResponse as RestFeaturesetVersionBackfillResponse, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class FeatureSetBackfillMetadata(RestTranslatableMixin): + """Feature Set Backfill Metadata + + :param job_ids: A list of IDs of the backfill jobs. Defaults to None. 
+ :type job_ids: Optional[List[str]] + :param type: The type of the backfill job. Defaults to None. + :type type: Optional[str] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + def __init__( + self, + *, + job_ids: Optional[List[str]] = None, + type: Optional[str] = None, # pylint: disable=redefined-builtin + # pylint: disable=unused-argument + **kwargs: Any + ) -> None: + self.type = type if type else "BackfillMaterialization" + self.job_ids = job_ids + + @classmethod + def _from_rest_object(cls, obj: RestFeaturesetVersionBackfillResponse) -> Optional["FeatureSetBackfillMetadata"]: + if not obj: + return None + return FeatureSetBackfillMetadata(job_ids=obj.job_ids) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_request.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_request.py new file mode 100644 index 00000000..0baebf4c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_request.py @@ -0,0 +1,91 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from azure.ai.ml._schema._feature_set.feature_set_backfill_schema import FeatureSetBackfillSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities._feature_set.feature_window import FeatureWindow +from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._util import load_from_dict + + +class FeatureSetBackfillRequest(RestTranslatableMixin): + """Feature Set Backfill Request + + :param name: The name of the backfill job request + :type name: str + :param version: The version of the backfill job request. + :type version: str + :param feature_window: The time window for the feature set backfill request. + :type feature_window: ~azure.ai.ml._restclient.v2023_04_01_preview.models.FeatureWindow + :param description: The description of the backfill job request. Defaults to None. + :type description: Optional[str] + :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None. + :type tags: Optional[dict[str, str]] + :keyword resource: The compute resource settings. Defaults to None. + :paramtype resource: Optional[~azure.ai.ml.entities.MaterializationComputeResource] + :param spark_configuration: Specifies the spark configuration. Defaults to None. 
+ :type spark_configuration: Optional[dict[str, str]] + """ + + def __init__( + self, + *, + name: str, + version: str, + feature_window: Optional[FeatureWindow] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + resource: Optional[MaterializationComputeResource] = None, + spark_configuration: Optional[Dict[str, str]] = None, + data_status: Optional[List[str]] = None, + job_id: Optional[str] = None, + **kwargs: Any, + ): + self.name = name + self.version = version + self.feature_window = feature_window + self.description = description + self.resource = resource + self.tags = tags + self.spark_configuration = spark_configuration + self.data_status = data_status + self.job_id = job_id + + @classmethod + # pylint: disable=unused-argument + def _resolve_cls_and_type(cls, data: Dict, params_override: Tuple) -> Tuple: + """Resolve the class to use for deserializing the data. Return current class if no override is provided. + + :param data: Data to deserialize. + :type data: dict + :param params_override: Parameters to override, defaults to None + :type params_override: typing.Optional[list] + :return: Class to use for deserializing the data & its "type". Type will be None if no override is provided. + :rtype: tuple[class, typing.Optional[str]] + """ + return cls, None + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "FeatureSetBackfillRequest": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + loaded_schema = load_from_dict(FeatureSetBackfillSchema, data, context, **kwargs) + return FeatureSetBackfillRequest(**loaded_schema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_materialization_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_materialization_metadata.py new file mode 100644 index 00000000..afcf3fd1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_materialization_metadata.py @@ -0,0 +1,98 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2023_10_01.models import JobBase as RestJobBase +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._system_data import SystemData + +from .materialization_type import MaterializationType + +FeaturestoreJobTypeMap: Dict[str, MaterializationType] = { + "BackfillMaterialization": MaterializationType.BACKFILL_MATERIALIZATION, + "RecurrentMaterialization": MaterializationType.RECURRENT_MATERIALIZATION, +} + + +class FeatureSetMaterializationMetadata(RestTranslatableMixin): + """Feature Set Materialization Metadata + + :param type: The type of the materialization job. + :type type: MaterializationType + :param feature_window_start_time: The feature window start time for the feature set materialization job. + :type feature_window_start_time: Optional[datetime] + :param feature_window_end_time: The feature window end time for the feature set materialization job. 
+ :type feature_window_end_time: Optional[datetime] + :param name: The name of the feature set materialization job. + :type name: Optional[str] + :param display_name: The display name for the feature set materialization job. + :type display_name: Optional[str] + :param creation_context: The creation context of the feature set materialization job. + :type creation_context: Optional[~azure.ai.ml.entities.SystemData] + :param duration: current time elapsed for feature set materialization job. + :type duration: Optional[~datetime.timedelta] + :param status: The status of the feature set materialization job. + :type status: Optional[str] + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: Optional[dict[str, str]] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + def __init__( + self, + *, + # pylint: disable=redefined-builtin + type: Optional[MaterializationType], + feature_window_start_time: Optional[datetime], + feature_window_end_time: Optional[datetime], + name: Optional[str], + display_name: Optional[str], + creation_context: Optional[SystemData], + duration: Optional[timedelta], + status: Optional[str], + tags: Optional[Dict[str, str]], + # pylint: disable=unused-argument + **kwargs: Any, + ): + self.type = type + self.feature_window_start_time = feature_window_start_time + self.feature_window_end_time = feature_window_end_time + self.name = name + self.display_name = display_name + self.creation_context = creation_context + self.duration = duration + self.status = status + self.tags = tags + + @classmethod + def _from_rest_object(cls, obj: RestJobBase) -> Optional["FeatureSetMaterializationMetadata"]: + if not obj: + return None + job_properties = obj.properties + job_type = job_properties.properties.get("azureml.FeatureStoreJobType", None) + feature_window_start_time = job_properties.properties.get("azureml.FeatureWindowStart", None) + feature_window_end_time = job_properties.properties.get("azureml.FeatureWindowEnd", None) + + time_format = "%Y-%m-%dT%H:%M:%SZ" + feature_window_start_time = ( + datetime.strptime(feature_window_start_time, time_format) if feature_window_start_time else None + ) + feature_window_end_time = ( + datetime.strptime(feature_window_end_time, time_format) if feature_window_end_time else None + ) + + return FeatureSetMaterializationMetadata( + type=FeaturestoreJobTypeMap.get(job_type), + feature_window_start_time=feature_window_start_time, + feature_window_end_time=feature_window_end_time, + name=obj.name, + display_name=job_properties.display_name, + creation_context=SystemData(created_at=obj.system_data.created_at), + status=job_properties.status, + tags=job_properties.tags, + duration=None, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_specification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_specification.py new file mode 100644 index 00000000..88ed093f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_specification.py @@ -0,0 +1,46 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
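FeatureSetMaterializationMetadata._from_rest_object above reads the feature window bounds from the job properties as strings, parses them with a fixed UTC timestamp format, and maps the job type string through FeaturestoreJobTypeMap. A short sketch of that parsing, with an illustrative timestamp:

from datetime import datetime
from azure.ai.ml.entities._feature_set.feature_set_materialization_metadata import FeaturestoreJobTypeMap

time_format = "%Y-%m-%dT%H:%M:%SZ"
window_start = datetime.strptime("2024-01-01T00:00:00Z", time_format)  # illustrative window start
job_type = FeaturestoreJobTypeMap["BackfillMaterialization"]           # MaterializationType.BACKFILL_MATERIALIZATION
print(window_start, job_type)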
+# --------------------------------------------------------- + +from os import PathLike +from typing import Any, Optional, Union + +from azure.ai.ml._restclient.v2023_10_01.models import FeaturesetSpecification as RestFeaturesetSpecification +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class FeatureSetSpecification(RestTranslatableMixin): + """Feature Set Specification + + :param path: Specifies the feature set spec path to file. Defaults to None. + :type path: Optional[str] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_featurestore.py + :start-after: [START configure_feature_set] + :end-before: [END configure_feature_set] + :language: python + :dedent: 8 + :caption: Using Feature Set Spec to create Feature Set + """ + + def __init__( + self, *, path: Optional[Union[PathLike, str]] = None, **kwargs: Any + ): # pylint: disable=unused-argument + """ + :param path: Specifies the spec path. + :type path: str + """ + self.path = path + + def _to_rest_object(self) -> RestFeaturesetSpecification: + return RestFeaturesetSpecification(path=self.path) + + @classmethod + def _from_rest_object(cls, obj: RestFeaturesetSpecification) -> Optional["FeatureSetSpecification"]: + if not obj: + return None + return FeatureSetSpecification(path=obj.path) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_transformation_code_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_transformation_code_metadata.py new file mode 100644 index 00000000..5fd8544e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_transformation_code_metadata.py @@ -0,0 +1,13 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Optional + + +class FeatureTransformationCodeMetadata(object): + def __init__(self, *, path: str, transformer_class: Optional[str] = None, **kwargs: Any): + self.path = path + self.transformer_class = transformer_class diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_window.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_window.py new file mode 100644 index 00000000..758d1ecf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_window.py @@ -0,0 +1,34 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from datetime import datetime +from typing import Any, Optional + +from azure.ai.ml._restclient.v2023_10_01.models import FeatureWindow as RestFeatureWindow +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class FeatureWindow(RestTranslatableMixin): + """Feature window + :keyword feature_window_end: Specifies the feature window end time. + :paramtype feature_window_end: ~datetime.datetime + :keyword feature_window_start: Specifies the feature window start time. 
+ :paramtype feature_window_start: ~datetime.datetime + """ + + # pylint: disable=unused-argument + def __init__(self, *, feature_window_start: datetime, feature_window_end: datetime, **kwargs: Any) -> None: + self.feature_window_start = feature_window_start + self.feature_window_end = feature_window_end + + def _to_rest_object(self) -> RestFeatureWindow: + return RestFeatureWindow( + feature_window_start=self.feature_window_start, feature_window_end=self.feature_window_end + ) + + @classmethod + def _from_rest_object(cls, obj: RestFeatureWindow) -> Optional["FeatureWindow"]: + if not obj: + return None + return FeatureWindow(feature_window_start=obj.feature_window_start, feature_window_end=obj.feature_window_end) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/featureset_spec_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/featureset_spec_metadata.py new file mode 100644 index 00000000..4178b074 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/featureset_spec_metadata.py @@ -0,0 +1,101 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from marshmallow import INCLUDE + +from azure.ai.ml._schema._feature_set.featureset_spec_metadata_schema import FeaturesetSpecMetadataSchema +from azure.ai.ml._schema._feature_set.featureset_spec_properties_schema import FeaturesetSpecPropertiesSchema +from azure.ai.ml._utils.utils import load_yaml +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._feature_store_entity.data_column import DataColumn +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .delay_metadata import DelayMetadata +from .feature import Feature +from .feature_transformation_code_metadata import FeatureTransformationCodeMetadata +from .source_metadata import SourceMetadata + + +class FeaturesetSpecMetadata(object): + """FeaturesetSpecMetadata for feature-set.""" + + def __init__( + self, + *, + source: SourceMetadata, + feature_transformation_code: Optional[FeatureTransformationCodeMetadata] = None, + features: List[Feature], + index_columns: Optional[List[DataColumn]] = None, + source_lookback: Optional[DelayMetadata] = None, + temporal_join_lookback: Optional[DelayMetadata] = None, + **_kwargs: Any, + ): + if source.type == "featureset" and index_columns: + msg = f"You cannot provide index_columns for {source.type} feature source." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.FEATURE_SET, + error_category=ErrorCategory.USER_ERROR, + ) + if not index_columns and source.type != "featureset": + msg = f"You need to provide index_columns for {source.type} feature source." 
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.FEATURE_SET, + error_category=ErrorCategory.USER_ERROR, + ) + self.source = source + self.feature_transformation_code = feature_transformation_code + self.features = features + self.index_columns = index_columns + self.source_lookback = source_lookback + self.temporal_join_lookback = temporal_join_lookback + + @classmethod + def load( + cls, + yaml_path: Union[PathLike, str], + **kwargs: Any, + ) -> "FeaturesetSpecMetadata": + """Construct an FeaturesetSpecMetadata object from yaml file. + + :param yaml_path: Path to a local file as the source. + :type yaml_path: PathLike | str + + :return: Constructed FeaturesetSpecMetadata object. + :rtype: FeaturesetSpecMetadata + """ + yaml_dict = load_yaml(yaml_path) + return cls._load(yaml_data=yaml_dict, yaml_path=yaml_path, **kwargs) + + @classmethod + def _load( + cls, + yaml_data: Optional[Dict], + yaml_path: Optional[Union[PathLike, str]], + **kwargs: Any, + ) -> "FeaturesetSpecMetadata": + yaml_data = yaml_data or {} + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + } + res: FeaturesetSpecMetadata = load_from_dict( + FeaturesetSpecMetadataSchema, yaml_data, context, "", unknown=INCLUDE, **kwargs + ) + + return res + + def _to_dict(self) -> Dict: + res: dict = FeaturesetSpecPropertiesSchema(context={BASE_PATH_CONTEXT_KEY: "./"}, unknown=INCLUDE).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_compute_resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_compute_resource.py new file mode 100644 index 00000000..5bcff24b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_compute_resource.py @@ -0,0 +1,41 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Optional + +from azure.ai.ml._restclient.v2023_10_01.models import ( + MaterializationComputeResource as RestMaterializationComputeResource, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class MaterializationComputeResource(RestTranslatableMixin): + """Materialization Compute resource + + :keyword instance_type: The compute instance type. + :paramtype instance_type: str + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START materialization_compute_resource] + :end-before: [END materialization_compute_resource] + :language: python + :dedent: 8 + :caption: Creating a MaterializationComputeResource object. 
+ """ + + def __init__(self, *, instance_type: str, **kwargs: Any): # pylint: disable=unused-argument + self.instance_type = instance_type + + def _to_rest_object(self) -> RestMaterializationComputeResource: + return RestMaterializationComputeResource(instance_type=self.instance_type) + + @classmethod + def _from_rest_object(cls, obj: RestMaterializationComputeResource) -> Optional["MaterializationComputeResource"]: + if not obj: + return None + return MaterializationComputeResource(instance_type=obj.instance_type) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_settings.py new file mode 100644 index 00000000..cf6f12e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_settings.py @@ -0,0 +1,100 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2023_10_01.models import MaterializationSettings as RestMaterializationSettings +from azure.ai.ml._restclient.v2023_10_01.models import MaterializationStoreType +from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._notification.notification import Notification +from azure.ai.ml.entities._schedule.trigger import RecurrenceTrigger + + +class MaterializationSettings(RestTranslatableMixin): + """Defines materialization settings. + + :keyword schedule: The schedule details. Defaults to None. + :paramtype schedule: Optional[~azure.ai.ml.entities.RecurrenceTrigger] + :keyword offline_enabled: Boolean that specifies if offline store is enabled. Defaults to None. + :paramtype offline_enabled: Optional[bool] + :keyword online_enabled: Boolean that specifies if online store is enabled. Defaults to None. + :paramtype online_enabled: Optional[bool] + :keyword notification: The notification details. Defaults to None. + :paramtype notification: Optional[~azure.ai.ml.entities.Notification] + :keyword resource: The compute resource settings. Defaults to None. + :paramtype resource: Optional[~azure.ai.ml.entities.MaterializationComputeResource] + :keyword spark_configuration: The spark compute settings. Defaults to None. + :paramtype spark_configuration: Optional[dict[str, str]] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START materialization_setting_configuration] + :end-before: [END materialization_setting_configuration] + :language: python + :dedent: 8 + :caption: Configuring MaterializationSettings. 
+ """ + + def __init__( + self, + *, + schedule: Optional[RecurrenceTrigger] = None, + offline_enabled: Optional[bool] = None, + online_enabled: Optional[bool] = None, + notification: Optional[Notification] = None, + resource: Optional[MaterializationComputeResource] = None, + spark_configuration: Optional[Dict[str, str]] = None, + # pylint: disable=unused-argument + **kwargs: Any, + ) -> None: + self.schedule = schedule + self.offline_enabled = offline_enabled + self.online_enabled = online_enabled + self.notification = notification + self.resource = resource + self.spark_configuration = spark_configuration + + def _to_rest_object(self) -> RestMaterializationSettings: + store_type = None + if self.offline_enabled and self.online_enabled: + store_type = MaterializationStoreType.ONLINE_AND_OFFLINE + elif self.offline_enabled: + store_type = MaterializationStoreType.OFFLINE + elif self.online_enabled: + store_type = MaterializationStoreType.ONLINE + else: + store_type = MaterializationStoreType.NONE + + return RestMaterializationSettings( + schedule=self.schedule._to_rest_object() if self.schedule else None, # pylint: disable=protected-access + notification=( + self.notification._to_rest_object() if self.notification else None # pylint: disable=protected-access + ), + resource=self.resource._to_rest_object() if self.resource else None, # pylint: disable=protected-access + spark_configuration=self.spark_configuration, + store_type=store_type, + ) + + @classmethod + def _from_rest_object(cls, obj: RestMaterializationSettings) -> Optional["MaterializationSettings"]: + if not obj: + return None + return MaterializationSettings( + schedule=( + RecurrenceTrigger._from_rest_object(obj.schedule) # pylint: disable=protected-access + if obj.schedule + else None + ), + notification=Notification._from_rest_object(obj.notification), # pylint: disable=protected-access + resource=MaterializationComputeResource._from_rest_object(obj.resource), # pylint: disable=protected-access + spark_configuration=obj.spark_configuration, + offline_enabled=obj.store_type + in {MaterializationStoreType.OFFLINE, MaterializationStoreType.ONLINE_AND_OFFLINE}, + online_enabled=obj.store_type + in {MaterializationStoreType.ONLINE, MaterializationStoreType.ONLINE_AND_OFFLINE}, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_type.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_type.py new file mode 100644 index 00000000..912d69fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_type.py @@ -0,0 +1,14 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from enum import Enum + +from azure.core import CaseInsensitiveEnumMeta + + +class MaterializationType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Materialization Type Enum""" + + RECURRENT_MATERIALIZATION = 1 + BACKFILL_MATERIALIZATION = 2 diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_metadata.py new file mode 100644 index 00000000..1c9e55fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_metadata.py @@ -0,0 +1,69 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=redefined-builtin,disable=unused-argument + +from typing import Any, Dict, Optional + +from azure.ai.ml.entities._feature_set.source_process_code_metadata import SourceProcessCodeMetadata +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .delay_metadata import DelayMetadata +from .timestamp_column_metadata import TimestampColumnMetadata + + +class SourceMetadata(object): + def __init__( + self, + *, + type: str, + timestamp_column: Optional[TimestampColumnMetadata] = None, + path: Optional[str] = None, + source_delay: Optional[DelayMetadata] = None, + source_process_code: Optional[SourceProcessCodeMetadata] = None, + dict: Optional[Dict] = None, + **kwargs: Any, + ): + if type == "custom": + # For custom feature source + # Required: timestamp_column, dict and source_process_code. + # Not support: path. + if path: + self.throw_exception("path", type, should_provide=False) + if not (timestamp_column and dict and source_process_code): + self.throw_exception("timestamp_column/dict/source_process_code", type, should_provide=True) + elif type == "featureset": + # For featureset feature source + # Required: path. + # Not support: timestamp_column, source_delay and source_process_code. + if timestamp_column or source_delay or source_process_code: + self.throw_exception("timestamp_column/source_delay/source_process_code", type, should_provide=False) + if not path: + self.throw_exception("path", type, should_provide=True) + else: + # For other type feature source + # Required: timestamp_column, path. + # Not support: source_process_code, dict + if dict or source_process_code: + self.throw_exception("dict/source_process_code", type, should_provide=False) + if not (timestamp_column and path): + self.throw_exception("timestamp_column/path", type, should_provide=True) + self.type = type + self.path = path + self.timestamp_column = timestamp_column + self.source_delay = source_delay + self.source_process_code = source_process_code + self.kwargs = dict + + @staticmethod + def throw_exception(property_names: str, type: str, should_provide: bool): + should_or_not = "need to" if should_provide else "cannot" + msg = f"You {should_or_not} provide {property_names} for {type} feature source." 
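Like the other enums in this package, MaterializationType uses azure.core's CaseInsensitiveEnumMeta, so member lookup by name ignores case. A tiny sketch, importing from the module shown above:

from azure.ai.ml.entities._feature_set.materialization_type import MaterializationType

# CaseInsensitiveEnumMeta makes name lookup case-insensitive.
assert MaterializationType["recurrent_materialization"] is MaterializationType.RECURRENT_MATERIALIZATION
assert MaterializationType["Backfill_Materialization"] is MaterializationType.BACKFILL_MATERIALIZATION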
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.FEATURE_SET, + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_process_code_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_process_code_metadata.py new file mode 100644 index 00000000..415785da --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_process_code_metadata.py @@ -0,0 +1,13 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Optional + + +class SourceProcessCodeMetadata(object): + def __init__(self, *, path: str, process_class: Optional[str] = None, **kwargs: Any): + self.path = path + self.process_class = process_class diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/timestamp_column_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/timestamp_column_metadata.py new file mode 100644 index 00000000..833088af --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/timestamp_column_metadata.py @@ -0,0 +1,14 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=redefined-builtin,disable=unused-argument + + +from typing import Any, Optional + + +class TimestampColumnMetadata(object): + def __init__(self, *, name: str, format: Optional[str] = None, **kwargs: Any): + self.name = name + self.format = format diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/_constants.py new file mode 100644 index 00000000..d6466401 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/_constants.py @@ -0,0 +1,15 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
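SourceMetadata enforces different required and forbidden fields depending on the source type. A minimal sketch of two valid configurations, with illustrative paths and names: the "custom" branch needs timestamp_column, dict and source_process_code and rejects path, while a plain storage-backed source needs timestamp_column and path.

from azure.ai.ml.entities._feature_set.source_metadata import SourceMetadata
from azure.ai.ml.entities._feature_set.source_process_code_metadata import SourceProcessCodeMetadata
from azure.ai.ml.entities._feature_set.timestamp_column_metadata import TimestampColumnMetadata

# A plain storage-backed source ("other" branch): requires timestamp_column and path.
parquet_source = SourceMetadata(
    type="parquet",
    path="abfss://container@account.dfs.core.windows.net/transactions/",  # illustrative path
    timestamp_column=TimestampColumnMetadata(name="event_time", format="yyyy-MM-dd HH:mm:ss"),  # illustrative format
)

# A custom source: requires timestamp_column, dict and source_process_code; path is not allowed.
custom_source = SourceMetadata(
    type="custom",
    timestamp_column=TimestampColumnMetadata(name="event_time"),
    source_process_code=SourceProcessCodeMetadata(
        path="./source_process.py",            # illustrative script path
        process_class="CustomSourceProcessor",  # illustrative class name
    ),
    dict={"query": "SELECT * FROM events"},  # free-form settings passed through to the source
)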
+# --------------------------------------------------------- + +OFFLINE_STORE_CONNECTION_NAME = "OfflineStoreConnectionName" +OFFLINE_MATERIALIZATION_STORE_TYPE = "azure_data_lake_gen2" +OFFLINE_STORE_CONNECTION_CATEGORY = "ADLSGen2" +ONLINE_STORE_CONNECTION_NAME = "OnlineStoreConnectionName" +ONLINE_MATERIALIZATION_STORE_TYPE = "redis" +ONLINE_STORE_CONNECTION_CATEGORY = "Redis" +DEFAULT_SPARK_RUNTIME_VERSION = "3.4.0" +STORE_REGEX_PATTERN = ( + "^/?subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Storage" + "/storageAccounts/([^/]+)/blobServices/default/containers/([^/]+)" +) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/feature_store.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/feature_store.py new file mode 100644 index 00000000..0c41f1a3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/feature_store.py @@ -0,0 +1,226 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + + +from os import PathLike +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace as RestWorkspace +from azure.ai.ml._schema._feature_store.feature_store_schema import FeatureStoreSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, WorkspaceKind +from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.entities._workspace.compute_runtime import ComputeRuntime +from azure.ai.ml.entities._workspace.customer_managed_key import CustomerManagedKey +from azure.ai.ml.entities._workspace.feature_store_settings import FeatureStoreSettings +from azure.ai.ml.entities._workspace.networking import ManagedNetwork +from azure.ai.ml.entities._workspace.workspace import Workspace + +from ._constants import DEFAULT_SPARK_RUNTIME_VERSION +from .materialization_store import MaterializationStore + + +class FeatureStore(Workspace): + """Feature Store + + :param name: The name of the feature store. + :type name: str + :param compute_runtime: The compute runtime of the feature store. Defaults to None. + :type compute_runtime: Optional[~azure.ai.ml.entities.ComputeRuntime] + :param offline_store: The offline store for feature store. + materialization_identity is required when offline_store is passed. Defaults to None. + :type offline_store: Optional[~azure.ai.ml.entities.MaterializationStore] + :param online_store: The online store for feature store. + materialization_identity is required when online_store is passed. Defaults to None. + :type online_store: Optional[~azure.ai.ml.entities.MaterializationStore] + :param materialization_identity: The identity used for materialization. Defaults to None. + :type materialization_identity: Optional[~azure.ai.ml.entities.ManagedIdentityConfiguration] + :param description: The description of the feature store. Defaults to None. + :type description: Optional[str] + :param tags: Tags of the feature store. + :type tags: dict + :param display_name: The display name for the feature store. This is non-unique within the resource group. + Defaults to None. + :type display_name: Optional[str] + :param location: The location to create the feature store in. 
+ If not specified, the same location as the resource group will be used. Defaults to None. + :type location: Optional[str] + :param resource_group: The name of the resource group to create the feature store in. Defaults to None. + :type resource_group: Optional[str] + :param hbi_workspace: Boolean for whether the customer data is of high business impact (HBI), + containing sensitive business information. Defaults to False. + For more information, see + https://learn.microsoft.com/azure/machine-learning/concept-data-encryption#encryption-at-rest. + :type hbi_workspace: Optional[bool] + :param storage_account: The resource ID of an existing storage account to use instead of creating a new one. + Defaults to None. + :type storage_account: Optional[str] + :param container_registry: The resource ID of an existing container registry + to use instead of creating a new one. Defaults to None. + :type container_registry: Optional[str] + :param key_vault: The resource ID of an existing key vault to use instead of creating a new one. Defaults to None. + :type key_vault: Optional[str] + :param application_insights: The resource ID of an existing application insights + to use instead of creating a new one. Defaults to None. + :type application_insights: Optional[str] + :param customer_managed_key: The key vault details for encrypting data with customer-managed keys. + If not specified, Microsoft-managed keys will be used by default. Defaults to None. + :type customer_managed_key: Optional[CustomerManagedKey] + :param image_build_compute: The name of the compute target to use for building environment + Docker images with the container registry is behind a VNet. Defaults to None. + :type image_build_compute: Optional[str] + :param public_network_access: Whether to allow public endpoint connectivity + when a workspace is private link enabled. Defaults to None. + :type public_network_access: Optional[str] + :param identity: The workspace's Managed Identity (user assigned, or system assigned). Defaults to None. + :type identity: Optional[IdentityConfiguration] + :param primary_user_assigned_identity: The workspace's primary user assigned identity. Defaults to None. + :type primary_user_assigned_identity: Optional[str] + :param managed_network: The workspace's Managed Network configuration. Defaults to None. + :type managed_network: Optional[ManagedNetwork] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_featurestore.py + :start-after: [START create_feature_store] + :end-before: [END create_feature_store] + :language: Python + :dedent: 8 + :caption: Instantiating a Feature Store object + """ + + def __init__( + self, + *, + name: str, + compute_runtime: Optional[ComputeRuntime] = None, + offline_store: Optional[MaterializationStore] = None, + online_store: Optional[MaterializationStore] = None, + materialization_identity: Optional[ManagedIdentityConfiguration] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + display_name: Optional[str] = None, + location: Optional[str] = None, + resource_group: Optional[str] = None, + hbi_workspace: bool = False, + storage_account: Optional[str] = None, + container_registry: Optional[str] = None, + key_vault: Optional[str] = None, + application_insights: Optional[str] = None, + customer_managed_key: Optional[CustomerManagedKey] = None, + image_build_compute: Optional[str] = None, + public_network_access: Optional[str] = None, + identity: Optional[IdentityConfiguration] = None, + primary_user_assigned_identity: Optional[str] = None, + managed_network: Optional[ManagedNetwork] = None, + **kwargs: Any, + ) -> None: + feature_store_settings = kwargs.pop( + "feature_store_settings", + FeatureStoreSettings( + compute_runtime=( + compute_runtime + if compute_runtime + else ComputeRuntime(spark_runtime_version=DEFAULT_SPARK_RUNTIME_VERSION) + ), + ), + ) + # TODO: Refactor this so that super().__init__() is not called twice coming from _from_rest_object() + super().__init__( + name=name, + description=description, + tags=tags, + kind=WorkspaceKind.FEATURE_STORE, + display_name=display_name, + location=location, + resource_group=resource_group, + hbi_workspace=hbi_workspace, + storage_account=storage_account, + container_registry=container_registry, + key_vault=key_vault, + application_insights=application_insights, + customer_managed_key=customer_managed_key, + image_build_compute=image_build_compute, + public_network_access=public_network_access, + managed_network=managed_network, + identity=identity, + primary_user_assigned_identity=primary_user_assigned_identity, + feature_store_settings=feature_store_settings, + **kwargs, + ) + self.offline_store = offline_store + self.online_store = online_store + self.materialization_identity = materialization_identity + self.identity = identity + self.public_network_access = public_network_access + self.managed_network = managed_network + # here, compute_runtime is used instead of feature_store_settings because + # it uses default spark version if no compute_runtime is specified during update + self.compute_runtime = compute_runtime + + @classmethod + def _from_rest_object( + cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None + ) -> Optional["FeatureStore"]: + if not rest_obj: + return None + + workspace_object = Workspace._from_rest_object(rest_obj, v2_service_context) + if workspace_object is not None: + return FeatureStore( + name=str(workspace_object.name), + id=workspace_object.id, + description=workspace_object.description, + tags=workspace_object.tags, + compute_runtime=ComputeRuntime._from_rest_object( + workspace_object._feature_store_settings.compute_runtime + if workspace_object._feature_store_settings + else None + ), + display_name=workspace_object.display_name, + discovery_url=workspace_object.discovery_url, + location=workspace_object.location, + resource_group=workspace_object.resource_group, + 
hbi_workspace=workspace_object.hbi_workspace, + storage_account=workspace_object.storage_account, + container_registry=workspace_object.container_registry, + key_vault=workspace_object.key_vault, + application_insights=workspace_object.application_insights, + customer_managed_key=workspace_object.customer_managed_key, + image_build_compute=workspace_object.image_build_compute, + public_network_access=workspace_object.public_network_access, + identity=workspace_object.identity, + primary_user_assigned_identity=workspace_object.primary_user_assigned_identity, + managed_network=workspace_object.managed_network, + workspace_id=rest_obj.workspace_id, + feature_store_settings=workspace_object._feature_store_settings, + ) + + return None + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "FeatureStore": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + loaded_schema = load_from_dict(FeatureStoreSchema, data, context, **kwargs) + return FeatureStore(**loaded_schema) + + def _to_dict(self) -> Dict: + res: dict = FeatureStoreSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/materialization_store.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/materialization_store.py new file mode 100644 index 00000000..c6a7e6a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/materialization_store.py @@ -0,0 +1,49 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._utils._arm_id_utils import AzureResourceId + + +class MaterializationStore: + """Materialization Store + + :param type: The type of the materialization store. + :type type: str + :param target: The ARM ID of the materialization store target. + :type target: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_featurestore.py + :start-after: [START configure_materialization_store] + :end-before: [END configure_materialization_store] + :language: Python + :dedent: 8 + :caption: Configuring a Materialization Store + """ + + def __init__(self, type: str, target: str) -> None: # pylint: disable=redefined-builtin + self.type = type + _ = AzureResourceId(target) + self.__target = target + + @property + def target(self) -> str: + """Get target value + + :return: returns the ID of the target + :rtype: str + """ + return self.__target + + @target.setter + def target(self, value: str) -> None: + """Set target value + + :param value: the ID of the target + :type value: str + :raises ~azure.ai.ml.exceptions.ValidationException~: Raised if the value is an invalid ARM ID. 
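Pulling the pieces above together, a minimal sketch of declaring a feature store with an ADLS Gen2 offline store and a user-assigned materialization identity. Every resource ID below is a placeholder, and the final create call is shown only as a comment because it assumes an authenticated MLClient.

from azure.ai.ml.entities import (
    FeatureStore,
    ManagedIdentityConfiguration,
    MaterializationStore,
)

OFFLINE_STORE_TARGET = (
    "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-rg"
    "/providers/Microsoft.Storage/storageAccounts/mystorage"
    "/blobServices/default/containers/offlinestore"
)  # placeholder ARM ID; the setter validates it via AzureResourceId

offline_store = MaterializationStore(type="azure_data_lake_gen2", target=OFFLINE_STORE_TARGET)

fs = FeatureStore(
    name="my-featurestore",
    description="Feature store for transaction features",
    offline_store=offline_store,
    materialization_identity=ManagedIdentityConfiguration(
        client_id="00000000-0000-0000-0000-000000000000",  # placeholder
        resource_id=(
            "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-rg"
            "/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-uai"
        ),  # placeholder
    ),
)
# Provisioning would then go through the workspace client, e.g.
# ml_client.feature_stores.begin_create(fs)  # assuming an authenticated MLClient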
+ """ + _ = AzureResourceId(value) + self.__target = value diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column.py new file mode 100644 index 00000000..a4446ad4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column.py @@ -0,0 +1,80 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=redefined-builtin,disable=unused-argument + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_10_01.models import FeatureDataType, IndexColumn +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .data_column_type import DataColumnType + +DataColumnTypeMap: Dict[DataColumnType, FeatureDataType] = { + DataColumnType.STRING: FeatureDataType.STRING, + DataColumnType.INTEGER: FeatureDataType.INTEGER, + DataColumnType.LONG: FeatureDataType.LONG, + DataColumnType.FLOAT: FeatureDataType.FLOAT, + DataColumnType.DOUBLE: FeatureDataType.DOUBLE, + DataColumnType.BINARY: FeatureDataType.BINARY, + DataColumnType.DATETIME: FeatureDataType.DATETIME, + DataColumnType.BOOLEAN: FeatureDataType.BOOLEAN, +} + +FeatureDataTypeMap: Dict[str, DataColumnType] = { + "String": DataColumnType.STRING, + "Integer": DataColumnType.INTEGER, + "Long": DataColumnType.LONG, + "Float": DataColumnType.FLOAT, + "Double": DataColumnType.DOUBLE, + "Binary": DataColumnType.BINARY, + "Datetime": DataColumnType.DATETIME, + "Boolean": DataColumnType.BOOLEAN, +} + + +class DataColumn(RestTranslatableMixin): + """A dataframe column + + :param name: The column name + :type name: str + :param type: The column data type. Defaults to None. + :type type: Optional[union[str, ~azure.ai.ml.entities.DataColumnType]] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + :raises ValidationException: Raised if type is specified and is not a valid DataColumnType or str. + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_featurestore.py + :start-after: [START configure_feature_store_entity] + :end-before: [END configure_feature_store_entity] + :language: Python + :dedent: 8 + :caption: Using DataColumn when creating an index column for a feature store entity + """ + + def __init__(self, *, name: str, type: Optional[Union[str, DataColumnType]] = None, **kwargs: Any): + if isinstance(type, str): + type = DataColumnType[type] + elif not isinstance(type, DataColumnType): + msg = f"Type should be DataColumnType enum string or enum type, found {type}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + ) + + self.name = name + self.type = type + + def _to_rest_object(self) -> IndexColumn: + return IndexColumn(column_name=self.name, data_type=DataColumnTypeMap.get(self.type, None)) + + @classmethod + def _from_rest_object(cls, obj: IndexColumn) -> "DataColumn": + return DataColumn(name=obj.column_name, type=FeatureDataTypeMap.get(obj.data_type, None)) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column_type.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column_type.py new file mode 100644 index 00000000..0bdfa002 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column_type.py @@ -0,0 +1,34 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from enum import Enum +from typing import Any + +from azure.core import CaseInsensitiveEnumMeta + + +class DataColumnType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Dataframe Column Type Enum + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_featurestore.py + :start-after: [START configure_feature_store_entity] + :end-before: [END configure_feature_store_entity] + :language: Python + :dedent: 8 + :caption: Using DataColumnType when instantiating a DataColumn + """ + + STRING = "string" + INTEGER = "integer" + LONG = "long" + FLOAT = "float" + DOUBLE = "double" + BINARY = "binary" + DATETIME = "datetime" + BOOLEAN = "boolean" + + def __str__(self) -> Any: + return self.value diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/feature_store_entity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/feature_store_entity.py new file mode 100644 index 00000000..6a04bc13 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/feature_store_entity.py @@ -0,0 +1,146 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
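A short sketch of the DataColumn/DataColumnType pair above: the type can be given either as the enum member or as its case-insensitive name, and the resulting columns are what the FeatureStoreEntity defined next expects in index_columns.

from azure.ai.ml.entities import DataColumn, DataColumnType

# Either the enum member or its (case-insensitive) name is accepted as `type`.
account_id = DataColumn(name="account_id", type=DataColumnType.STRING)
event_time = DataColumn(name="event_time", type="datetime")

assert event_time.type is DataColumnType.DATETIME
# These columns would then be passed as index_columns=[account_id, ...] when
# constructing the FeatureStoreEntity shown below.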
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2023_10_01.models import ( + FeaturestoreEntityContainer, + FeaturestoreEntityContainerProperties, + FeaturestoreEntityVersion, + FeaturestoreEntityVersionProperties, +) +from azure.ai.ml._schema._feature_store_entity.feature_store_entity_schema import FeatureStoreEntitySchema +from azure.ai.ml._utils._arm_id_utils import get_arm_id_object_from_id +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities._assets.asset import Asset +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .data_column import DataColumn + + +class FeatureStoreEntity(Asset): + """Feature Store Entity + + :param name: The name of the feature store entity resource. + :type name: str + :param version: The version of the feature store entity resource. + :type version: str + :param index_columns: Specifies index columns of the feature-store entity resource. + :type index_columns: list[~azure.ai.ml.entities.DataColumn] + :param stage: The feature store entity stage. Allowed values: Development, Production, Archived. + Defaults to "Development". + :type stage: Optional[str] + :param description: The description of the feature store entity resource. Defaults to None. + :type description: Optional[str] + :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None. + :type tags: Optional[dict[str, str]] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + :raises ValidationException: Raised if stage is specified and is not valid. + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_featurestore.py + :start-after: [START configure_feature_store_entity] + :end-before: [END configure_feature_store_entity] + :language: Python + :dedent: 8 + :caption: Configuring a Feature Store Entity + """ + + def __init__( + self, + *, + name: str, + version: str, + index_columns: List[DataColumn], + stage: Optional[str] = "Development", + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + **kwargs, + ) + if stage and stage not in ["Development", "Production", "Archived"]: + msg = f"Stage must be Development, Production, or Archived, found {stage}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.INVALID_VALUE, + target=ErrorTarget.FEATURE_STORE_ENTITY, + error_category=ErrorCategory.USER_ERROR, + ) + self.index_columns = index_columns + self.version = version + self.latest_version = None + self.stage = stage + + def _to_rest_object(self) -> FeaturestoreEntityVersion: + feature_store_entity_version_properties = FeaturestoreEntityVersionProperties( + description=self.description, + index_columns=[column._to_rest_object() for column in self.index_columns], + tags=self.tags, + properties=self.properties, + stage=self.stage, + ) + return FeaturestoreEntityVersion(properties=feature_store_entity_version_properties) + + @classmethod + def _from_rest_object(cls, rest_obj: FeaturestoreEntityVersion) -> "FeatureStoreEntity": + rest_object_details: FeaturestoreEntityVersionProperties = rest_obj.properties + arm_id_object = get_arm_id_object_from_id(rest_obj.id) + featurestoreEntity = FeatureStoreEntity( + name=arm_id_object.asset_name, + version=arm_id_object.asset_version, + index_columns=[DataColumn._from_rest_object(column) for column in rest_object_details.index_columns], + stage=rest_object_details.stage, + description=rest_object_details.description, + tags=rest_object_details.tags, + ) + return featurestoreEntity + + @classmethod + def _from_container_rest_object(cls, rest_obj: FeaturestoreEntityContainer) -> "FeatureStoreEntity": + rest_object_details: FeaturestoreEntityContainerProperties = rest_obj.properties + arm_id_object = get_arm_id_object_from_id(rest_obj.id) + featurestoreEntity = FeatureStoreEntity( + name=arm_id_object.asset_name, + description=rest_object_details.description, + tags=rest_object_details.tags, + index_columns=[], + version="", + ) + featurestoreEntity.latest_version = rest_object_details.latest_version + return featurestoreEntity + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "FeatureStoreEntity": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + loaded_schema = load_from_dict(FeatureStoreEntitySchema, data, context, **kwargs) + return FeatureStoreEntity(**loaded_schema) + + def _to_dict(self) -> Dict: + res: dict = FeatureStoreEntitySchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py new file mode 100644 index 00000000..43f615c3 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py @@ -0,0 +1,16 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""AzureML Retrieval Augmented Generation (RAG) utilities.""" + +from .input._ai_search_config import AzureAISearchConfig +from .input._index_data_source import IndexDataSource, GitSource, LocalSource +from .model_config import ModelConfiguration + +__all__ = [ + "ModelConfiguration", + "AzureAISearchConfig", + "IndexDataSource", + "GitSource", + "LocalSource", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py new file mode 100644 index 00000000..884faf82 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py @@ -0,0 +1,748 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=protected-access +# pylint: disable=no-member + +import json +import re +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import AssetTypes, LegacyAssetTypes +from azure.ai.ml.entities import PipelineJob +from azure.ai.ml.entities._builders.base_node import pipeline_node_decorator +from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration, UserIdentityConfiguration +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.pipeline._io import PipelineInput +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException +from azure.ai.ml.constants._common import DataIndexTypes +from azure.ai.ml.constants._component import LLMRAGComponentUri +from azure.ai.ml.entities._indexes.entities.data_index import DataIndex + +SUPPORTED_INPUTS = [ + LegacyAssetTypes.PATH, + AssetTypes.URI_FILE, + AssetTypes.URI_FOLDER, + AssetTypes.MLTABLE, +] + + +def _build_data_index(io_dict: Union[Dict, DataIndex]): + if io_dict is None: + return io_dict + if isinstance(io_dict, DataIndex): + component_io = io_dict + else: + if isinstance(io_dict, dict): + component_io = DataIndex(**io_dict) + else: + msg = "data_index only support dict and DataIndex" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return component_io + + +@experimental +@pipeline_node_decorator +def index_data( + *, + data_index: DataIndex, + description: Optional[str] = None, + tags: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + serverless_instance_type: Optional[str] = None, + ml_client: Optional[Any] = None, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + input_data_override: Optional[Input] = None, + **kwargs, +) -> PipelineJob: + """ + Create a PipelineJob object which can be used inside dsl.pipeline. + + :keyword data_index: The data index configuration. + :type data_index: DataIndex + :keyword description: Description of the job. 
+ :type description: str + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :keyword name: Name of the job. + :type name: str + :keyword display_name: Display name of the job. + :type display_name: str + :keyword experiment_name: Name of the experiment the job will be created under. + :type experiment_name: str + :keyword compute: The compute resource the job runs on. + :type compute: str + :keyword serverless_instance_type: The instance type to use for serverless compute. + :type serverless_instance_type: Optional[str] + :keyword ml_client: The ml client to use for the job. + :type ml_client: Any + :keyword identity: Identity configuration for the job. + :type identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] + :keyword input_data_override: Input data override for the job. + Used to pipe output of step into DataIndex Job in a pipeline. + :type input_data_override: Optional[Input] + :return: A PipelineJob object. + :rtype: ~azure.ai.ml.entities.PipelineJob. + """ + data_index = _build_data_index(data_index) + + if data_index.index.type == DataIndexTypes.FAISS: + configured_component = data_index_faiss( + ml_client, + data_index, + description, + tags, + name, + display_name, + experiment_name, + compute, + serverless_instance_type, + identity, + input_data_override, + ) + elif data_index.index.type in (DataIndexTypes.ACS, DataIndexTypes.PINECONE): + if kwargs.get("incremental_update", False): + configured_component = data_index_incremental_update_hosted( + ml_client, + data_index, + description, + tags, + name, + display_name, + experiment_name, + compute, + serverless_instance_type, + identity, + input_data_override, + ) + else: + configured_component = data_index_hosted( + ml_client, + data_index, + description, + tags, + name, + display_name, + experiment_name, + compute, + serverless_instance_type, + identity, + input_data_override, + ) + else: + raise ValueError(f"Unsupported index type: {data_index.index.type}") + + configured_component.properties["azureml.mlIndexAssetName"] = data_index.name + configured_component.properties["azureml.mlIndexAssetKind"] = data_index.index.type + configured_component.properties["azureml.mlIndexAssetSource"] = "Data Asset" + + return configured_component + + +# pylint: disable=too-many-statements +def data_index_incremental_update_hosted( + ml_client: Any, + data_index: DataIndex, + description: Optional[str] = None, + tags: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + serverless_instance_type: Optional[str] = None, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + input_data_override: Optional[Input] = None, +): + from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline + + crack_and_chunk_and_embed_component = get_component_obj( + ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK_AND_EMBED + ) + + if data_index.index.type == DataIndexTypes.ACS: + update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_ACS_INDEX) + elif data_index.index.type == DataIndexTypes.PINECONE: + update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_PINECONE_INDEX) + else: + raise ValueError(f"Unsupported hosted index type: {data_index.index.type}") + + register_mlindex_asset_component = get_component_obj(ml_client, 
LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET) + + @pipeline( # type: ignore [call-overload] + name=name if name else f"data_index_incremental_update_{data_index.index.type}", + description=description, + tags=tags, + display_name=( + display_name if display_name else f"LLM - Data to {data_index.index.type.upper()} (Incremental Update)" + ), + experiment_name=experiment_name, + compute=compute, + get_component=True, + ) + def data_index_pipeline( + input_data: Input, + embeddings_model: str, + index_config: str, + index_connection_id: str, + chunk_size: int = 768, + chunk_overlap: int = 0, + input_glob: str = "**/*", + citation_url: Optional[str] = None, + citation_replacement_regex: Optional[str] = None, + aoai_connection_id: Optional[str] = None, + embeddings_container: Optional[Input] = None, + ): + """ + Generate embeddings for a `input_data` source and + push them into a hosted index (such as Azure Cognitive Search and Pinecone). + + :param input_data: The input data to be indexed. + :type input_data: Input + :param embeddings_model: The embedding model to use when processing source data chunks. + :type embeddings_model: str + :param index_config: The configuration for the hosted index. + :type index_config: str + :param index_connection_id: The connection ID for the hosted index. + :type index_connection_id: str + :param chunk_size: The size of the chunks to break the input data into. + :type chunk_size: int + :param chunk_overlap: The number of tokens to overlap between chunks. + :type chunk_overlap: int + :param input_glob: The glob pattern to use when searching for input data. + :type input_glob: str + :param citation_url: The URL to use when generating citations for the input data. + :type citation_url: str + :param citation_replacement_regex: The regex to use when generating citations for the input data. + :type citation_replacement_regex: str + :param aoai_connection_id: The connection ID for the Azure Open AI service. + :type aoai_connection_id: str + :param embeddings_container: The container to use when caching embeddings. + :type embeddings_container: Input + :return: The URI of the generated Azure Cognitive Search index. + :rtype: str. 
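For orientation while reading these pipeline builders, a hedged sketch of driving the index_data entry point defined earlier in this module. Here my_data_index is assumed to be a fully populated DataIndex (its source, embedding and index sections are covered by data_index.py), and the serverless instance type is a placeholder.

from azure.ai.ml import MLClient
from azure.ai.ml.entities import UserIdentityConfiguration
from azure.identity import DefaultAzureCredential

ml_client = MLClient.from_config(credential=DefaultAzureCredential())  # assumes a workspace config.json

# `my_data_index` is a DataIndex describing source data, embedding settings and the
# target index (ACS/Pinecone/FAISS); index_data is the entry point defined above.
index_job = index_data(
    data_index=my_data_index,
    name="index-product-docs",
    serverless_instance_type="Standard_E8s_v3",  # placeholder instance type
    ml_client=ml_client,
    identity=UserIdentityConfiguration(),
)
submitted = ml_client.jobs.create_or_update(index_job)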
+ """ + crack_and_chunk_and_embed = crack_and_chunk_and_embed_component( + input_data=input_data, + input_glob=input_glob, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + citation_url=citation_url, + citation_replacement_regex=citation_replacement_regex, + embeddings_container=embeddings_container, + embeddings_model=embeddings_model, + embeddings_connection_id=aoai_connection_id, + ) + if compute is None or compute == "serverless": + use_automatic_compute(crack_and_chunk_and_embed, instance_type=serverless_instance_type) + if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type] + crack_and_chunk_and_embed.outputs.embeddings = Output( + type="uri_folder", path=f"{embeddings_container.path}/{{name}}" # type: ignore [union-attr] + ) + if identity: + crack_and_chunk_and_embed.identity = identity + + if data_index.index.type == DataIndexTypes.ACS: + update_index = update_index_component( + embeddings=crack_and_chunk_and_embed.outputs.embeddings, acs_config=index_config + ) + update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_ACS"] = index_connection_id + elif data_index.index.type == DataIndexTypes.PINECONE: + update_index = update_index_component( + embeddings=crack_and_chunk_and_embed.outputs.embeddings, pinecone_config=index_config + ) + update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_PINECONE"] = index_connection_id + else: + raise ValueError(f"Unsupported hosted index type: {data_index.index.type}") + if compute is None or compute == "serverless": + use_automatic_compute(update_index, instance_type=serverless_instance_type) + if identity: + update_index.identity = identity + + register_mlindex_asset = register_mlindex_asset_component( + storage_uri=update_index.outputs.index, + asset_name=data_index.name, + ) + if compute is None or compute == "serverless": + use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type) + if identity: + register_mlindex_asset.identity = identity + return { + "mlindex_asset_uri": update_index.outputs.index, + "mlindex_asset_id": register_mlindex_asset.outputs.asset_id, + } + + if input_data_override is not None: + input_data = input_data_override + else: + input_data = Input( + type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type] + ) + + index_config = { + "index_name": data_index.index.name if data_index.index.name is not None else data_index.name, + "full_sync": True, + } + if data_index.index.config is not None: + index_config.update(data_index.index.config) + + component = data_index_pipeline( + input_data=input_data, + input_glob=data_index.source.input_glob, # type: ignore [arg-type] + chunk_size=data_index.source.chunk_size, # type: ignore [arg-type] + chunk_overlap=data_index.source.chunk_overlap, # type: ignore [arg-type] + citation_url=data_index.source.citation_url, + citation_replacement_regex=( + json.dumps(data_index.source.citation_url_replacement_regex._to_dict()) + if data_index.source.citation_url_replacement_regex + else None + ), + embeddings_model=build_model_protocol(data_index.embedding.model), + aoai_connection_id=_resolve_connection_id(ml_client, data_index.embedding.connection), + embeddings_container=( + Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path) + if data_index.embedding.cache_path + else None + ), + index_config=json.dumps(index_config), + index_connection_id=_resolve_connection_id(ml_client, data_index.index.connection), # type: ignore [arg-type] + 
) + # Hack until full Component classes are implemented that can annotate the optional parameters properly + component.inputs["input_glob"]._meta.optional = True + component.inputs["chunk_size"]._meta.optional = True + component.inputs["chunk_overlap"]._meta.optional = True + component.inputs["citation_url"]._meta.optional = True + component.inputs["citation_replacement_regex"]._meta.optional = True + component.inputs["aoai_connection_id"]._meta.optional = True + component.inputs["embeddings_container"]._meta.optional = True + + if data_index.path: + component.outputs.mlindex_asset_uri = Output( # type: ignore [attr-defined] + type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type] + ) + + return component + + +def data_index_faiss( + ml_client: Any, + data_index: DataIndex, + description: Optional[str] = None, + tags: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + serverless_instance_type: Optional[str] = None, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + input_data_override: Optional[Input] = None, +): + from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline + + crack_and_chunk_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK) + generate_embeddings_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_GENERATE_EMBEDDINGS) + create_faiss_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CREATE_FAISS_INDEX) + register_mlindex_asset_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET) + + @pipeline( # type: ignore [call-overload] + name=name if name else "data_index_faiss", + description=description, + tags=tags, + display_name=display_name if display_name else "LLM - Data to Faiss", + experiment_name=experiment_name, + compute=compute, + get_component=True, + ) + def data_index_faiss_pipeline( + input_data: Input, + embeddings_model: str, + chunk_size: int = 1024, + data_source_glob: str = None, # type: ignore [assignment] + data_source_url: str = None, # type: ignore [assignment] + document_path_replacement_regex: str = None, # type: ignore [assignment] + aoai_connection_id: str = None, # type: ignore [assignment] + embeddings_container: Input = None, # type: ignore [assignment] + ): + """ + Generate embeddings for a `input_data` source and create a Faiss index from them. + + :param input_data: The input data to be indexed. + :type input_data: Input + :param embeddings_model: The embedding model to use when processing source data chunks. + :type embeddings_model: str + :param chunk_size: The size of the chunks to break the input data into. + :type chunk_size: int + :param data_source_glob: The glob pattern to use when searching for input data. + :type data_source_glob: str + :param data_source_url: The URL to use when generating citations for the input data. + :type data_source_url: str + :param document_path_replacement_regex: The regex to use when generating citations for the input data. + :type document_path_replacement_regex: str + :param aoai_connection_id: The connection ID for the Azure Open AI service. + :type aoai_connection_id: str + :param embeddings_container: The container to use when caching embeddings. + :type embeddings_container: Input + :return: The URI of the generated Faiss index. + :rtype: str. 
+ """ + crack_and_chunk = crack_and_chunk_component( + input_data=input_data, + input_glob=data_source_glob, + chunk_size=chunk_size, + data_source_url=data_source_url, + document_path_replacement_regex=document_path_replacement_regex, + ) + if compute is None or compute == "serverless": + use_automatic_compute(crack_and_chunk, instance_type=serverless_instance_type) + if identity: + crack_and_chunk.identity = identity + + generate_embeddings = generate_embeddings_component( + chunks_source=crack_and_chunk.outputs.output_chunks, + embeddings_container=embeddings_container, + embeddings_model=embeddings_model, + ) + if compute is None or compute == "serverless": + use_automatic_compute(generate_embeddings, instance_type=serverless_instance_type) + if optional_pipeline_input_provided(aoai_connection_id): # type: ignore [arg-type] + generate_embeddings.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_AOAI"] = aoai_connection_id + if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type] + generate_embeddings.outputs.embeddings = Output( + type="uri_folder", path=f"{embeddings_container.path}/{{name}}" + ) + if identity: + generate_embeddings.identity = identity + + create_faiss_index = create_faiss_index_component(embeddings=generate_embeddings.outputs.embeddings) + if compute is None or compute == "serverless": + use_automatic_compute(create_faiss_index, instance_type=serverless_instance_type) + if identity: + create_faiss_index.identity = identity + + register_mlindex_asset = register_mlindex_asset_component( + storage_uri=create_faiss_index.outputs.index, + asset_name=data_index.name, + ) + if compute is None or compute == "serverless": + use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type) + if identity: + register_mlindex_asset.identity = identity + return { + "mlindex_asset_uri": create_faiss_index.outputs.index, + "mlindex_asset_id": register_mlindex_asset.outputs.asset_id, + } + + if input_data_override is not None: + input_data = input_data_override + else: + input_data = Input( + type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type] + ) + + component = data_index_faiss_pipeline( + input_data=input_data, + embeddings_model=build_model_protocol(data_index.embedding.model), + chunk_size=data_index.source.chunk_size, # type: ignore [arg-type] + data_source_glob=data_index.source.input_glob, # type: ignore [arg-type] + data_source_url=data_index.source.citation_url, # type: ignore [arg-type] + document_path_replacement_regex=( + json.dumps(data_index.source.citation_url_replacement_regex._to_dict()) # type: ignore [arg-type] + if data_index.source.citation_url_replacement_regex + else None + ), + aoai_connection_id=_resolve_connection_id( + ml_client, data_index.embedding.connection + ), # type: ignore [arg-type] + embeddings_container=( + Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path) # type: ignore [arg-type] + if data_index.embedding.cache_path + else None + ), + ) + # Hack until full Component classes are implemented that can annotate the optional parameters properly + component.inputs["data_source_glob"]._meta.optional = True + component.inputs["data_source_url"]._meta.optional = True + component.inputs["document_path_replacement_regex"]._meta.optional = True + component.inputs["aoai_connection_id"]._meta.optional = True + component.inputs["embeddings_container"]._meta.optional = True + if data_index.path: + component.outputs.mlindex_asset_uri 
= Output( + type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type] + ) + + return component + + +# pylint: disable=too-many-statements +def data_index_hosted( + ml_client: Any, + data_index: DataIndex, + description: Optional[str] = None, + tags: Optional[Dict] = None, + name: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + serverless_instance_type: Optional[str] = None, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + input_data_override: Optional[Input] = None, +): + from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline + + crack_and_chunk_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK) + generate_embeddings_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_GENERATE_EMBEDDINGS) + + if data_index.index.type == DataIndexTypes.ACS: + update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_ACS_INDEX) + elif data_index.index.type == DataIndexTypes.PINECONE: + update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_PINECONE_INDEX) + else: + raise ValueError(f"Unsupported hosted index type: {data_index.index.type}") + + register_mlindex_asset_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET) + + @pipeline( # type: ignore [call-overload] + name=name if name else f"data_index_{data_index.index.type}", + description=description, + tags=tags, + display_name=display_name if display_name else f"LLM - Data to {data_index.index.type.upper()}", + experiment_name=experiment_name, + compute=compute, + get_component=True, + ) + def data_index_pipeline( + input_data: Input, + embeddings_model: str, + index_config: str, + index_connection_id: str, + chunk_size: int = 1024, + data_source_glob: str = None, # type: ignore [assignment] + data_source_url: str = None, # type: ignore [assignment] + document_path_replacement_regex: str = None, # type: ignore [assignment] + aoai_connection_id: str = None, # type: ignore [assignment] + embeddings_container: Input = None, # type: ignore [assignment] + ): + """ + Generate embeddings for a `input_data` source + and push them into a hosted index (such as Azure Cognitive Search and Pinecone). + + :param input_data: The input data to be indexed. + :type input_data: Input + :param embeddings_model: The embedding model to use when processing source data chunks. + :type embeddings_model: str + :param index_config: The configuration for the hosted index. + :type index_config: str + :param index_connection_id: The connection ID for the hosted index. + :type index_connection_id: str + :param chunk_size: The size of the chunks to break the input data into. + :type chunk_size: int + :param data_source_glob: The glob pattern to use when searching for input data. + :type data_source_glob: str + :param data_source_url: The URL to use when generating citations for the input data. + :type data_source_url: str + :param document_path_replacement_regex: The regex to use when generating citations for the input data. + :type document_path_replacement_regex: str + :param aoai_connection_id: The connection ID for the Azure Open AI service. + :type aoai_connection_id: str + :param embeddings_container: The container to use when caching embeddings. + :type embeddings_container: Input + :return: The URI of the generated Azure Cognitive Search index. 
+ :rtype: str. + """ + crack_and_chunk = crack_and_chunk_component( + input_data=input_data, + input_glob=data_source_glob, + chunk_size=chunk_size, + data_source_url=data_source_url, + document_path_replacement_regex=document_path_replacement_regex, + ) + if compute is None or compute == "serverless": + use_automatic_compute(crack_and_chunk, instance_type=serverless_instance_type) + if identity: + crack_and_chunk.identity = identity + + generate_embeddings = generate_embeddings_component( + chunks_source=crack_and_chunk.outputs.output_chunks, + embeddings_container=embeddings_container, + embeddings_model=embeddings_model, + ) + if compute is None or compute == "serverless": + use_automatic_compute(generate_embeddings, instance_type=serverless_instance_type) + if optional_pipeline_input_provided(aoai_connection_id): # type: ignore [arg-type] + generate_embeddings.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_AOAI"] = aoai_connection_id + if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type] + generate_embeddings.outputs.embeddings = Output( + type="uri_folder", path=f"{embeddings_container.path}/{{name}}" + ) + if identity: + generate_embeddings.identity = identity + + if data_index.index.type == DataIndexTypes.ACS: + update_index = update_index_component( + embeddings=generate_embeddings.outputs.embeddings, acs_config=index_config + ) + update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_ACS"] = index_connection_id + elif data_index.index.type == DataIndexTypes.PINECONE: + update_index = update_index_component( + embeddings=generate_embeddings.outputs.embeddings, pinecone_config=index_config + ) + update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_PINECONE"] = index_connection_id + else: + raise ValueError(f"Unsupported hosted index type: {data_index.index.type}") + if compute is None or compute == "serverless": + use_automatic_compute(update_index, instance_type=serverless_instance_type) + if identity: + update_index.identity = identity + + register_mlindex_asset = register_mlindex_asset_component( + storage_uri=update_index.outputs.index, + asset_name=data_index.name, + ) + if compute is None or compute == "serverless": + use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type) + if identity: + register_mlindex_asset.identity = identity + return { + "mlindex_asset_uri": update_index.outputs.index, + "mlindex_asset_id": register_mlindex_asset.outputs.asset_id, + } + + if input_data_override is not None: + input_data = input_data_override + else: + input_data = Input( + type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type] + ) + + index_config = { + "index_name": data_index.index.name if data_index.index.name is not None else data_index.name, + } + if data_index.index.config is not None: + index_config.update(data_index.index.config) + + component = data_index_pipeline( + input_data=input_data, + embeddings_model=build_model_protocol(data_index.embedding.model), + index_config=json.dumps(index_config), + index_connection_id=_resolve_connection_id(ml_client, data_index.index.connection), # type: ignore [arg-type] + chunk_size=data_index.source.chunk_size, # type: ignore [arg-type] + data_source_glob=data_index.source.input_glob, # type: ignore [arg-type] + data_source_url=data_index.source.citation_url, # type: ignore [arg-type] + document_path_replacement_regex=( + json.dumps(data_index.source.citation_url_replacement_regex._to_dict()) # 
type: ignore [arg-type] + if data_index.source.citation_url_replacement_regex + else None + ), + aoai_connection_id=_resolve_connection_id( + ml_client, data_index.embedding.connection # type: ignore [arg-type] + ), + embeddings_container=( + Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path) # type: ignore [arg-type] + if data_index.embedding.cache_path + else None + ), + ) + # Hack until full Component classes are implemented that can annotate the optional parameters properly + component.inputs["data_source_glob"]._meta.optional = True + component.inputs["data_source_url"]._meta.optional = True + component.inputs["document_path_replacement_regex"]._meta.optional = True + component.inputs["aoai_connection_id"]._meta.optional = True + component.inputs["embeddings_container"]._meta.optional = True + + if data_index.path: + component.outputs.mlindex_asset_uri = Output( + type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type] + ) + + return component + + +def optional_pipeline_input_provided(input: Optional[PipelineInput]): + """ + Checks if optional pipeline inputs are provided. + + :param input: The pipeline input to check. + :type input: Optional[PipelineInput] + :return: True if the input is not None and has a value, False otherwise. + :rtype: bool. + """ + return input is not None and input._data is not None + + +def use_automatic_compute(component, instance_count=1, instance_type=None): + """ + Configure input `component` to use automatic compute with `instance_count` and `instance_type`. + + This avoids the need to provision a compute cluster to run the component. + :param component: The component to configure. + :type component: Any + :param instance_count: The number of instances to use. + :type instance_count: int + :param instance_type: The type of instance to use. + :type instance_type: str + :return: The configured component. + :rtype: Any. 
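+
+ A minimal illustrative sketch (``chunk_step`` is a hypothetical pipeline node; the instance type is only an example value):
+
+ .. code-block:: python
+
+ chunk_step = use_automatic_compute(chunk_step, instance_count=1, instance_type="Standard_E8s_v3")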
+ """ + component.set_resources( + instance_count=instance_count, + instance_type=instance_type, + properties={"compute_specification": {"automatic": True}}, + ) + return component + + +def get_component_obj(ml_client, component_uri): + from azure.ai.ml import MLClient + + if not isinstance(component_uri, str): + # Assume Component object + return component_uri + + matches = re.match( + r"azureml://registries/(?P<registry_name>.*)/components/(?P<component_name>.*)" + r"/(?P<identifier_type>.*)/(?P<identifier_name>.*)", + component_uri, + ) + if matches is None: + from azure.ai.ml import load_component + + # Assume local path to component + return load_component(source=component_uri) + + registry_name = matches.group("registry_name") + registry_client = MLClient( + subscription_id=ml_client.subscription_id, + resource_group_name=ml_client.resource_group_name, + credential=ml_client._credential, + registry_name=registry_name, + ) + component_obj = registry_client.components.get( + matches.group("component_name"), + **{matches.group("identifier_type").rstrip("s"): matches.group("identifier_name")}, + ) + return component_obj + + +def _resolve_connection_id(ml_client, connection: Optional[str] = None) -> Optional[str]: + if connection is None: + return None + + if isinstance(connection, str): + from azure.ai.ml._utils._arm_id_utils import AMLNamedArmId + + connection_name = AMLNamedArmId(connection).asset_name + + connection = ml_client.connections.get(connection_name) + if connection is None: + return None + return connection.id # type: ignore [attr-defined] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py new file mode 100644 index 00000000..094d19aa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py @@ -0,0 +1,243 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""DataIndex entities.""" + +from typing import Dict, Optional + +from azure.ai.ml.constants._common import DataIndexTypes +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.entities._assets import Data +from azure.ai.ml.entities._inputs_outputs.utils import _remove_empty_values +from azure.ai.ml.entities._mixins import DictMixin + + +@experimental +class CitationRegex(DictMixin): + """ + :keyword match_pattern: Regex to match citation in the citation_url + input file path. + e.g. '(.*)/articles/(.*)(\\.[^.]+)$'. + :type match_pattern: str + :keyword replacement_pattern: Replacement string for citation. e.g. '\\1/\\2'. 
+ :type replacement_pattern: str
+ """
+
+ def __init__(
+ self,
+ *,
+ match_pattern: str,
+ replacement_pattern: str,
+ ):
+ """Initialize a CitationRegex object."""
+ self.match_pattern = match_pattern
+ self.replacement_pattern = replacement_pattern
+
+ def _to_dict(self) -> Dict:
+ """Convert the CitationRegex object to a dict.
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = [
+ "match_pattern",
+ "replacement_pattern",
+ ]
+ result = {key: getattr(self, key) for key in keys}
+ return _remove_empty_values(result)
+
+
+@experimental
+class IndexSource(DictMixin):
+ """Configuration for the source data to be processed and written to an index.
+ :keyword input_data: Input Data to index files from. MLTable type inputs will use `mode: eval_mount`.
+ :type input_data: Data
+ :keyword input_glob: Glob pattern to select which files from `input_data` are indexed,
+ e.g. '**/*' to index all files.
+ :type input_glob: str, optional
+ :keyword chunk_size: Maximum number of tokens to put in each chunk.
+ :type chunk_size: int, optional
+ :keyword chunk_overlap: Number of tokens to overlap between chunks.
+ :type chunk_overlap: int, optional
+ :keyword citation_url: Base URL to join with file paths to create full source file URL for chunk metadata.
+ :type citation_url: str, optional
+ :keyword citation_url_replacement_regex: Regex match and replacement patterns for citation url. Useful if the paths
+ in `input_data` don't match the desired citation format.
+ :type citation_url_replacement_regex: CitationRegex, optional
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the IndexSource object cannot be validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ input_data: Data,
+ input_glob: Optional[str] = None,
+ chunk_size: Optional[int] = None,
+ chunk_overlap: Optional[int] = None,
+ citation_url: Optional[str] = None,
+ citation_url_replacement_regex: Optional[CitationRegex] = None,
+ ):
+ """Initialize an IndexSource object."""
+ self.input_data = input_data
+ self.input_glob = input_glob
+ self.chunk_size = chunk_size
+ self.chunk_overlap = chunk_overlap
+ self.citation_url = citation_url
+ self.citation_url_replacement_regex = citation_url_replacement_regex
+
+ def _to_dict(self) -> Dict:
+ """Convert the IndexSource object to a dict.
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = [
+ "input_data",
+ "input_glob",
+ "chunk_size",
+ "chunk_overlap",
+ "citation_url",
+ "citation_url_replacement_regex",
+ ]
+ result = {key: getattr(self, key) for key in keys}
+ return _remove_empty_values(result)
+
+
+@experimental
+class Embedding(DictMixin):
+ """Configuration for the embedding model used to process source data chunks.
+ :keyword model: The model to use to embed data. E.g. 'hugging_face://model/sentence-transformers/all-mpnet-base-v2'
+ or 'azure_open_ai://deployment/{deployment_name}/model/{model_name}'
+ :type model: str
+ :keyword connection: Connection reference to use for embedding model information,
+ only needed for hosted embeddings models (such as Azure OpenAI).
+ :type connection: str, optional
+ :keyword cache_path: Folder containing previously generated embeddings.
+ Should be parent folder of the 'embeddings' output path used for this component.
+ Will compare input data to existing embeddings and only embed changed/new data, reusing existing chunks. 
+ :type cache_path: str, optional
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the Embedding object cannot be validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ model: str,
+ connection: Optional[str] = None,
+ cache_path: Optional[str] = None,
+ ):
+ """Initialize an Embedding object."""
+ self.model = model
+ self.connection = connection
+ self.cache_path = cache_path
+
+ def _to_dict(self) -> Dict:
+ """Convert the Embedding object to a dict.
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = [
+ "model",
+ "connection",
+ "cache_path",
+ ]
+ result = {key: getattr(self, key) for key in keys}
+ return _remove_empty_values(result)
+
+
+@experimental
+class IndexStore(DictMixin):
+ """Configuration for the destination index to write processed data to.
+ :keyword type: The type of index to write to. Currently supported types are 'acs', 'pinecone', and 'faiss'.
+ :type type: str
+ :keyword name: Name of index to update/create, only needed for hosted indexes
+ (such as Azure Cognitive Search and Pinecone).
+ :type name: str, optional
+ :keyword connection: Connection reference to use for index information,
+ only needed for hosted indexes (such as Azure Cognitive Search and Pinecone).
+ :type connection: str, optional
+ :keyword config: Configuration for the index.
+ Primarily used to configure AI Search and Pinecone specific settings,
+ such as a custom `field_mapping` for known field types.
+ :type config: dict, optional
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the IndexStore object cannot be validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ type: str = DataIndexTypes.FAISS,
+ name: Optional[str] = None,
+ connection: Optional[str] = None,
+ config: Optional[Dict] = None,
+ ):
+ """Initialize an IndexStore object."""
+ self.type = type
+ self.name = name
+ self.connection = connection
+ self.config = config
+
+ def _to_dict(self) -> Dict:
+ """Convert the IndexStore object to a dict.
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = ["type", "name", "connection", "config"]
+ result = {key: getattr(self, key) for key in keys}
+ return _remove_empty_values(result)
+
+
+@experimental
+class DataIndex(Data):
+ """Data asset created by a data index job.
+ :param name: Name of the asset.
+ :type name: str
+ :param path: The path to the asset being created by the data index job.
+ :type path: str
+ :param source: The source data to be indexed.
+ :type source: IndexSource
+ :param embedding: The embedding model to use when processing source data chunks.
+ :type embedding: Embedding
+ :param index: The destination index to write processed data to.
+ :type index: IndexStore
+ :param version: Version of the asset created by running this DataIndex Job.
+ :type version: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param kwargs: A dictionary of additional configuration parameters. 
+ :type kwargs: dict + """ + + def __init__( + self, + *, + name: str, + source: IndexSource, + embedding: Embedding, + index: IndexStore, + incremental_update: bool = False, + path: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs, + ): + """Initialize a DataIndex object.""" + super().__init__( + name=name, + version=version, + description=description, + tags=tags, + properties=properties, + path=path, + **kwargs, + ) + self.source = source + self.embedding = embedding + self.index = index + self.incremental_update = incremental_update diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py new file mode 100644 index 00000000..b2163c40 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py @@ -0,0 +1,31 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# General todo: need to determine which args are required or optional when parsed out into groups like this. +# General todo: move these to more permanent locations? + +# Defines stuff related to the resulting created index, like the index type. + +from typing import Optional +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class AzureAISearchConfig: + """Config class for creating an Azure AI Search index. + + :param index_name: The name of the Azure AI Search index. + :type index_name: Optional[str] + :param connection_id: The Azure AI Search connection ID. + :type connection_id: Optional[str] + """ + + def __init__( + self, + *, + index_name: Optional[str] = None, + connection_id: Optional[str] = None, + ) -> None: + self.index_name = index_name + self.connection_id = connection_id diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py new file mode 100644 index 00000000..0eec691a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional + + +class IndexConfig: # pylint: disable=too-many-instance-attributes + """Convenience class that contains all config values that for index creation that are + NOT specific to the index source data or the created index type. Meant for internal use only + to simplify function headers. 
The user-entry point is a function that + should still contain all the fields in this class as individual function parameters. + + Params omitted for brevity and to avoid maintaining duplicate docs. See index creation function + for actual parameter descriptions. + """ + + def __init__( + self, + *, + output_index_name: str, + vector_store: str, + data_source_url: Optional[str] = None, + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + input_glob: Optional[str] = None, + max_sample_files: Optional[int] = None, + chunk_prepend_summary: Optional[bool] = None, + document_path_replacement_regex: Optional[str] = None, + embeddings_container: Optional[str] = None, + embeddings_model: str, + aoai_connection_id: str, + _dry_run: bool = False + ): + self.output_index_name = output_index_name + self.vector_store = vector_store + self.data_source_url = data_source_url + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.input_glob = input_glob + self.max_sample_files = max_sample_files + self.chunk_prepend_summary = chunk_prepend_summary + self.document_path_replacement_regex = document_path_replacement_regex + self.embeddings_container = embeddings_container + self.embeddings_model = embeddings_model + self.aoai_connection_id = aoai_connection_id + self._dry_run = _dry_run diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py new file mode 100644 index 00000000..92b62b6b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py @@ -0,0 +1,62 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Union + +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.constants._common import IndexInputType + + +# General todo: need to determine which args are required or optional when parsed out into groups like this. +# General todo: move these to more permanent locations? + + +# Defines stuff related to supplying inputs for an index AKA the base data. +@experimental +class IndexDataSource: + """Base class for configs that define data that will be processed into an ML index. + This class should not be instantiated directly. Use one of its child classes instead. + + :param input_type: A type enum describing the source of the index. Used to avoid + direct type checking. + :type input_type: Union[str, ~azure.ai.ml.constants._common.IndexInputType] + """ + + def __init__(self, *, input_type: Union[str, IndexInputType]): + self.input_type = input_type + + +# Field bundle for creating an index from files located in a Git repo. +# TODO Does git_url need to specifically be an SSH or HTTPS style link? +# TODO What is git connection id? +@experimental +class GitSource(IndexDataSource): + """Config class for creating an ML index from files located in a git repository. + + :param url: A link to the repository to use. + :type url: str + :param branch_name: The name of the branch to use from the target repository. 
+ :type branch_name: str
+ :param connection_id: The connection ID for GitHub.
+ :type connection_id: str
+ """
+
+ def __init__(self, *, url: str, branch_name: str, connection_id: str):
+ self.url = url
+ self.branch_name = branch_name
+ self.connection_id = connection_id
+ super().__init__(input_type=IndexInputType.GIT)
+
+
+@experimental
+class LocalSource(IndexDataSource):
+ """Config class for creating an ML index from a collection of local files.
+
+ :param input_data: An input object describing the local location of index source files.
+ :type input_data: ~azure.ai.ml.Input
+ """
+
+ def __init__(self, *, input_data: str): # todo Make sure type of input_data is correct
+ self.input_data = Input(type="uri_folder", path=input_data)
+ super().__init__(input_type=IndexInputType.LOCAL)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
new file mode 100644
index 00000000..c9e54da4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
@@ -0,0 +1,122 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+from azure.ai.ml.entities._workspace.connections.connection_subtypes import (
+ AzureOpenAIConnection,
+ AadCredentialConfiguration,
+)
+
+
+@experimental
+@dataclass
+class ModelConfiguration:
+ """Configuration for an embedding model.
+
+ :param api_base: The base URL for the API.
+ :type api_base: Optional[str]
+ :param api_key: The API key.
+ :type api_key: Optional[str]
+ :param api_version: The API version.
+ :type api_version: Optional[str]
+ :param model_name: The name of the model.
+ :type model_name: Optional[str]
+ :param deployment_name: The deployment name of the model.
+ :type deployment_name: Optional[str]
+ :param connection_name: The name of the workspace connection of this model.
+ :type connection_name: Optional[str]
+ :param connection_type: The type of the workspace connection of this model.
+ :type connection_type: Optional[str]
+ :param model_kwargs: Additional keyword arguments for the model. 
+ :type model_kwargs: Dict[str, Any] + """ + + api_base: Optional[str] + api_key: Optional[str] + api_version: Optional[str] + connection_name: Optional[str] + connection_type: Optional[str] + model_name: Optional[str] + deployment_name: Optional[str] + model_kwargs: Dict[str, Any] + + def __init__( + self, + *, + api_base: Optional[str], + api_key: Optional[str], + api_version: Optional[str], + connection_name: Optional[str], + connection_type: Optional[str], + model_name: Optional[str], + deployment_name: Optional[str], + model_kwargs: Dict[str, Any] + ): + self.api_base = api_base + self.api_key = api_key + self.api_version = api_version + self.connection_name = connection_name + self.connection_type = connection_type + self.model_name = model_name + self.deployment_name = deployment_name + self.model_kwargs = model_kwargs + + @staticmethod + def from_connection( + connection: WorkspaceConnection, + model_name: Optional[str] = None, + deployment_name: Optional[str] = None, + **kwargs + ) -> "ModelConfiguration": + """Create an model configuration from a Connection. + + :param connection: The WorkspaceConnection object. + :type connection: ~azure.ai.ml.entities.WorkspaceConnection + :param model_name: The name of the model. + :type model_name: Optional[str] + :param deployment_name: The name of the deployment. + :type deployment_name: Optional[str] + :return: The model configuration. + :rtype: ~azure.ai.ml.entities._indexes.entities.ModelConfiguration + :raises TypeError: If the connection is not an AzureOpenAIConnection. + :raises ValueError: If the connection does not contain an OpenAI key. + """ + if isinstance(connection, AzureOpenAIConnection) or camel_to_snake(connection.type) == "azure_open_ai": + connection_type = "azure_open_ai" + api_version = connection.api_version # type: ignore[attr-defined] + if not model_name or not deployment_name: + raise ValueError("Please specify model_name and deployment_name.") + elif connection.type and connection.type.lower() == "serverless": + connection_type = "serverless" + api_version = None + if not connection.id: + raise TypeError("The connection id is missing from the serverless connection object.") + else: + raise TypeError("Connection object is not supported.") + + if isinstance(connection.credentials, AadCredentialConfiguration): + key = None + else: + key = connection.credentials.get("key") # type: ignore[union-attr] + if key is None and connection_type == "azure_open_ai": + import os + + if "AZURE_OPENAI_API_KEY" in os.environ: + key = os.getenv("AZURE_OPENAI_API_KEY") + else: + raise ValueError("Unable to retrieve openai key from connection object or env variable.") + + return ModelConfiguration( + api_base=connection.target, + api_key=key, + api_version=api_version, + connection_name=connection.name, + connection_type=connection_type, + model_name=model_name, + deployment_name=deployment_name, + model_kwargs=kwargs, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py new file mode 100644 index 00000000..f65f5505 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py @@ -0,0 +1,10 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+"""AzureML Retrieval Augmented Generation (RAG) utilities."""
+
+from ._models import build_model_protocol
+from ._open_ai_utils import build_open_ai_protocol, build_connection_id
+from ._pipeline_decorator import pipeline
+
+__all__ = ["build_model_protocol", "build_open_ai_protocol", "build_connection_id", "pipeline"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py
new file mode 100644
index 00000000..d3e8c952
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""DataIndex embedding model helpers."""
+import re
+from typing import Optional
+
+OPEN_AI_PROTOCOL_TEMPLATE = "azure_open_ai://deployment/{}/model/{}"
+OPEN_AI_PROTOCOL_REGEX_PATTERN = OPEN_AI_PROTOCOL_TEMPLATE.format(".*", ".*")
+OPEN_AI_SHORT_FORM_PROTOCOL_TEMPLATE = "azure_open_ai://deployments?/{}"
+OPEN_AI_SHORT_FORM_PROTOCOL_REGEX_PATTERN = OPEN_AI_SHORT_FORM_PROTOCOL_TEMPLATE.format(".*")
+
+HUGGINGFACE_PROTOCOL_TEMPLATE = "hugging_face://model/{}"
+HUGGINGFACE_PROTOCOL_REGEX_PATTERN = HUGGINGFACE_PROTOCOL_TEMPLATE.format(".*")
+
+
+def build_model_protocol(model: Optional[str] = None):
+ if not model or re.match(OPEN_AI_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+ return model
+ if re.match(OPEN_AI_SHORT_FORM_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+ return model
+ if re.match(HUGGINGFACE_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+ return model
+
+ return OPEN_AI_PROTOCOL_TEMPLATE.format(model, model)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py
new file mode 100644
index 00000000..d38a447f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from typing import Optional + +from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource +from azure.ai.ml._scope_dependent_operations import OperationScope + +OPEN_AI_PROTOCOL_TEMPLATE = "azure_open_ai://deployment/{}/model/{}" + + +def build_open_ai_protocol( + model: Optional[str] = None, + deployment: Optional[str] = None, +): + if not deployment or not model: + return None + return OPEN_AI_PROTOCOL_TEMPLATE.format(deployment, model) + + +def build_connection_id(id: Optional[str], scope: OperationScope): + if not id or not scope.subscription_id or not scope.resource_group_name or not scope.workspace_name: + return id + + if is_ARM_id_for_resource(id, "connections", True): + return id + + # pylint: disable=line-too-long + template = "/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{workspace_name}/connections/{id}" + return template.format( + subscription_id=scope.subscription_id, + resource_group_name=scope.resource_group_name, + workspace_name=scope.workspace_name, + id=id, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py new file mode 100644 index 00000000..e70f97f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py @@ -0,0 +1,248 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import inspect +import logging +from functools import wraps +from pathlib import Path +from typing import Any, Callable, Dict, Optional, TypeVar, Union, overload + +from typing_extensions import ParamSpec + +from azure.ai.ml.entities import Data, Model, PipelineJob, PipelineJobSettings +from azure.ai.ml.entities._builders.pipeline import Pipeline +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput, _GroupAttrDict +from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression +from azure.ai.ml.exceptions import UserErrorException + +from azure.ai.ml.dsl._pipeline_component_builder import PipelineComponentBuilder, _is_inside_dsl_pipeline_func +from azure.ai.ml.dsl._pipeline_decorator import _validate_args +from azure.ai.ml.dsl._settings import _dsl_settings_stack +from azure.ai.ml.dsl._utils import _resolve_source_file + +SUPPORTED_INPUT_TYPES = ( + PipelineInput, + NodeOutput, + Input, + Model, + Data, # For the case use a Data object as an input, we will convert it to Input object + Pipeline, # For the case use a pipeline node as the input, we use its only one output as the real input. + str, + bool, + int, + float, + PipelineExpression, + _GroupAttrDict, +) +module_logger = logging.getLogger(__name__) + +T = TypeVar("T") +P = ParamSpec("P") + + +# Overload the returns a decorator when func is None +@overload +def pipeline( + func: None, + *, + name: Optional[str] = None, + version: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + experiment_name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + **kwargs: Any, +) -> Callable[[Callable[P, T]], Callable[P, PipelineJob]]: ... 
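+
+# Illustrative sketch of the two usage forms covered by these overloads
+# (``my_pipeline_func`` is a placeholder name, not part of this module):
+#
+# @pipeline(name="my_pipeline") # func is None, so a decorator is returned
+# def my_pipeline_func(): ...
+#
+# @pipeline # func is the callable, so the decorated function is returned
+# def my_pipeline_func(): ...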
+ + +# Overload the returns a decorated function when func isn't None +@overload +def pipeline( + func: Callable[P, T], + *, + name: Optional[str] = None, + version: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + experiment_name: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + **kwargs: Any, +) -> Callable[P, PipelineJob]: ... + + +def pipeline( + func: Optional[Callable[P, T]] = None, + *, + name: Optional[str] = None, + version: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + experiment_name: Optional[str] = None, + tags: Optional[Union[Dict[str, str], str]] = None, + **kwargs: Any, +) -> Union[Callable[[Callable[P, T]], Callable[P, PipelineJob]], Callable[P, PipelineJob]]: + """Build a pipeline which contains all component nodes defined in this function. + + :param func: The user pipeline function to be decorated. + :type func: types.FunctionType + :keyword name: The name of pipeline component, defaults to function name. + :paramtype name: str + :keyword version: The version of pipeline component, defaults to "1". + :paramtype version: str + :keyword display_name: The display name of pipeline component, defaults to function name. + :paramtype display_name: str + :keyword description: The description of the built pipeline. + :paramtype description: str + :keyword experiment_name: Name of the experiment the job will be created under, \ + if None is provided, experiment will be set to current directory. + :paramtype experiment_name: str + :keyword tags: The tags of pipeline component. + :paramtype tags: dict[str, str] + :return: Either + * A decorator, if `func` is None + * The decorated `func` + + :rtype: Union[ + Callable[[Callable], Callable[..., ~azure.ai.ml.entities.PipelineJob]], + Callable[P, ~azure.ai.ml.entities.PipelineJob] + + ] + + .. admonition:: Example: + + .. literalinclude:: ../../../../samples/ml_samples_pipeline_job_configurations.py + :start-after: [START configure_pipeline] + :end-before: [END configure_pipeline] + :language: python + :dedent: 8 + :caption: Shows how to create a pipeline using this decorator. + """ + + # get_component force pipeline to return Pipeline instead of PipelineJob so we can set optional argument + # need to remove get_component and rely on azure.ai.ml.dsl.pipeline + get_component = kwargs.get("get_component", False) + + def pipeline_decorator(func: Callable[P, T]) -> Callable: + if not isinstance(func, Callable): # type: ignore + raise UserErrorException(f"Dsl pipeline decorator accept only function type, got {type(func)}.") + + non_pipeline_inputs = kwargs.get("non_pipeline_inputs", []) or kwargs.get("non_pipeline_parameters", []) + # compute variable names changed from default_compute_targe -> compute -> default_compute -> none + # to support legacy usage, we support them with priority. 
+ compute = kwargs.get("compute", None) + default_compute_target = kwargs.get("default_compute_target", None) + default_compute_target = kwargs.get("default_compute", None) or default_compute_target + continue_on_step_failure = kwargs.get("continue_on_step_failure", None) + on_init = kwargs.get("on_init", None) + on_finalize = kwargs.get("on_finalize", None) + + default_datastore = kwargs.get("default_datastore", None) + force_rerun = kwargs.get("force_rerun", None) + job_settings = { + "default_datastore": default_datastore, + "continue_on_step_failure": continue_on_step_failure, + "force_rerun": force_rerun, + "default_compute": default_compute_target, + "on_init": on_init, + "on_finalize": on_finalize, + } + func_entry_path = _resolve_source_file() + if not func_entry_path: + func_path = Path(inspect.getfile(func)) + # in notebook, func_path may be a fake path and will raise error when trying to resolve this fake path + if func_path.exists(): + func_entry_path = func_path.resolve().absolute() + + job_settings = {k: v for k, v in job_settings.items() if v is not None} + pipeline_builder = PipelineComponentBuilder( + func=func, + name=name, + version=version, + display_name=display_name, + description=description, + default_datastore=default_datastore, + tags=tags, + source_path=str(func_entry_path), + non_pipeline_inputs=non_pipeline_inputs, + ) + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[Pipeline, PipelineJob]: + # Default args will be added here. + # Node: push/pop stack here instead of put it inside build() + # Because we only want to enable dsl settings on top level pipeline + _dsl_settings_stack.push() # use this stack to track on_init/on_finalize settings + try: + # Convert args to kwargs + provided_positional_kwargs = _validate_args(func, args, kwargs, non_pipeline_inputs) + + # When pipeline supports variable params, update pipeline component to support the inputs in **kwargs. + pipeline_parameters = { + k: v for k, v in provided_positional_kwargs.items() if k not in non_pipeline_inputs + } + pipeline_builder._update_inputs(pipeline_parameters) + + non_pipeline_params_dict = { + k: v for k, v in provided_positional_kwargs.items() if k in non_pipeline_inputs + } + + # TODO: cache built pipeline component + pipeline_component = pipeline_builder.build( + user_provided_kwargs=provided_positional_kwargs, + non_pipeline_inputs_dict=non_pipeline_params_dict, + non_pipeline_inputs=non_pipeline_inputs, + ) + finally: + # use `finally` to ensure pop operation from the stack + dsl_settings = _dsl_settings_stack.pop() + + # update on_init/on_finalize settings if init/finalize job is set + if dsl_settings.init_job_set: + job_settings["on_init"] = dsl_settings.init_job_name(pipeline_component.jobs) + if dsl_settings.finalize_job_set: + job_settings["on_finalize"] = dsl_settings.finalize_job_name(pipeline_component.jobs) + + # TODO: pass compute & default_compute separately? + common_init_args: Any = { + "experiment_name": experiment_name, + "component": pipeline_component, + "inputs": pipeline_parameters, + "tags": tags, + } + built_pipeline: Any = None + if _is_inside_dsl_pipeline_func() or get_component: + # on_init/on_finalize is not supported for pipeline component + if job_settings.get("on_init") is not None or job_settings.get("on_finalize") is not None: + raise UserErrorException("On_init/on_finalize is not supported for pipeline component.") + # Build pipeline node instead of pipeline job if inside dsl. 
+ built_pipeline = Pipeline(_from_component_func=True, **common_init_args) + if job_settings: + module_logger.warning( + ("Job settings %s on pipeline function %r are ignored when using inside PipelineJob."), + job_settings, + func.__name__, + ) + else: + built_pipeline = PipelineJob( + jobs=pipeline_component.jobs, + compute=compute, + settings=PipelineJobSettings(**job_settings), + **common_init_args, + ) + + return built_pipeline + + # Bug Item number: 2883169 + wrapper._is_dsl_func = True # type: ignore + wrapper._job_settings = job_settings # type: ignore + wrapper._pipeline_builder = pipeline_builder # type: ignore + return wrapper + + # enable use decorator without "()" if all arguments are default values + if func is not None: + return pipeline_decorator(func) + return pipeline_decorator diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py new file mode 100644 index 00000000..90affdda --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py @@ -0,0 +1,73 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +"""This package includes the type classes which could be used in dsl.pipeline, +command function, or any other place that requires job inputs/outputs. + +.. note:: + + The following pseudo-code shows how to create a pipeline with such classes. + + .. code-block:: python + + @pipeline() + def some_pipeline( + input_param: Input(type="uri_folder", path="xxx", mode="ro_mount"), + int_param0: Input(type="integer", default=0, min=-3, max=10), + int_param1 = 2 + str_param = 'abc', + ): + pass + + + The following pseudo-code shows how to create a command with such classes. + + .. 
code-block:: python + + my_command = command( + name="my_command", + display_name="my_command", + description="This is a command", + tags=dict(), + command="python train.py --input-data ${{inputs.input_data}} --lr ${{inputs.learning_rate}}", + code="./src", + compute="cpu-cluster", + environment="my-env:1", + distribution=MpiDistribution(process_count_per_instance=4), + environment_variables=dict(foo="bar"), + # Customers can still do this: + # resources=Resources(instance_count=2, instance_type="STANDARD_D2"), + # limits=Limits(timeout=300), + inputs={ + "float": Input(type="number", default=1.1, min=0, max=5), + "integer": Input(type="integer", default=2, min=-1, max=4), + "integer1": 2, + "string0": Input(type="string", default="default_str0"), + "string1": "default_str1", + "boolean": Input(type="boolean", default=False), + "uri_folder": Input(type="uri_folder", path="https://my-blob/path/to/data", mode="ro_mount"), + "uri_file": Input(type="uri_file", path="https://my-blob/path/to/data", mode="download"), + }, + outputs={"my_model": Output(type="mlflow_model")}, + ) + node = my_command() +""" + +from .enum_input import EnumInput +from .external_data import Database, FileSystem +from .group_input import GroupInput +from .input import Input +from .output import Output +from .utils import _get_param_with_standard_annotation, is_group + +__all__ = [ + "Input", + "Output", + "EnumInput", + "GroupInput", + "is_group", + "_get_param_with_standard_annotation", + "Database", + "FileSystem", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py new file mode 100644 index 00000000..3a726b38 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py @@ -0,0 +1,34 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Any + +from azure.ai.ml._schema.component.input_output import SUPPORTED_PARAM_TYPES +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin + + +class _InputOutputBase(DictMixin, RestTranslatableMixin): + def __init__( + self, + *, + # pylint: disable=redefined-builtin + type: Any, + # pylint: disable=unused-argument + **kwargs: Any, + ) -> None: + """Base class for Input & Output class. + + This class is introduced to support literal output in the future. + + :param type: The type of the Input/Output. + :type type: str + """ + self.type = type + + def _is_literal(self) -> bool: + """Check whether input is a literal + + :return: True if this input is literal input. + :rtype: bool + """ + return self.type in SUPPORTED_PARAM_TYPES diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py new file mode 100644 index 00000000..d6c88eef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py @@ -0,0 +1,133 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from enum import EnumMeta +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .input import Input + + +class EnumInput(Input): + """Enum parameter parse the value according to its enum values.""" + + def __init__( + self, + *, + enum: Optional[Union[EnumMeta, Sequence[str]]] = None, + default: Any = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Enum parameter parse the value according to its enum values. + + :param enum: Enum values. + :type enum: Union[EnumMeta, Sequence[str]] + :param default: Default value of the parameter + :type default: Any + :param description: Description of the parameter + :type description: str + """ + enum_values = self._assert_enum_valid(enum) + self._enum_class: Optional[EnumMeta] = None + # This is used to parse enum class instead of enum str value if a enum class is provided. + if isinstance(enum, EnumMeta): + self._enum_class = enum + self._str2enum = dict(zip(enum_values, enum)) + else: + self._str2enum = {v: v for v in enum_values} + super().__init__(type="string", default=default, enum=enum_values, description=description) + + @property + def _allowed_types(self) -> Tuple: + return ( + (str,) + if not self._enum_class + else ( + self._enum_class, + str, + ) + ) + + @classmethod + def _assert_enum_valid(cls, enum: Optional[Union[EnumMeta, Sequence[str]]]) -> List: + """Check whether the enum is valid and return the values of the enum. + + :param enum: The enum to validate + :type enum: Type + :return: The enum values + :rtype: List[Any] + """ + if isinstance(enum, EnumMeta): + enum_values = [str(option.value) for option in enum] # type: ignore + elif isinstance(enum, Iterable): + enum_values = list(enum) + else: + msg = "enum must be a subclass of Enum or an iterable." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if len(enum_values) <= 0: + msg = "enum must have enum values." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if any(not isinstance(v, str) for v in enum_values): + msg = "enum values must be str type." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return enum_values + + def _parse(self, val: str) -> Any: + """Parse the enum value from a string value or the enum value. + + :param val: The string to parse + :type val: str + :return: The enum value + :rtype: Any + """ + if val is None: + return val + + if self._enum_class and isinstance(val, self._enum_class): + return val # Directly return the enum value if it is the enum. 
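+ # For string input, fall through to the mapping below; values outside the declared enum raise a ValidationException.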
+ + if val not in self._str2enum: + msg = "Not a valid enum value: '{}', valid values: {}" + raise ValidationException( + message=msg.format(val, ", ".join(self.enum)), + no_personal_data_message=msg.format("[val]", "[enum]"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return self._str2enum[val] + + def _update_default(self, default_value: Any) -> None: + """Enum parameter support updating values with a string value. + + :param default_value: The default value for the input + :type default_value: Any + """ + enum_val = self._parse(default_value) + if self._enum_class and isinstance(enum_val, self._enum_class): + enum_val = enum_val.value + self.default = enum_val diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py new file mode 100644 index 00000000..8a4fe21f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py @@ -0,0 +1,207 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from inspect import Parameter +from typing import Dict, List, Optional, Union + +from azure.ai.ml.constants._component import ExternalDataType +from azure.ai.ml.entities._inputs_outputs.utils import _remove_empty_values +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin + + +class StoredProcedureParameter(DictMixin, RestTranslatableMixin): + """Define a stored procedure parameter class for DataTransfer import database task. + + :keyword name: The name of the database stored procedure. + :paramtype name: str + :keyword value: The value of the database stored procedure. + :paramtype value: str + :keyword type: The type of the database stored procedure. + :paramtype type: str + """ + + def __init__( + self, + *, + name: Optional[str] = None, + value: Optional[str] = None, + type: Optional[str] = None, # pylint: disable=redefined-builtin + ) -> None: + self.type = type + self.name = name + self.value = value + + +class Database(DictMixin, RestTranslatableMixin): + """Define a database class for a DataTransfer Component or Job. + + :keyword query: The SQL query to retrieve data from the database. + :paramtype query: str + :keyword table_name: The name of the database table. + :paramtype table_name: str + :keyword stored_procedure: The name of the stored procedure. + :paramtype stored_procedure: str + :keyword stored_procedure_params: The parameters for the stored procedure. + :paramtype stored_procedure_params: List + :keyword connection: The connection string for the database. + The credential information should be stored in the connection. + :paramtype connection: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the Database object cannot be successfully validated. + Details will be provided in the error message. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_input_output_configurations.py + :start-after: [START configure_database] + :end-before: [END configure_database] + :language: python + :dedent: 8 + :caption: Create a database and querying a database table. 
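+
+ A minimal inline sketch (the connection name is a placeholder):
+ e.g. source_database = Database(query="SELECT * FROM my_table", connection="azureml:my_azuresql_connection")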
+ """ + + _EMPTY = Parameter.empty + + def __init__( + self, + *, + query: Optional[str] = None, + table_name: Optional[str] = None, + stored_procedure: Optional[str] = None, + stored_procedure_params: Optional[List[Dict]] = None, + connection: Optional[str] = None, + ) -> None: + # As an annotation, it is not allowed to initialize the name. + # The name will be updated by the annotated variable name. + self.name = None + self.type = ExternalDataType.DATABASE + self.connection = connection + self.query = query + self.table_name = table_name + self.stored_procedure = stored_procedure + self.stored_procedure_params = stored_procedure_params + + def _to_dict(self, remove_name: bool = True) -> Dict: + """Convert the Source object to a dict. + + :param remove_name: Whether to remove the `name` key from the dict representation. Defaults to True. + :type remove_name: bool + :return: The dictionary representation of the class + :rtype: Dict + """ + keys = [ + "name", + "type", + "query", + "stored_procedure", + "stored_procedure_params", + "connection", + "table_name", + ] + if remove_name: + keys.remove("name") + result = {key: getattr(self, key) for key in keys} + res: dict = _remove_empty_values(result) + return res + + def _to_rest_object(self) -> Dict: + # this is for component rest object when using Source as component inputs, as for job input usage, + # rest object is generated by extracting Source's properties, see details in to_rest_dataset_literal_inputs() + result = self._to_dict() + return result + + def _update_name(self, name: str) -> None: + self.name = name + + @classmethod + def _from_rest_object(cls, obj: Dict) -> "Database": + return Database(**obj) + + @property + def stored_procedure_params(self) -> Optional[List]: + """Get or set the parameters for the stored procedure. + + :return: The parameters for the stored procedure. + :rtype: List[StoredProcedureParameter] + """ + + return self._stored_procedure_params + + @stored_procedure_params.setter + def stored_procedure_params(self, value: Union[Dict[str, str], List, None]) -> None: + """Set the parameters for the stored procedure. + + :param value: The parameters for the stored procedure. + :type value: Union[Dict[str, str], StoredProcedureParameter, None] + """ + if value is None: + self._stored_procedure_params = value + else: + if not isinstance(value, list): + value = [value] + for index, item in enumerate(value): + if isinstance(item, dict): + value[index] = StoredProcedureParameter(**item) + self._stored_procedure_params = value + + +class FileSystem(DictMixin, RestTranslatableMixin): + """Define a file system class of a DataTransfer Component or Job. + + e.g. source_s3 = FileSystem(path='s3://my_bucket/my_folder', connection='azureml:my_s3_connection') + + :param path: The path to which the input is pointing. Could be pointing to the path of file system. Default is None. + :type path: str + :param connection: Connection is workspace, we didn't support storage connection here, need leverage workspace + connection to store these credential info. Default is None. + :type connection: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Source cannot be successfully validated. + Details will be provided in the error message. 
+ """ + + _EMPTY = Parameter.empty + + def __init__( + self, + *, + path: Optional[str] = None, + connection: Optional[str] = None, + ) -> None: + self.type = ExternalDataType.FILE_SYSTEM + self.name: Optional[str] = None + self.connection = connection + self.path: Optional[str] = None + + if path is not None and not isinstance(path, str): + # this logic will make dsl data binding expression working in the same way as yaml + # it's written to handle InputOutputBase, but there will be loop import if we import InputOutputBase here + self.path = str(path) + else: + self.path = path + + def _to_dict(self, remove_name: bool = True) -> Dict: + """Convert the Source object to a dict. + + :param remove_name: Whether to remove the `name` key from the dict representation. Defaults to True. + :type remove_name: bool + :return: The dictionary representation of the object + :rtype: Dict + """ + keys = ["name", "path", "type", "connection"] + if remove_name: + keys.remove("name") + result = {key: getattr(self, key) for key in keys} + res: dict = _remove_empty_values(result) + return res + + def _to_rest_object(self) -> Dict: + # this is for component rest object when using Source as component inputs, as for job input usage, + # rest object is generated by extracting Source's properties, see details in to_rest_dataset_literal_inputs() + result = self._to_dict() + return result + + def _update_name(self, name: str) -> None: + self.name = name + + @classmethod + def _from_rest_object(cls, obj: Dict) -> "FileSystem": + return FileSystem(**obj) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py new file mode 100644 index 00000000..e7fc565c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py @@ -0,0 +1,251 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import copy +from enum import Enum as PyEnum +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationException + +from .input import Input +from .output import Output +from .utils import is_group + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._job.pipeline._io import _GroupAttrDict + + +class GroupInput(Input): + """Define a group input object. + + :param values: The values of the group input. + :type values: dict + :param _group_class: The class representing the group. 
+ :type _group_class: Any + """ + + def __init__(self, values: dict, _group_class: Any) -> None: + super().__init__(type=IOConstants.GROUP_TYPE_NAME) + self.assert_group_value_valid(values) + self.values: Any = values + # Create empty default by values + # Note Output do not have default so just set a None + self.default = self._create_default() + # Save group class for init function generation + self._group_class = _group_class + + @classmethod + def _create_group_attr_dict(cls, dct: dict) -> "_GroupAttrDict": + from .._job.pipeline._io import _GroupAttrDict + + return _GroupAttrDict(dct) + + @classmethod + def _is_group_attr_dict(cls, obj: object) -> bool: + from .._job.pipeline._io import _GroupAttrDict + + return isinstance(obj, _GroupAttrDict) + + def __getattr__(self, item: Any) -> Any: + try: + # TODO: Bug Item number: 2883363 + return super().__getattr__(item) # type: ignore + except AttributeError: + # TODO: why values is not a dict in some cases? + if isinstance(self.values, dict) and item in self.values: + return self.values[item] + raise + + def _create_default(self) -> "_GroupAttrDict": + from .._job.pipeline._io import PipelineInput + + default_dict: dict = {} + # Note: no top-level group names at this time. + for k, v in self.values.items(): + # skip create default for outputs or port inputs + if isinstance(v, Output): + continue + + # Create PipelineInput object if not subgroup + if not isinstance(v, GroupInput): + default_dict[k] = PipelineInput(name=k, data=v.default, meta=v) + continue + # Copy and insert k into group names for subgroup + default_dict[k] = copy.deepcopy(v.default) + default_dict[k].insert_group_name_for_items(k) + return self._create_group_attr_dict(default_dict) + + @classmethod + def assert_group_value_valid(cls, values: Dict) -> None: + """Check if all values in the group are supported types. + + :param values: The values of the group. + :type values: dict + :raises ValueError: If a value in the group is not a supported type or if a parameter name is duplicated. + :raises UserErrorException: If a value in the group has an unsupported type. + """ + names = set() + msg = ( + f"Parameter {{!r}} with type {{!r}} is not supported in group. " + f"Supported types are: {list(IOConstants.INPUT_TYPE_COMBINATION.keys())}" + ) + for key, value in values.items(): + if not isinstance(value, (Input, Output)): + raise ValueError(msg.format(key, type(value).__name__)) + if value.type is None: + # Skip check for parameter translated from pipeline job (lost type) + continue + if value.type not in IOConstants.INPUT_TYPE_COMBINATION and not isinstance(value, GroupInput): + raise UserErrorException(msg.format(key, value.type)) + if key in names: + if not isinstance(value, Input): + raise ValueError(f"Duplicate parameter name {value.name!r} found in Group values.") + names.add(key) + + def flatten(self, group_parameter_name: str) -> Dict: + """Flatten the group and return all parameters. + + :param group_parameter_name: The name of the group parameter. + :type group_parameter_name: str + :return: A dictionary of flattened parameters. 
+ :rtype: dict + """ + all_parameters = {} + group_parameter_name = group_parameter_name if group_parameter_name else "" + for key, value in self.values.items(): + flattened_name = ".".join([group_parameter_name, key]) + if isinstance(value, GroupInput): + all_parameters.update(value.flatten(flattened_name)) + else: + all_parameters[flattened_name] = value + return all_parameters + + def _to_dict(self) -> dict: + attr_dict = super()._to_dict() + attr_dict["values"] = {k: v._to_dict() for k, v in self.values.items()} # pylint: disable=protected-access + return attr_dict + + @staticmethod + def custom_class_value_to_attr_dict(value: Any, group_names: Optional[List] = None) -> Any: + """Convert a custom parameter group class object to GroupAttrDict. + + :param value: The value to convert. + :type value: any + :param group_names: The names of the parent groups. + :type group_names: list + :return: The converted value as a GroupAttrDict. + :rtype: GroupAttrDict or any + """ + if not is_group(value): + return value + group_definition = getattr(value, IOConstants.GROUP_ATTR_NAME) + group_names = [*group_names] if group_names else [] + attr_dict = {} + from .._job.pipeline._io import PipelineInput + + for k, v in value.__dict__.items(): + if is_group(v): + attr_dict[k] = GroupInput.custom_class_value_to_attr_dict(v, [*group_names, k]) + continue + data = v.value if isinstance(v, PyEnum) else v + if GroupInput._is_group_attr_dict(data): + attr_dict[k] = data + continue + attr_dict[k] = PipelineInput(name=k, meta=group_definition.get(k), data=data, group_names=group_names) + return GroupInput._create_group_attr_dict(attr_dict) + + @staticmethod + def validate_conflict_keys(keys: List) -> None: + """Validate conflicting keys in a flattened input dictionary, like {'a.b.c': 1, 'a.b': 1}. + + :param keys: The keys to validate. + :type keys: list + :raises ValidationException: If conflicting keys are found. + """ + conflict_msg = "Conflict parameter key '%s' and '%s'." + + def _group_count(s: str) -> int: + return len(s.split(".")) - 1 + + # Sort order by group numbers + keys = sorted(list(keys), key=_group_count) + for idx, key1 in enumerate(keys[:-1]): + for key2 in keys[idx + 1 :]: + if _group_count(key2) == 0: + continue + # Skip case a.b.c and a.b.c1 + if _group_count(key1) == _group_count(key2): + continue + if not key2.startswith(key1): + continue + # Invalid case 'a.b' in 'a.b.c' + raise ValidationException( + message=conflict_msg % (key1, key2), + no_personal_data_message=conflict_msg % ("[key1]", "[key2]"), + target=ErrorTarget.PIPELINE, + ) + + @staticmethod + def restore_flattened_inputs(inputs: Dict) -> Dict: + """Restore flattened inputs to structured groups. + + :param inputs: The flattened input dictionary. + :type inputs: dict + :return: The restored structured inputs. + :rtype: dict + """ + GroupInput.validate_conflict_keys(list(inputs.keys())) + restored_inputs = {} + group_inputs: Dict = {} + # 1. 
Build all group parameters dict + for name, data in inputs.items(): + # for a.b.c, group names is [a, b] + name_splits = name.split(".") + group_names, param_name = name_splits[:-1], name_splits[-1] + if not group_names: + restored_inputs[name] = data + continue + # change {'a.b.c': data} -> {'a': {'b': {'c': data}}} + target_dict = group_inputs + for group_name in group_names: + if group_name not in target_dict: + target_dict[group_name] = {} + target_dict = target_dict[group_name] + target_dict[param_name] = data + + def restore_from_dict_recursively(_data: dict) -> Union[GroupInput, "_GroupAttrDict"]: + for key, val in _data.items(): + if type(val) == dict: # pylint: disable=unidiomatic-typecheck + _data[key] = restore_from_dict_recursively(val) + # Create GroupInput for definition and _GroupAttrDict for PipelineInput + # Regard all Input class as parameter definition, as data will not appear in group now. + if all(isinstance(val, Input) for val in _data.values()): + return GroupInput(values=_data, _group_class=None) + return GroupInput._create_group_attr_dict(dct=_data) + + # 2. Rehydrate dict to GroupInput(definition) or GroupAttrDict. + for name, data in group_inputs.items(): + restored_inputs[name] = restore_from_dict_recursively(data) + return restored_inputs + + def _update_default(self, default_value: object = None) -> None: + default_cls = type(default_value) + + # Assert '__dsl_group__' must in the class of default value + if self._is_group_attr_dict(default_value): + self.default = default_value + self.optional = False + return + if default_value and not is_group(default_cls): + raise ValueError(f"Default value must be instance of parameter group, got {default_cls}.") + if hasattr(default_value, "__dict__"): + # Convert default value with customer type to _AttrDict + self.default = GroupInput.custom_class_value_to_attr_dict(default_value) + # Update item annotation + for key, annotation in self.values.items(): + if not hasattr(default_value, key): + continue + annotation._update_default(getattr(default_value, key)) # pylint: disable=protected-access + self.optional = default_value is None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py new file mode 100644 index 00000000..4a945108 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py @@ -0,0 +1,547 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=redefined-builtin +# disable redefined-builtin to use type/min/max as argument name + +import math +from inspect import Parameter +from typing import Any, Dict, List, Optional, Union, overload + +from typing_extensions import Literal + +from azure.ai.ml.constants._component import ComponentParameterTypes, IOConstants +from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty +from azure.ai.ml.exceptions import ( + ErrorCategory, + ErrorTarget, + UserErrorException, + ValidationErrorType, + ValidationException, +) + +from .base import _InputOutputBase +from .utils import _get_param_with_standard_annotation, _remove_empty_values + + +class Input(_InputOutputBase): # pylint: disable=too-many-instance-attributes + """Initialize an Input object. + + :keyword type: The type of the data input. 
Accepted values are + 'uri_folder', 'uri_file', 'mltable', 'mlflow_model', 'custom_model', 'integer', 'number', 'string', and + 'boolean'. Defaults to 'uri_folder'. + :paramtype type: str + :keyword path: The path to the input data. Paths can be local paths, remote data uris, or a registered AzureML asset + ID. + :paramtype path: Optional[str] + :keyword mode: The access mode of the data input. Accepted values are: + * 'ro_mount': Mount the data to the compute target as read-only, + * 'download': Download the data to the compute target, + * 'direct': Pass in the URI as a string to be accessed at runtime + :paramtype mode: Optional[str] + :keyword path_on_compute: The access path of the data input for compute + :paramtype path_on_compute: Optional[str] + :keyword default: The default value of the input. If a default is set, the input data will be optional. + :paramtype default: Union[str, int, float, bool] + :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job + execution will fail. + :paramtype min: Union[int, float] + :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job + execution will fail. + :paramtype max: Union[int, float] + :keyword optional: Specifies if the input is optional. + :paramtype optional: Optional[bool] + :keyword description: Description of the input + :paramtype description: Optional[str] + :keyword datastore: The datastore to upload local files to. + :paramtype datastore: str + :keyword intellectual_property: Intellectual property for the input. + :paramtype intellectual_property: IntellectualProperty + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated. + Details will be provided in the error message. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START create_inputs_outputs] + :end-before: [END create_inputs_outputs] + :language: python + :dedent: 8 + :caption: Creating a CommandJob with two inputs. + """ + + _EMPTY = Parameter.empty + _IO_KEYS = [ + "path", + "type", + "mode", + "path_on_compute", + "description", + "default", + "min", + "max", + "enum", + "optional", + "datastore", + ] + + @overload + def __init__( + self, + *, + type: str, + path: Optional[str] = None, + mode: Optional[str] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """""" + + @overload + def __init__( + self, + *, + type: Literal["number"] = "number", + default: Optional[float] = None, + min: Optional[float] = None, + max: Optional[float] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a number input. + + :keyword type: The type of the data input. Can only be set to "number". + :paramtype type: str + :keyword default: The default value of the input. If a default is set, the input data will be optional. + :paramtype default: Union[str, int, float, bool] + :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job + execution will fail. + :paramtype min: Optional[float] + :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job + execution will fail. + :paramtype max: Optional[float] + :keyword optional: Specifies if the input is optional. 
+ :paramtype optional: bool + :keyword description: Description of the input + :paramtype description: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated. + Details will be provided in the error message. + """ + + @overload + def __init__( + self, + *, + type: Literal["integer"] = "integer", + default: Optional[int] = None, + min: Optional[int] = None, + max: Optional[int] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize an integer input. + + :keyword type: The type of the data input. Can only be set to "integer". + :paramtype type: str + :keyword default: The default value of the input. If a default is set, the input data will be optional. + :paramtype default: Union[str, int, float, bool] + :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job + execution will fail. + :paramtype min: Optional[int] + :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job + execution will fail. + :paramtype max: Optional[int] + :keyword optional: Specifies if the input is optional. + :paramtype optional: bool + :keyword description: Description of the input + :paramtype description: str + """ + + @overload + def __init__( + self, + *, + type: Literal["string"] = "string", + default: Optional[str] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + path: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a string input. + + :keyword type: The type of the data input. Can only be set to "string". + :paramtype type: str + :keyword default: The default value of this input. When a `default` is set, the input will be optional. + :paramtype default: str + :keyword optional: Determine if this input is optional. + :paramtype optional: bool + :keyword description: Description of the input. + :paramtype description: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated. + Details will be provided in the error message. + """ + + @overload + def __init__( + self, + *, + type: Literal["boolean"] = "boolean", + default: Optional[bool] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a bool input. + + :keyword type: The type of the data input. Can only be set to "boolean". + :paramtype type: str + :keyword path: The path to the input data. Paths can be local paths, remote data uris, or a registered AzureML + asset id. + :paramtype path: str + :keyword default: The default value of the input. If a default is set, the input data will be optional. + :paramtype default: Union[str, int, float, bool] + :keyword optional: Specifies if the input is optional. + :paramtype optional: bool + :keyword description: Description of the input + :paramtype description: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated. + Details will be provided in the error message. 
+ """ + + def __init__( + self, + *, + type: str = "uri_folder", + path: Optional[str] = None, + mode: Optional[str] = None, + path_on_compute: Optional[str] = None, + default: Optional[Union[str, int, float, bool]] = None, + optional: Optional[bool] = None, + min: Optional[Union[int, float]] = None, + max: Optional[Union[int, float]] = None, + enum: Any = None, + description: Optional[str] = None, + datastore: Optional[str] = None, + **kwargs: Any, + ) -> None: + super(Input, self).__init__(type=type) + # As an annotation, it is not allowed to initialize the _port_name. + self._port_name = None + self.description = description + self.path: Any = None + + if path is not None and not isinstance(path, str): + # this logic will make dsl data binding expression working in the same way as yaml + # it's written to handle InputOutputBase, but there will be loop import if we import InputOutputBase here + self.path = str(path) + else: + self.path = path + self.path_on_compute = path_on_compute + self.mode = None if self._is_primitive_type else mode + self._update_default(default) + self.optional = optional + # set the flag to mark if the optional=True is inferred by us. + self._is_inferred_optional = False + self.min = min + self.max = max + self.enum = enum + self.datastore = datastore + intellectual_property = kwargs.pop("intellectual_property", None) + if intellectual_property: + self._intellectual_property = ( + intellectual_property + if isinstance(intellectual_property, IntellectualProperty) + else IntellectualProperty(**intellectual_property) + ) + # normalize properties like ["default", "min", "max", "optional"] + self._normalize_self_properties() + + self._validate_parameter_combinations() + + @property + def _allowed_types(self) -> Any: + if self._multiple_types: + return None + return IOConstants.PRIMITIVE_STR_2_TYPE.get(self.type) + + @property + def _is_primitive_type(self) -> bool: + if self._multiple_types: + # note: we suppose that no primitive type will be included when there are multiple types + return False + return self.type in IOConstants.PRIMITIVE_STR_2_TYPE + + @property + def _multiple_types(self) -> bool: + """Returns True if this input has multiple types. + + Currently, there are two scenarios that need to check this property: + 1. before `in` as it may throw exception; there will be `in` operation for validation/transformation. + 2. `str()` of list is not ideal, so we need to manually create its string result. + + :return: Whether this input has multiple types + :rtype: bool + """ + return isinstance(self.type, list) + + def _is_literal(self) -> bool: + """Whether this input is a literal + + Override this function as `self.type` can be list and not hashable for operation `in`. + + :return: Whether is a literal + :rtype: bool + """ + return not self._multiple_types and super(Input, self)._is_literal() + + def _is_enum(self) -> bool: + """Whether input is an enum + + :return: True if the input is enum. + :rtype: bool + """ + res: bool = self.type == ComponentParameterTypes.STRING and self.enum + return res + + def _to_dict(self) -> Dict: + """Convert the Input object to a dict. + + :return: Dictionary representation of Input + :rtype: Dict + """ + keys = self._IO_KEYS + result = {key: getattr(self, key) for key in keys} + res: dict = _remove_empty_values(result) + return res + + def _parse(self, val: Any) -> Union[int, float, bool, str, Any]: + """Parse value passed from command line. + + :param val: The input value + :type val: T + :return: The parsed value. 
+ :rtype: Union[int, float, bool, str, T] + """ + if self.type == "integer": + return int(float(val)) # backend returns 10.0,for integer, parse it to float before int + if self.type == "number": + return float(val) + if self.type == "boolean": + lower_val = str(val).lower() + if lower_val not in {"true", "false"}: + msg = "Boolean parameter '{}' only accept True/False, got {}." + raise ValidationException( + message=msg.format(self._port_name, val), + no_personal_data_message=msg.format("[self._port_name]", "[val]"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return lower_val == "true" + if self.type == "string": + return val if isinstance(val, str) else str(val) + return val + + def _parse_and_validate(self, val: Any) -> Union[int, float, bool, str, Any]: + """Parse the val passed from the command line and validate the value. + + :param val: The input string value from the command line. + :type val: T + :return: The parsed value, an exception will be raised if the value is invalid. + :rtype: Union[int, float, bool, str, T] + """ + if self._is_primitive_type: + val = self._parse(val) if isinstance(val, str) else val + self._validate_or_throw(val) + return val + + def _update_name(self, name: Any) -> None: + self._port_name = name + + def _update_default(self, default_value: Any) -> None: + """Update provided default values. + + :param default_value: The default value of the Input + :type default_value: Any + """ + name = "" if not self._port_name else f"{self._port_name!r} " + msg_prefix = f"Default value of Input {name}" + + if not self._is_primitive_type and default_value is not None: + msg = f"{msg_prefix}cannot be set: Non-primitive type Input has no default value." + raise UserErrorException(msg) + if isinstance(default_value, float) and not math.isfinite(default_value): + # Since nan/inf cannot be stored in the backend, just ignore them. + # logger.warning("Float default value %r is not allowed, ignored." % default_value) + return + # pylint: disable=pointless-string-statement + """Update provided default values. + Here we need to make sure the type of default value is allowed or it could be parsed.. + """ + if default_value is not None: + if type(default_value) not in IOConstants.PRIMITIVE_TYPE_2_STR: + msg = ( + f"{msg_prefix}cannot be set: type must be one of " + f"{list(IOConstants.PRIMITIVE_TYPE_2_STR.values())}, got '{type(default_value)}'." + ) + raise UserErrorException(msg) + + if not isinstance(default_value, self._allowed_types): + try: + default_value = self._parse(default_value) + # return original validation exception which is custom defined if raised by self._parse + except ValidationException as e: + raise e + except Exception as e: + msg = f"{msg_prefix}cannot be parsed, got '{default_value}', type = {type(default_value)!r}." + raise UserErrorException(msg) from e + self.default = default_value + + def _validate_or_throw(self, value: Any) -> None: + """Validate input parameter value, throw exception if not as expected. + + It will throw exception if validate failed, otherwise do nothing. + + :param value: A value to validate + :type value: Any + """ + if not self.optional and value is None: + msg = "Parameter {} cannot be None since it is not optional." 
+ raise ValidationException( + message=msg.format(self._port_name), + no_personal_data_message=msg.format("[self._port_name]"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + if self._allowed_types and value is not None: + if not isinstance(value, self._allowed_types): + msg = "Unexpected data type for parameter '{}'. Expected {} but got {}." + raise ValidationException( + message=msg.format(self._port_name, self._allowed_types, type(value)), + no_personal_data_message=msg.format("[_port_name]", self._allowed_types, type(value)), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + # for numeric values, need extra check for min max value + if not self._multiple_types and self.type in ("integer", "number"): + if self.min is not None and value < self.min: + msg = "Parameter '{}' should not be less than {}." + raise ValidationException( + message=msg.format(self._port_name, self.min), + no_personal_data_message=msg.format("[_port_name]", self.min), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + if self.max is not None and value > self.max: + msg = "Parameter '{}' should not be greater than {}." + raise ValidationException( + message=msg.format(self._port_name, self.max), + no_personal_data_message=msg.format("[_port_name]", self.max), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _get_python_builtin_type_str(self) -> str: + """Get python builtin type for current input in string, eg: str. + + Return yaml type if not available. + + :return: The name of the input type + :rtype: str + """ + if self._multiple_types: + return "[" + ", ".join(self.type) + "]" + if self._is_primitive_type: + res_primitive_type: str = IOConstants.PRIMITIVE_STR_2_TYPE[self.type].__name__ + return res_primitive_type + res: str = self.type + return res + + def _validate_parameter_combinations(self) -> None: + """Validate different parameter combinations according to type.""" + parameters = ["type", "path", "mode", "default", "min", "max"] + parameters_dict: dict = {key: getattr(self, key, None) for key in parameters} + type = parameters_dict.pop("type") + + # validate parameter combination + if not self._multiple_types and type in IOConstants.INPUT_TYPE_COMBINATION: + valid_parameters = IOConstants.INPUT_TYPE_COMBINATION[type] + for key, value in parameters_dict.items(): + if key not in valid_parameters and value is not None: + msg = "Invalid parameter for '{}' Input, parameter '{}' should be None but got '{}'" + raise ValidationException( + message=msg.format(type, key, value), + no_personal_data_message=msg.format("[type]", "[parameter]", "[parameter_value]"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _simple_parse(self, value: Any, _type: Any = None) -> Any: + if self._multiple_types: + return value + if _type is None: + _type = self.type + if _type in IOConstants.PARAM_PARSERS: + return IOConstants.PARAM_PARSERS[_type](value) + return value + + def _normalize_self_properties(self) -> None: + # parse value from string to its original type. 
eg: "false" -> False + for key in ["min", "max"]: + if getattr(self, key) is not None: + origin_value = getattr(self, key) + new_value = self._simple_parse(origin_value) + setattr(self, key, new_value) + if self.optional: + self.optional = self._simple_parse(getattr(self, "optional", "false"), _type="boolean") + + @classmethod + def _get_input_by_type(cls, t: type, optional: Any = None) -> Optional["Input"]: + if t in IOConstants.PRIMITIVE_TYPE_2_STR: + return cls(type=IOConstants.PRIMITIVE_TYPE_2_STR[t], optional=optional) + return None + + @classmethod + def _get_default_unknown_input(cls, optional: Optional[bool] = None) -> "Input": + # Set type as None here to avoid schema validation failed + res: Input = cls(type=None, optional=optional) # type: ignore + return res + + @classmethod + def _get_param_with_standard_annotation(cls, func: Any) -> Dict: + return _get_param_with_standard_annotation(func, is_func=True) + + def _to_rest_object(self) -> Dict: + # this is for component rest object when using Input as component inputs, as for job input usage, + # rest object is generated by extracting Input's properties, see details in to_rest_dataset_literal_inputs() + result = self._to_dict() + # parse string -> String, integer -> Integer, etc. + if result["type"] in IOConstants.TYPE_MAPPING_YAML_2_REST: + result["type"] = IOConstants.TYPE_MAPPING_YAML_2_REST[result["type"]] + return result + + @classmethod + def _map_from_rest_type(cls, _type: Union[str, List]) -> Union[str, List]: + # this is for component rest object when using Input as component inputs + reversed_data_type_mapping = {v: k for k, v in IOConstants.TYPE_MAPPING_YAML_2_REST.items()} + # parse String -> string, Integer -> integer, etc + if not isinstance(_type, list) and _type in reversed_data_type_mapping: + res: str = reversed_data_type_mapping[_type] + return res + return _type + + @classmethod + def _from_rest_object(cls, obj: Dict) -> "Input": + obj["type"] = cls._map_from_rest_type(obj["type"]) + + return cls(**obj) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py new file mode 100644 index 00000000..1c4dcd06 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py @@ -0,0 +1,180 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=redefined-builtin +import re +from typing import Any, Dict, Optional, overload + +from typing_extensions import Literal + +from azure.ai.ml.constants import AssetTypes +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty +from azure.ai.ml.exceptions import UserErrorException + +from .base import _InputOutputBase +from .utils import _remove_empty_values + + +class Output(_InputOutputBase): + _IO_KEYS = ["name", "version", "path", "path_on_compute", "type", "mode", "description", "early_available"] + + @overload + def __init__( + self, + *, + type: str, + path: Optional[str] = None, + mode: Optional[str] = None, + description: Optional[str] = None, + **kwargs: Any, + ): ... + + @overload + def __init__( + self, + type: Literal["uri_file"] = "uri_file", + path: Optional[str] = None, + mode: Optional[str] = None, + description: Optional[str] = None, + ): + """Define a URI file output. 
+ + :keyword type: The type of the data output. Can only be set to 'uri_file'. + :paramtype type: str + :keyword path: The remote path where the output should be stored. + :paramtype path: str + :keyword mode: The access mode of the data output. Accepted values are + * 'rw_mount': Read-write mount the data, + * 'upload': Upload the data from the compute target, + * 'direct': Pass in the URI as a string + :paramtype mode: str + :keyword description: The description of the output. + :paramtype description: str + :keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without + setting a version. + :paramtype name: str + :keyword version: The version used to register the output as a Data or Model asset. A version can be set only + when name is set. + :paramtype version: str + """ + + def __init__( # type: ignore[misc] + self, + *, + type: str = AssetTypes.URI_FOLDER, + path: Optional[str] = None, + mode: Optional[str] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Define an output. + + :keyword type: The type of the data output. Accepted values are 'uri_folder', 'uri_file', 'mltable', + 'mlflow_model', 'custom_model', and user-defined types. Defaults to 'uri_folder'. + :paramtype type: str + :keyword path: The remote path where the output should be stored. + :paramtype path: Optional[str] + :keyword mode: The access mode of the data output. Accepted values are + * 'rw_mount': Read-write mount the data + * 'upload': Upload the data from the compute target + * 'direct': Pass in the URI as a string + :paramtype mode: Optional[str] + :keyword path_on_compute: The access path of the data output for compute + :paramtype path_on_compute: Optional[str] + :keyword description: The description of the output. + :paramtype description: Optional[str] + :keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without + setting a version. + :paramtype name: str + :keyword version: The version used to register the output as a Data or Model asset. A version can be set only + when name is set. + :paramtype version: str + :keyword is_control: Determine if the output is a control output. + :paramtype is_control: bool + :keyword early_available: Mark the output for early node orchestration. + :paramtype early_available: bool + :keyword intellectual_property: Intellectual property associated with the output. + It can be an instance of `IntellectualProperty` or a dictionary that will be used to create an instance. + :paramtype intellectual_property: Union[ + ~azure.ai.ml.entities._assets.intellectual_property.IntellectualProperty, dict] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START create_inputs_outputs] + :end-before: [END create_inputs_outputs] + :language: python + :dedent: 8 + :caption: Creating a CommandJob with a folder output. + """ + super(Output, self).__init__(type=type) + # As an annotation, it is not allowed to initialize the _port_name. 
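+        # It stays None until the owning component/pipeline binds this Output to a named port
+        # (see _update_annotation_with_default in utils.py, which assigns _port_name).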
+ self._port_name = None + self.name = kwargs.pop("name", None) + self.version = kwargs.pop("version", None) + self._is_primitive_type = self.type in IOConstants.PRIMITIVE_STR_2_TYPE + self.description = description + self.path = path + self.path_on_compute = kwargs.pop("path_on_compute", None) + self.mode = mode + # use this field to mark Output for early node orchestrate, currently hide in kwargs + self.early_available = kwargs.pop("early_available", None) + self._intellectual_property = None + intellectual_property = kwargs.pop("intellectual_property", None) + if intellectual_property: + self._intellectual_property = ( + intellectual_property + if isinstance(intellectual_property, IntellectualProperty) + else IntellectualProperty(**intellectual_property) + ) + self._assert_name_and_version() + # normalize properties + self._normalize_self_properties() + + def _get_hint(self, new_line_style: bool = False) -> Optional[str]: + comment_str = self.description.replace('"', '\\"') if self.description else self.type + return '"""%s"""' % comment_str if comment_str and new_line_style else comment_str + + def _to_dict(self) -> Dict: + """Convert the Output object to a dict. + + :return: The dictionary representation of Output + :rtype: Dict + """ + keys = self._IO_KEYS + result = {key: getattr(self, key) for key in keys} + res: dict = _remove_empty_values(result) + return res + + def _to_rest_object(self) -> Dict: + # this is for component rest object when using Output as component outputs, as for job output usage, + # rest object is generated by extracting Output's properties, see details in to_rest_data_outputs() + return self._to_dict() + + def _simple_parse(self, value: Any, _type: Any = None) -> Any: + if _type is None: + _type = self.type + if _type in IOConstants.PARAM_PARSERS: + return IOConstants.PARAM_PARSERS[_type](value) + return value + + def _normalize_self_properties(self) -> None: + # parse value from string to its original type. eg: "false" -> False + if self.early_available: + self.early_available = self._simple_parse(getattr(self, "early_available", "false"), _type="boolean") + + @classmethod + def _from_rest_object(cls, obj: Dict) -> "Output": + # this is for component rest object when using Output as component outputs + return Output(**obj) + + def _assert_name_and_version(self) -> None: + if self.name and not (re.match("^[A-Za-z0-9_-]*$", self.name) and len(self.name) <= 255): + raise UserErrorException( + f"The output name {self.name} can only contain alphanumeric characters, dashes and underscores, " + f"with a limit of 255 characters." + ) + if self.version and not self.name: + raise UserErrorException("Output name is required when output version is specified.") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py new file mode 100644 index 00000000..bd752571 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py @@ -0,0 +1,479 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access +# enable protected access for protected helper functions + +import copy +from collections import OrderedDict +from enum import Enum as PyEnum +from enum import EnumMeta +from inspect import Parameter, getmro, signature +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast + +from typing_extensions import Annotated, Literal, TypeAlias + +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.exceptions import UserErrorException + +# avoid circular import error +if TYPE_CHECKING: + from .input import Input + from .output import Output + +SUPPORTED_RETURN_TYPES_PRIMITIVE = list(IOConstants.PRIMITIVE_TYPE_2_STR.keys()) + +Annotation: TypeAlias = Union[str, Type, Annotated[Any, Any], None] # type: ignore + + +def is_group(obj: object) -> bool: + """Return True if obj is a group or an instance of a parameter group class. + + :param obj: The object to check. + :type obj: Any + :return: True if obj is a group or an instance, False otherwise. + :rtype: bool + """ + return hasattr(obj, IOConstants.GROUP_ATTR_NAME) + + +def _get_annotation_by_value(val: Any) -> Union["Input", Type["Input"]]: + # TODO: we'd better remove this potential recursive import + from .enum_input import EnumInput + from .input import Input + + annotation: Any = None + + def _is_dataset(data: Any) -> bool: + from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin + + DATASET_TYPES = JobIOMixin + return isinstance(data, DATASET_TYPES) + + if _is_dataset(val): + annotation = Input + elif val is Parameter.empty or val is None: + # If no default value or default is None, create val as the basic parameter type, + # it could be replaced using component parameter definition. + annotation = Input._get_default_unknown_input() + elif isinstance(val, PyEnum): + # Handle enum values + annotation = EnumInput(enum=val.__class__) + else: + _new_annotation = _get_annotation_cls_by_type(type(val), raise_error=False) + if not _new_annotation: + # Fall back to default + annotation = Input._get_default_unknown_input() + else: + return _new_annotation + return cast(Union["Input", Type["Input"]], annotation) + + +def _get_annotation_cls_by_type( + t: type, raise_error: bool = False, optional: Optional[bool] = None +) -> Optional["Input"]: + # TODO: we'd better remove this potential recursive import + from .input import Input + + cls = Input._get_input_by_type(t, optional=optional) + if cls is None and raise_error: + raise UserErrorException(f"Can't convert type {t} to azure.ai.ml.Input") + return cls + + +# pylint: disable=too-many-statements +def _get_param_with_standard_annotation( + cls_or_func: Union[Callable, Type], is_func: bool = False, skip_params: Optional[List[str]] = None +) -> Dict[str, Union[Annotation, "Input", "Output"]]: + """Standardize function parameters or class fields with dsl.types annotation. + + :param cls_or_func: Either a class or a function + :type cls_or_func: Union[Callable, Type] + :param is_func: Whether `cls_or_func` is a function. Defaults to False. 
+ :type is_func: bool + :param skip_params: + :type skip_params: Optional[List[str]] + :return: A dictionary of field annotations + :rtype: Dict[str, Union[Annotation, "Input", "Output"]] + """ + # TODO: we'd better remove this potential recursive import + from .group_input import GroupInput + from .input import Input + from .output import Output + + def _is_dsl_type_cls(t: Any) -> bool: + if type(t) is not type: # pylint: disable=unidiomatic-typecheck + return False + return issubclass(t, (Input, Output)) + + def _is_dsl_types(o: object) -> bool: + return _is_dsl_type_cls(type(o)) + + def _get_fields(annotations: Dict) -> Dict: + """Return field names to annotations mapping in class. + + :param annotations: The annotations + :type annotations: Dict[str, Union[Annotation, Input, Output]] + :return: The field dict + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + annotation_fields = OrderedDict() + for name, annotation in annotations.items(): + # Skip return type + if name == "return": + continue + # Handle EnumMeta annotation + if isinstance(annotation, EnumMeta): + from .enum_input import EnumInput + + annotation = EnumInput(type="string", enum=annotation) + # Handle Group annotation + if is_group(annotation): + _deep_copy: GroupInput = copy.deepcopy(getattr(annotation, IOConstants.GROUP_ATTR_NAME)) + annotation = _deep_copy + # Try creating annotation by type when got like 'param: int' + if not _is_dsl_type_cls(annotation) and not _is_dsl_types(annotation): + origin_annotation = annotation + annotation = cast(Input, _get_annotation_cls_by_type(annotation, raise_error=False)) + if not annotation: + msg = f"Unsupported annotation type {origin_annotation!r} for parameter {name!r}." + raise UserErrorException(msg) + annotation_fields[name] = annotation + return annotation_fields + + def _merge_field_keys( + annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any] + ) -> List[str]: + """Merge field keys from annotations and cls dict to get all fields in class. + + :param annotation_fields: The field annotations + :type annotation_fields: Dict[str, Union[Annotation, Input, Output]] + :param defaults_dict: The map of variable name to default value + :type defaults_dict: Dict[str, Any] + :return: A list of field keys + :rtype: List[str] + """ + anno_keys = list(annotation_fields.keys()) + dict_keys = defaults_dict.keys() + if not dict_keys: + return anno_keys + return [*anno_keys, *[key for key in dict_keys if key not in anno_keys]] + + def _update_annotation_with_default( + anno: Union[Annotation, Input, Output], name: str, default: Any + ) -> Union[Annotation, Input, Output]: + """Create annotation if is type class and update the default. 
+ + :param anno: The annotation + :type anno: Union[Annotation, Input, Output] + :param name: The port name + :type name: str + :param default: The default value + :type default: Any + :return: The updated annotation + :rtype: Union[Annotation, Input, Output] + """ + # Create instance if is type class + complete_annotation = anno + if _is_dsl_type_cls(anno): + if anno is not None and not isinstance(anno, (str, Input, Output)): + complete_annotation = anno() + if complete_annotation is not None and not isinstance(complete_annotation, str): + complete_annotation._port_name = name + if default is Input._EMPTY: + return complete_annotation + if isinstance(complete_annotation, Input): + # Non-parameter Input has no default attribute + if complete_annotation._is_primitive_type and complete_annotation.default is not None: + # logger.warning( + # f"Warning: Default value of f{complete_annotation.name!r} is set twice: " + # f"{complete_annotation.default!r} and {default!r}, will use {default!r}" + # ) + pass + complete_annotation._update_default(default) + if isinstance(complete_annotation, Output) and default is not None: + msg = ( + f"Default value of Output {complete_annotation._port_name!r} cannot be set:" + f"Output has no default value." + ) + raise UserErrorException(msg) + return complete_annotation + + def _update_fields_with_default( + annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any] + ) -> Dict[str, Union[Annotation, Input, Output]]: + """Use public values in class dict to update annotations. + + :param annotation_fields: The field annotations + :type annotation_fields: Dict[str, Union[Annotation, Input, Output]] + :param defaults_dict: A map of variable name to default value + :type defaults_dict: Dict[str, Any] + :return: List of field names + :rtype: List[str] + """ + all_fields = OrderedDict() + all_filed_keys = _merge_field_keys(annotation_fields, defaults_dict) + for name in all_filed_keys: + # Get or create annotation + annotation = ( + annotation_fields[name] + if name in annotation_fields + else _get_annotation_by_value(defaults_dict.get(name, Input._EMPTY)) + ) + # Create annotation if is class type and update default + annotation = _update_annotation_with_default(annotation, name, defaults_dict.get(name, Input._EMPTY)) + all_fields[name] = annotation + return all_fields + + def _merge_and_reorder( + inherited_fields: Dict[str, Union[Annotation, Input, Output]], + cls_fields: Dict[str, Union[Annotation, Input, Output]], + ) -> Dict[str, Union[Annotation, Input, Output]]: + """Merge inherited fields with cls fields. + + The order inside each part will not be changed. Order will be: + + {inherited_no_default_fields} + {cls_no_default_fields} + {inherited_default_fields} + {cls_default_fields}. + + + :param inherited_fields: The inherited fields + :type inherited_fields: Dict[str, Union[Annotation, Input, Output]] + :param cls_fields: The class fields + :type cls_fields: Dict[str, Union[Annotation, Input, Output]] + :return: The merged fields + :rtype: Dict[str, Union[Annotation, Input, Output]] + + .. admonition:: Additional Note + + :class: note + + If cls overwrite an inherited no default field with default, it will be put in the + cls_default_fields part and deleted from inherited_no_default_fields: + + .. 
code-block:: python + + @dsl.group + class SubGroup: + int_param0: Integer + int_param1: int + + @dsl.group + class Group(SubGroup): + int_param3: Integer + int_param1: int = 1 + + The init function of Group will be `def __init__(self, *, int_param0, int_param3, int_param1=1)`. + """ + + def _split( + _fields: Dict[str, Union[Annotation, Input, Output]] + ) -> Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]: + """Split fields to two parts from the first default field. + + :param _fields: The fields + :type _fields: Dict[str, Union[Annotation, Input, Output]] + :return: A 2-tuple of (fields with no defaults, fields with defaults) + :rtype: Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]] + """ + _no_defaults_fields, _defaults_fields = {}, {} + seen_default = False + for key, val in _fields.items(): + if val is not None and not isinstance(val, str): + if val.get("default", None) or seen_default: + seen_default = True + _defaults_fields[key] = val + else: + _no_defaults_fields[key] = val + return _no_defaults_fields, _defaults_fields + + inherited_no_default, inherited_default = _split(inherited_fields) + cls_no_default, cls_default = _split(cls_fields) + # Cross comparison and delete from inherited_fields if same key appeared in cls_fields + # pylint: disable=consider-iterating-dictionary + for key in cls_default.keys(): + if key in inherited_no_default.keys(): + del inherited_no_default[key] + for key in cls_no_default.keys(): + if key in inherited_default.keys(): + del inherited_default[key] + return OrderedDict( + { + **inherited_no_default, + **cls_no_default, + **inherited_default, + **cls_default, + } + ) + + def _get_inherited_fields() -> Dict[str, Union[Annotation, Input, Output]]: + """Get all fields inherited from @group decorated base classes. + + :return: The field dict + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + # Return value of _get_param_with_standard_annotation + _fields: Dict[str, Union[Annotation, Input, Output]] = OrderedDict({}) + if is_func: + return _fields + # In reversed order so that more derived classes + # override earlier field definitions in base classes. 
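+            # e.g. for `class Group(Base1, Base2)`, __mro__[-1:0:-1] walks object -> Base2 -> Base1,
+            # so the last-merged (MRO-earliest) base wins when several bases declare the same field,
+            # matching normal attribute resolution.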
+ if isinstance(cls_or_func, type): + for base in cls_or_func.__mro__[-1:0:-1]: + if is_group(base): + # merge and reorder fields from current base with previous + _fields = _merge_and_reorder( + _fields, copy.deepcopy(getattr(base, IOConstants.GROUP_ATTR_NAME).values) + ) + return _fields + + skip_params = skip_params or [] + inherited_fields = _get_inherited_fields() + # From annotations get field with type + annotations: Dict[str, Annotation] = getattr(cls_or_func, "__annotations__", {}) + annotations = {k: v for k, v in annotations.items() if k not in skip_params} + annotations = _update_io_from_mldesigner(annotations) + annotation_fields = _get_fields(annotations) + defaults_dict: Dict[str, Any] = {} + # Update fields use class field with defaults from class dict or signature(func).paramters + if not is_func: + # Only consider public fields in class dict + defaults_dict = { + key: val for key, val in cls_or_func.__dict__.items() if not key.startswith("_") and key not in skip_params + } + else: + # Infer parameter type from value if is function + defaults_dict = { + key: val.default + for key, val in signature(cls_or_func).parameters.items() + if key not in skip_params and val.kind != val.VAR_KEYWORD + } + fields = _update_fields_with_default(annotation_fields, defaults_dict) + all_fields = _merge_and_reorder(inherited_fields, fields) + return all_fields + + +def _update_io_from_mldesigner(annotations: Dict[str, Annotation]) -> Dict[str, Union[Annotation, "Input", "Output"]]: + """Translates IOBase from mldesigner package to azure.ml.entities.Input/Output. + + This function depends on: + + * `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output instance annotations + to IO entities. + * class names of `mldesigner._input_output` to translate Input/Output class annotations + to IO entities. 
+ + :param annotations: A map of variable names to annotations + :type annotations: Dict[str, Annotation] + :return: Dict with mldesigner IO types converted to azure-ai-ml Input/Output + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + from typing_extensions import get_args, get_origin + + from azure.ai.ml import Input, Output + + from .enum_input import EnumInput + + mldesigner_pkg = "mldesigner" + param_name = "_Param" + return_annotation_key = "return" + + def _is_primitive_type(io: type) -> bool: + """Checks whether type is a primitive type + + :param io: A type + :type io: type + :return: Return true if type is subclass of mldesigner._input_output._Param + :rtype: bool + """ + return any(io.__module__.startswith(mldesigner_pkg) and item.__name__ == param_name for item in getmro(io)) + + def _is_input_or_output_type(io: type, type_str: Literal["Input", "Output", "Meta"]) -> bool: + """Checks whether a type is an Input or Output type + + :param io: A type + :type io: type + :param type_str: The kind of type to check for + :type type_str: Literal["Input", "Output", "Meta"] + :return: Return true if type name contains type_str + :rtype: bool + """ + if isinstance(io, type) and io.__module__.startswith(mldesigner_pkg): + if type_str in io.__name__: + return True + return False + + result = {} + for key, io in annotations.items(): # pylint: disable=too-many-nested-blocks + if isinstance(io, type): + if _is_input_or_output_type(io, "Input"): + # mldesigner.Input -> entities.Input + io = Input + elif _is_input_or_output_type(io, "Output"): + # mldesigner.Output -> entities.Output + io = Output + elif _is_primitive_type(io): + io = ( + Output(type=io.TYPE_NAME) # type: ignore + if key == return_annotation_key + else Input(type=io.TYPE_NAME) # type: ignore + ) + elif hasattr(io, "_to_io_entity_args_dict"): + try: + if _is_input_or_output_type(type(io), "Input"): + # mldesigner.Input() -> entities.Input() + if io is not None: + io = Input(**io._to_io_entity_args_dict()) + elif _is_input_or_output_type(type(io), "Output"): + # mldesigner.Output() -> entities.Output() + if io is not None: + io = Output(**io._to_io_entity_args_dict()) + elif _is_primitive_type(type(io)): + if io is not None and not isinstance(io, str): + if io._is_enum(): + io = EnumInput(**io._to_io_entity_args_dict()) + else: + io = ( + Output(**io._to_io_entity_args_dict()) + if key == return_annotation_key + else Input(**io._to_io_entity_args_dict()) + ) + except BaseException as e: + raise UserErrorException(f"Failed to parse {io} to azure-ai-ml Input/Output: {str(e)}") from e + # Handle Annotated annotation + elif get_origin(io) is Annotated: + hint_type, arg, *hint_args = get_args(io) # pylint: disable=unused-variable + if hint_type in SUPPORTED_RETURN_TYPES_PRIMITIVE: + if not _is_input_or_output_type(type(arg), "Meta"): + raise UserErrorException( + f"Annotated Metadata class only support " + f"mldesigner._input_output.Meta, " + f"it is {type(arg)} now." 
+ ) + if arg.type is not None and arg.type != hint_type: + raise UserErrorException( + f"Meta class type {arg.type} should be same as Annotated type: " f"{hint_type}" + ) + arg.type = hint_type + io = ( + Output(**arg._to_io_entity_args_dict()) + if key == return_annotation_key + else Input(**arg._to_io_entity_args_dict()) + ) + result[key] = io + return result + + +def _remove_empty_values(data: Any) -> Any: + """Recursively removes None values from a dict + + :param data: The value to remove None from + :type data: T + :return: + * `data` if `data` is not a dict + * `data` with None values recursively filtered out if data is a dict + :rtype: T + """ + if not isinstance(data, dict): + return data + return {k: _remove_empty_values(v) for k, v in data.items() if v is not None} diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_input_output_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_input_output_helpers.py new file mode 100644 index 00000000..1a13ab41 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_input_output_helpers.py @@ -0,0 +1,427 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import collections.abc +import re +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + CustomModelJobInput as RestCustomModelJobInput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + CustomModelJobOutput as RestCustomModelJobOutput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import InputDeliveryMode +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobInput as RestJobInput +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobInputType +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobOutput as RestJobOutput +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobOutputType, LiteralJobInput +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + MLFlowModelJobInput as RestMLFlowModelJobInput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + MLFlowModelJobOutput as RestMLFlowModelJobOutput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + MLTableJobInput as RestMLTableJobInput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + MLTableJobOutput as RestMLTableJobOutput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import OutputDeliveryMode +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + TritonModelJobInput as RestTritonModelJobInput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + TritonModelJobOutput as RestTritonModelJobOutput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + UriFileJobInput as RestUriFileJobInput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + UriFileJobOutput as RestUriFileJobOutput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + UriFolderJobInput as RestUriFolderJobInput, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + UriFolderJobOutput as RestUriFolderJobOutput, +) +from azure.ai.ml._utils.utils import is_data_binding_expression +from azure.ai.ml.constants import AssetTypes, InputOutputModes, JobType +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.input_output_entry import InputOutputEntry +from azure.ai.ml.entities._util import normalize_job_input_output_type +from azure.ai.ml.exceptions import ( + ErrorCategory, + ErrorTarget, + JobException, + ValidationErrorType, + ValidationException, +) + +INPUT_MOUNT_MAPPING_FROM_REST = { + InputDeliveryMode.READ_WRITE_MOUNT: InputOutputModes.RW_MOUNT, + InputDeliveryMode.READ_ONLY_MOUNT: InputOutputModes.RO_MOUNT, + InputDeliveryMode.DOWNLOAD: InputOutputModes.DOWNLOAD, + InputDeliveryMode.DIRECT: InputOutputModes.DIRECT, + InputDeliveryMode.EVAL_MOUNT: InputOutputModes.EVAL_MOUNT, + InputDeliveryMode.EVAL_DOWNLOAD: InputOutputModes.EVAL_DOWNLOAD, +} + +INPUT_MOUNT_MAPPING_TO_REST = { + InputOutputModes.MOUNT: InputDeliveryMode.READ_ONLY_MOUNT, + InputOutputModes.RW_MOUNT: InputDeliveryMode.READ_WRITE_MOUNT, + InputOutputModes.RO_MOUNT: InputDeliveryMode.READ_ONLY_MOUNT, + InputOutputModes.DOWNLOAD: InputDeliveryMode.DOWNLOAD, + InputOutputModes.EVAL_MOUNT: InputDeliveryMode.EVAL_MOUNT, + InputOutputModes.EVAL_DOWNLOAD: InputDeliveryMode.EVAL_DOWNLOAD, + InputOutputModes.DIRECT: InputDeliveryMode.DIRECT, +} + + +OUTPUT_MOUNT_MAPPING_FROM_REST = { + 
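+    # Unlike inputs, outputs in this mapping only cover read-write mount, upload and direct delivery;
+    # there are no read-only or eval variants on the output side.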
OutputDeliveryMode.READ_WRITE_MOUNT: InputOutputModes.RW_MOUNT, + OutputDeliveryMode.UPLOAD: InputOutputModes.UPLOAD, + OutputDeliveryMode.DIRECT: InputOutputModes.DIRECT, +} + +OUTPUT_MOUNT_MAPPING_TO_REST = { + InputOutputModes.MOUNT: OutputDeliveryMode.READ_WRITE_MOUNT, + InputOutputModes.UPLOAD: OutputDeliveryMode.UPLOAD, + InputOutputModes.RW_MOUNT: OutputDeliveryMode.READ_WRITE_MOUNT, + InputOutputModes.DIRECT: OutputDeliveryMode.DIRECT, +} + + +# TODO: Remove this as both rest type and sdk type are snake case now. +def get_output_type_mapping_from_rest() -> Dict[str, str]: + """Gets the mapping of JobOutputType to AssetType + + :return: Mapping of JobOutputType to AssetType + :rtype: Dict[str, str] + """ + return { + JobOutputType.URI_FILE: AssetTypes.URI_FILE, + JobOutputType.URI_FOLDER: AssetTypes.URI_FOLDER, + JobOutputType.MLTABLE: AssetTypes.MLTABLE, + JobOutputType.MLFLOW_MODEL: AssetTypes.MLFLOW_MODEL, + JobOutputType.CUSTOM_MODEL: AssetTypes.CUSTOM_MODEL, + JobOutputType.TRITON_MODEL: AssetTypes.TRITON_MODEL, + } + + +def get_input_rest_cls_dict() -> Dict[str, RestJobInput]: + """Gets the mapping of AssetType to RestJobInput + + :return: Map of AssetType to RestJobInput + :rtype: Dict[str, RestJobInput] + """ + return { + AssetTypes.URI_FILE: RestUriFileJobInput, + AssetTypes.URI_FOLDER: RestUriFolderJobInput, + AssetTypes.MLTABLE: RestMLTableJobInput, + AssetTypes.MLFLOW_MODEL: RestMLFlowModelJobInput, + AssetTypes.CUSTOM_MODEL: RestCustomModelJobInput, + AssetTypes.TRITON_MODEL: RestTritonModelJobInput, + } + + +def get_output_rest_cls_dict() -> Dict[str, RestJobOutput]: + """Get output rest init cls dict. + + :return: Map of AssetType to RestJobOutput + :rtype: Dict[str, RestJobOutput] + """ + return { + AssetTypes.URI_FILE: RestUriFileJobOutput, + AssetTypes.URI_FOLDER: RestUriFolderJobOutput, + AssetTypes.MLTABLE: RestMLTableJobOutput, + AssetTypes.MLFLOW_MODEL: RestMLFlowModelJobOutput, + AssetTypes.CUSTOM_MODEL: RestCustomModelJobOutput, + AssetTypes.TRITON_MODEL: RestTritonModelJobOutput, + } + + +def build_input_output( + item: Union[InputOutputEntry, Input, Output, str, bool, int, float], + inputs: bool = True, +) -> Union[InputOutputEntry, Input, Output, str, bool, int, float]: + if isinstance(item, (Input, InputOutputEntry, Output)): + # return objects constructed at yaml load or specified in sdk + return item + # parse dictionary into supported class + if isinstance(item, collections.abc.Mapping): + if item.get("data"): + return InputOutputEntry(**item) + # else default to JobInput + return Input(**item) if inputs else Output(**item) + # return literal inputs as-is + return item + + +def _validate_inputs_for(input_consumer_name: str, input_consumer: str, inputs: Optional[Dict]) -> None: + implicit_inputs = re.findall(r"\${{inputs\.([\w\.-]+)}}", input_consumer) + # optional inputs no need to validate whether they're in inputs + optional_inputs = re.findall(r"\[[\w\.\s-]*\${{inputs\.([\w\.-]+)}}]", input_consumer) + for key in implicit_inputs: + if inputs is not None and inputs.get(key, None) is None and key not in optional_inputs: + msg = "Inputs to job does not contain '{}' referenced in " + input_consumer_name + raise ValidationException( + message=msg.format(key), + no_personal_data_message=msg.format("[key]"), + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + +def validate_inputs_for_command(command: Optional[str], inputs: Optional[Dict]) -> None: + if command is not None: + 
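+        # Every ${{inputs.<name>}} placeholder referenced in the command string must be present in `inputs`;
+        # placeholders wrapped in [...] are treated as optional and are exempt from the check.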
_validate_inputs_for("command", command, inputs)
+
+
+def validate_inputs_for_args(args: str, inputs: Optional[Dict[str, Any]]) -> None:
+    _validate_inputs_for("args", args, inputs)
+
+
+def validate_key_contains_allowed_characters(key: str) -> None:
+    if re.match(r"^[a-zA-Z_]+[a-zA-Z0-9_]*$", key) is None:
+        msg = "Key name {} must be composed of letters, numbers, and underscores"
+        raise ValidationException(
+            message=msg.format(key),
+            no_personal_data_message=msg.format("[key]"),
+            target=ErrorTarget.JOB,
+            error_category=ErrorCategory.USER_ERROR,
+            error_type=ValidationErrorType.INVALID_VALUE,
+        )
+
+
+def validate_pipeline_input_key_characters(key: str) -> None:
+    # Pipeline input keys allow '.' to support parameter groups.
+    # Note: ([a-zA-Z_]+[a-zA-Z0-9_]*) is a valid single key,
+    # so a valid pipeline key is: ^{single_key}([.]{single_key})*$
+    if re.match(IOConstants.VALID_KEY_PATTERN, key) is None:
+        msg = (
+            "Pipeline input key name {} must be composed of letters, numbers, and underscores, optionally separated by dots."
+        )
+        raise ValidationException(
+            message=msg.format(key),
+            no_personal_data_message=msg.format("[key]"),
+            target=ErrorTarget.JOB,
+            error_category=ErrorCategory.USER_ERROR,
+        )
+
+
+def to_rest_dataset_literal_inputs(
+    inputs: Optional[Dict],
+    *,
+    job_type: Optional[str],
+) -> Dict[str, RestJobInput]:
+    """Turns dataset and literal inputs into a dictionary of REST JobInput.
+
+    :param inputs: Dictionary of dataset and literal inputs to job
+    :type inputs: Dict[str, Union[int, str, float, bool, JobInput]]
+    :return: A dictionary mapping input name to a RestJobInput
+    :rtype: Dict[str, RestJobInput]
+    :keyword job_type: When job_type is pipeline, enable dot('.') in parameter keys to support parameter group.
+        TODO: Remove this after moving name validation to Job's customized validate.
+    :paramtype job_type: str
+    """
+    rest_inputs = {}
+
+    if inputs is not None:
+        # Pack up the inputs into REST format
+        for input_name, input_value in inputs.items():
+            if job_type == JobType.PIPELINE:
+                validate_pipeline_input_key_characters(input_name)
+            elif job_type:
+                # We pass job_type=None for pipeline nodes and want to skip this check for them.
+                validate_key_contains_allowed_characters(input_name)
+            if isinstance(input_value, Input):
+                if (
+                    input_value.path
+                    and isinstance(input_value.path, str)
+                    and is_data_binding_expression(input_value.path)
+                ):
+                    input_data = LiteralJobInput(value=input_value.path)
+                    # set mode attribute manually for binding job input
+                    if input_value.mode:
+                        input_data.mode = INPUT_MOUNT_MAPPING_TO_REST[input_value.mode]
+                    if getattr(input_value, "path_on_compute", None) is not None:
+                        input_data.pathOnCompute = input_value.path_on_compute
+                    input_data.job_input_type = JobInputType.LITERAL
+                else:
+                    target_cls_dict = get_input_rest_cls_dict()
+
+                    if input_value.type in target_cls_dict:
+                        input_data = target_cls_dict[input_value.type](
+                            uri=input_value.path,
+                            mode=(INPUT_MOUNT_MAPPING_TO_REST[input_value.mode.lower()] if input_value.mode else None),
+                        )
+                    else:
+                        msg = f"Job input type {input_value.type} is not supported as job input."
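+                        # Only asset types with a REST input class (uri_file, uri_folder, mltable,
+                        # mlflow_model, custom_model, triton_model) are accepted; anything else is rejected.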
+                        raise ValidationException(
+                            message=msg,
+                            no_personal_data_message=msg,
+                            target=ErrorTarget.JOB,
+                            error_category=ErrorCategory.USER_ERROR,
+                            error_type=ValidationErrorType.INVALID_VALUE,
+                        )
+            elif input_value is None:
+                # If the input is None, we need to pass the original None to the REST API
+                input_data = LiteralJobInput(value=None)
+            else:
+                # otherwise, the input is a literal input
+                if isinstance(input_value, dict):
+                    input_data = LiteralJobInput(value=str(input_value["value"]))
+                    # set mode attribute manually for binding job input
+                    if "mode" in input_value:
+                        input_data.mode = input_value["mode"]
+                else:
+                    input_data = LiteralJobInput(value=str(input_value))
+                input_data.job_input_type = JobInputType.LITERAL
+            # Pack up inputs into PipelineInputs or ComponentJobInputs depending on caller
+            rest_inputs[input_name] = input_data
+    return rest_inputs
+
+
+def from_rest_inputs_to_dataset_literal(inputs: Dict[str, RestJobInput]) -> Dict:
+    """Turns REST dataset and literal inputs into the SDK format.
+
+    :param inputs: Dictionary mapping input name to a RestJobInput
+    :type inputs: Dict[str, RestJobInput]
+    :return: A dictionary mapping input name to a literal value or JobInput
+    :rtype: Dict[str, Union[int, str, float, bool, JobInput]]
+    """
+    if inputs is None:
+        return {}
+    from_rest_inputs = {}
+    # Unpack the inputs
+    for input_name, input_value in inputs.items():
+        # TODO:Brandon Clarify with PMs if user should be able to define null input objects
+        if input_value is None:
+            continue
+
+        # TODO: Remove this as both rest type and sdk type are snake case now.
+        type_transfer_dict = get_output_type_mapping_from_rest()
+        # deal with invalid input type submitted by feb api
+        # todo: backend help convert node level input/output type
+        normalize_job_input_output_type(input_value)
+
+        if input_value.job_input_type in type_transfer_dict:
+            if input_value.uri:
+                path = input_value.uri
+                if getattr(input_value, "pathOnCompute", None) is not None:
+                    sourcePathOnCompute = input_value.pathOnCompute
+                else:
+                    sourcePathOnCompute = None
+                input_data = Input(
+                    type=type_transfer_dict[input_value.job_input_type],
+                    path=path,
+                    mode=(INPUT_MOUNT_MAPPING_FROM_REST[input_value.mode] if input_value.mode else None),
+                    path_on_compute=sourcePathOnCompute,
+                )
+        elif input_value.job_input_type == JobInputType.LITERAL:
+            # otherwise, the input is a literal, so just unpack the InputData value field
+            input_data = input_value.value
+        else:
+            msg = f"Job input type {input_value.job_input_type} is not supported as job input."
+            raise ValidationException(
+                message=msg,
+                no_personal_data_message=msg,
+                target=ErrorTarget.JOB,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.INVALID_VALUE,
+            )
+
+        from_rest_inputs[input_name] = input_data # pylint: disable=possibly-used-before-assignment
+    return from_rest_inputs
+
+
+def to_rest_data_outputs(outputs: Optional[Dict]) -> Dict[str, RestJobOutput]:
+    """Turns job outputs into REST format.
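+
+    For example, an ``Output(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RW_MOUNT)`` entry is
+    converted into a ``RestUriFolderJobOutput`` with ``READ_WRITE_MOUNT`` delivery mode.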
+ + :param outputs: Dictionary of dataset outputs from job + :type outputs: Dict[str, JobOutput] + :return: A dictionary mapping output name to a RestJobOutput + :rtype: Dict[str, RestJobOutput] + """ + rest_outputs = {} + if outputs is not None: + for output_name, output_value in outputs.items(): + validate_key_contains_allowed_characters(output_name) + if output_value is None: + # pipeline output could be none, default to URI folder with None mode + output_cls = RestUriFolderJobOutput + rest_outputs[output_name] = output_cls(mode=None) + else: + target_cls_dict = get_output_rest_cls_dict() + + output_value_type = output_value.type if output_value.type else AssetTypes.URI_FOLDER + if output_value_type in target_cls_dict: + output = target_cls_dict[output_value_type]( + asset_name=output_value.name, + asset_version=output_value.version, + uri=output_value.path, + mode=(OUTPUT_MOUNT_MAPPING_TO_REST[output_value.mode.lower()] if output_value.mode else None), + pathOnCompute=getattr(output_value, "path_on_compute", None), + description=output_value.description, + ) + else: + msg = "unsupported JobOutput type: {}".format(output_value.type) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + rest_outputs[output_name] = output + return rest_outputs + + +def from_rest_data_outputs(outputs: Dict[str, RestJobOutput]) -> Dict[str, Output]: + """Turns REST outputs into the SDK format. + + :param outputs: Dictionary of dataset and literal inputs to job + :type outputs: Dict[str, RestJobOutput] + :return: A dictionary mapping input name to a InputOutputEntry + :rtype: Dict[str, JobOutput] + """ + output_type_mapping = get_output_type_mapping_from_rest() + from_rest_outputs = {} + if outputs is None: + return {} + for output_name, output_value in outputs.items(): + if output_value is None: + continue + # deal with invalid output type submitted by feb api + # todo: backend help convert node level input/output type + normalize_job_input_output_type(output_value) + if getattr(output_value, "pathOnCompute", None) is not None: + sourcePathOnCompute = output_value.pathOnCompute + else: + sourcePathOnCompute = None + if output_value.job_output_type in output_type_mapping: + from_rest_outputs[output_name] = Output( + type=output_type_mapping[output_value.job_output_type], + path=output_value.uri, + mode=(OUTPUT_MOUNT_MAPPING_FROM_REST[output_value.mode] if output_value.mode else None), + path_on_compute=sourcePathOnCompute, + description=output_value.description, + name=output_value.asset_name, + version=(output_value.asset_version if hasattr(output_value, "asset_version") else None), + ) + else: + msg = "unsupported JobOutput type: {}".format(output_value.job_output_type) + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + return from_rest_outputs diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_studio_url_from_job_id.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_studio_url_from_job_id.py new file mode 100644 index 00000000..63ad6f06 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_studio_url_from_job_id.py @@ -0,0 +1,26 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import re +from typing import Optional + +from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _get_default_cloud_name + +JOB_ID_RE_PATTERN = re.compile( + ( + r"\/subscriptions\/(?P<subscription>[\w,-]+)\/resourceGroups\/(?P<resource_group>[\w,-]+)\/providers" + r"\/Microsoft\.MachineLearningServices\/workspaces\/(?P<workspace>[\w,-]+)\/jobs\/(?P<run_id>[\w,-]+)" + ) # fmt: skip +) + + +def studio_url_from_job_id(job_id: str) -> Optional[str]: + resource_id = _get_aml_resource_id_from_metadata(_get_default_cloud_name()) + m = JOB_ID_RE_PATTERN.match(job_id) + if m: + return ( + f"{resource_id}/runs/{m.group('run_id')}?wsid=/subscriptions/{m.group('subscription')}" + f"/resourcegroups/{m.group('resource_group')}/workspaces/{m.group('workspace')}" + ) # fmt: skip + return None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/__init__.py new file mode 100644 index 00000000..e99e9321 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/__init__.py @@ -0,0 +1,16 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .search_space import SearchSpace +from .stack_ensemble_settings import StackEnsembleSettings +from .training_settings import ClassificationTrainingSettings, TrainingSettings + +__all__ = [ + "ClassificationTrainingSettings", + "TrainingSettings", + "SearchSpace", + "StackEnsembleSettings", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_job.py new file mode 100644 index 00000000..9e1b4d05 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_job.py @@ -0,0 +1,283 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + JobBase, + MLTableJobInput, + QueueSettings, + ResourceConfiguration, + TaskType, +) +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import TYPE, AssetTypes +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin +from azure.ai.ml.entities._job.pipeline._io import AutoMLNodeIOMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +module_logger = logging.getLogger(__name__) + + +class AutoMLJob(Job, JobIOMixin, AutoMLNodeIOMixin, ABC): + """Initialize an AutoML job entity. + + Constructor for an AutoMLJob. 
+ + :keyword resources: Resource configuration for the AutoML job, defaults to None + :paramtype resources: typing.Optional[ResourceConfiguration] + :keyword identity: Identity that training job will use while running on compute, defaults to None + :paramtype identity: typing.Optional[ typing.Union[ManagedIdentityConfiguration, AmlTokenConfiguration + , UserIdentityConfiguration] ] + :keyword environment_id: The environment id for the AutoML job, defaults to None + :paramtype environment_id: typing.Optional[str] + :keyword environment_variables: The environment variables for the AutoML job, defaults to None + :paramtype environment_variables: typing.Optional[Dict[str, str]] + :keyword outputs: The outputs for the AutoML job, defaults to None + :paramtype outputs: typing.Optional[Dict[str, str]] + :keyword queue_settings: The queue settings for the AutoML job, defaults to None + :paramtype queue_settings: typing.Optional[QueueSettings] + :raises ValidationException: task type validation error + :raises NotImplementedError: Raises NotImplementedError + :return: An AutoML Job + :rtype: AutoMLJob + """ + + def __init__( + self, + *, + resources: Optional[ResourceConfiguration] = None, + identity: Optional[ + Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + queue_settings: Optional[QueueSettings] = None, + **kwargs: Any, + ) -> None: + """Initialize an AutoML job entity. + + Constructor for an AutoMLJob. + + :keyword resources: Resource configuration for the AutoML job, defaults to None + :paramtype resources: typing.Optional[ResourceConfiguration] + :keyword identity: Identity that training job will use while running on compute, defaults to None + :paramtype identity: typing.Optional[ typing.Union[ManagedIdentityConfiguration, AmlTokenConfiguration + , UserIdentityConfiguration] ] + :keyword environment_id: The environment id for the AutoML job, defaults to None + :paramtype environment_id: typing.Optional[str] + :keyword environment_variables: The environment variables for the AutoML job, defaults to None + :paramtype environment_variables: typing.Optional[Dict[str, str]] + :keyword outputs: The outputs for the AutoML job, defaults to None + :paramtype outputs: typing.Optional[Dict[str, str]] + :keyword queue_settings: The queue settings for the AutoML job, defaults to None + :paramtype queue_settings: typing.Optional[QueueSettings] + :raises ValidationException: task type validation error + :raises NotImplementedError: Raises NotImplementedError + """ + kwargs[TYPE] = JobType.AUTOML + self.environment_id = kwargs.pop("environment_id", None) + self.environment_variables = kwargs.pop("environment_variables", None) + self.outputs = kwargs.pop("outputs", None) + + super().__init__(**kwargs) + + self.resources = resources + self.identity = identity + self.queue_settings = queue_settings + + @property + @abstractmethod + def training_data(self) -> Input: + """The training data for the AutoML job. + + :raises NotImplementedError: Raises NotImplementedError + :return: Returns the training data for the AutoML job. + :rtype: Input + """ + raise NotImplementedError() + + @training_data.setter + def training_data(self, value: Any) -> None: + self.training_data = value + + @property + @abstractmethod + def validation_data(self) -> Input: + """The validation data for the AutoML job. + + :raises NotImplementedError: Raises NotImplementedError + :return: Returns the validation data for the AutoML job. 
+ :rtype: Input + """ + raise NotImplementedError() + + @validation_data.setter + def validation_data(self, value: Any) -> None: + self.validation_data = value + + @property + @abstractmethod + def test_data(self) -> Input: + """The test data for the AutoML job. + + :raises NotImplementedError: Raises NotImplementedError + :return: Returns the test data for the AutoML job. + :rtype: Input + """ + raise NotImplementedError() + + @test_data.setter + def test_data(self, value: Any) -> None: + self.test_data = value + + @classmethod + def _load_from_rest(cls, obj: JobBase) -> "AutoMLJob": + """Loads the rest object to a dict containing items to init the AutoMLJob objects. + + :param obj: Azure Resource Manager resource envelope. + :type obj: JobBase + :raises ValidationException: task type validation error + :return: An AutoML Job + :rtype: AutoMLJob + """ + task_type = ( + camel_to_snake(obj.properties.task_details.task_type) if obj.properties.task_details.task_type else None + ) + class_type = cls._get_task_mapping().get(task_type, None) + if class_type: + res: AutoMLJob = class_type._from_rest_object(obj) + return res + msg = f"Unsupported task type: {obj.properties.task_details.task_type}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "AutoMLJob": + """Loads the dictionary objects to an AutoMLJob object. + + :param data: A data dictionary. + :type data: typing.Dict + :param context: A context dictionary. + :type context: typing.Dict + :param additional_message: An additional message to be logged in the ValidationException. + :type additional_message: str + + :raises ValidationException: task type validation error + :return: An AutoML Job + :rtype: AutoMLJob + """ + task_type = data.get(AutoMLConstants.TASK_TYPE_YAML) + class_type = cls._get_task_mapping().get(task_type, None) + if class_type: + res: AutoMLJob = class_type._load_from_dict( + data, + context, + additional_message, + **kwargs, + ) + return res + msg = f"Unsupported task type: {task_type}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "AutoMLJob": + """Create an automl job instance from schema parsed dict. + + :param loaded_data: A loaded_data dictionary. + :type loaded_data: typing.Dict + :raises ValidationException: task type validation error + :return: An AutoML Job + :rtype: AutoMLJob + """ + task_type = loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML) + class_type = cls._get_task_mapping().get(task_type, None) + if class_type: + res: AutoMLJob = class_type._create_instance_from_schema_dict(loaded_data=loaded_data) + return res + msg = f"Unsupported task type: {task_type}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + @classmethod + def _get_task_mapping(cls) -> Dict: + """Create a mapping of task type to job class. + + :return: An AutoMLVertical object containing the task type to job class mapping. 
+ :rtype: AutoMLVertical + """ + from .image import ( + ImageClassificationJob, + ImageClassificationMultilabelJob, + ImageInstanceSegmentationJob, + ImageObjectDetectionJob, + ) + from .nlp import TextClassificationJob, TextClassificationMultilabelJob, TextNerJob + from .tabular import ClassificationJob, ForecastingJob, RegressionJob + + # create a mapping of task type to job class + return { + camel_to_snake(TaskType.CLASSIFICATION): ClassificationJob, + camel_to_snake(TaskType.REGRESSION): RegressionJob, + camel_to_snake(TaskType.FORECASTING): ForecastingJob, + camel_to_snake(TaskType.IMAGE_CLASSIFICATION): ImageClassificationJob, + camel_to_snake(TaskType.IMAGE_CLASSIFICATION_MULTILABEL): ImageClassificationMultilabelJob, + camel_to_snake(TaskType.IMAGE_OBJECT_DETECTION): ImageObjectDetectionJob, + camel_to_snake(TaskType.IMAGE_INSTANCE_SEGMENTATION): ImageInstanceSegmentationJob, + camel_to_snake(TaskType.TEXT_NER): TextNerJob, + camel_to_snake(TaskType.TEXT_CLASSIFICATION): TextClassificationJob, + camel_to_snake(TaskType.TEXT_CLASSIFICATION_MULTILABEL): TextClassificationMultilabelJob, + } + + def _resolve_data_inputs(self, rest_job: "AutoMLJob") -> None: + """Resolve JobInputs to MLTableJobInputs within data_settings. + + :param rest_job: The rest job object. + :type rest_job: AutoMLJob + """ + if isinstance(rest_job.training_data, Input): + rest_job.training_data = MLTableJobInput(uri=rest_job.training_data.path) + if isinstance(rest_job.validation_data, Input): + rest_job.validation_data = MLTableJobInput(uri=rest_job.validation_data.path) + if hasattr(rest_job, "test_data") and isinstance(rest_job.test_data, Input): + rest_job.test_data = MLTableJobInput(uri=rest_job.test_data.path) + + def _restore_data_inputs(self) -> None: + """Restore MLTableJobInputs to JobInputs within data_settings.""" + if isinstance(self.training_data, MLTableJobInput): + self.training_data = Input(type=AssetTypes.MLTABLE, path=self.training_data.uri) + if isinstance(self.validation_data, MLTableJobInput): + self.validation_data = Input(type=AssetTypes.MLTABLE, path=self.validation_data.uri) + if hasattr(self, "test_data") and isinstance(self.test_data, MLTableJobInput): + self.test_data = Input(type=AssetTypes.MLTABLE, path=self.test_data.uri) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_vertical.py new file mode 100644 index 00000000..f11be81c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_vertical.py @@ -0,0 +1,134 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from abc import abstractmethod +from typing import Any, Optional + +from azure.ai.ml import Input + +from .automl_job import AutoMLJob + + +class AutoMLVertical(AutoMLJob): + """Abstract class for AutoML verticals. + + :param task_type: The type of task to run. Possible values include: "classification", "regression", "forecasting". + :type task_type: str + :param training_data: Training data input + :type training_data: Input + :param validation_data: Validation data input + :type validation_data: Input + :param test_data: Test data input, defaults to None + :type test_data: typing.Optional[Input] + :raises ValueError: If task_type is not one of "classification", "regression", "forecasting". 
+ :raises ValueError: If training_data is not of type Input. + :raises ValueError: If validation_data is not of type Input. + :raises ValueError: If test_data is not of type Input. + """ + + @abstractmethod + def __init__( + self, + task_type: str, + training_data: Input, + validation_data: Input, + test_data: Optional[Input] = None, + **kwargs: Any + ) -> None: + """Initialize AutoMLVertical. + + Constructor for AutoMLVertical. + + :param task_type: The type of task to run. Possible values include: "classification", "regression" + , "forecasting". + :type task_type: str + :param training_data: Training data input + :type training_data: Input + :param validation_data: Validation data input + :type validation_data: Input + :param test_data: Test data input, defaults to None + :type test_data: typing.Optional[Input] + :raises ValueError: If task_type is not one of "classification", "regression", "forecasting". + :raises ValueError: If training_data is not of type Input. + :raises ValueError: If validation_data is not of type Input. + :raises ValueError: If test_data is not of type Input. + """ + self._task_type = task_type + self.training_data = training_data + self.validation_data = validation_data + self.test_data = test_data # type: ignore + super().__init__(**kwargs) + + @property + def task_type(self) -> str: + """Get task type. + + :return: The type of task to run. Possible values include: "classification", "regression", "forecasting". + :rtype: str + """ + return self._task_type + + @task_type.setter + def task_type(self, task_type: str) -> None: + """Set task type. + + :param task_type: The type of task to run. Possible values include: "classification", "regression" + , "forecasting". + :type task_type: str + """ + self._task_type = task_type + + @property + def training_data(self) -> Input: + """Get training data. + + :return: Training data input + :rtype: Input + """ + return self._training_data + + @training_data.setter + def training_data(self, training_data: Input) -> None: + """Set training data. + + :param training_data: Training data input + :type training_data: Input + """ + self._training_data = training_data + + @property + def validation_data(self) -> Input: + """Get validation data. + + :return: Validation data input + :rtype: Input + """ + return self._validation_data + + @validation_data.setter + def validation_data(self, validation_data: Input) -> None: + """Set validation data. + + :param validation_data: Validation data input + :type validation_data: Input + """ + self._validation_data = validation_data + + @property + def test_data(self) -> Input: + """Get test data. + + :return: Test data input + :rtype: Input + """ + return self._test_data + + @test_data.setter + def test_data(self, test_data: Input) -> None: + """Set test data. + + :param test_data: Test data input + :type test_data: Input + """ + self._test_data = test_data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/featurization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/featurization_settings.py new file mode 100644 index 00000000..c9e73d21 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/featurization_settings.py @@ -0,0 +1,32 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from typing import Optional + +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class FeaturizationSettings(RestTranslatableMixin): + """Base Featurization settings.""" + + def __init__( + self, + *, + dataset_language: Optional[str] = None, + ): + self.dataset_language = dataset_language + + def __eq__(self, other: object) -> bool: + if not isinstance(other, FeaturizationSettings): + return NotImplemented + + return self.dataset_language == other.dataset_language + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class FeaturizationSettingsType: + NLP = "nlp" + TABULAR = "tabular" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/__init__.py new file mode 100644 index 00000000..46964086 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/__init__.py @@ -0,0 +1,35 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from .automl_image import AutoMLImage +from .image_classification_job import ImageClassificationJob +from .image_classification_multilabel_job import ImageClassificationMultilabelJob +from .image_classification_search_space import ImageClassificationSearchSpace +from .image_instance_segmentation_job import ImageInstanceSegmentationJob +from .image_limit_settings import ImageLimitSettings +from .image_model_settings import ( + ImageModelSettingsClassification, + ImageModelSettingsObjectDetection, + LogTrainingMetrics, + LogValidationLoss, +) +from .image_object_detection_job import ImageObjectDetectionJob +from .image_object_detection_search_space import ImageObjectDetectionSearchSpace +from .image_sweep_settings import ImageSweepSettings + +__all__ = [ + "AutoMLImage", + "LogTrainingMetrics", + "LogValidationLoss", + "ImageClassificationJob", + "ImageClassificationMultilabelJob", + "ImageClassificationSearchSpace", + "ImageInstanceSegmentationJob", + "ImageLimitSettings", + "ImageObjectDetectionJob", + "ImageObjectDetectionSearchSpace", + "ImageSweepSettings", + "ImageModelSettingsClassification", + "ImageModelSettingsObjectDetection", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image.py new file mode 100644 index 00000000..a07bba4a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image.py @@ -0,0 +1,244 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from abc import ABC +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import LogVerbosity, SamplingAlgorithmType +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.automl.automl_vertical import AutoMLVertical +from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings +from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings +from azure.ai.ml.entities._job.sweep.early_termination_policy import ( + BanditPolicy, + MedianStoppingPolicy, + TruncationSelectionPolicy, +) +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class AutoMLImage(AutoMLVertical, ABC): + """Base class for all AutoML Image jobs. + You should not instantiate this class directly. + Instead you should create classes for specific AutoML Image tasks. + + :keyword task_type: Required. Type of task to run. + Possible values include: "ImageClassification", "ImageClassificationMultilabel", + "ImageObjectDetection", "ImageInstanceSegmentation" + :paramtype task_type: str + :keyword limits: Limit settings for all AutoML Image jobs. Defaults to None. + :paramtype limits: Optional[~azure.ai.ml.automl.ImageLimitSettings] + :keyword sweep: Sweep settings for all AutoML Image jobs. Defaults to None. + :paramtype sweep: Optional[~azure.ai.ml.automl.ImageSweepSettings] + :keyword kwargs: Additional keyword arguments for AutoMLImage. + :paramtype kwargs: Dict[str, Any] + """ + + def __init__( + self, + *, + task_type: str, + limits: Optional[ImageLimitSettings] = None, + sweep: Optional[ImageSweepSettings] = None, + **kwargs: Any, + ) -> None: + self.log_verbosity = kwargs.pop("log_verbosity", LogVerbosity.INFO) + self.target_column_name = kwargs.pop("target_column_name", None) + self.validation_data_size = kwargs.pop("validation_data_size", None) + + super().__init__( + task_type=task_type, + training_data=kwargs.pop("training_data", None), + validation_data=kwargs.pop("validation_data", None), + **kwargs, + ) + + # Set default value for self._limits as it is a required property in rest object. + self._limits = limits or ImageLimitSettings() + self._sweep = sweep + + @property + def log_verbosity(self) -> LogVerbosity: + """Returns the verbosity of the logger. + + :return: The log verbosity. + :rtype: ~azure.ai.ml._restclient.v2023_04_01_preview.models.LogVerbosity + """ + return self._log_verbosity + + @log_verbosity.setter + def log_verbosity(self, value: Union[str, LogVerbosity]) -> None: + """Sets the verbosity of the logger. + + :param value: The value to set the log verbosity to. + Possible values include: "NotSet", "Debug", "Info", "Warning", "Error", "Critical". + :type value: Union[str, ~azure.ai.ml._restclient.v2023_04_01_preview.models.LogVerbosity] + """ + self._log_verbosity = None if value is None else LogVerbosity[camel_to_snake(value).upper()] + + @property + def limits(self) -> ImageLimitSettings: + """Returns the limit settings for all AutoML Image jobs. + + :return: The limit settings. + :rtype: ~azure.ai.ml.automl.ImageLimitSettings + """ + return self._limits + + @limits.setter + def limits(self, value: Union[Dict, ImageLimitSettings]) -> None: + if isinstance(value, ImageLimitSettings): + self._limits = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for limit settings." 
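+                # value is neither an ImageLimitSettings instance nor a dict such as
+                # {"timeout_minutes": 60, "max_trials": 10}, so reject it.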
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_limits(**value) + + @property + def sweep(self) -> Optional[ImageSweepSettings]: + """Returns the sweep settings for all AutoML Image jobs. + + :return: The sweep settings. + :rtype: ~azure.ai.ml.automl.ImageSweepSettings + """ + return self._sweep + + @sweep.setter + def sweep(self, value: Union[Dict, ImageSweepSettings]) -> None: + """Sets the sweep settings for all AutoML Image jobs. + + :param value: The value to set the sweep settings to. + :type value: Union[Dict, ~azure.ai.ml.automl.ImageSweepSettings] + :raises ~azure.ai.ml.exceptions.ValidationException: If value is not a dictionary. + :return: None + """ + if isinstance(value, ImageSweepSettings): + self._sweep = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for sweep settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_sweep(**value) + + def set_data( + self, + *, + training_data: Input, + target_column_name: str, + validation_data: Optional[Input] = None, + validation_data_size: Optional[float] = None, + ) -> None: + """Data settings for all AutoML Image jobs. + + :keyword training_data: Required. Training data. + :type training_data: ~azure.ai.ml.entities.Input + :keyword target_column_name: Required. Target column name. + :type target_column_name: str + :keyword validation_data: Optional. Validation data. + :type validation_data: Optional[~azure.ai.ml.entities.Input] + :keyword validation_data_size: Optional. The fraction of training dataset that needs to be set aside for + validation purpose. Values should be in range (0.0 , 1.0). + Applied only when validation dataset is not provided. + :type validation_data_size: Optional[float] + :return: None + """ + self.target_column_name = self.target_column_name if target_column_name is None else target_column_name + self.training_data = self.training_data if training_data is None else training_data + self.validation_data = self.validation_data if validation_data is None else validation_data + self.validation_data_size = self.validation_data_size if validation_data_size is None else validation_data_size + + def set_limits( + self, + *, + max_concurrent_trials: Optional[int] = None, + max_trials: Optional[int] = None, + timeout_minutes: Optional[int] = None, + ) -> None: + """Limit settings for all AutoML Image Jobs. + + :keyword max_concurrent_trials: Maximum number of trials to run concurrently. + :type max_concurrent_trials: Optional[int]. Defaults to None. + :keyword max_trials: Maximum number of trials to run. Defaults to None. + :type max_trials: Optional[int] + :keyword timeout_minutes: AutoML job timeout. 
+ :type timeout_minutes: ~datetime.timedelta + :return: None + """ + self._limits = self._limits or ImageLimitSettings() + self._limits.max_concurrent_trials = ( + max_concurrent_trials if max_concurrent_trials is not None else self._limits.max_concurrent_trials + ) + self._limits.max_trials = max_trials if max_trials is not None else self._limits.max_trials + self._limits.timeout_minutes = timeout_minutes if timeout_minutes is not None else self._limits.timeout_minutes + + def set_sweep( + self, + *, + sampling_algorithm: Union[ + str, SamplingAlgorithmType.RANDOM, SamplingAlgorithmType.GRID, SamplingAlgorithmType.BAYESIAN + ], + early_termination: Optional[Union[BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy]] = None, + ) -> None: + """Sweep settings for all AutoML Image jobs. + + :keyword sampling_algorithm: Required. Type of the hyperparameter sampling + algorithms. Possible values include: "Grid", "Random", "Bayesian". + :type sampling_algorithm: Union[str, ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.RANDOM, + ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.GRID, + ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.BAYESIAN] + :keyword early_termination: Type of early termination policy. + :type early_termination: Union[ + ~azure.mgmt.machinelearningservices.models.BanditPolicy, + ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy, + ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy] + :return: None + """ + if self._sweep: + self._sweep.sampling_algorithm = sampling_algorithm + else: + self._sweep = ImageSweepSettings(sampling_algorithm=sampling_algorithm) + + self._sweep.early_termination = early_termination or self._sweep.early_termination + + def __eq__(self, other: object) -> bool: + """Compares two AutoMLImage objects for equality. + + :param other: The other AutoMLImage object to compare to. + :type other: ~azure.ai.ml.automl.AutoMLImage + :return: True if the two AutoMLImage objects are equal; False otherwise. + :rtype: bool + """ + if not isinstance(other, AutoMLImage): + return NotImplemented + + return ( + self.target_column_name == other.target_column_name + and self.training_data == other.training_data + and self.validation_data == other.validation_data + and self.validation_data_size == other.validation_data_size + and self._limits == other._limits + and self._sweep == other._sweep + ) + + def __ne__(self, other: object) -> bool: + """Compares two AutoMLImage objects for inequality. + + :param other: The other AutoMLImage object to compare to. + :type other: ~azure.ai.ml.automl.AutoMLImage + :return: True if the two AutoMLImage objects are not equal; False otherwise. + :rtype: bool + """ + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_classification_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_classification_base.py new file mode 100644 index 00000000..ef0c8a2d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_classification_base.py @@ -0,0 +1,439 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import LearningRateScheduler, StochasticOptimizer +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._job.automl.image.automl_image import AutoMLImage +from azure.ai.ml.entities._job.automl.image.image_classification_search_space import ImageClassificationSearchSpace +from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings +from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsClassification +from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings +from azure.ai.ml.entities._job.automl.search_space import SearchSpace +from azure.ai.ml.entities._job.automl.utils import cast_to_specific_search_space +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class AutoMLImageClassificationBase(AutoMLImage): + """Base class for AutoML Image Classification and Image Classification Multilabel tasks. + Please do not instantiate this class directly. Instantiate one of the child classes instead. + + :keyword task_type: Type of task to run. + Possible values include: "ImageClassification", "ImageClassificationMultilabel". + :paramtype task_type: str + :keyword limits: Limits for Automl image classification jobs. Defaults to None. + :paramtype limits: Optional[~azure.ai.ml.automl.ImageLimitSettings] + :keyword sweep: Sweep settings for Automl image classification jobs. Defaults to None. + :paramtype sweep: Optional[~azure.ai.ml.automl.ImageSweepSettings] + :keyword training_parameters: Training parameters for Automl image classification jobs. Defaults to None. + :paramtype training_parameters: Optional[~azure.ai.ml.automl.ImageModelSettingsClassification] + :keyword search_space: Search space for Automl image classification jobs. Defaults to None. + :paramtype search_space: Optional[List[~azure.ai.ml.automl.ImageClassificationSearchSpace]] + :keyword kwargs: Other Keyword arguments for AutoMLImageClassificationBase class. + :paramtype kwargs: Dict[str, Any] + """ + + def __init__( + self, + *, + task_type: str, + limits: Optional[ImageLimitSettings] = None, + sweep: Optional[ImageSweepSettings] = None, + training_parameters: Optional[ImageModelSettingsClassification] = None, + search_space: Optional[List[ImageClassificationSearchSpace]] = None, + **kwargs: Any, + ) -> None: + self._training_parameters: Optional[ImageModelSettingsClassification] = None + + super().__init__(task_type=task_type, limits=limits, sweep=sweep, **kwargs) + self.training_parameters = training_parameters # Assigning training_parameters through setter method. + self._search_space = search_space + + @property + def training_parameters(self) -> Optional[ImageModelSettingsClassification]: + """ + :rtype: ~azure.ai.ml.automl.ImageModelSettingsClassification + :return: Training parameters for AutoML Image Classification and Image Classification Multilabel tasks. + """ + return self._training_parameters + + @training_parameters.setter + def training_parameters(self, value: Union[Dict, ImageModelSettingsClassification]) -> None: + """Setting Image training parameters for AutoML Image Classification and Image Classification Multilabel tasks. + + :param value: Training parameters for AutoML Image Classification and Image Classification Multilabel tasks. 
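+            For example, ``{"model_name": "seresnext", "number_of_epochs": 10}`` or an
+            ``ImageModelSettingsClassification`` instance with the same fields set.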
+ :type value: Union[Dict, ~azure.ai.ml.automl.ImageModelSettingsClassification] + :raises ~azure.ml.exceptions.ValidationException if value is not a dictionary or + ImageModelSettingsClassification. + :return: None + """ + if value is None: + self._training_parameters = None + elif isinstance(value, ImageModelSettingsClassification): + self._training_parameters = value + # set_training_parameters convert parameter values from snake case str to enum. + # We need to add any future enum parameters in this call to support snake case str. + self.set_training_parameters( + optimizer=value.optimizer, + learning_rate_scheduler=value.learning_rate_scheduler, + ) + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for model settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_training_parameters(**value) + + @property + def search_space(self) -> Optional[List[ImageClassificationSearchSpace]]: + """ + :rtype: List[~azure.ai.ml.automl.ImageClassificationSearchSpace] + :return: Search space for AutoML Image Classification and Image Classification Multilabel tasks. + """ + return self._search_space + + @search_space.setter + def search_space(self, value: Union[List[Dict], List[SearchSpace]]) -> None: + """Setting Image search space for AutoML Image Classification and Image Classification Multilabel tasks. + + :param value: Search space for AutoML Image Classification and Image Classification Multilabel tasks. + :type value: Union[List[Dict], List[~azure.ai.ml.automl.ImageClassificationSearchSpace]] + :raises ~azure.ml.exceptions.ValidationException if value is not a list of dictionaries or + ImageClassificationSearchSpace. + """ + if not isinstance(value, list): + msg = "Expected a list for search space." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + all_dict_type = all(isinstance(item, dict) for item in value) + all_search_space_type = all(isinstance(item, SearchSpace) for item in value) + + if all_search_space_type or all_dict_type: + self._search_space = [ + cast_to_specific_search_space(item, ImageClassificationSearchSpace, self.task_type) # type: ignore + for item in value + ] + else: + msg = "Expected all items in the list to be either dictionaries or ImageClassificationSearchSpace objects." 
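+            # A mixed list (some dicts, some SearchSpace objects) is rejected; convert the entries
+            # to a single representation before assigning them to search_space.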
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + # pylint: disable=too-many-locals + def set_training_parameters( + self, + *, + advanced_settings: Optional[str] = None, + ams_gradient: Optional[bool] = None, + beta1: Optional[float] = None, + beta2: Optional[float] = None, + checkpoint_frequency: Optional[int] = None, + checkpoint_run_id: Optional[str] = None, + distributed: Optional[bool] = None, + early_stopping: Optional[bool] = None, + early_stopping_delay: Optional[int] = None, + early_stopping_patience: Optional[int] = None, + enable_onnx_normalization: Optional[bool] = None, + evaluation_frequency: Optional[int] = None, + gradient_accumulation_step: Optional[int] = None, + layers_to_freeze: Optional[int] = None, + learning_rate: Optional[float] = None, + learning_rate_scheduler: Optional[Union[str, LearningRateScheduler]] = None, + model_name: Optional[str] = None, + momentum: Optional[float] = None, + nesterov: Optional[bool] = None, + number_of_epochs: Optional[int] = None, + number_of_workers: Optional[int] = None, + optimizer: Optional[Union[str, StochasticOptimizer]] = None, + random_seed: Optional[int] = None, + step_lr_gamma: Optional[float] = None, + step_lr_step_size: Optional[int] = None, + training_batch_size: Optional[int] = None, + validation_batch_size: Optional[int] = None, + warmup_cosine_lr_cycles: Optional[float] = None, + warmup_cosine_lr_warmup_epochs: Optional[int] = None, + weight_decay: Optional[float] = None, + training_crop_size: Optional[int] = None, + validation_crop_size: Optional[int] = None, + validation_resize_size: Optional[int] = None, + weighted_loss: Optional[int] = None, + ) -> None: + """Setting Image training parameters for AutoML Image Classification and Image Classification Multilabel tasks. + + :keyword advanced_settings: Settings for advanced scenarios. + :paramtype advanced_settings: str + :keyword ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'. + :paramtype ams_gradient: bool + :keyword beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the + range [0, 1]. + :paramtype beta1: float + :keyword beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the + range [0, 1]. + :paramtype beta2: float + :keyword checkpoint_frequency: Frequency to store model checkpoints. Must be a positive + integer. + :paramtype checkpoint_frequency: int + :keyword checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for + incremental training. + :paramtype checkpoint_run_id: str + :keyword distributed: Whether to use distributed training. + :paramtype distributed: bool + :keyword early_stopping: Enable early stopping logic during training. + :paramtype early_stopping: bool + :keyword early_stopping_delay: Minimum number of epochs or validation evaluations to wait + before primary metric improvement + is tracked for early stopping. Must be a positive integer. + :paramtype early_stopping_delay: int + :keyword early_stopping_patience: Minimum number of epochs or validation evaluations with no + primary metric improvement before + the run is stopped. Must be a positive integer. + :paramtype early_stopping_patience: int + :keyword enable_onnx_normalization: Enable normalization when exporting ONNX model. + :paramtype enable_onnx_normalization: bool + :keyword evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. 
+ Must be a positive integer. + :paramtype evaluation_frequency: int + :keyword gradient_accumulation_step: Gradient accumulation means running a configured number of + "GradAccumulationStep" steps without + updating the model weights while accumulating the gradients of those steps, and then using + the accumulated gradients to compute the weight updates. Must be a positive integer. + :paramtype gradient_accumulation_step: int + :keyword layers_to_freeze: Number of layers to freeze for the model. Must be a positive + integer. + For instance, passing 2 as value for 'seresnext' means + freezing layer0 and layer1. For a full list of models supported and details on layer freeze, + please + see: https://learn.microsoft.com/azure/machine-learning/reference-automl-images-hyperparameters#model-agnostic-hyperparameters. # pylint: disable=line-too-long + :type layers_to_freeze: int + :keyword learning_rate: Initial learning rate. Must be a float in the range [0, 1]. + :paramtype learning_rate: float + :keyword learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or + 'step'. Possible values include: "None", "WarmupCosine", "Step". + :type learning_rate_scheduler: str or + ~azure.mgmt.machinelearningservices.models.LearningRateScheduler + :keyword model_name: Name of the model to use for training. + For more information on the available models please visit the official documentation: + https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type model_name: str + :keyword momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, + 1]. + :paramtype momentum: float + :keyword nesterov: Enable nesterov when optimizer is 'sgd'. + :paramtype nesterov: bool + :keyword number_of_epochs: Number of training epochs. Must be a positive integer. + :paramtype number_of_epochs: int + :keyword number_of_workers: Number of data loader workers. Must be a non-negative integer. + :paramtype number_of_workers: int + :keyword optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw". + :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer + :keyword random_seed: Random seed to be used when using deterministic training. + :paramtype random_seed: int + :keyword step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float + in the range [0, 1]. + :paramtype step_lr_gamma: float + :keyword step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be + a positive integer. + :paramtype step_lr_step_size: int + :keyword training_batch_size: Training batch size. Must be a positive integer. + :paramtype training_batch_size: int + :keyword validation_batch_size: Validation batch size. Must be a positive integer. + :paramtype validation_batch_size: int + :keyword warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is + 'warmup_cosine'. Must be a float in the range [0, 1]. + :paramtype warmup_cosine_lr_cycles: float + :keyword warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is + 'warmup_cosine'. Must be a positive integer. + :paramtype warmup_cosine_lr_warmup_epochs: int + :keyword weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must + be a float in the range[0, 1]. + :paramtype weight_decay: float + :keyword training_crop_size: Image crop size that is input to the neural network for the + training dataset. Must be a positive integer. 
+ :paramtype training_crop_size: int + :keyword validation_crop_size: Image crop size that is input to the neural network for the + validation dataset. Must be a positive integer. + :paramtype validation_crop_size: int + :keyword validation_resize_size: Image size to which to resize before cropping for validation + dataset. Must be a positive integer. + :paramtype validation_resize_size: int + :keyword weighted_loss: Weighted loss. The accepted values are 0 for no weighted loss. + 1 for weighted loss with sqrt.(class_weights). 2 for weighted loss with class_weights. Must be + 0 or 1 or 2. + :paramtype weighted_loss: int + """ + self._training_parameters = self._training_parameters or ImageModelSettingsClassification() + + self._training_parameters.advanced_settings = ( + advanced_settings if advanced_settings is not None else self._training_parameters.advanced_settings + ) + self._training_parameters.ams_gradient = ( + ams_gradient if ams_gradient is not None else self._training_parameters.ams_gradient + ) + self._training_parameters.beta1 = beta1 if beta1 is not None else self._training_parameters.beta1 + self._training_parameters.beta2 = beta2 if beta2 is not None else self._training_parameters.beta2 + self._training_parameters.checkpoint_frequency = ( + checkpoint_frequency if checkpoint_frequency is not None else self._training_parameters.checkpoint_frequency + ) + self._training_parameters.checkpoint_run_id = ( + checkpoint_run_id if checkpoint_run_id is not None else self._training_parameters.checkpoint_run_id + ) + self._training_parameters.distributed = ( + distributed if distributed is not None else self._training_parameters.distributed + ) + self._training_parameters.early_stopping = ( + early_stopping if early_stopping is not None else self._training_parameters.early_stopping + ) + self._training_parameters.early_stopping_delay = ( + early_stopping_delay if early_stopping_delay is not None else self._training_parameters.early_stopping_delay + ) + self._training_parameters.early_stopping_patience = ( + early_stopping_patience + if early_stopping_patience is not None + else self._training_parameters.early_stopping_patience + ) + self._training_parameters.enable_onnx_normalization = ( + enable_onnx_normalization + if enable_onnx_normalization is not None + else self._training_parameters.enable_onnx_normalization + ) + self._training_parameters.evaluation_frequency = ( + evaluation_frequency if evaluation_frequency is not None else self._training_parameters.evaluation_frequency + ) + self._training_parameters.gradient_accumulation_step = ( + gradient_accumulation_step + if gradient_accumulation_step is not None + else self._training_parameters.gradient_accumulation_step + ) + self._training_parameters.layers_to_freeze = ( + layers_to_freeze if layers_to_freeze is not None else self._training_parameters.layers_to_freeze + ) + self._training_parameters.learning_rate = ( + learning_rate if learning_rate is not None else self._training_parameters.learning_rate + ) + self._training_parameters.learning_rate_scheduler = ( + LearningRateScheduler[camel_to_snake(learning_rate_scheduler).upper()] + if learning_rate_scheduler is not None + else self._training_parameters.learning_rate_scheduler + ) + self._training_parameters.model_name = ( + model_name if model_name is not None else self._training_parameters.model_name + ) + self._training_parameters.momentum = momentum if momentum is not None else self._training_parameters.momentum + self._training_parameters.nesterov = nesterov if nesterov 
is not None else self._training_parameters.nesterov + self._training_parameters.number_of_epochs = ( + number_of_epochs if number_of_epochs is not None else self._training_parameters.number_of_epochs + ) + self._training_parameters.number_of_workers = ( + number_of_workers if number_of_workers is not None else self._training_parameters.number_of_workers + ) + self._training_parameters.optimizer = ( + StochasticOptimizer[camel_to_snake(optimizer).upper()] + if optimizer is not None + else self._training_parameters.optimizer + ) + self._training_parameters.random_seed = ( + random_seed if random_seed is not None else self._training_parameters.random_seed + ) + self._training_parameters.step_lr_gamma = ( + step_lr_gamma if step_lr_gamma is not None else self._training_parameters.step_lr_gamma + ) + self._training_parameters.step_lr_step_size = ( + step_lr_step_size if step_lr_step_size is not None else self._training_parameters.step_lr_step_size + ) + self._training_parameters.training_batch_size = ( + training_batch_size if training_batch_size is not None else self._training_parameters.training_batch_size + ) + self._training_parameters.validation_batch_size = ( + validation_batch_size + if validation_batch_size is not None + else self._training_parameters.validation_batch_size + ) + self._training_parameters.warmup_cosine_lr_cycles = ( + warmup_cosine_lr_cycles + if warmup_cosine_lr_cycles is not None + else self._training_parameters.warmup_cosine_lr_cycles + ) + self._training_parameters.warmup_cosine_lr_warmup_epochs = ( + warmup_cosine_lr_warmup_epochs + if warmup_cosine_lr_warmup_epochs is not None + else self._training_parameters.warmup_cosine_lr_warmup_epochs + ) + self._training_parameters.weight_decay = ( + weight_decay if weight_decay is not None else self._training_parameters.weight_decay + ) + self._training_parameters.training_crop_size = ( + training_crop_size if training_crop_size is not None else self._training_parameters.training_crop_size + ) + self._training_parameters.validation_crop_size = ( + validation_crop_size if validation_crop_size is not None else self._training_parameters.validation_crop_size + ) + self._training_parameters.validation_resize_size = ( + validation_resize_size + if validation_resize_size is not None + else self._training_parameters.validation_resize_size + ) + self._training_parameters.weighted_loss = ( + weighted_loss if weighted_loss is not None else self._training_parameters.weighted_loss + ) + + # pylint: enable=too-many-locals + + def extend_search_space( + self, + value: Union[SearchSpace, List[SearchSpace]], + ) -> None: + """Add Search space for AutoML Image Classification and Image Classification Multilabel tasks. 
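+
+        For example, passing a single SearchSpace whose ``model_name`` is a choice over several
+        architectures appends one more sub-space for the sweep to explore.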
+ + :param value: specify either an instance of ImageClassificationSearchSpace or list of + ImageClassificationSearchSpace for searching through the parameter space + :type value: Union[ImageClassificationSearchSpace, List[ImageClassificationSearchSpace]] + """ + self._search_space = self._search_space or [] + + if isinstance(value, list): + self._search_space.extend( + [ + cast_to_specific_search_space(item, ImageClassificationSearchSpace, self.task_type) # type: ignore + for item in value + ] + ) + else: + self._search_space.append( + cast_to_specific_search_space(value, ImageClassificationSearchSpace, self.task_type) # type: ignore + ) + + @classmethod + def _get_search_space_from_str(cls, search_space_str: str) -> Optional[List[ImageClassificationSearchSpace]]: + return ( + [ImageClassificationSearchSpace._from_rest_object(entry) for entry in search_space_str if entry is not None] + if search_space_str is not None + else None + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AutoMLImageClassificationBase): + return NotImplemented + + if not super().__eq__(other): + return False + + return self._training_parameters == other._training_parameters and self._search_space == other._search_space + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_object_detection_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_object_detection_base.py new file mode 100644 index 00000000..db0c7bc6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_object_detection_base.py @@ -0,0 +1,524 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + LearningRateScheduler, + LogTrainingMetrics, + LogValidationLoss, + ModelSize, + StochasticOptimizer, + ValidationMetricType, +) +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._job.automl import SearchSpace +from azure.ai.ml.entities._job.automl.image.automl_image import AutoMLImage +from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings +from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsObjectDetection +from azure.ai.ml.entities._job.automl.image.image_object_detection_search_space import ImageObjectDetectionSearchSpace +from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings +from azure.ai.ml.entities._job.automl.utils import cast_to_specific_search_space +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class AutoMLImageObjectDetectionBase(AutoMLImage): + """Base class for AutoML Image Object Detection and Image Instance Segmentation tasks. + + :keyword task_type: Type of task to run. Possible values include: "ImageObjectDetection", + "ImageInstanceSegmentation". + :paramtype task_type: str + :keyword limits: The resource limits for the job. + :paramtype limits: Optional[~azure.ai.ml.entities._job.automl.image.image_limit_settings.ImageLimitSettings] + :keyword sweep: The sweep settings for the job. 
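A sketch of how the limits/sweep/search-space pieces referenced in this docstring fit together; image_classification_job is the hypothetical job from the earlier sketch, and sampling_algorithm/early_termination feed the ImageSweepSettings mentioned above.

from azure.ai.ml.automl import SearchSpace
from azure.ai.ml.sweep import BanditPolicy, Choice, Uniform

image_classification_job.set_sweep(
    sampling_algorithm="random",
    early_termination=BanditPolicy(evaluation_interval=2, slack_factor=0.2),
)
image_classification_job.extend_search_space(
    SearchSpace(
        model_name=Choice(["vitb16r224", "seresnext"]),
        learning_rate=Uniform(0.001, 0.01),
    )
)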
+ :paramtype sweep: Optional[~azure.ai.ml.entities._job.automl.image.image_sweep_settings.ImageSweepSettings] + :keyword training_parameters: The training parameters for the job. + :paramtype training_parameters: Optional[~azure.ai.ml.automl.ImageModelSettingsObjectDetection] + :keyword search_space: The search space for the job. + :paramtype search_space: Optional[List[~azure.ai.ml.automl.ImageObjectDetectionSearchSpace]] + """ + + def __init__( + self, + *, + task_type: str, + limits: Optional[ImageLimitSettings] = None, + sweep: Optional[ImageSweepSettings] = None, + training_parameters: Optional[ImageModelSettingsObjectDetection] = None, + search_space: Optional[List[ImageObjectDetectionSearchSpace]] = None, + **kwargs: Any, + ) -> None: + self._training_parameters: Optional[ImageModelSettingsObjectDetection] = None + + super().__init__(task_type=task_type, limits=limits, sweep=sweep, **kwargs) + + self.training_parameters = training_parameters # Assigning training_parameters through setter method. + + self._search_space = search_space + + @property + def training_parameters(self) -> Optional[ImageModelSettingsObjectDetection]: + return self._training_parameters + + @training_parameters.setter + def training_parameters(self, value: Union[Dict, ImageModelSettingsObjectDetection]) -> None: + if value is None: + self._training_parameters = None + elif isinstance(value, ImageModelSettingsObjectDetection): + self._training_parameters = value + # set_training_parameters convert parameter values from snake case str to enum. + # We need to add any future enum parameters in this call to support snake case str. + self.set_training_parameters( + optimizer=value.optimizer, + learning_rate_scheduler=value.learning_rate_scheduler, + model_size=value.model_size, + validation_metric_type=value.validation_metric_type, + log_training_metrics=value.log_training_metrics, + log_validation_loss=value.log_validation_loss, + ) + elif value is None: + self._training_parameters = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for model settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_training_parameters(**value) + + @property + def search_space(self) -> Optional[List[ImageObjectDetectionSearchSpace]]: + return self._search_space + + @search_space.setter + def search_space(self, value: Union[List[Dict], List[SearchSpace]]) -> None: + if not isinstance(value, list): + msg = "Expected a list for search space." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + all_dict_type = all(isinstance(item, dict) for item in value) + all_search_space_type = all(isinstance(item, SearchSpace) for item in value) + + if all_search_space_type or all_dict_type: + self._search_space = [ + cast_to_specific_search_space(item, ImageObjectDetectionSearchSpace, self.task_type) # type: ignore + for item in value + ] + else: + msg = "Expected all items in the list to be either dictionaries or SearchSpace objects." 
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + # pylint: disable=too-many-locals + def set_training_parameters( + self, + *, + advanced_settings: Optional[str] = None, + ams_gradient: Optional[bool] = None, + beta1: Optional[float] = None, + beta2: Optional[float] = None, + checkpoint_frequency: Optional[int] = None, + checkpoint_run_id: Optional[str] = None, + distributed: Optional[bool] = None, + early_stopping: Optional[bool] = None, + early_stopping_delay: Optional[int] = None, + early_stopping_patience: Optional[int] = None, + enable_onnx_normalization: Optional[bool] = None, + evaluation_frequency: Optional[int] = None, + gradient_accumulation_step: Optional[int] = None, + layers_to_freeze: Optional[int] = None, + learning_rate: Optional[float] = None, + learning_rate_scheduler: Optional[Union[str, LearningRateScheduler]] = None, + model_name: Optional[str] = None, + momentum: Optional[float] = None, + nesterov: Optional[bool] = None, + number_of_epochs: Optional[int] = None, + number_of_workers: Optional[int] = None, + optimizer: Optional[Union[str, StochasticOptimizer]] = None, + random_seed: Optional[int] = None, + step_lr_gamma: Optional[float] = None, + step_lr_step_size: Optional[int] = None, + training_batch_size: Optional[int] = None, + validation_batch_size: Optional[int] = None, + warmup_cosine_lr_cycles: Optional[float] = None, + warmup_cosine_lr_warmup_epochs: Optional[int] = None, + weight_decay: Optional[float] = None, + box_detections_per_image: Optional[int] = None, + box_score_threshold: Optional[float] = None, + image_size: Optional[int] = None, + max_size: Optional[int] = None, + min_size: Optional[int] = None, + model_size: Optional[Union[str, ModelSize]] = None, + multi_scale: Optional[bool] = None, + nms_iou_threshold: Optional[float] = None, + tile_grid_size: Optional[str] = None, + tile_overlap_ratio: Optional[float] = None, + tile_predictions_nms_threshold: Optional[float] = None, + validation_iou_threshold: Optional[float] = None, + validation_metric_type: Optional[Union[str, ValidationMetricType]] = None, + log_training_metrics: Optional[Union[str, LogTrainingMetrics]] = None, + log_validation_loss: Optional[Union[str, LogValidationLoss]] = None, + ) -> None: + """Setting Image training parameters for for AutoML Image Object Detection and Image Instance Segmentation + tasks. + + :keyword advanced_settings: Settings for advanced scenarios. + :paramtype advanced_settings: str + :keyword ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'. + :paramtype ams_gradient: bool + :keyword beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the + range [0, 1]. + :paramtype beta1: float + :keyword beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the + range [0, 1]. + :paramtype beta2: float + :keyword checkpoint_frequency: Frequency to store model checkpoints. Must be a positive + integer. + :paramtype checkpoint_frequency: int + :keyword checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for + incremental training. + :paramtype checkpoint_run_id: str + :keyword distributed: Whether to use distributed training. + :paramtype distributed: bool + :keyword early_stopping: Enable early stopping logic during training. 
+ :paramtype early_stopping: bool + :keyword early_stopping_delay: Minimum number of epochs or validation evaluations to wait + before primary metric improvement + is tracked for early stopping. Must be a positive integer. + :paramtype early_stopping_delay: int + :keyword early_stopping_patience: Minimum number of epochs or validation evaluations with no + primary metric improvement before + the run is stopped. Must be a positive integer. + :paramtype early_stopping_patience: int + :keyword enable_onnx_normalization: Enable normalization when exporting ONNX model. + :paramtype enable_onnx_normalization: bool + :keyword evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. + Must be a positive integer. + :paramtype evaluation_frequency: int + :keyword gradient_accumulation_step: Gradient accumulation means running a configured number of + "GradAccumulationStep" steps without + updating the model weights while accumulating the gradients of those steps, and then using + the accumulated gradients to compute the weight updates. Must be a positive integer. + :paramtype gradient_accumulation_step: int + :keyword layers_to_freeze: Number of layers to freeze for the model. Must be a positive + integer. + For instance, passing 2 as value for 'seresnext' means + freezing layer0 and layer1. For a full list of models supported and details on layer freeze, + please + see: https://learn.microsoft.com/azure/machine-learning/reference-automl-images-hyperparameters#model-agnostic-hyperparameters. # pylint: disable=line-too-long + :type layers_to_freeze: int + :keyword learning_rate: Initial learning rate. Must be a float in the range [0, 1]. + :paramtype learning_rate: float + :keyword learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or + 'step'. Possible values include: "None", "WarmupCosine", "Step". + :type learning_rate_scheduler: str or + ~azure.mgmt.machinelearningservices.models.LearningRateScheduler + :keyword model_name: Name of the model to use for training. + For more information on the available models please visit the official documentation: + https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type model_name: str + :keyword momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, + 1]. + :paramtype momentum: float + :keyword nesterov: Enable nesterov when optimizer is 'sgd'. + :paramtype nesterov: bool + :keyword number_of_epochs: Number of training epochs. Must be a positive integer. + :paramtype number_of_epochs: int + :keyword number_of_workers: Number of data loader workers. Must be a non-negative integer. + :paramtype number_of_workers: int + :keyword optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw". + :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer + :keyword random_seed: Random seed to be used when using deterministic training. + :paramtype random_seed: int + :keyword step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float + in the range [0, 1]. + :paramtype step_lr_gamma: float + :keyword step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be + a positive integer. + :paramtype step_lr_step_size: int + :keyword training_batch_size: Training batch size. Must be a positive integer. + :paramtype training_batch_size: int + :keyword validation_batch_size: Validation batch size. Must be a positive integer. 
+ :paramtype validation_batch_size: int + :keyword warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is + 'warmup_cosine'. Must be a float in the range [0, 1]. + :paramtype warmup_cosine_lr_cycles: float + :keyword warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is + 'warmup_cosine'. Must be a positive integer. + :paramtype warmup_cosine_lr_warmup_epochs: int + :keyword weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must + be a float in the range[0, 1]. + :paramtype weight_decay: float + :keyword box_detections_per_image: Maximum number of detections per image, for all classes. + Must be a positive integer. + Note: This settings is not supported for the 'yolov5' algorithm. + :type box_detections_per_image: int + :keyword box_score_threshold: During inference, only return proposals with a classification + score greater than + BoxScoreThreshold. Must be a float in the range[0, 1]. + :paramtype box_score_threshold: float + :keyword image_size: Image size for training and validation. Must be a positive integer. + Note: The training run may get into CUDA OOM if the size is too big. + Note: This settings is only supported for the 'yolov5' algorithm. + :type image_size: int + :keyword max_size: Maximum size of the image to be rescaled before feeding it to the backbone. + Must be a positive integer. Note: training run may get into CUDA OOM if the size is too big. + Note: This settings is not supported for the 'yolov5' algorithm. + :type max_size: int + :keyword min_size: Minimum size of the image to be rescaled before feeding it to the backbone. + Must be a positive integer. Note: training run may get into CUDA OOM if the size is too big. + Note: This settings is not supported for the 'yolov5' algorithm. + :type min_size: int + :keyword model_size: Model size. Must be 'small', 'medium', 'large', or 'extra_large'. + Note: training run may get into CUDA OOM if the model size is too big. + Note: This settings is only supported for the 'yolov5' algorithm. + :type model_size: str or ~azure.mgmt.machinelearningservices.models.ModelSize + :keyword multi_scale: Enable multi-scale image by varying image size by +/- 50%. + Note: training run may get into CUDA OOM if no sufficient GPU memory. + Note: This settings is only supported for the 'yolov5' algorithm. + :type multi_scale: bool + :keyword nms_iou_threshold: IOU threshold used during inference in NMS post processing. Must be + float in the range [0, 1]. + :paramtype nms_iou_threshold: float + :keyword tile_grid_size: The grid size to use for tiling each image. Note: TileGridSize must + not be + None to enable small object detection logic. A string containing two integers in mxn format. + :type tile_grid_size: str + :keyword tile_overlap_ratio: Overlap ratio between adjacent tiles in each dimension. Must be + float in the range [0, 1). + :paramtype tile_overlap_ratio: float + :keyword tile_predictions_nms_threshold: The IOU threshold to use to perform NMS while merging + predictions from tiles and image. + Used in validation/ inference. Must be float in the range [0, 1]. + NMS: Non-maximum suppression. + :type tile_predictions_nms_threshold: str + :keyword validation_iou_threshold: IOU threshold to use when computing validation metric. Must + be float in the range [0, 1]. + :paramtype validation_iou_threshold: float + :keyword validation_metric_type: Metric computation method to use for validation metrics. Must + be 'none', 'coco', 'voc', or 'coco_voc'. 
+ :paramtype validation_metric_type: str or ~azure.mgmt.machinelearningservices.models.ValidationMetricType + :keyword log_training_metrics: indicates whether or not to log training metrics. Must + be 'Enable' or 'Disable' + :paramtype log_training_metrics: str or ~azure.mgmt.machinelearningservices.models.LogTrainingMetrics + :keyword log_validation_loss: indicates whether or not to log validation loss. Must + be 'Enable' or 'Disable' + :paramtype log_validation_loss: str or ~azure.mgmt.machinelearningservices.models.LogValidationLoss + """ + self._training_parameters = self._training_parameters or ImageModelSettingsObjectDetection() + + self._training_parameters.advanced_settings = ( + advanced_settings if advanced_settings is not None else self._training_parameters.advanced_settings + ) + self._training_parameters.ams_gradient = ( + ams_gradient if ams_gradient is not None else self._training_parameters.ams_gradient + ) + self._training_parameters.beta1 = beta1 if beta1 is not None else self._training_parameters.beta1 + self._training_parameters.beta2 = beta2 if beta2 is not None else self._training_parameters.beta2 + self._training_parameters.checkpoint_frequency = ( + checkpoint_frequency if checkpoint_frequency is not None else self._training_parameters.checkpoint_frequency + ) + self._training_parameters.checkpoint_run_id = ( + checkpoint_run_id if checkpoint_run_id is not None else self._training_parameters.checkpoint_run_id + ) + self._training_parameters.distributed = ( + distributed if distributed is not None else self._training_parameters.distributed + ) + self._training_parameters.early_stopping = ( + early_stopping if early_stopping is not None else self._training_parameters.early_stopping + ) + self._training_parameters.early_stopping_delay = ( + early_stopping_delay if early_stopping_delay is not None else self._training_parameters.early_stopping_delay + ) + self._training_parameters.early_stopping_patience = ( + early_stopping_patience + if early_stopping_patience is not None + else self._training_parameters.early_stopping_patience + ) + self._training_parameters.enable_onnx_normalization = ( + enable_onnx_normalization + if enable_onnx_normalization is not None + else self._training_parameters.enable_onnx_normalization + ) + self._training_parameters.evaluation_frequency = ( + evaluation_frequency if evaluation_frequency is not None else self._training_parameters.evaluation_frequency + ) + self._training_parameters.gradient_accumulation_step = ( + gradient_accumulation_step + if gradient_accumulation_step is not None + else self._training_parameters.gradient_accumulation_step + ) + self._training_parameters.layers_to_freeze = ( + layers_to_freeze if layers_to_freeze is not None else self._training_parameters.layers_to_freeze + ) + self._training_parameters.learning_rate = ( + learning_rate if learning_rate is not None else self._training_parameters.learning_rate + ) + self._training_parameters.learning_rate_scheduler = ( + LearningRateScheduler[camel_to_snake(learning_rate_scheduler)] + if learning_rate_scheduler is not None + else self._training_parameters.learning_rate_scheduler + ) + self._training_parameters.model_name = ( + model_name if model_name is not None else self._training_parameters.model_name + ) + self._training_parameters.momentum = momentum if momentum is not None else self._training_parameters.momentum + self._training_parameters.nesterov = nesterov if nesterov is not None else self._training_parameters.nesterov + 
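A sketch of calling the method documented above with snake_case strings for the enum-valued settings (image_object_detection_job is an assumed ImageObjectDetectionJob); the strings are looked up as enum members as described in the docstring.

image_object_detection_job.set_training_parameters(
    model_name="yolov5",
    model_size="medium",                # looked up as ModelSize.MEDIUM (yolov5 only)
    tile_grid_size="3x2",               # enables tiling for small-object detection
    validation_metric_type="coco_voc",  # looked up as ValidationMetricType.COCO_VOC
    log_training_metrics="Enable",      # looked up as LogTrainingMetrics.ENABLE
    early_stopping=True,
    early_stopping_patience=5,
)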
self._training_parameters.number_of_epochs = ( + number_of_epochs if number_of_epochs is not None else self._training_parameters.number_of_epochs + ) + self._training_parameters.number_of_workers = ( + number_of_workers if number_of_workers is not None else self._training_parameters.number_of_workers + ) + self._training_parameters.optimizer = ( + StochasticOptimizer[camel_to_snake(optimizer)] + if optimizer is not None + else self._training_parameters.optimizer + ) + self._training_parameters.random_seed = ( + random_seed if random_seed is not None else self._training_parameters.random_seed + ) + self._training_parameters.step_lr_gamma = ( + step_lr_gamma if step_lr_gamma is not None else self._training_parameters.step_lr_gamma + ) + self._training_parameters.step_lr_step_size = ( + step_lr_step_size if step_lr_step_size is not None else self._training_parameters.step_lr_step_size + ) + self._training_parameters.training_batch_size = ( + training_batch_size if training_batch_size is not None else self._training_parameters.training_batch_size + ) + self._training_parameters.validation_batch_size = ( + validation_batch_size + if validation_batch_size is not None + else self._training_parameters.validation_batch_size + ) + self._training_parameters.warmup_cosine_lr_cycles = ( + warmup_cosine_lr_cycles + if warmup_cosine_lr_cycles is not None + else self._training_parameters.warmup_cosine_lr_cycles + ) + self._training_parameters.warmup_cosine_lr_warmup_epochs = ( + warmup_cosine_lr_warmup_epochs + if warmup_cosine_lr_warmup_epochs is not None + else self._training_parameters.warmup_cosine_lr_warmup_epochs + ) + self._training_parameters.weight_decay = ( + weight_decay if weight_decay is not None else self._training_parameters.weight_decay + ) + self._training_parameters.box_detections_per_image = ( + box_detections_per_image + if box_detections_per_image is not None + else self._training_parameters.box_detections_per_image + ) + self._training_parameters.box_score_threshold = ( + box_score_threshold if box_score_threshold is not None else self._training_parameters.box_score_threshold + ) + self._training_parameters.image_size = ( + image_size if image_size is not None else self._training_parameters.image_size + ) + self._training_parameters.max_size = max_size if max_size is not None else self._training_parameters.max_size + self._training_parameters.min_size = min_size if min_size is not None else self._training_parameters.min_size + self._training_parameters.model_size = ( + ModelSize[camel_to_snake(model_size)] if model_size is not None else self._training_parameters.model_size + ) + self._training_parameters.multi_scale = ( + multi_scale if multi_scale is not None else self._training_parameters.multi_scale + ) + self._training_parameters.nms_iou_threshold = ( + nms_iou_threshold if nms_iou_threshold is not None else self._training_parameters.nms_iou_threshold + ) + self._training_parameters.tile_grid_size = ( + tile_grid_size if tile_grid_size is not None else self._training_parameters.tile_grid_size + ) + self._training_parameters.tile_overlap_ratio = ( + tile_overlap_ratio if tile_overlap_ratio is not None else self._training_parameters.tile_overlap_ratio + ) + self._training_parameters.tile_predictions_nms_threshold = ( + tile_predictions_nms_threshold + if tile_predictions_nms_threshold is not None + else self._training_parameters.tile_predictions_nms_threshold + ) + self._training_parameters.validation_iou_threshold = ( + validation_iou_threshold + if validation_iou_threshold is not 
None + else self._training_parameters.validation_iou_threshold + ) + self._training_parameters.validation_metric_type = ( + ValidationMetricType[camel_to_snake(validation_metric_type)] + if validation_metric_type is not None + else self._training_parameters.validation_metric_type + ) + self._training_parameters.log_training_metrics = ( + LogTrainingMetrics[camel_to_snake(log_training_metrics)] + if log_training_metrics is not None + else self._training_parameters.log_training_metrics + ) + self._training_parameters.log_validation_loss = ( + LogValidationLoss[camel_to_snake(log_validation_loss)] + if log_validation_loss is not None + else self._training_parameters.log_validation_loss + ) + + # pylint: enable=too-many-locals + + def extend_search_space( + self, + value: Union[SearchSpace, List[SearchSpace]], + ) -> None: + """Add search space for AutoML Image Object Detection and Image Instance Segmentation tasks. + + :param value: Search through the parameter space + :type value: Union[SearchSpace, List[SearchSpace]] + """ + self._search_space = self._search_space or [] + + if isinstance(value, list): + self._search_space.extend( + [ + cast_to_specific_search_space(item, ImageObjectDetectionSearchSpace, self.task_type) # type: ignore + for item in value + ] + ) + else: + self._search_space.append( + cast_to_specific_search_space(value, ImageObjectDetectionSearchSpace, self.task_type) # type: ignore + ) + + @classmethod + def _get_search_space_from_str(cls, search_space_str: str) -> Optional[List[ImageObjectDetectionSearchSpace]]: + return ( + [ + ImageObjectDetectionSearchSpace._from_rest_object(entry) + for entry in search_space_str + if entry is not None + ] + if search_space_str is not None + else None + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AutoMLImageObjectDetectionBase): + return NotImplemented + + if not super().__eq__(other): + return False + + return self._training_parameters == other._training_parameters and self._search_space == other._search_space + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_job.py new file mode 100644 index 00000000..a1b9dbc3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_job.py @@ -0,0 +1,244 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
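To show how extend_search_space on the object-detection base above is meant to be used, a short sketch with a hypothetical job; each SearchSpace entry is cast to ImageObjectDetectionSearchSpace for the task type.

from azure.ai.ml.automl import SearchSpace
from azure.ai.ml.sweep import Choice, Uniform

image_object_detection_job.extend_search_space(
    [
        SearchSpace(model_name=Choice(["yolov5"]), learning_rate=Uniform(0.0001, 0.01)),
        SearchSpace(model_name=Choice(["fasterrcnn_resnet50_fpn"]), optimizer=Choice(["sgd", "adamw"])),
    ]
)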
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics +from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageClassification as RestImageClassification +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, TaskType +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.image.automl_image_classification_base import AutoMLImageClassificationBase +from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings +from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsClassification +from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings +from azure.ai.ml.entities._util import load_from_dict + + +class ImageClassificationJob(AutoMLImageClassificationBase): + """Configuration for AutoML multi-class Image Classification job. + + :param primary_metric: The primary metric to use for optimization. + :type primary_metric: Optional[str, ~azure.ai.ml.automl.ClassificationMultilabelPrimaryMetrics] + :param kwargs: Job-specific arguments. + :type kwargs: Dict[str, Any] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_classification_job] + :end-before: [END automl.automl_image_job.image_classification_job] + :language: python + :dedent: 8 + :caption: creating an automl image classification job + """ + + _DEFAULT_PRIMARY_METRIC = ClassificationPrimaryMetrics.ACCURACY + + def __init__( + self, + *, + primary_metric: Optional[Union[str, ClassificationPrimaryMetrics]] = None, + **kwargs: Any, + ) -> None: + + # Extract any super class init settings + limits = kwargs.pop("limits", None) + sweep = kwargs.pop("sweep", None) + training_parameters = kwargs.pop("training_parameters", None) + search_space = kwargs.pop("search_space", None) + + super().__init__( + task_type=TaskType.IMAGE_CLASSIFICATION, + limits=limits, + sweep=sweep, + training_parameters=training_parameters, + search_space=search_space, + **kwargs, + ) + + self.primary_metric = primary_metric or ImageClassificationJob._DEFAULT_PRIMARY_METRIC + + @property + def primary_metric(self) -> Optional[Union[str, ClassificationPrimaryMetrics]]: + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, ClassificationPrimaryMetrics]) -> None: + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + self._primary_metric = ( + ImageClassificationJob._DEFAULT_PRIMARY_METRIC + if value is None + else ClassificationPrimaryMetrics[camel_to_snake(value).upper()] + ) + + def _to_rest_object(self) -> JobBase: + image_classification_task = RestImageClassification( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + validation_data_size=self.validation_data_size, + limit_settings=self._limits._to_rest_object() if self._limits else None, + sweep_settings=self._sweep._to_rest_object() if self._sweep else None, + model_settings=self._training_parameters._to_rest_object() if self._training_parameters else None, + search_space=( + [entry._to_rest_object() for entry in self._search_space if entry is not None] + if self._search_space is not None + else None + ), + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + ) + # resolve data inputs in rest obj + self._resolve_data_inputs(image_classification_task) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=image_classification_task, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "ImageClassificationJob": + properties: RestAutoMLJob = obj.properties + task_details: RestImageClassification = properties.task_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": properties.experiment_name, + "services": properties.services, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": 
properties.display_name, + "compute": properties.compute_id, + "outputs": from_rest_data_outputs(properties.outputs), + "resources": properties.resources, + "identity": ( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + "queue_settings": properties.queue_settings, + } + + image_classification_job = cls( + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + validation_data_size=task_details.validation_data_size, + limits=( + ImageLimitSettings._from_rest_object(task_details.limit_settings) + if task_details.limit_settings + else None + ), + sweep=( + ImageSweepSettings._from_rest_object(task_details.sweep_settings) + if task_details.sweep_settings + else None + ), + training_parameters=( + ImageModelSettingsClassification._from_rest_object(task_details.model_settings) + if task_details.model_settings + else None + ), + search_space=cls._get_search_space_from_str(task_details.search_space), + primary_metric=task_details.primary_metric, + log_verbosity=task_details.log_verbosity, + **job_args_dict, + ) + + image_classification_job._restore_data_inputs() + + return image_classification_job + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "ImageClassificationJob": + from azure.ai.ml._schema.automl.image_vertical.image_classification import ImageClassificationSchema + from azure.ai.ml._schema.pipeline.automl_node import ImageClassificationMulticlassNodeSchema + + inside_pipeline = kwargs.pop("inside_pipeline", False) + if inside_pipeline: + if context.get("inside_pipeline", None) is None: + context["inside_pipeline"] = True + loaded_data = load_from_dict( + ImageClassificationMulticlassNodeSchema, + data, + context, + additional_message, + **kwargs, + ) + else: + loaded_data = load_from_dict(ImageClassificationSchema, data, context, additional_message, **kwargs) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ImageClassificationJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + data_settings = { + "training_data": loaded_data.pop("training_data"), + "target_column_name": loaded_data.pop("target_column_name"), + "validation_data": loaded_data.pop("validation_data", None), + "validation_data_size": loaded_data.pop("validation_data_size", None), + } + job = ImageClassificationJob(**loaded_data) + job.set_data(**data_settings) + return job + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.image_vertical.image_classification import ImageClassificationSchema + from azure.ai.ml._schema.pipeline.automl_node import ImageClassificationMulticlassNodeSchema + + schema_dict: dict = {} + if inside_pipeline: + schema_dict = ImageClassificationMulticlassNodeSchema( + context={BASE_PATH_CONTEXT_KEY: "./", "inside_pipeline": True} + ).dump(self) + else: + schema_dict = ImageClassificationSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return schema_dict + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageClassificationJob): + return NotImplemented + + if not super().__eq__(other): + return False + + return self.primary_metric == other.primary_metric + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git 
a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_multilabel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_multilabel_job.py new file mode 100644 index 00000000..541f41c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_multilabel_job.py @@ -0,0 +1,252 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationMultilabelPrimaryMetrics +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ImageClassificationMultilabel as RestImageClassificationMultilabel, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, TaskType +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.image.automl_image_classification_base import AutoMLImageClassificationBase +from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings +from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsClassification +from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings +from azure.ai.ml.entities._util import load_from_dict + + +class ImageClassificationMultilabelJob(AutoMLImageClassificationBase): + """Configuration for AutoML multi-label Image Classification job. + + :param primary_metric: The primary metric to use for optimization. + :type primary_metric: Optional[str, ~azure.ai.ml.automl.ClassificationMultilabelPrimaryMetrics] + :param kwargs: Job-specific arguments. + :type kwargs: Dict[str, Any] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_classification_multilabel_job] + :end-before: [END automl.automl_image_job.image_classification_multilabel_job] + :language: python + :dedent: 8 + :caption: creating an automl image classification multilabel job + """ + + _DEFAULT_PRIMARY_METRIC = ClassificationMultilabelPrimaryMetrics.IOU + + def __init__( + self, + *, + primary_metric: Optional[Union[str, ClassificationMultilabelPrimaryMetrics]] = None, + **kwargs: Any, + ) -> None: + + # Extract any super class init settings + limits = kwargs.pop("limits", None) + sweep = kwargs.pop("sweep", None) + training_parameters = kwargs.pop("training_parameters", None) + search_space = kwargs.pop("search_space", None) + + super().__init__( + task_type=TaskType.IMAGE_CLASSIFICATION_MULTILABEL, + limits=limits, + sweep=sweep, + training_parameters=training_parameters, + search_space=search_space, + **kwargs, + ) + + self.primary_metric = primary_metric or ImageClassificationMultilabelJob._DEFAULT_PRIMARY_METRIC + + @property + def primary_metric(self) -> Union[str, ClassificationMultilabelPrimaryMetrics]: + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, ClassificationMultilabelPrimaryMetrics]) -> None: + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + self._primary_metric = ( + ImageClassificationMultilabelJob._DEFAULT_PRIMARY_METRIC + if value is None + else ClassificationMultilabelPrimaryMetrics[camel_to_snake(value).upper()] + ) + + def _to_rest_object(self) -> JobBase: + image_classification_multilabel_task = RestImageClassificationMultilabel( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + validation_data_size=self.validation_data_size, + limit_settings=self._limits._to_rest_object() if self._limits else None, + sweep_settings=self._sweep._to_rest_object() if self._sweep else None, + model_settings=self._training_parameters._to_rest_object() if self._training_parameters else None, + search_space=( + [entry._to_rest_object() for entry in self._search_space if entry is not None] + if self._search_space is not None + else None + ), + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + ) + # resolve data inputs in rest obj + self._resolve_data_inputs(image_classification_multilabel_task) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=image_classification_multilabel_task, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "ImageClassificationMultilabelJob": + properties: RestAutoMLJob = obj.properties + task_details: RestImageClassificationMultilabel = properties.task_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": 
properties.experiment_name, + "services": properties.services, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": properties.display_name, + "compute": properties.compute_id, + "outputs": from_rest_data_outputs(properties.outputs), + "resources": properties.resources, + "identity": ( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + "queue_settings": properties.queue_settings, + } + + image_classification_multilabel_job = cls( + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + validation_data_size=task_details.validation_data_size, + limits=( + ImageLimitSettings._from_rest_object(task_details.limit_settings) + if task_details.limit_settings + else None + ), + sweep=( + ImageSweepSettings._from_rest_object(task_details.sweep_settings) + if task_details.sweep_settings + else None + ), + training_parameters=( + ImageModelSettingsClassification._from_rest_object(task_details.model_settings) + if task_details.model_settings + else None + ), + search_space=cls._get_search_space_from_str(task_details.search_space), + primary_metric=task_details.primary_metric, + log_verbosity=task_details.log_verbosity, + **job_args_dict, + ) + + image_classification_multilabel_job._restore_data_inputs() + + return image_classification_multilabel_job + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "ImageClassificationMultilabelJob": + from azure.ai.ml._schema.automl.image_vertical.image_classification import ImageClassificationMultilabelSchema + from azure.ai.ml._schema.pipeline.automl_node import ImageClassificationMultilabelNodeSchema + + inside_pipeline = kwargs.pop("inside_pipeline", False) + if inside_pipeline: + if context.get("inside_pipeline", None) is None: + context["inside_pipeline"] = True + loaded_data = load_from_dict( + ImageClassificationMultilabelNodeSchema, + data, + context, + additional_message, + **kwargs, + ) + else: + loaded_data = load_from_dict( + ImageClassificationMultilabelSchema, + data, + context, + additional_message, + **kwargs, + ) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ImageClassificationMultilabelJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + data_settings = { + "training_data": loaded_data.pop("training_data"), + "target_column_name": loaded_data.pop("target_column_name"), + "validation_data": loaded_data.pop("validation_data", None), + "validation_data_size": loaded_data.pop("validation_data_size", None), + } + job = ImageClassificationMultilabelJob(**loaded_data) + job.set_data(**data_settings) + return job + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.image_vertical.image_classification import ImageClassificationMultilabelSchema + from azure.ai.ml._schema.pipeline.automl_node import ImageClassificationMultilabelNodeSchema + + schema_dict: dict = {} + if inside_pipeline: + schema_dict = ImageClassificationMultilabelNodeSchema( + context={BASE_PATH_CONTEXT_KEY: "./", "inside_pipeline": True} + ).dump(self) + else: + schema_dict = ImageClassificationMultilabelSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return schema_dict + + def __eq__(self, other: object) -> bool: + if not 
isinstance(other, ImageClassificationMultilabelJob): + return NotImplemented + + if not super().__eq__(other): + return False + + return self.primary_metric == other.primary_metric + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_search_space.py new file mode 100644 index 00000000..0691f243 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_search_space.py @@ -0,0 +1,437 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=R0902,too-many-locals + +from typing import Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageModelDistributionSettingsClassification +from azure.ai.ml.entities._job.automl.search_space import SearchSpace +from azure.ai.ml.entities._job.automl.search_space_utils import _convert_from_rest_object, _convert_to_rest_object +from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class ImageClassificationSearchSpace(RestTranslatableMixin): + """Search space for AutoML Image Classification and Image Classification + Multilabel tasks. + + :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'. + :type ams_gradient: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the + range [0, 1]. + :type beta1: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the + range [0, 1]. + :type beta2: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param distributed: Whether to use distributer training. + :type distributed: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param early_stopping: Enable early stopping logic during training. + :type early_stopping: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait + before primary metric improvement + is tracked for early stopping. Must be a positive integer. + :type early_stopping_delay: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param early_stopping_patience: Minimum number of epochs or validation evaluations with no + primary metric improvement before + the run is stopped. Must be a positive integer. + :type early_stopping_patience: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param enable_onnx_normalization: Enable normalization when exporting ONNX model. + :type enable_onnx_normalization: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. + Must be a positive integer. 
+ :type evaluation_frequency: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param gradient_accumulation_step: Gradient accumulation means running a configured number of + "GradAccumulationStep" steps without + updating the model weights while accumulating the gradients of those steps, and then using + the accumulated gradients to compute the weight updates. Must be a positive integer. + :type gradient_accumulation_step: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive + integer. + For instance, passing 2 as value for 'seresnext' means + freezing layer0 and layer1. For a full list of models supported and details on layer freeze, + please + see: https://learn.microsoft.com/azure/machine-learning/reference-automl-images-hyperparameters#model-agnostic-hyperparameters. # pylint: disable=line-too-long + :type layers_to_freeze: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param learning_rate: Initial learning rate. Must be a float in the range [0, 1]. + :type learning_rate: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or + 'step'. + :type learning_rate_scheduler: str or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param model_name: Name of the model to use for training. + For more information on the available models please visit the official documentation: + https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type model_name: str or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, + 1]. + :type momentum: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param nesterov: Enable nesterov when optimizer is 'sgd'. + :type nesterov: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param number_of_epochs: Number of training epochs. Must be a positive integer. + :type number_of_epochs: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param number_of_workers: Number of data loader workers. Must be a non-negative integer. + :type number_of_workers: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param optimizer: Type of optimizer. Must be either 'sgd', 'adam', or 'adamw'. + :type optimizer: str or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param random_seed: Random seed to be used when using deterministic training. + :type random_seed: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float + in the range [0, 1]. + :type step_lr_gamma: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be + a positive integer. + :type step_lr_step_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param training_batch_size: Training batch size. Must be a positive integer. + :type training_batch_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param validation_batch_size: Validation batch size. Must be a positive integer. 
+ :type validation_batch_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is + 'warmup_cosine'. Must be a float in the range [0, 1]. + :type warmup_cosine_lr_cycles: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is + 'warmup_cosine'. Must be a positive integer. + :type warmup_cosine_lr_warmup_epochs: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must + be a float in the range[0, 1]. + :type weight_decay: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param training_crop_size: Image crop size that is input to the neural network for the + training dataset. Must be a positive integer. + :type training_crop_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param validation_crop_size: Image crop size that is input to the neural network for the + validation dataset. Must be a positive integer. + :type validation_crop_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param validation_resize_size: Image size to which to resize before cropping for validation + dataset. Must be a positive integer. + :type validation_resize_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + :param weighted_loss: Weighted loss. The accepted values are 0 for no weighted loss. + 1 for weighted loss with sqrt.(class_weights). 2 for weighted loss with class_weights. Must be + 0 or 1 or 2. + :type weighted_loss: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_classification_search_space] + :end-before: [END automl.automl_image_job.image_classification_search_space] + :language: python + :dedent: 8 + :caption: Defining an automl image classification search space + """ + + def __init__( + self, + *, + ams_gradient: Optional[Union[bool, SweepDistribution]] = None, + beta1: Optional[Union[float, SweepDistribution]] = None, + beta2: Optional[Union[float, SweepDistribution]] = None, + distributed: Optional[Union[bool, SweepDistribution]] = None, + early_stopping: Optional[Union[bool, SweepDistribution]] = None, + early_stopping_delay: Optional[Union[int, SweepDistribution]] = None, + early_stopping_patience: Optional[Union[int, SweepDistribution]] = None, + enable_onnx_normalization: Optional[Union[bool, SweepDistribution]] = None, + evaluation_frequency: Optional[Union[int, SweepDistribution]] = None, + gradient_accumulation_step: Optional[Union[int, SweepDistribution]] = None, + layers_to_freeze: Optional[Union[int, SweepDistribution]] = None, + learning_rate: Optional[Union[float, SweepDistribution]] = None, + learning_rate_scheduler: Optional[Union[str, SweepDistribution]] = None, + model_name: Optional[Union[str, SweepDistribution]] = None, + momentum: Optional[Union[float, SweepDistribution]] = None, + nesterov: Optional[Union[bool, SweepDistribution]] = None, + number_of_epochs: Optional[Union[int, SweepDistribution]] = None, + number_of_workers: Optional[Union[int, SweepDistribution]] = None, + optimizer: Optional[Union[str, SweepDistribution]] = None, + random_seed: Optional[Union[int, SweepDistribution]] = None, + step_lr_gamma: Optional[Union[float, SweepDistribution]] = None, + step_lr_step_size: Optional[Union[int, SweepDistribution]] = None, + training_batch_size: Optional[Union[int, SweepDistribution]] = None, + validation_batch_size: Optional[Union[int, SweepDistribution]] = None, + warmup_cosine_lr_cycles: Optional[Union[float, SweepDistribution]] = None, + warmup_cosine_lr_warmup_epochs: Optional[Union[int, SweepDistribution]] = None, + weight_decay: Optional[Union[float, SweepDistribution]] = None, + training_crop_size: Optional[Union[int, SweepDistribution]] = None, + validation_crop_size: Optional[Union[int, SweepDistribution]] = None, + validation_resize_size: Optional[Union[int, SweepDistribution]] = None, + weighted_loss: Optional[Union[int, SweepDistribution]] = None, + ) -> None: + self.ams_gradient = ams_gradient + self.beta1 = beta1 + self.beta2 = beta2 + self.distributed = distributed + self.early_stopping = early_stopping + self.early_stopping_delay = early_stopping_delay + self.early_stopping_patience = early_stopping_patience + self.enable_onnx_normalization = enable_onnx_normalization + self.evaluation_frequency = evaluation_frequency + self.gradient_accumulation_step = gradient_accumulation_step + self.layers_to_freeze = layers_to_freeze + self.learning_rate = learning_rate + self.learning_rate_scheduler = learning_rate_scheduler + self.model_name = model_name + self.momentum = momentum + self.nesterov = nesterov + self.number_of_epochs = number_of_epochs + self.number_of_workers = number_of_workers + self.optimizer = optimizer + self.random_seed = random_seed + self.step_lr_gamma = step_lr_gamma + self.step_lr_step_size = step_lr_step_size + self.training_batch_size = training_batch_size + self.validation_batch_size = validation_batch_size + self.warmup_cosine_lr_cycles = warmup_cosine_lr_cycles + 
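A sketch of constructing the search space directly, as the constructor above allows; every field accepts either a literal value or a SweepDistribution from azure.ai.ml.sweep (names and values here are illustrative).

from azure.ai.ml.automl import ImageClassificationSearchSpace
from azure.ai.ml.sweep import Choice, Uniform

classification_space = ImageClassificationSearchSpace(
    model_name=Choice(["vitb16r224", "seresnext"]),
    learning_rate=Uniform(0.001, 0.01),
    number_of_epochs=Choice([15, 30]),
    training_crop_size=224,  # literal values may be mixed with distributions
)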
self.warmup_cosine_lr_warmup_epochs = warmup_cosine_lr_warmup_epochs + self.weight_decay = weight_decay + self.training_crop_size = training_crop_size + self.validation_crop_size = validation_crop_size + self.validation_resize_size = validation_resize_size + self.weighted_loss = weighted_loss + + def _to_rest_object(self) -> ImageModelDistributionSettingsClassification: + return ImageModelDistributionSettingsClassification( + ams_gradient=_convert_to_rest_object(self.ams_gradient) if self.ams_gradient is not None else None, + beta1=_convert_to_rest_object(self.beta1) if self.beta1 is not None else None, + beta2=_convert_to_rest_object(self.beta2) if self.beta2 is not None else None, + distributed=_convert_to_rest_object(self.distributed) if self.distributed is not None else None, + early_stopping=_convert_to_rest_object(self.early_stopping) if self.early_stopping is not None else None, + early_stopping_delay=( + _convert_to_rest_object(self.early_stopping_delay) if self.early_stopping_delay is not None else None + ), + early_stopping_patience=( + _convert_to_rest_object(self.early_stopping_patience) + if self.early_stopping_patience is not None + else None + ), + enable_onnx_normalization=( + _convert_to_rest_object(self.enable_onnx_normalization) + if self.enable_onnx_normalization is not None + else None + ), + evaluation_frequency=( + _convert_to_rest_object(self.evaluation_frequency) if self.evaluation_frequency is not None else None + ), + gradient_accumulation_step=( + _convert_to_rest_object(self.gradient_accumulation_step) + if self.gradient_accumulation_step is not None + else None + ), + layers_to_freeze=( + _convert_to_rest_object(self.layers_to_freeze) if self.layers_to_freeze is not None else None + ), + learning_rate=_convert_to_rest_object(self.learning_rate) if self.learning_rate is not None else None, + learning_rate_scheduler=( + _convert_to_rest_object(self.learning_rate_scheduler) + if self.learning_rate_scheduler is not None + else None + ), + model_name=_convert_to_rest_object(self.model_name) if self.model_name is not None else None, + momentum=_convert_to_rest_object(self.momentum) if self.momentum is not None else None, + nesterov=_convert_to_rest_object(self.nesterov) if self.nesterov is not None else None, + number_of_epochs=( + _convert_to_rest_object(self.number_of_epochs) if self.number_of_epochs is not None else None + ), + number_of_workers=( + _convert_to_rest_object(self.number_of_workers) if self.number_of_workers is not None else None + ), + optimizer=_convert_to_rest_object(self.optimizer) if self.optimizer is not None else None, + random_seed=_convert_to_rest_object(self.random_seed) if self.random_seed is not None else None, + step_lr_gamma=_convert_to_rest_object(self.step_lr_gamma) if self.step_lr_gamma is not None else None, + step_lr_step_size=( + _convert_to_rest_object(self.step_lr_step_size) if self.step_lr_step_size is not None else None + ), + training_batch_size=( + _convert_to_rest_object(self.training_batch_size) if self.training_batch_size is not None else None + ), + validation_batch_size=( + _convert_to_rest_object(self.validation_batch_size) if self.validation_batch_size is not None else None + ), + warmup_cosine_lr_cycles=( + _convert_to_rest_object(self.warmup_cosine_lr_cycles) + if self.warmup_cosine_lr_cycles is not None + else None + ), + warmup_cosine_lr_warmup_epochs=( + _convert_to_rest_object(self.warmup_cosine_lr_warmup_epochs) + if self.warmup_cosine_lr_warmup_epochs is not None + else None + ), + 
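+            # The remaining fields follow the same pattern as above: a value (either a literal
+            # or a SweepDistribution) is converted to its REST representation only when it was
+            # explicitly set; unset fields stay None.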
weight_decay=_convert_to_rest_object(self.weight_decay) if self.weight_decay is not None else None, + training_crop_size=( + _convert_to_rest_object(self.training_crop_size) if self.training_crop_size is not None else None + ), + validation_crop_size=( + _convert_to_rest_object(self.validation_crop_size) if self.validation_crop_size is not None else None + ), + validation_resize_size=( + _convert_to_rest_object(self.validation_resize_size) + if self.validation_resize_size is not None + else None + ), + weighted_loss=_convert_to_rest_object(self.weighted_loss) if self.weighted_loss is not None else None, + ) + + @classmethod + def _from_rest_object(cls, obj: ImageModelDistributionSettingsClassification) -> "ImageClassificationSearchSpace": + return cls( + ams_gradient=_convert_from_rest_object(obj.ams_gradient) if obj.ams_gradient is not None else None, + beta1=_convert_from_rest_object(obj.beta1) if obj.beta1 is not None else None, + beta2=_convert_from_rest_object(obj.beta2) if obj.beta2 is not None else None, + distributed=_convert_from_rest_object(obj.distributed) if obj.distributed is not None else None, + early_stopping=_convert_from_rest_object(obj.early_stopping) if obj.early_stopping is not None else None, + early_stopping_delay=( + _convert_from_rest_object(obj.early_stopping_delay) if obj.early_stopping_delay is not None else None + ), + early_stopping_patience=( + _convert_from_rest_object(obj.early_stopping_patience) + if obj.early_stopping_patience is not None + else None + ), + enable_onnx_normalization=( + _convert_from_rest_object(obj.enable_onnx_normalization) + if obj.enable_onnx_normalization is not None + else None + ), + evaluation_frequency=( + _convert_from_rest_object(obj.evaluation_frequency) if obj.evaluation_frequency is not None else None + ), + gradient_accumulation_step=( + _convert_from_rest_object(obj.gradient_accumulation_step) + if obj.gradient_accumulation_step is not None + else None + ), + layers_to_freeze=( + _convert_from_rest_object(obj.layers_to_freeze) if obj.layers_to_freeze is not None else None + ), + learning_rate=_convert_from_rest_object(obj.learning_rate) if obj.learning_rate is not None else None, + learning_rate_scheduler=( + _convert_from_rest_object(obj.learning_rate_scheduler) + if obj.learning_rate_scheduler is not None + else None + ), + model_name=_convert_from_rest_object(obj.model_name) if obj.model_name is not None else None, + momentum=_convert_from_rest_object(obj.momentum) if obj.momentum is not None else None, + nesterov=_convert_from_rest_object(obj.nesterov) if obj.nesterov is not None else None, + number_of_epochs=( + _convert_from_rest_object(obj.number_of_epochs) if obj.number_of_epochs is not None else None + ), + number_of_workers=( + _convert_from_rest_object(obj.number_of_workers) if obj.number_of_workers is not None else None + ), + optimizer=_convert_from_rest_object(obj.optimizer) if obj.optimizer is not None else None, + random_seed=_convert_from_rest_object(obj.random_seed) if obj.random_seed is not None else None, + step_lr_gamma=_convert_from_rest_object(obj.step_lr_gamma) if obj.step_lr_gamma is not None else None, + step_lr_step_size=( + _convert_from_rest_object(obj.step_lr_step_size) if obj.step_lr_step_size is not None else None + ), + training_batch_size=( + _convert_from_rest_object(obj.training_batch_size) if obj.training_batch_size is not None else None + ), + validation_batch_size=( + _convert_from_rest_object(obj.validation_batch_size) if obj.validation_batch_size is not None else None + ), + 
warmup_cosine_lr_cycles=( + _convert_from_rest_object(obj.warmup_cosine_lr_cycles) + if obj.warmup_cosine_lr_cycles is not None + else None + ), + warmup_cosine_lr_warmup_epochs=( + _convert_from_rest_object(obj.warmup_cosine_lr_warmup_epochs) + if obj.warmup_cosine_lr_warmup_epochs is not None + else None + ), + weight_decay=_convert_from_rest_object(obj.weight_decay) if obj.weight_decay is not None else None, + training_crop_size=( + _convert_from_rest_object(obj.training_crop_size) if obj.training_crop_size is not None else None + ), + validation_crop_size=( + _convert_from_rest_object(obj.validation_crop_size) if obj.validation_crop_size is not None else None + ), + validation_resize_size=( + _convert_from_rest_object(obj.validation_resize_size) + if obj.validation_resize_size is not None + else None + ), + weighted_loss=_convert_from_rest_object(obj.weighted_loss) if obj.weighted_loss is not None else None, + ) + + @classmethod + def _from_search_space_object(cls, obj: SearchSpace) -> "ImageClassificationSearchSpace": + return cls( + ams_gradient=obj.ams_gradient if hasattr(obj, "ams_gradient") else None, + beta1=obj.beta1 if hasattr(obj, "beta1") else None, + beta2=obj.beta2 if hasattr(obj, "beta2") else None, + distributed=obj.distributed if hasattr(obj, "distributed") else None, + early_stopping=obj.early_stopping if hasattr(obj, "early_stopping") else None, + early_stopping_delay=obj.early_stopping_delay if hasattr(obj, "early_stopping_delay") else None, + early_stopping_patience=obj.early_stopping_patience if hasattr(obj, "early_stopping_patience") else None, + enable_onnx_normalization=( + obj.enable_onnx_normalization if hasattr(obj, "enable_onnx_normalization") else None + ), + evaluation_frequency=obj.evaluation_frequency if hasattr(obj, "evaluation_frequency") else None, + gradient_accumulation_step=( + obj.gradient_accumulation_step if hasattr(obj, "gradient_accumulation_step") else None + ), + layers_to_freeze=obj.layers_to_freeze if hasattr(obj, "layers_to_freeze") else None, + learning_rate=obj.learning_rate if hasattr(obj, "learning_rate") else None, + learning_rate_scheduler=obj.learning_rate_scheduler if hasattr(obj, "learning_rate_scheduler") else None, + model_name=obj.model_name if hasattr(obj, "model_name") else None, + momentum=obj.momentum if hasattr(obj, "momentum") else None, + nesterov=obj.nesterov if hasattr(obj, "nesterov") else None, + number_of_epochs=obj.number_of_epochs if hasattr(obj, "number_of_epochs") else None, + number_of_workers=obj.number_of_workers if hasattr(obj, "number_of_workers") else None, + optimizer=obj.optimizer if hasattr(obj, "optimizer") else None, + random_seed=obj.random_seed if hasattr(obj, "random_seed") else None, + step_lr_gamma=obj.step_lr_gamma if hasattr(obj, "step_lr_gamma") else None, + step_lr_step_size=obj.step_lr_step_size if hasattr(obj, "step_lr_step_size") else None, + training_batch_size=obj.training_batch_size if hasattr(obj, "training_batch_size") else None, + validation_batch_size=obj.validation_batch_size if hasattr(obj, "validation_batch_size") else None, + warmup_cosine_lr_cycles=obj.warmup_cosine_lr_cycles if hasattr(obj, "warmup_cosine_lr_cycles") else None, + warmup_cosine_lr_warmup_epochs=( + obj.warmup_cosine_lr_warmup_epochs if hasattr(obj, "warmup_cosine_lr_warmup_epochs") else None + ), + weight_decay=obj.weight_decay if hasattr(obj, "weight_decay") else None, + training_crop_size=obj.training_crop_size if hasattr(obj, "training_crop_size") else None, + validation_crop_size=obj.validation_crop_size if 
hasattr(obj, "validation_crop_size") else None, + validation_resize_size=obj.validation_resize_size if hasattr(obj, "validation_resize_size") else None, + weighted_loss=obj.weighted_loss if hasattr(obj, "weighted_loss") else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageClassificationSearchSpace): + return NotImplemented + + return ( + self.ams_gradient == other.ams_gradient + and self.beta1 == other.beta1 + and self.beta2 == other.beta2 + and self.distributed == other.distributed + and self.early_stopping == other.early_stopping + and self.early_stopping_delay == other.early_stopping_delay + and self.early_stopping_patience == other.early_stopping_patience + and self.enable_onnx_normalization == other.enable_onnx_normalization + and self.evaluation_frequency == other.evaluation_frequency + and self.gradient_accumulation_step == other.gradient_accumulation_step + and self.layers_to_freeze == other.layers_to_freeze + and self.learning_rate == other.learning_rate + and self.learning_rate_scheduler == other.learning_rate_scheduler + and self.model_name == other.model_name + and self.momentum == other.momentum + and self.nesterov == other.nesterov + and self.number_of_epochs == other.number_of_epochs + and self.number_of_workers == other.number_of_workers + and self.optimizer == other.optimizer + and self.random_seed == other.random_seed + and self.step_lr_gamma == other.step_lr_gamma + and self.step_lr_step_size == other.step_lr_step_size + and self.training_batch_size == other.training_batch_size + and self.validation_batch_size == other.validation_batch_size + and self.warmup_cosine_lr_cycles == other.warmup_cosine_lr_cycles + and self.warmup_cosine_lr_warmup_epochs == other.warmup_cosine_lr_warmup_epochs + and self.weight_decay == other.weight_decay + and self.training_crop_size == other.training_crop_size + and self.validation_crop_size == other.validation_crop_size + and self.validation_resize_size == other.validation_resize_size + and self.weighted_loss == other.weighted_loss + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_instance_segmentation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_instance_segmentation_job.py new file mode 100644 index 00000000..c97d3c11 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_instance_segmentation_job.py @@ -0,0 +1,249 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ImageInstanceSegmentation as RestImageInstanceSegmentation, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import InstanceSegmentationPrimaryMetrics, JobBase, TaskType +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.image.automl_image_object_detection_base import AutoMLImageObjectDetectionBase +from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings +from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsObjectDetection +from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings +from azure.ai.ml.entities._util import load_from_dict + + +class ImageInstanceSegmentationJob(AutoMLImageObjectDetectionBase): + """Configuration for AutoML Image Instance Segmentation job. + + :keyword primary_metric: The primary metric to use for optimization. + :paramtype primary_metric: Optional[str, ~azure.ai.ml.automl.InstanceSegmentationPrimaryMetrics] + :keyword kwargs: Job-specific arguments. + :paramtype kwargs: Dict[str, Any] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_instance_segmentation_job] + :end-before: [END automl.automl_image_job.image_instance_segmentation_job] + :language: python + :dedent: 8 + :caption: creating an automl image instance segmentation job + """ + + _DEFAULT_PRIMARY_METRIC = InstanceSegmentationPrimaryMetrics.MEAN_AVERAGE_PRECISION + + def __init__( + self, + *, + primary_metric: Optional[Union[str, InstanceSegmentationPrimaryMetrics]] = None, + **kwargs: Any, + ) -> None: + # Extract any super class init settings + limits = kwargs.pop("limits", None) + sweep = kwargs.pop("sweep", None) + training_parameters = kwargs.pop("training_parameters", None) + search_space = kwargs.pop("search_space", None) + + super().__init__( + task_type=TaskType.IMAGE_INSTANCE_SEGMENTATION, + limits=limits, + sweep=sweep, + training_parameters=training_parameters, + search_space=search_space, + **kwargs, + ) + self.primary_metric = primary_metric or ImageInstanceSegmentationJob._DEFAULT_PRIMARY_METRIC + + @property + def primary_metric(self) -> Union[str, InstanceSegmentationPrimaryMetrics]: + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, InstanceSegmentationPrimaryMetrics]) -> None: + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + self._primary_metric = ( + ImageInstanceSegmentationJob._DEFAULT_PRIMARY_METRIC + if value is None + else InstanceSegmentationPrimaryMetrics[camel_to_snake(value).upper()] + ) + + def _to_rest_object(self) -> JobBase: + image_instance_segmentation_task = RestImageInstanceSegmentation( + target_column_name=self.target_column_name, + training_data=self.training_data, + 
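+            # training_data and validation_data are passed through unchanged here; they are
+            # rewritten into their REST form by self._resolve_data_inputs() after this object
+            # is built.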
validation_data=self.validation_data, + validation_data_size=self.validation_data_size, + limit_settings=self._limits._to_rest_object() if self._limits else None, + sweep_settings=self._sweep._to_rest_object() if self._sweep else None, + model_settings=self._training_parameters._to_rest_object() if self._training_parameters else None, + search_space=( + [entry._to_rest_object() for entry in self._search_space if entry is not None] + if self._search_space is not None + else None + ), + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + ) + # resolve data inputs in rest obj + self._resolve_data_inputs(image_instance_segmentation_task) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=image_instance_segmentation_task, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "ImageInstanceSegmentationJob": + properties: RestAutoMLJob = obj.properties + task_details: RestImageInstanceSegmentation = properties.task_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": properties.experiment_name, + "services": properties.services, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": properties.display_name, + "compute": properties.compute_id, + "outputs": from_rest_data_outputs(properties.outputs), + "resources": properties.resources, + "identity": ( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + "queue_settings": properties.queue_settings, + } + + image_instance_segmentation_job = cls( + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + validation_data_size=task_details.validation_data_size, + limits=( + ImageLimitSettings._from_rest_object(task_details.limit_settings) + if task_details.limit_settings + else None + ), + sweep=( + ImageSweepSettings._from_rest_object(task_details.sweep_settings) + if task_details.sweep_settings + else None + ), + training_parameters=( + ImageModelSettingsObjectDetection._from_rest_object(task_details.model_settings) + if task_details.model_settings + else None + ), + search_space=cls._get_search_space_from_str(task_details.search_space), + primary_metric=task_details.primary_metric, + log_verbosity=task_details.log_verbosity, + **job_args_dict, + ) + + image_instance_segmentation_job._restore_data_inputs() + + return image_instance_segmentation_job + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "ImageInstanceSegmentationJob": + from azure.ai.ml._schema.automl.image_vertical.image_object_detection import ImageInstanceSegmentationSchema + from azure.ai.ml._schema.pipeline.automl_node import ImageInstanceSegmentationNodeSchema + + inside_pipeline = 
kwargs.pop("inside_pipeline", False) + if inside_pipeline: + if context.get("inside_pipeline", None) is None: + context["inside_pipeline"] = True + loaded_data = load_from_dict( + ImageInstanceSegmentationNodeSchema, + data, + context, + additional_message, + **kwargs, + ) + else: + loaded_data = load_from_dict( + ImageInstanceSegmentationSchema, + data, + context, + additional_message, + **kwargs, + ) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ImageInstanceSegmentationJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + data_settings = { + "training_data": loaded_data.pop("training_data"), + "target_column_name": loaded_data.pop("target_column_name"), + "validation_data": loaded_data.pop("validation_data", None), + "validation_data_size": loaded_data.pop("validation_data_size", None), + } + job = ImageInstanceSegmentationJob(**loaded_data) + job.set_data(**data_settings) + return job + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.image_vertical.image_object_detection import ImageInstanceSegmentationSchema + from azure.ai.ml._schema.pipeline.automl_node import ImageInstanceSegmentationNodeSchema + + schema_dict: dict = {} + if inside_pipeline: + schema_dict = ImageInstanceSegmentationNodeSchema( + context={BASE_PATH_CONTEXT_KEY: "./", "inside_pipeline": True} + ).dump(self) + else: + schema_dict = ImageInstanceSegmentationSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return schema_dict + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageInstanceSegmentationJob): + return NotImplemented + + if not super().__eq__(other): + return False + + return self.primary_metric == other.primary_metric + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_limit_settings.py new file mode 100644 index 00000000..12ec8b57 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_limit_settings.py @@ -0,0 +1,117 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageLimitSettings as RestImageLimitSettings +from azure.ai.ml._utils.utils import from_iso_duration_format_mins, to_iso_duration_format_mins +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class ImageLimitSettings(RestTranslatableMixin): + r"""Limit settings for AutoML Image Verticals. + + ImageLimitSettings is a class that contains the following parameters: max_concurrent_trials, max_trials, and \ + timeout_minutes. + + This is an optional configuration method to configure limits parameters such as timeouts etc. + + .. note:: + + The number of concurrent runs is gated on the resources available in the specified compute target. + Ensure that the compute target has the available resources for the desired concurrency. + + :keyword max_concurrent_trials: Maximum number of concurrent AutoML iterations, defaults to None. 
+ :paramtype max_concurrent_trials: typing.Optional[int] + :keyword max_trials: Represents the maximum number of trials (children jobs). + :paramtype max_trials: typing.Optional[int] + :keyword timeout_minutes: AutoML job timeout. Defaults to None + :paramtype timeout_minutes: typing.Optional[int] + :raises ValueError: If max_concurrent_trials is not None and is not a positive integer. + :raises ValueError: If max_trials is not None and is not a positive integer. + :raises ValueError: If timeout_minutes is not None and is not a positive integer. + :return: ImageLimitSettings object. + :rtype: ImageLimitSettings + + .. tip:: + It's a good practice to match max_concurrent_trials count with the number of nodes in the cluster. + For example, if you have a cluster with 4 nodes, set max_concurrent_trials to 4. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_limit_settings] + :end-before: [END automl.automl_image_job.image_limit_settings] + :language: python + :dedent: 8 + :caption: Defining the limit settings for an automl image job. + """ + + def __init__( + self, + *, + max_concurrent_trials: Optional[int] = None, + max_trials: Optional[int] = None, + timeout_minutes: Optional[int] = None, + ) -> None: + self.max_concurrent_trials = max_concurrent_trials + self.max_trials = max_trials + self.timeout_minutes = timeout_minutes + + def _to_rest_object(self) -> RestImageLimitSettings: + """Convert ImageLimitSettings objects to a rest object. + + :return: A rest object of ImageLimitSettings objects. + :rtype: RestImageLimitSettings + """ + return RestImageLimitSettings( + max_concurrent_trials=self.max_concurrent_trials, + max_trials=self.max_trials, + timeout=to_iso_duration_format_mins(self.timeout_minutes), + ) + + @classmethod + def _from_rest_object(cls, obj: RestImageLimitSettings) -> "ImageLimitSettings": + """Convert the rest object to a dict containing items to init the ImageLimitSettings objects. + + :param obj: Limit settings for the AutoML job in Rest format. + :type obj: RestImageLimitSettings + :return: Limit settings for an AutoML Image Vertical. + :rtype: ImageLimitSettings + """ + return cls( + max_concurrent_trials=obj.max_concurrent_trials, + max_trials=obj.max_trials, + timeout_minutes=from_iso_duration_format_mins(obj.timeout), + ) + + def __eq__(self, other: object) -> bool: + """Check equality between two ImageLimitSettings objects. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, ImageLimitSettings): + return NotImplemented + + return ( + self.max_concurrent_trials == other.max_concurrent_trials + and self.max_trials == other.max_trials + and self.timeout_minutes == other.timeout_minutes + ) + + def __ne__(self, other: object) -> bool: + """Check inequality between two ImageLimitSettings objects. 
+ + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_model_settings.py new file mode 100644 index 00000000..890f987a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_model_settings.py @@ -0,0 +1,876 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Optional + +# pylint: disable=R0902,too-many-locals +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ImageModelSettingsClassification as RestImageModelSettingsClassification, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ImageModelSettingsObjectDetection as RestImageModelSettingsObjectDetection, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + LearningRateScheduler, + LogTrainingMetrics, + LogValidationLoss, + ModelSize, + StochasticOptimizer, + ValidationMetricType, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class ImageModelDistributionSettings(RestTranslatableMixin): + """Model settings for all AutoML Image Verticals. + Please do not instantiate directly. Use the child classes instead. + + :param advanced_settings: Settings for advanced scenarios. + :type advanced_settings: str + :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'. + :type ams_gradient: bool + :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the range + [0, 1]. + :type beta1: float + :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the range + [0, 1]. + :type beta2: float + :param checkpoint_frequency: Frequency to store model checkpoints. Must be a positive integer. + :type checkpoint_frequency: int + :param checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for + incremental training. + :type checkpoint_run_id: str + :param distributed: Whether to use distributed training. + :type distributed: bool + :param early_stopping: Enable early stopping logic during training. + :type early_stopping: bool + :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait before + primary metric improvement + is tracked for early stopping. Must be a positive integer. + :type early_stopping_delay: int + :param early_stopping_patience: Minimum number of epochs or validation evaluations with no + primary metric improvement before + the run is stopped. Must be a positive integer. + :type early_stopping_patience: int + :param enable_onnx_normalization: Enable normalization when exporting ONNX model. + :type enable_onnx_normalization: bool + :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. Must + be a positive integer. + :type evaluation_frequency: int + :param gradient_accumulation_step: Gradient accumulation means running a configured number of + "GradAccumulationStep" steps without + updating the model weights while accumulating the gradients of those steps, and then using + the accumulated gradients to compute the weight updates. Must be a positive integer. 
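+     For example, with a training batch size of 16 and a gradient_accumulation_step of 4,
+     gradients are accumulated over 4 batches (64 samples) before each weight update.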
+ :type gradient_accumulation_step: int + :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive integer. + For instance, passing 2 as value for 'seresnext' means + freezing layer0 and layer1. For a full list of models supported and details on layer freeze, + please + see: https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type layers_to_freeze: int + :param learning_rate: Initial learning rate. Must be a float in the range [0, 1]. + :type learning_rate: float + :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or + 'step'. Possible values include: "None", "WarmupCosine", "Step". + :type learning_rate_scheduler: str or + ~azure.mgmt.machinelearningservices.models.LearningRateScheduler + :param model_name: Name of the model to use for training. + For more information on the available models please visit the official documentation: + https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type model_name: str + :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, 1]. + :type momentum: float + :param nesterov: Enable nesterov when optimizer is 'sgd'. + :type nesterov: bool + :param number_of_epochs: Number of training epochs. Must be a positive integer. + :type number_of_epochs: int + :param number_of_workers: Number of data loader workers. Must be a non-negative integer. + :type number_of_workers: int + :param optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw". + :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer + :param random_seed: Random seed to be used when using deterministic training. + :type random_seed: int + :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float in + the range [0, 1]. + :type step_lr_gamma: float + :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be a + positive integer. + :type step_lr_step_size: int + :param training_batch_size: Training batch size. Must be a positive integer. + :type training_batch_size: int + :param validation_batch_size: Validation batch size. Must be a positive integer. + :type validation_batch_size: int + :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is + 'warmup_cosine'. Must be a float in the range [0, 1]. + :type warmup_cosine_lr_cycles: float + :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is + 'warmup_cosine'. Must be a positive integer. + :type warmup_cosine_lr_warmup_epochs: int + :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must be + a float in the range[0, 1]. 
+ :type weight_decay: float + """ + + def __init__( + self, + *, + advanced_settings: Optional[str] = None, + ams_gradient: Optional[bool] = None, + beta1: Optional[float] = None, + beta2: Optional[float] = None, + checkpoint_frequency: Optional[int] = None, + checkpoint_run_id: Optional[str] = None, + distributed: Optional[bool] = None, + early_stopping: Optional[bool] = None, + early_stopping_delay: Optional[int] = None, + early_stopping_patience: Optional[int] = None, + enable_onnx_normalization: Optional[bool] = None, + evaluation_frequency: Optional[int] = None, + gradient_accumulation_step: Optional[int] = None, + layers_to_freeze: Optional[int] = None, + learning_rate: Optional[float] = None, + learning_rate_scheduler: Optional[LearningRateScheduler] = None, + model_name: Optional[str] = None, + momentum: Optional[float] = None, + nesterov: Optional[bool] = None, + number_of_epochs: Optional[int] = None, + number_of_workers: Optional[int] = None, + optimizer: Optional[StochasticOptimizer] = None, + random_seed: Optional[int] = None, + step_lr_gamma: Optional[float] = None, + step_lr_step_size: Optional[int] = None, + training_batch_size: Optional[int] = None, + validation_batch_size: Optional[int] = None, + warmup_cosine_lr_cycles: Optional[float] = None, + warmup_cosine_lr_warmup_epochs: Optional[int] = None, + weight_decay: Optional[float] = None, + ): + self.advanced_settings = advanced_settings + self.ams_gradient = ams_gradient + self.beta1 = beta1 + self.beta2 = beta2 + self.checkpoint_frequency = checkpoint_frequency + self.checkpoint_run_id = checkpoint_run_id + self.distributed = distributed + self.early_stopping = early_stopping + self.early_stopping_delay = early_stopping_delay + self.early_stopping_patience = early_stopping_patience + self.enable_onnx_normalization = enable_onnx_normalization + self.evaluation_frequency = evaluation_frequency + self.gradient_accumulation_step = gradient_accumulation_step + self.layers_to_freeze = layers_to_freeze + self.learning_rate = learning_rate + self.learning_rate_scheduler = learning_rate_scheduler + self.model_name = model_name + self.momentum = momentum + self.nesterov = nesterov + self.number_of_epochs = number_of_epochs + self.number_of_workers = number_of_workers + self.optimizer = optimizer + self.random_seed = random_seed + self.step_lr_gamma = step_lr_gamma + self.step_lr_step_size = step_lr_step_size + self.training_batch_size = training_batch_size + self.validation_batch_size = validation_batch_size + self.warmup_cosine_lr_cycles = warmup_cosine_lr_cycles + self.warmup_cosine_lr_warmup_epochs = warmup_cosine_lr_warmup_epochs + self.weight_decay = weight_decay + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageModelDistributionSettings): + return NotImplemented + + return ( + self.advanced_settings == other.advanced_settings + and self.ams_gradient == other.ams_gradient + and self.beta1 == other.beta1 + and self.beta2 == other.beta2 + and self.checkpoint_frequency == other.checkpoint_frequency + and self.checkpoint_run_id == other.checkpoint_run_id + and self.distributed == other.distributed + and self.early_stopping == other.early_stopping + and self.early_stopping_delay == other.early_stopping_delay + and self.early_stopping_patience == other.early_stopping_patience + and self.enable_onnx_normalization == other.enable_onnx_normalization + and self.evaluation_frequency == other.evaluation_frequency + and self.gradient_accumulation_step == other.gradient_accumulation_step + and 
self.layers_to_freeze == other.layers_to_freeze + and self.learning_rate == other.learning_rate + and self.learning_rate_scheduler == other.learning_rate_scheduler + and self.model_name == other.model_name + and self.momentum == other.momentum + and self.nesterov == other.nesterov + and self.number_of_epochs == other.number_of_epochs + and self.number_of_workers == other.number_of_workers + and self.optimizer == other.optimizer + and self.random_seed == other.random_seed + and self.step_lr_gamma == other.step_lr_gamma + and self.step_lr_step_size == other.step_lr_step_size + and self.training_batch_size == other.training_batch_size + and self.validation_batch_size == other.validation_batch_size + and self.warmup_cosine_lr_cycles == other.warmup_cosine_lr_cycles + and self.warmup_cosine_lr_warmup_epochs == other.warmup_cosine_lr_warmup_epochs + and self.weight_decay == other.weight_decay + ) + + +class ImageModelSettingsClassification(ImageModelDistributionSettings): + """Model settings for AutoML Image Classification tasks. + + :param advanced_settings: Settings for advanced scenarios. + :type advanced_settings: str + :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'. + :type ams_gradient: bool + :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the range + [0, 1]. + :type beta1: float + :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the range + [0, 1]. + :type beta2: float + :param checkpoint_frequency: Frequency to store model checkpoints. Must be a positive integer. + :type checkpoint_frequency: int + :param checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for + incremental training. + :type checkpoint_run_id: str + :param distributed: Whether to use distributed training. + :type distributed: bool + :param early_stopping: Enable early stopping logic during training. + :type early_stopping: bool + :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait before + primary metric improvement + is tracked for early stopping. Must be a positive integer. + :type early_stopping_delay: int + :param early_stopping_patience: Minimum number of epochs or validation evaluations with no + primary metric improvement before + the run is stopped. Must be a positive integer. + :type early_stopping_patience: int + :param enable_onnx_normalization: Enable normalization when exporting ONNX model. + :type enable_onnx_normalization: bool + :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. Must + be a positive integer. + :type evaluation_frequency: int + :param gradient_accumulation_step: Gradient accumulation means running a configured number of + "GradAccumulationStep" steps without + updating the model weights while accumulating the gradients of those steps, and then using + the accumulated gradients to compute the weight updates. Must be a positive integer. + :type gradient_accumulation_step: int + :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive integer. + For instance, passing 2 as value for 'seresnext' means + freezing layer0 and layer1. For a full list of models supported and details on layer freeze, + please + see: https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type layers_to_freeze: int + :param learning_rate: Initial learning rate. Must be a float in the range [0, 1]. 
+ :type learning_rate: float + :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or + 'step'. Possible values include: "None", "WarmupCosine", "Step". + :type learning_rate_scheduler: str or + ~azure.mgmt.machinelearningservices.models.LearningRateScheduler + :param model_name: Name of the model to use for training. + For more information on the available models please visit the official documentation: + https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type model_name: str + :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, 1]. + :type momentum: float + :param nesterov: Enable nesterov when optimizer is 'sgd'. + :type nesterov: bool + :param number_of_epochs: Number of training epochs. Must be a positive integer. + :type number_of_epochs: int + :param number_of_workers: Number of data loader workers. Must be a non-negative integer. + :type number_of_workers: int + :param optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw". + :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer + :param random_seed: Random seed to be used when using deterministic training. + :type random_seed: int + :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float in + the range [0, 1]. + :type step_lr_gamma: float + :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be a + positive integer. + :type step_lr_step_size: int + :param training_batch_size: Training batch size. Must be a positive integer. + :type training_batch_size: int + :param validation_batch_size: Validation batch size. Must be a positive integer. + :type validation_batch_size: int + :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is + 'warmup_cosine'. Must be a float in the range [0, 1]. + :type warmup_cosine_lr_cycles: float + :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is + 'warmup_cosine'. Must be a positive integer. + :type warmup_cosine_lr_warmup_epochs: int + :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must be + a float in the range[0, 1]. + :type weight_decay: float + :param training_crop_size: Image crop size that is input to the neural network for the training + dataset. Must be a positive integer. + :type training_crop_size: int + :param validation_crop_size: Image crop size that is input to the neural network for the + validation dataset. Must be a positive integer. + :type validation_crop_size: int + :param validation_resize_size: Image size to which to resize before cropping for validation + dataset. Must be a positive integer. + :type validation_resize_size: int + :param weighted_loss: Weighted loss. The accepted values are 0 for no weighted loss. + 1 for weighted loss with sqrt.(class_weights). 2 for weighted loss with class_weights. Must be + 0 or 1 or 2. + :type weighted_loss: int + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_classification_model_settings] + :end-before: [END automl.automl_image_job.image_classification_model_settings] + :language: python + :dedent: 8 + :caption: Defining the automl image classification model settings. 
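+
+    A minimal inline sketch of the same idea (values are illustrative only, and the import is
+    assumed to come from the public ``azure.ai.ml.automl`` namespace):
+
+    .. code-block:: python
+
+        from azure.ai.ml.automl import ImageModelSettingsClassification
+
+        model_settings = ImageModelSettingsClassification(
+            model_name="vitb16r224",
+            number_of_epochs=15,
+            learning_rate=0.001,
+            early_stopping=True,
+        )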
+ """ + + def __init__( + self, + *, + advanced_settings: Optional[str] = None, + ams_gradient: Optional[bool] = None, + beta1: Optional[float] = None, + beta2: Optional[float] = None, + checkpoint_frequency: Optional[int] = None, + checkpoint_run_id: Optional[str] = None, + distributed: Optional[bool] = None, + early_stopping: Optional[bool] = None, + early_stopping_delay: Optional[int] = None, + early_stopping_patience: Optional[int] = None, + enable_onnx_normalization: Optional[bool] = None, + evaluation_frequency: Optional[int] = None, + gradient_accumulation_step: Optional[int] = None, + layers_to_freeze: Optional[int] = None, + learning_rate: Optional[float] = None, + learning_rate_scheduler: Optional[LearningRateScheduler] = None, + model_name: Optional[str] = None, + momentum: Optional[float] = None, + nesterov: Optional[bool] = None, + number_of_epochs: Optional[int] = None, + number_of_workers: Optional[int] = None, + optimizer: Optional[StochasticOptimizer] = None, + random_seed: Optional[int] = None, + step_lr_gamma: Optional[float] = None, + step_lr_step_size: Optional[int] = None, + training_batch_size: Optional[int] = None, + validation_batch_size: Optional[int] = None, + warmup_cosine_lr_cycles: Optional[float] = None, + warmup_cosine_lr_warmup_epochs: Optional[int] = None, + weight_decay: Optional[float] = None, + training_crop_size: Optional[int] = None, + validation_crop_size: Optional[int] = None, + validation_resize_size: Optional[int] = None, + weighted_loss: Optional[int] = None, + **kwargs: Any, + ): + super(ImageModelSettingsClassification, self).__init__( + advanced_settings=advanced_settings, + ams_gradient=ams_gradient, + beta1=beta1, + beta2=beta2, + checkpoint_frequency=checkpoint_frequency, + checkpoint_run_id=checkpoint_run_id, + distributed=distributed, + early_stopping=early_stopping, + early_stopping_delay=early_stopping_delay, + early_stopping_patience=early_stopping_patience, + enable_onnx_normalization=enable_onnx_normalization, + evaluation_frequency=evaluation_frequency, + gradient_accumulation_step=gradient_accumulation_step, + layers_to_freeze=layers_to_freeze, + learning_rate=learning_rate, + learning_rate_scheduler=learning_rate_scheduler, + model_name=model_name, + momentum=momentum, + nesterov=nesterov, + number_of_epochs=number_of_epochs, + number_of_workers=number_of_workers, + optimizer=optimizer, + random_seed=random_seed, + step_lr_gamma=step_lr_gamma, + step_lr_step_size=step_lr_step_size, + training_batch_size=training_batch_size, + validation_batch_size=validation_batch_size, + warmup_cosine_lr_cycles=warmup_cosine_lr_cycles, + warmup_cosine_lr_warmup_epochs=warmup_cosine_lr_warmup_epochs, + weight_decay=weight_decay, + **kwargs, + ) + self.training_crop_size = training_crop_size + self.validation_crop_size = validation_crop_size + self.validation_resize_size = validation_resize_size + self.weighted_loss = weighted_loss + + def _to_rest_object(self) -> RestImageModelSettingsClassification: + return RestImageModelSettingsClassification( + advanced_settings=self.advanced_settings, + ams_gradient=self.ams_gradient, + beta1=self.beta1, + beta2=self.beta2, + checkpoint_frequency=self.checkpoint_frequency, + checkpoint_run_id=self.checkpoint_run_id, + distributed=self.distributed, + early_stopping=self.early_stopping, + early_stopping_delay=self.early_stopping_delay, + early_stopping_patience=self.early_stopping_patience, + enable_onnx_normalization=self.enable_onnx_normalization, + evaluation_frequency=self.evaluation_frequency, + 
gradient_accumulation_step=self.gradient_accumulation_step, + layers_to_freeze=self.layers_to_freeze, + learning_rate=self.learning_rate, + learning_rate_scheduler=self.learning_rate_scheduler, + model_name=self.model_name, + momentum=self.momentum, + nesterov=self.nesterov, + number_of_epochs=self.number_of_epochs, + number_of_workers=self.number_of_workers, + optimizer=self.optimizer, + random_seed=self.random_seed, + step_lr_gamma=self.step_lr_gamma, + step_lr_step_size=self.step_lr_step_size, + training_batch_size=self.training_batch_size, + validation_batch_size=self.validation_batch_size, + warmup_cosine_lr_cycles=self.warmup_cosine_lr_cycles, + warmup_cosine_lr_warmup_epochs=self.warmup_cosine_lr_warmup_epochs, + weight_decay=self.weight_decay, + training_crop_size=self.training_crop_size, + validation_crop_size=self.validation_crop_size, + validation_resize_size=self.validation_resize_size, + weighted_loss=self.weighted_loss, + ) + + @classmethod + def _from_rest_object(cls, obj: RestImageModelSettingsClassification) -> "ImageModelSettingsClassification": + return cls( + advanced_settings=obj.advanced_settings, + ams_gradient=obj.ams_gradient, + beta1=obj.beta1, + beta2=obj.beta2, + checkpoint_frequency=obj.checkpoint_frequency, + checkpoint_run_id=obj.checkpoint_run_id, + distributed=obj.distributed, + early_stopping=obj.early_stopping, + early_stopping_delay=obj.early_stopping_delay, + early_stopping_patience=obj.early_stopping_patience, + enable_onnx_normalization=obj.enable_onnx_normalization, + evaluation_frequency=obj.evaluation_frequency, + gradient_accumulation_step=obj.gradient_accumulation_step, + layers_to_freeze=obj.layers_to_freeze, + learning_rate=obj.learning_rate, + learning_rate_scheduler=obj.learning_rate_scheduler, + model_name=obj.model_name, + momentum=obj.momentum, + nesterov=obj.nesterov, + number_of_epochs=obj.number_of_epochs, + number_of_workers=obj.number_of_workers, + optimizer=obj.optimizer, + random_seed=obj.random_seed, + step_lr_gamma=obj.step_lr_gamma, + step_lr_step_size=obj.step_lr_step_size, + training_batch_size=obj.training_batch_size, + validation_batch_size=obj.validation_batch_size, + warmup_cosine_lr_cycles=obj.warmup_cosine_lr_cycles, + warmup_cosine_lr_warmup_epochs=obj.warmup_cosine_lr_warmup_epochs, + weight_decay=obj.weight_decay, + training_crop_size=obj.training_crop_size, + validation_crop_size=obj.validation_crop_size, + validation_resize_size=obj.validation_resize_size, + weighted_loss=obj.weighted_loss, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageModelSettingsClassification): + return NotImplemented + + return ( + super().__eq__(other) + and self.training_crop_size == other.training_crop_size + and self.validation_crop_size == other.validation_crop_size + and self.validation_resize_size == other.validation_resize_size + and self.weighted_loss == other.weighted_loss + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class ImageModelSettingsObjectDetection(ImageModelDistributionSettings): + """Model settings for AutoML Image Object Detection Task. + + :param advanced_settings: Settings for advanced scenarios. + :type advanced_settings: str + :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'. + :type ams_gradient: bool + :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the range + [0, 1]. + :type beta1: float + :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. 
Must be a float in the range + [0, 1]. + :type beta2: float + :param checkpoint_frequency: Frequency to store model checkpoints. Must be a positive integer. + :type checkpoint_frequency: int + :param checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for + incremental training. + :type checkpoint_run_id: str + :param distributed: Whether to use distributed training. + :type distributed: bool + :param early_stopping: Enable early stopping logic during training. + :type early_stopping: bool + :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait before + primary metric improvement + is tracked for early stopping. Must be a positive integer. + :type early_stopping_delay: int + :param early_stopping_patience: Minimum number of epochs or validation evaluations with no + primary metric improvement before + the run is stopped. Must be a positive integer. + :type early_stopping_patience: int + :param enable_onnx_normalization: Enable normalization when exporting ONNX model. + :type enable_onnx_normalization: bool + :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. Must + be a positive integer. + :type evaluation_frequency: int + :param gradient_accumulation_step: Gradient accumulation means running a configured number of + "GradAccumulationStep" steps without + updating the model weights while accumulating the gradients of those steps, and then using + the accumulated gradients to compute the weight updates. Must be a positive integer. + :type gradient_accumulation_step: int + :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive integer. + For instance, passing 2 as value for 'seresnext' means + freezing layer0 and layer1. For a full list of models supported and details on layer freeze, + please + see: https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type layers_to_freeze: int + :param learning_rate: Initial learning rate. Must be a float in the range [0, 1]. + :type learning_rate: float + :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or + 'step'. Possible values include: "None", "WarmupCosine", "Step". + :type learning_rate_scheduler: str or + ~azure.mgmt.machinelearningservices.models.LearningRateScheduler + :param model_name: Name of the model to use for training. + For more information on the available models please visit the official documentation: + https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models. + :type model_name: str + :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, 1]. + :type momentum: float + :param nesterov: Enable nesterov when optimizer is 'sgd'. + :type nesterov: bool + :param number_of_epochs: Number of training epochs. Must be a positive integer. + :type number_of_epochs: int + :param number_of_workers: Number of data loader workers. Must be a non-negative integer. + :type number_of_workers: int + :param optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw". + :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer + :param random_seed: Random seed to be used when using deterministic training. + :type random_seed: int + :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float in + the range [0, 1]. 
+    :type step_lr_gamma: float
+    :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be a
+     positive integer.
+    :type step_lr_step_size: int
+    :param training_batch_size: Training batch size. Must be a positive integer.
+    :type training_batch_size: int
+    :param validation_batch_size: Validation batch size. Must be a positive integer.
+    :type validation_batch_size: int
+    :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+     'warmup_cosine'. Must be a float in the range [0, 1].
+    :type warmup_cosine_lr_cycles: float
+    :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+     'warmup_cosine'. Must be a positive integer.
+    :type warmup_cosine_lr_warmup_epochs: int
+    :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must be
+     a float in the range [0, 1].
+    :type weight_decay: float
+    :param box_detections_per_image: Maximum number of detections per image, for all classes. Must
+     be a positive integer.
+     Note: This setting is not supported for the 'yolov5' algorithm.
+    :type box_detections_per_image: int
+    :param box_score_threshold: During inference, only return proposals with a classification score
+     greater than BoxScoreThreshold. Must be a float in the range [0, 1].
+    :type box_score_threshold: float
+    :param image_size: Image size for train and validation. Must be a positive integer.
+     Note: The training run may get into CUDA OOM if the size is too big.
+     Note: This setting is only supported for the 'yolov5' algorithm.
+    :type image_size: int
+    :param max_size: Maximum size of the image to be rescaled before feeding it to the backbone.
+     Must be a positive integer. Note: The training run may get into CUDA OOM if the size is too big.
+     Note: This setting is not supported for the 'yolov5' algorithm.
+    :type max_size: int
+    :param min_size: Minimum size of the image to be rescaled before feeding it to the backbone.
+     Must be a positive integer. Note: The training run may get into CUDA OOM if the size is too big.
+     Note: This setting is not supported for the 'yolov5' algorithm.
+    :type min_size: int
+    :param model_size: Model size. Must be 'small', 'medium', or 'large'.
+     Note: The training run may get into CUDA OOM if the model size is too big.
+     Note: This setting is only supported for the 'yolov5' algorithm. Possible values include:
+     "None", "Small", "Medium", "Large", "ExtraLarge".
+    :type model_size: str or ~azure.mgmt.machinelearningservices.models.ModelSize
+    :param multi_scale: Enable multi-scale image by varying image size by +/- 50%.
+     Note: The training run may get into CUDA OOM if there is not sufficient GPU memory.
+     Note: This setting is only supported for the 'yolov5' algorithm.
+    :type multi_scale: bool
+    :param nms_iou_threshold: IOU threshold used during inference in NMS post processing. Must be a
+     float in the range [0, 1].
+    :type nms_iou_threshold: float
+    :param tile_grid_size: The grid size to use for tiling each image. Note: TileGridSize must not
+     be None to enable small object detection logic. A string containing two integers in mxn format.
+     Note: This setting is not supported for the 'yolov5' algorithm.
+    :type tile_grid_size: str
+    :param tile_overlap_ratio: Overlap ratio between adjacent tiles in each dimension. Must be a
+     float in the range [0, 1).
+     Note: This setting is not supported for the 'yolov5' algorithm.
+ :type tile_overlap_ratio: float + :param tile_predictions_nms_threshold: The IOU threshold to use to perform NMS while merging + predictions from tiles and image. + Used in validation/ inference. Must be float in the range [0, 1]. + Note: This settings is not supported for the 'yolov5' algorithm. + :type tile_predictions_nms_threshold: float + :param validation_iou_threshold: IOU threshold to use when computing validation metric. Must be + float in the range [0, 1]. + :type validation_iou_threshold: float + :param validation_metric_type: Metric computation method to use for validation metrics. Possible + values include: "None", "Coco", "Voc", "CocoVoc". + :type validation_metric_type: str or + ~azure.mgmt.machinelearningservices.models.ValidationMetricType + :param log_training_metrics: indicates whether or not to log training metrics + :type log_training_metrics: str or + ~azure.mgmt.machinelearningservices.models.LogTrainingMetrics + :param log_validation_loss: indicates whether or not to log validation loss + :type log_validation_loss: str or + ~azure.mgmt.machinelearningservices.models.LogValidationLoss + + .. literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_object_detection_model_settings] + :end-before: [END automl.automl_image_job.image_object_detection_model_settings] + :language: python + :dedent: 8 + :caption: Defining the automl image object detection or instance segmentation model settings. + """ + + def __init__( + self, + *, + advanced_settings: Optional[str] = None, + ams_gradient: Optional[bool] = None, + beta1: Optional[float] = None, + beta2: Optional[float] = None, + checkpoint_frequency: Optional[int] = None, + checkpoint_run_id: Optional[str] = None, + distributed: Optional[bool] = None, + early_stopping: Optional[bool] = None, + early_stopping_delay: Optional[int] = None, + early_stopping_patience: Optional[int] = None, + enable_onnx_normalization: Optional[bool] = None, + evaluation_frequency: Optional[int] = None, + gradient_accumulation_step: Optional[int] = None, + layers_to_freeze: Optional[int] = None, + learning_rate: Optional[float] = None, + learning_rate_scheduler: Optional[LearningRateScheduler] = None, + model_name: Optional[str] = None, + momentum: Optional[float] = None, + nesterov: Optional[bool] = None, + number_of_epochs: Optional[int] = None, + number_of_workers: Optional[int] = None, + optimizer: Optional[StochasticOptimizer] = None, + random_seed: Optional[int] = None, + step_lr_gamma: Optional[float] = None, + step_lr_step_size: Optional[int] = None, + training_batch_size: Optional[int] = None, + validation_batch_size: Optional[int] = None, + warmup_cosine_lr_cycles: Optional[float] = None, + warmup_cosine_lr_warmup_epochs: Optional[int] = None, + weight_decay: Optional[float] = None, + box_detections_per_image: Optional[int] = None, + box_score_threshold: Optional[float] = None, + image_size: Optional[int] = None, + max_size: Optional[int] = None, + min_size: Optional[int] = None, + model_size: Optional[ModelSize] = None, + multi_scale: Optional[bool] = None, + nms_iou_threshold: Optional[float] = None, + tile_grid_size: Optional[str] = None, + tile_overlap_ratio: Optional[float] = None, + tile_predictions_nms_threshold: Optional[float] = None, + validation_iou_threshold: Optional[float] = None, + validation_metric_type: Optional[ValidationMetricType] = None, + log_training_metrics: Optional[LogTrainingMetrics] = None, + log_validation_loss: Optional[LogValidationLoss] = None, + 
**kwargs: Any, + ): + super(ImageModelSettingsObjectDetection, self).__init__( + advanced_settings=advanced_settings, + ams_gradient=ams_gradient, + beta1=beta1, + beta2=beta2, + checkpoint_frequency=checkpoint_frequency, + checkpoint_run_id=checkpoint_run_id, + distributed=distributed, + early_stopping=early_stopping, + early_stopping_delay=early_stopping_delay, + early_stopping_patience=early_stopping_patience, + enable_onnx_normalization=enable_onnx_normalization, + evaluation_frequency=evaluation_frequency, + gradient_accumulation_step=gradient_accumulation_step, + layers_to_freeze=layers_to_freeze, + learning_rate=learning_rate, + learning_rate_scheduler=learning_rate_scheduler, + model_name=model_name, + momentum=momentum, + nesterov=nesterov, + number_of_epochs=number_of_epochs, + number_of_workers=number_of_workers, + optimizer=optimizer, + random_seed=random_seed, + step_lr_gamma=step_lr_gamma, + step_lr_step_size=step_lr_step_size, + training_batch_size=training_batch_size, + validation_batch_size=validation_batch_size, + warmup_cosine_lr_cycles=warmup_cosine_lr_cycles, + warmup_cosine_lr_warmup_epochs=warmup_cosine_lr_warmup_epochs, + weight_decay=weight_decay, + **kwargs, + ) + self.box_detections_per_image = box_detections_per_image + self.box_score_threshold = box_score_threshold + self.image_size = image_size + self.max_size = max_size + self.min_size = min_size + self.model_size = model_size + self.multi_scale = multi_scale + self.nms_iou_threshold = nms_iou_threshold + self.tile_grid_size = tile_grid_size + self.tile_overlap_ratio = tile_overlap_ratio + self.tile_predictions_nms_threshold = tile_predictions_nms_threshold + self.validation_iou_threshold = validation_iou_threshold + self.validation_metric_type = validation_metric_type + self.log_training_metrics = log_training_metrics + self.log_validation_loss = log_validation_loss + + def _to_rest_object(self) -> RestImageModelSettingsObjectDetection: + return RestImageModelSettingsObjectDetection( + advanced_settings=self.advanced_settings, + ams_gradient=self.ams_gradient, + beta1=self.beta1, + beta2=self.beta2, + checkpoint_frequency=self.checkpoint_frequency, + checkpoint_run_id=self.checkpoint_run_id, + distributed=self.distributed, + early_stopping=self.early_stopping, + early_stopping_delay=self.early_stopping_delay, + early_stopping_patience=self.early_stopping_patience, + enable_onnx_normalization=self.enable_onnx_normalization, + evaluation_frequency=self.evaluation_frequency, + gradient_accumulation_step=self.gradient_accumulation_step, + layers_to_freeze=self.layers_to_freeze, + learning_rate=self.learning_rate, + learning_rate_scheduler=self.learning_rate_scheduler, + model_name=self.model_name, + momentum=self.momentum, + nesterov=self.nesterov, + number_of_epochs=self.number_of_epochs, + number_of_workers=self.number_of_workers, + optimizer=self.optimizer, + random_seed=self.random_seed, + step_lr_gamma=self.step_lr_gamma, + step_lr_step_size=self.step_lr_step_size, + training_batch_size=self.training_batch_size, + validation_batch_size=self.validation_batch_size, + warmup_cosine_lr_cycles=self.warmup_cosine_lr_cycles, + warmup_cosine_lr_warmup_epochs=self.warmup_cosine_lr_warmup_epochs, + weight_decay=self.weight_decay, + box_detections_per_image=self.box_detections_per_image, + box_score_threshold=self.box_score_threshold, + image_size=self.image_size, + max_size=self.max_size, + min_size=self.min_size, + model_size=self.model_size, + multi_scale=self.multi_scale, + 
nms_iou_threshold=self.nms_iou_threshold, + tile_grid_size=self.tile_grid_size, + tile_overlap_ratio=self.tile_overlap_ratio, + tile_predictions_nms_threshold=self.tile_predictions_nms_threshold, + validation_iou_threshold=self.validation_iou_threshold, + validation_metric_type=self.validation_metric_type, + log_training_metrics=self.log_training_metrics, + log_validation_loss=self.log_validation_loss, + ) + + @classmethod + def _from_rest_object(cls, obj: RestImageModelSettingsObjectDetection) -> "ImageModelSettingsObjectDetection": + return cls( + advanced_settings=obj.advanced_settings, + ams_gradient=obj.ams_gradient, + beta1=obj.beta1, + beta2=obj.beta2, + checkpoint_frequency=obj.checkpoint_frequency, + checkpoint_run_id=obj.checkpoint_run_id, + distributed=obj.distributed, + early_stopping=obj.early_stopping, + early_stopping_delay=obj.early_stopping_delay, + early_stopping_patience=obj.early_stopping_patience, + enable_onnx_normalization=obj.enable_onnx_normalization, + evaluation_frequency=obj.evaluation_frequency, + gradient_accumulation_step=obj.gradient_accumulation_step, + layers_to_freeze=obj.layers_to_freeze, + learning_rate=obj.learning_rate, + learning_rate_scheduler=obj.learning_rate_scheduler, + model_name=obj.model_name, + momentum=obj.momentum, + nesterov=obj.nesterov, + number_of_epochs=obj.number_of_epochs, + number_of_workers=obj.number_of_workers, + optimizer=obj.optimizer, + random_seed=obj.random_seed, + step_lr_gamma=obj.step_lr_gamma, + step_lr_step_size=obj.step_lr_step_size, + training_batch_size=obj.training_batch_size, + validation_batch_size=obj.validation_batch_size, + warmup_cosine_lr_cycles=obj.warmup_cosine_lr_cycles, + warmup_cosine_lr_warmup_epochs=obj.warmup_cosine_lr_warmup_epochs, + weight_decay=obj.weight_decay, + box_detections_per_image=obj.box_detections_per_image, + box_score_threshold=obj.box_score_threshold, + image_size=obj.image_size, + max_size=obj.max_size, + min_size=obj.min_size, + model_size=obj.model_size, + multi_scale=obj.multi_scale, + nms_iou_threshold=obj.nms_iou_threshold, + tile_grid_size=obj.tile_grid_size, + tile_overlap_ratio=obj.tile_overlap_ratio, + tile_predictions_nms_threshold=obj.tile_predictions_nms_threshold, + validation_iou_threshold=obj.validation_iou_threshold, + validation_metric_type=obj.validation_metric_type, + log_training_metrics=obj.log_training_metrics, + log_validation_loss=obj.log_validation_loss, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageModelSettingsObjectDetection): + return NotImplemented + + return ( + super().__eq__(other) + and self.box_detections_per_image == other.box_detections_per_image + and self.box_score_threshold == other.box_score_threshold + and self.image_size == other.image_size + and self.max_size == other.max_size + and self.min_size == other.min_size + and self.model_size == other.model_size + and self.multi_scale == other.multi_scale + and self.nms_iou_threshold == other.nms_iou_threshold + and self.tile_grid_size == other.tile_grid_size + and self.tile_overlap_ratio == other.tile_overlap_ratio + and self.tile_predictions_nms_threshold == other.tile_predictions_nms_threshold + and self.validation_iou_threshold == other.validation_iou_threshold + and self.validation_metric_type == other.validation_metric_type + and self.log_training_metrics == other.log_training_metrics + and self.log_validation_loss == other.log_validation_loss + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git 
a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_job.py new file mode 100644 index 00000000..f8d070d2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_job.py @@ -0,0 +1,240 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageObjectDetection as RestImageObjectDetection +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, ObjectDetectionPrimaryMetrics, TaskType +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.image.automl_image_object_detection_base import AutoMLImageObjectDetectionBase +from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings +from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsObjectDetection +from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings +from azure.ai.ml.entities._util import load_from_dict + + +class ImageObjectDetectionJob(AutoMLImageObjectDetectionBase): + """Configuration for AutoML Image Object Detection job. + + :keyword primary_metric: The primary metric to use for optimization. + :paramtype primary_metric: Optional[str, ~azure.ai.ml.ObjectDetectionPrimaryMetrics] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_object_detection_job] + :end-before: [END automl.automl_image_job.image_object_detection_job] + :language: python + :dedent: 8 + :caption: creating an automl image object detection job + """ + + _DEFAULT_PRIMARY_METRIC = ObjectDetectionPrimaryMetrics.MEAN_AVERAGE_PRECISION + + def __init__( + self, + *, + primary_metric: Optional[Union[str, ObjectDetectionPrimaryMetrics]] = None, + **kwargs: Any, + ) -> None: + + # Extract any super class init settings + limits = kwargs.pop("limits", None) + sweep = kwargs.pop("sweep", None) + training_parameters = kwargs.pop("training_parameters", None) + search_space = kwargs.pop("search_space", None) + + super().__init__( + task_type=TaskType.IMAGE_OBJECT_DETECTION, + limits=limits, + sweep=sweep, + training_parameters=training_parameters, + search_space=search_space, + **kwargs, + ) + + self.primary_metric = primary_metric or ImageObjectDetectionJob._DEFAULT_PRIMARY_METRIC + + @property + def primary_metric(self) -> Union[str, ObjectDetectionPrimaryMetrics]: + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, ObjectDetectionPrimaryMetrics]) -> None: + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + self._primary_metric = ( + ImageObjectDetectionJob._DEFAULT_PRIMARY_METRIC + if value is None + else ObjectDetectionPrimaryMetrics[camel_to_snake(value).upper()] + ) + + def _to_rest_object(self) -> JobBase: + image_object_detection_task = RestImageObjectDetection( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + validation_data_size=self.validation_data_size, + limit_settings=self._limits._to_rest_object() if self._limits else None, + sweep_settings=self._sweep._to_rest_object() if self._sweep else None, + model_settings=self._training_parameters._to_rest_object() if self._training_parameters else None, + search_space=( + [entry._to_rest_object() for entry in self._search_space if entry is not None] + if self._search_space is not None + else None + ), + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + ) + # resolve data inputs in rest object + self._resolve_data_inputs(image_object_detection_task) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=image_object_detection_task, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "ImageObjectDetectionJob": + properties: RestAutoMLJob = obj.properties + task_details: RestImageObjectDetection = properties.task_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": properties.experiment_name, + "services": properties.services, + "status": properties.status, + "creation_context": obj.system_data, + 
"display_name": properties.display_name, + "compute": properties.compute_id, + "outputs": from_rest_data_outputs(properties.outputs), + "resources": properties.resources, + "identity": ( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + "queue_settings": properties.queue_settings, + } + + image_object_detection_job = cls( + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + validation_data_size=task_details.validation_data_size, + limits=( + ImageLimitSettings._from_rest_object(task_details.limit_settings) + if task_details.limit_settings + else None + ), + sweep=( + ImageSweepSettings._from_rest_object(task_details.sweep_settings) + if task_details.sweep_settings + else None + ), + training_parameters=( + ImageModelSettingsObjectDetection._from_rest_object(task_details.model_settings) + if task_details.model_settings + else None + ), + search_space=cls._get_search_space_from_str(task_details.search_space), + primary_metric=task_details.primary_metric, + log_verbosity=task_details.log_verbosity, + **job_args_dict, + ) + + image_object_detection_job._restore_data_inputs() + + return image_object_detection_job + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "ImageObjectDetectionJob": + from azure.ai.ml._schema.automl.image_vertical.image_object_detection import ImageObjectDetectionSchema + from azure.ai.ml._schema.pipeline.automl_node import ImageObjectDetectionNodeSchema + + if kwargs.pop("inside_pipeline", False): + if context.get("inside_pipeline", None) is None: + context["inside_pipeline"] = True + loaded_data = load_from_dict( + ImageObjectDetectionNodeSchema, + data, + context, + additional_message, + **kwargs, + ) + else: + loaded_data = load_from_dict(ImageObjectDetectionSchema, data, context, additional_message, **kwargs) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ImageObjectDetectionJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + data_settings = { + "training_data": loaded_data.pop("training_data"), + "target_column_name": loaded_data.pop("target_column_name"), + "validation_data": loaded_data.pop("validation_data", None), + "validation_data_size": loaded_data.pop("validation_data_size", None), + } + job = ImageObjectDetectionJob(**loaded_data) + job.set_data(**data_settings) + return job + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.image_vertical.image_object_detection import ImageObjectDetectionSchema + from azure.ai.ml._schema.pipeline.automl_node import ImageObjectDetectionNodeSchema + + schema_dict: dict = {} + if inside_pipeline: + schema_dict = ImageObjectDetectionNodeSchema( + context={BASE_PATH_CONTEXT_KEY: "./", "inside_pipeline": True} + ).dump(self) + else: + schema_dict = ImageObjectDetectionSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return schema_dict + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageObjectDetectionJob): + return NotImplemented + + if not super().__eq__(other): + return False + + return self.primary_metric == other.primary_metric + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git 
a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_search_space.py new file mode 100644 index 00000000..a9004d1e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_search_space.py @@ -0,0 +1,899 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=R0902,too-many-locals
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageModelDistributionSettingsObjectDetection
+from azure.ai.ml.entities._job.automl.search_space import SearchSpace
+from azure.ai.ml.entities._job.automl.search_space_utils import _convert_from_rest_object, _convert_to_rest_object
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.sweep import (
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+)
+
+
+class ImageObjectDetectionSearchSpace(RestTranslatableMixin):
+ """Search space for AutoML Image Object Detection and Image Instance Segmentation tasks.
+
+ :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'.
+ :type ams_gradient: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :type beta1: float or ~azure.ai.ml.entities.SweepDistribution
+ :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :type beta2: float or ~azure.ai.ml.entities.SweepDistribution
+ :param distributed: Whether to use distributed training.
+ :type distributed: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param early_stopping: Enable early stopping logic during training.
+ :type early_stopping: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait
+ before primary metric improvement
+ is tracked for early stopping. Must be a positive integer.
+ :type early_stopping_delay: int or ~azure.ai.ml.entities.SweepDistribution
+ :param early_stopping_patience: Minimum number of epochs or validation evaluations with no
+ primary metric improvement before the run is stopped. Must be a positive integer.
+ :type early_stopping_patience: int or ~azure.ai.ml.entities.SweepDistribution
+ :param enable_onnx_normalization: Enable normalization when exporting ONNX model.
+ :type enable_onnx_normalization: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores.
+ Must be a positive integer.
+ :type evaluation_frequency: int or ~azure.ai.ml.entities.SweepDistribution
+ :param gradient_accumulation_step: Gradient accumulation means running a configured number of
+ "GradAccumulationStep" steps without updating the model weights while accumulating the gradients of those steps,
+ and then using the accumulated gradients to compute the weight updates. Must be a positive integer.
+ :type gradient_accumulation_step: int or ~azure.ai.ml.entities.SweepDistribution
+ :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive
+ integer. For instance, passing 2 as value for 'seresnext' means freezing layer0 and layer1.
+ For a full list of models supported and details on layer freeze, please
+ see: https://learn.microsoft.com/azure/machine-learning/reference-automl-images-hyperparameters#model-agnostic-hyperparameters. # pylint: disable=line-too-long
+ :type layers_to_freeze: int or ~azure.ai.ml.entities.SweepDistribution
+ :param learning_rate: Initial learning rate. Must be a float in the range [0, 1].
+ :type learning_rate: float or ~azure.ai.ml.entities.SweepDistribution
+ :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or
+ 'step'.
+ :type learning_rate_scheduler: str or ~azure.ai.ml.entities.SweepDistribution
+ :param model_name: Name of the model to use for training.
+ For more information on the available models please visit the official documentation:
+ https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type model_name: str or ~azure.ai.ml.entities.SweepDistribution
+ :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range
+ [0, 1].
+ :type momentum: float or ~azure.ai.ml.entities.SweepDistribution
+ :param nesterov: Enable nesterov when optimizer is 'sgd'.
+ :type nesterov: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param number_of_epochs: Number of training epochs. Must be a positive integer.
+ :type number_of_epochs: int or ~azure.ai.ml.entities.SweepDistribution
+ :param number_of_workers: Number of data loader workers. Must be a non-negative integer.
+ :type number_of_workers: int or ~azure.ai.ml.entities.SweepDistribution
+ :param optimizer: Type of optimizer. Must be either 'sgd', 'adam', or 'adamw'.
+ :type optimizer: str or ~azure.ai.ml.entities.SweepDistribution
+ :param random_seed: Random seed to be used when using deterministic training.
+ :type random_seed: int or ~azure.ai.ml.entities.SweepDistribution
+ :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float
+ in the range [0, 1].
+ :type step_lr_gamma: float or ~azure.ai.ml.entities.SweepDistribution
+ :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be
+ a positive integer.
+ :type step_lr_step_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param training_batch_size: Training batch size. Must be a positive integer.
+ :type training_batch_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param validation_batch_size: Validation batch size. Must be a positive integer.
+ :type validation_batch_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+ 'warmup_cosine'. Must be a float in the range [0, 1].
+ :type warmup_cosine_lr_cycles: float or ~azure.ai.ml.entities.SweepDistribution
+ :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+ 'warmup_cosine'. Must be a positive integer.
+ :type warmup_cosine_lr_warmup_epochs: int or ~azure.ai.ml.entities.SweepDistribution
+ :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must
+ be a float in the range [0, 1].
+ :type weight_decay: float or ~azure.ai.ml.entities.SweepDistribution
+ :param box_detections_per_image: Maximum number of detections per image, for all classes.
+ Must be a positive integer. Note: This setting is not supported for the 'yolov5' algorithm.
+ :type box_detections_per_image: int or ~azure.ai.ml.entities.SweepDistribution
+ :param box_score_threshold: During inference, only return proposals with a classification
+ score greater than BoxScoreThreshold. Must be a float in the range [0, 1].
+ :type box_score_threshold: float or ~azure.ai.ml.entities.SweepDistribution
+ :param image_size: Image size for train and validation. Must be a positive integer.
+ Note: The training run may get into CUDA OOM if the size is too big.
+ Note: This setting is only supported for the 'yolov5' algorithm.
+ :type image_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param max_size: Maximum size of the image to be rescaled before feeding it to the backbone.
+ Must be a positive integer. Note: The training run may get into CUDA OOM if the size is too big.
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type max_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param min_size: Minimum size of the image to be rescaled before feeding it to the backbone.
+ Must be a positive integer. Note: The training run may get into CUDA OOM if the size is too big.
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type min_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param model_size: Model size. Must be 'small', 'medium', 'large', or 'extra_large'.
+ Note: The training run may get into CUDA OOM if the model size is too big.
+ Note: This setting is only supported for the 'yolov5' algorithm.
+ :type model_size: str or ~azure.ai.ml.entities.SweepDistribution
+ :param multi_scale: Enable multi-scale image by varying image size by +/- 50%.
+ Note: The training run may get into CUDA OOM if there is not sufficient GPU memory.
+ Note: This setting is only supported for the 'yolov5' algorithm.
+ :type multi_scale: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param nms_iou_threshold: IOU threshold used during inference in NMS post processing. Must be
+ a float in the range [0, 1].
+ :type nms_iou_threshold: float or ~azure.ai.ml.entities.SweepDistribution
+ :param tile_grid_size: The grid size to use for tiling each image. Note: TileGridSize must
+ not be None to enable small object detection logic. A string containing two integers in mxn format.
+ :type tile_grid_size: str or ~azure.ai.ml.entities.SweepDistribution
+ :param tile_overlap_ratio: Overlap ratio between adjacent tiles in each dimension. Must be
+ a float in the range [0, 1).
+ :type tile_overlap_ratio: float or ~azure.ai.ml.entities.SweepDistribution
+ :param tile_predictions_nms_threshold: The IOU threshold to use to perform NMS while merging
+ predictions from tiles and image. Used in validation/inference. Must be a float in the range [0, 1].
+ NMS: Non-maximum suppression.
+ :type tile_predictions_nms_threshold: float or ~azure.ai.ml.entities.SweepDistribution
+ :param validation_iou_threshold: IOU threshold to use when computing validation metric. Must
+ be a float in the range [0, 1].
+ :type validation_iou_threshold: float or ~azure.ai.ml.entities.SweepDistribution
+ :param validation_metric_type: Metric computation method to use for validation metrics. Must
+ be 'none', 'coco', 'voc', or 'coco_voc'.
+ :type validation_metric_type: str or ~azure.ai.ml.entities.SweepDistribution
+
+ .. admonition:: Example:
+
+ ..
literalinclude:: ../samples/ml_samples_automl_image.py + :start-after: [START automl.automl_image_job.image_object_detection_search_space] + :end-before: [END automl.automl_image_job.image_object_detection_search_space] + :language: python + :dedent: 8 + :caption: Defining an automl image object detection or instance segmentation search space + """ + + def __init__( + self, + *, + ams_gradient: Optional[ + Union[ + bool, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + beta1: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + beta2: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + distributed: Optional[ + Union[ + bool, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + early_stopping: Optional[ + Union[ + bool, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + early_stopping_delay: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + early_stopping_patience: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + enable_onnx_normalization: Optional[ + Union[ + bool, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + evaluation_frequency: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + gradient_accumulation_step: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + layers_to_freeze: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + learning_rate: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + learning_rate_scheduler: Optional[ + Union[ + str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + model_name: Optional[ + Union[ + str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + momentum: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + nesterov: Optional[ + Union[ + bool, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + number_of_epochs: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + number_of_workers: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + optimizer: Optional[ + Union[ + str, Choice, LogNormal, LogUniform, 
Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + random_seed: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + step_lr_gamma: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + step_lr_step_size: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + training_batch_size: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + validation_batch_size: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + warmup_cosine_lr_cycles: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + warmup_cosine_lr_warmup_epochs: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + weight_decay: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + box_detections_per_image: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + box_score_threshold: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + image_size: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + max_size: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + min_size: Optional[ + Union[ + int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + model_size: Optional[ + Union[ + str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + multi_scale: Optional[ + Union[ + bool, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + nms_iou_threshold: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + tile_grid_size: Optional[ + Union[ + str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + tile_overlap_ratio: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + tile_predictions_nms_threshold: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + validation_iou_threshold: Optional[ + Union[ + float, + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + Uniform, + ] + ] = None, + 
validation_metric_type: Optional[ + Union[ + str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ] + ] = None, + ) -> None: + self.ams_gradient = ams_gradient + self.beta1 = beta1 + self.beta2 = beta2 + self.distributed = distributed + self.early_stopping = early_stopping + self.early_stopping_delay = early_stopping_delay + self.early_stopping_patience = early_stopping_patience + self.enable_onnx_normalization = enable_onnx_normalization + self.evaluation_frequency = evaluation_frequency + self.gradient_accumulation_step = gradient_accumulation_step + self.layers_to_freeze = layers_to_freeze + self.learning_rate = learning_rate + self.learning_rate_scheduler = learning_rate_scheduler + self.model_name = model_name + self.momentum = momentum + self.nesterov = nesterov + self.number_of_epochs = number_of_epochs + self.number_of_workers = number_of_workers + self.optimizer = optimizer + self.random_seed = random_seed + self.step_lr_gamma = step_lr_gamma + self.step_lr_step_size = step_lr_step_size + self.training_batch_size = training_batch_size + self.validation_batch_size = validation_batch_size + self.warmup_cosine_lr_cycles = warmup_cosine_lr_cycles + self.warmup_cosine_lr_warmup_epochs = warmup_cosine_lr_warmup_epochs + self.weight_decay = weight_decay + self.box_detections_per_image = box_detections_per_image + self.box_score_threshold = box_score_threshold + self.image_size = image_size + self.max_size = max_size + self.min_size = min_size + self.model_size = model_size + self.multi_scale = multi_scale + self.nms_iou_threshold = nms_iou_threshold + self.tile_grid_size = tile_grid_size + self.tile_overlap_ratio = tile_overlap_ratio + self.tile_predictions_nms_threshold = tile_predictions_nms_threshold + self.validation_iou_threshold = validation_iou_threshold + self.validation_metric_type = validation_metric_type + + def _to_rest_object(self) -> ImageModelDistributionSettingsObjectDetection: + return ImageModelDistributionSettingsObjectDetection( + ams_gradient=_convert_to_rest_object(self.ams_gradient) if self.ams_gradient is not None else None, + beta1=_convert_to_rest_object(self.beta1) if self.beta1 is not None else None, + beta2=_convert_to_rest_object(self.beta2) if self.beta2 is not None else None, + distributed=_convert_to_rest_object(self.distributed) if self.distributed is not None else None, + early_stopping=_convert_to_rest_object(self.early_stopping) if self.early_stopping is not None else None, + early_stopping_delay=( + _convert_to_rest_object(self.early_stopping_delay) if self.early_stopping_delay is not None else None + ), + early_stopping_patience=( + _convert_to_rest_object(self.early_stopping_patience) + if self.early_stopping_patience is not None + else None + ), + enable_onnx_normalization=( + _convert_to_rest_object(self.enable_onnx_normalization) + if self.enable_onnx_normalization is not None + else None + ), + evaluation_frequency=( + _convert_to_rest_object(self.evaluation_frequency) if self.evaluation_frequency is not None else None + ), + gradient_accumulation_step=( + _convert_to_rest_object(self.gradient_accumulation_step) + if self.gradient_accumulation_step is not None + else None + ), + layers_to_freeze=( + _convert_to_rest_object(self.layers_to_freeze) if self.layers_to_freeze is not None else None + ), + learning_rate=_convert_to_rest_object(self.learning_rate) if self.learning_rate is not None else None, + learning_rate_scheduler=( + _convert_to_rest_object(self.learning_rate_scheduler) + if 
self.learning_rate_scheduler is not None + else None + ), + model_name=_convert_to_rest_object(self.model_name) if self.model_name is not None else None, + momentum=_convert_to_rest_object(self.momentum) if self.momentum is not None else None, + nesterov=_convert_to_rest_object(self.nesterov) if self.nesterov is not None else None, + number_of_epochs=( + _convert_to_rest_object(self.number_of_epochs) if self.number_of_epochs is not None else None + ), + number_of_workers=( + _convert_to_rest_object(self.number_of_workers) if self.number_of_workers is not None else None + ), + optimizer=_convert_to_rest_object(self.optimizer) if self.optimizer is not None else None, + random_seed=_convert_to_rest_object(self.random_seed) if self.random_seed is not None else None, + step_lr_gamma=_convert_to_rest_object(self.step_lr_gamma) if self.step_lr_gamma is not None else None, + step_lr_step_size=( + _convert_to_rest_object(self.step_lr_step_size) if self.step_lr_step_size is not None else None + ), + training_batch_size=( + _convert_to_rest_object(self.training_batch_size) if self.training_batch_size is not None else None + ), + validation_batch_size=( + _convert_to_rest_object(self.validation_batch_size) if self.validation_batch_size is not None else None + ), + warmup_cosine_lr_cycles=( + _convert_to_rest_object(self.warmup_cosine_lr_cycles) + if self.warmup_cosine_lr_cycles is not None + else None + ), + warmup_cosine_lr_warmup_epochs=( + _convert_to_rest_object(self.warmup_cosine_lr_warmup_epochs) + if self.warmup_cosine_lr_warmup_epochs is not None + else None + ), + weight_decay=_convert_to_rest_object(self.weight_decay) if self.weight_decay is not None else None, + box_detections_per_image=( + _convert_to_rest_object(self.box_detections_per_image) + if self.box_detections_per_image is not None + else None + ), + box_score_threshold=( + _convert_to_rest_object(self.box_score_threshold) if self.box_score_threshold is not None else None + ), + image_size=_convert_to_rest_object(self.image_size) if self.image_size is not None else None, + max_size=_convert_to_rest_object(self.max_size) if self.max_size is not None else None, + min_size=_convert_to_rest_object(self.min_size) if self.min_size is not None else None, + model_size=_convert_to_rest_object(self.model_size) if self.model_size is not None else None, + multi_scale=_convert_to_rest_object(self.multi_scale) if self.multi_scale is not None else None, + nms_iou_threshold=( + _convert_to_rest_object(self.nms_iou_threshold) if self.nms_iou_threshold is not None else None + ), + tile_grid_size=_convert_to_rest_object(self.tile_grid_size) if self.tile_grid_size is not None else None, + tile_overlap_ratio=( + _convert_to_rest_object(self.tile_overlap_ratio) if self.tile_overlap_ratio is not None else None + ), + tile_predictions_nms_threshold=( + _convert_to_rest_object(self.tile_predictions_nms_threshold) + if self.tile_predictions_nms_threshold is not None + else None + ), + validation_iou_threshold=( + _convert_to_rest_object(self.validation_iou_threshold) + if self.validation_iou_threshold is not None + else None + ), + validation_metric_type=( + _convert_to_rest_object(self.validation_metric_type) + if self.validation_metric_type is not None + else None + ), + ) + + @classmethod + def _from_rest_object(cls, obj: ImageModelDistributionSettingsObjectDetection) -> "ImageObjectDetectionSearchSpace": + return cls( + ams_gradient=_convert_from_rest_object(obj.ams_gradient) if obj.ams_gradient is not None else None, + 
beta1=_convert_from_rest_object(obj.beta1) if obj.beta1 is not None else None, + beta2=_convert_from_rest_object(obj.beta2) if obj.beta2 is not None else None, + distributed=_convert_from_rest_object(obj.distributed) if obj.distributed is not None else None, + early_stopping=_convert_from_rest_object(obj.early_stopping) if obj.early_stopping is not None else None, + early_stopping_delay=( + _convert_from_rest_object(obj.early_stopping_delay) if obj.early_stopping_delay is not None else None + ), + early_stopping_patience=( + _convert_from_rest_object(obj.early_stopping_patience) + if obj.early_stopping_patience is not None + else None + ), + enable_onnx_normalization=( + _convert_from_rest_object(obj.enable_onnx_normalization) + if obj.enable_onnx_normalization is not None + else None + ), + evaluation_frequency=( + _convert_from_rest_object(obj.evaluation_frequency) if obj.evaluation_frequency is not None else None + ), + gradient_accumulation_step=( + _convert_from_rest_object(obj.gradient_accumulation_step) + if obj.gradient_accumulation_step is not None + else None + ), + layers_to_freeze=( + _convert_from_rest_object(obj.layers_to_freeze) if obj.layers_to_freeze is not None else None + ), + learning_rate=_convert_from_rest_object(obj.learning_rate) if obj.learning_rate is not None else None, + learning_rate_scheduler=( + _convert_from_rest_object(obj.learning_rate_scheduler) + if obj.learning_rate_scheduler is not None + else None + ), + model_name=_convert_from_rest_object(obj.model_name) if obj.model_name is not None else None, + momentum=_convert_from_rest_object(obj.momentum) if obj.momentum is not None else None, + nesterov=_convert_from_rest_object(obj.nesterov) if obj.nesterov is not None else None, + number_of_epochs=( + _convert_from_rest_object(obj.number_of_epochs) if obj.number_of_epochs is not None else None + ), + number_of_workers=( + _convert_from_rest_object(obj.number_of_workers) if obj.number_of_workers is not None else None + ), + optimizer=_convert_from_rest_object(obj.optimizer) if obj.optimizer is not None else None, + random_seed=_convert_from_rest_object(obj.random_seed) if obj.random_seed is not None else None, + step_lr_gamma=_convert_from_rest_object(obj.step_lr_gamma) if obj.step_lr_gamma is not None else None, + step_lr_step_size=( + _convert_from_rest_object(obj.step_lr_step_size) if obj.step_lr_step_size is not None else None + ), + training_batch_size=( + _convert_from_rest_object(obj.training_batch_size) if obj.training_batch_size is not None else None + ), + validation_batch_size=( + _convert_from_rest_object(obj.validation_batch_size) if obj.validation_batch_size is not None else None + ), + warmup_cosine_lr_cycles=( + _convert_from_rest_object(obj.warmup_cosine_lr_cycles) + if obj.warmup_cosine_lr_cycles is not None + else None + ), + warmup_cosine_lr_warmup_epochs=( + _convert_from_rest_object(obj.warmup_cosine_lr_warmup_epochs) + if obj.warmup_cosine_lr_warmup_epochs is not None + else None + ), + weight_decay=_convert_from_rest_object(obj.weight_decay) if obj.weight_decay is not None else None, + box_detections_per_image=( + _convert_from_rest_object(obj.box_detections_per_image) + if obj.box_detections_per_image is not None + else None + ), + box_score_threshold=( + _convert_from_rest_object(obj.box_score_threshold) if obj.box_score_threshold is not None else None + ), + image_size=_convert_from_rest_object(obj.image_size) if obj.image_size is not None else None, + max_size=_convert_from_rest_object(obj.max_size) if obj.max_size is not None 
else None, + min_size=_convert_from_rest_object(obj.min_size) if obj.min_size is not None else None, + model_size=_convert_from_rest_object(obj.model_size) if obj.model_size is not None else None, + multi_scale=_convert_from_rest_object(obj.multi_scale) if obj.multi_scale is not None else None, + nms_iou_threshold=( + _convert_from_rest_object(obj.nms_iou_threshold) if obj.nms_iou_threshold is not None else None + ), + tile_grid_size=_convert_from_rest_object(obj.tile_grid_size) if obj.tile_grid_size is not None else None, + tile_overlap_ratio=( + _convert_from_rest_object(obj.tile_overlap_ratio) if obj.tile_overlap_ratio is not None else None + ), + tile_predictions_nms_threshold=( + _convert_from_rest_object(obj.tile_predictions_nms_threshold) + if obj.tile_predictions_nms_threshold is not None + else None + ), + validation_iou_threshold=( + _convert_from_rest_object(obj.validation_iou_threshold) + if obj.validation_iou_threshold is not None + else None + ), + validation_metric_type=( + _convert_from_rest_object(obj.validation_metric_type) + if obj.validation_metric_type is not None + else None + ), + ) + + @classmethod + def _from_search_space_object(cls, obj: SearchSpace) -> "ImageObjectDetectionSearchSpace": + return cls( + ams_gradient=obj.ams_gradient if hasattr(obj, "ams_gradient") else None, + beta1=obj.beta1 if hasattr(obj, "beta1") else None, + beta2=obj.beta2 if hasattr(obj, "beta2") else None, + distributed=obj.distributed if hasattr(obj, "distributed") else None, + early_stopping=obj.early_stopping if hasattr(obj, "early_stopping") else None, + early_stopping_delay=obj.early_stopping_delay if hasattr(obj, "early_stopping_delay") else None, + early_stopping_patience=obj.early_stopping_patience if hasattr(obj, "early_stopping_patience") else None, + enable_onnx_normalization=( + obj.enable_onnx_normalization if hasattr(obj, "enable_onnx_normalization") else None + ), + evaluation_frequency=obj.evaluation_frequency if hasattr(obj, "evaluation_frequency") else None, + gradient_accumulation_step=( + obj.gradient_accumulation_step if hasattr(obj, "gradient_accumulation_step") else None + ), + layers_to_freeze=obj.layers_to_freeze if hasattr(obj, "layers_to_freeze") else None, + learning_rate=obj.learning_rate if hasattr(obj, "learning_rate") else None, + learning_rate_scheduler=obj.learning_rate_scheduler if hasattr(obj, "learning_rate_scheduler") else None, + model_name=obj.model_name if hasattr(obj, "model_name") else None, + momentum=obj.momentum if hasattr(obj, "momentum") else None, + nesterov=obj.nesterov if hasattr(obj, "nesterov") else None, + number_of_epochs=obj.number_of_epochs if hasattr(obj, "number_of_epochs") else None, + number_of_workers=obj.number_of_workers if hasattr(obj, "number_of_workers") else None, + optimizer=obj.optimizer if hasattr(obj, "optimizer") else None, + random_seed=obj.random_seed if hasattr(obj, "random_seed") else None, + step_lr_gamma=obj.step_lr_gamma if hasattr(obj, "step_lr_gamma") else None, + step_lr_step_size=obj.step_lr_step_size if hasattr(obj, "step_lr_step_size") else None, + training_batch_size=obj.training_batch_size if hasattr(obj, "training_batch_size") else None, + validation_batch_size=obj.validation_batch_size if hasattr(obj, "validation_batch_size") else None, + warmup_cosine_lr_cycles=obj.warmup_cosine_lr_cycles if hasattr(obj, "warmup_cosine_lr_cycles") else None, + warmup_cosine_lr_warmup_epochs=( + obj.warmup_cosine_lr_warmup_epochs if hasattr(obj, "warmup_cosine_lr_warmup_epochs") else None + ), + 
weight_decay=obj.weight_decay if hasattr(obj, "weight_decay") else None, + box_detections_per_image=obj.box_detections_per_image if hasattr(obj, "box_detections_per_image") else None, + box_score_threshold=obj.box_score_threshold if hasattr(obj, "box_score_threshold") else None, + image_size=obj.image_size if hasattr(obj, "image_size") else None, + max_size=obj.max_size if hasattr(obj, "max_size") else None, + min_size=obj.min_size if hasattr(obj, "min_size") else None, + model_size=obj.model_size if hasattr(obj, "model_size") else None, + multi_scale=obj.multi_scale if hasattr(obj, "multi_scale") else None, + nms_iou_threshold=obj.nms_iou_threshold if hasattr(obj, "nms_iou_threshold") else None, + tile_grid_size=obj.tile_grid_size if hasattr(obj, "tile_grid_size") else None, + tile_overlap_ratio=obj.tile_overlap_ratio if hasattr(obj, "tile_overlap_ratio") else None, + tile_predictions_nms_threshold=( + obj.tile_predictions_nms_threshold if hasattr(obj, "tile_predictions_nms_threshold") else None + ), + validation_iou_threshold=obj.validation_iou_threshold if hasattr(obj, "validation_iou_threshold") else None, + validation_metric_type=obj.validation_metric_type if hasattr(obj, "validation_metric_type") else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageObjectDetectionSearchSpace): + return NotImplemented + + return ( + self.ams_gradient == other.ams_gradient + and self.beta1 == other.beta1 + and self.beta2 == other.beta2 + and self.distributed == other.distributed + and self.early_stopping == other.early_stopping + and self.early_stopping_delay == other.early_stopping_delay + and self.early_stopping_patience == other.early_stopping_patience + and self.enable_onnx_normalization == other.enable_onnx_normalization + and self.evaluation_frequency == other.evaluation_frequency + and self.gradient_accumulation_step == other.gradient_accumulation_step + and self.layers_to_freeze == other.layers_to_freeze + and self.learning_rate == other.learning_rate + and self.learning_rate_scheduler == other.learning_rate_scheduler + and self.model_name == other.model_name + and self.momentum == other.momentum + and self.nesterov == other.nesterov + and self.number_of_epochs == other.number_of_epochs + and self.number_of_workers == other.number_of_workers + and self.optimizer == other.optimizer + and self.random_seed == other.random_seed + and self.step_lr_gamma == other.step_lr_gamma + and self.step_lr_step_size == other.step_lr_step_size + and self.training_batch_size == other.training_batch_size + and self.validation_batch_size == other.validation_batch_size + and self.warmup_cosine_lr_cycles == other.warmup_cosine_lr_cycles + and self.warmup_cosine_lr_warmup_epochs == other.warmup_cosine_lr_warmup_epochs + and self.weight_decay == other.weight_decay + and self.box_detections_per_image == other.box_detections_per_image + and self.box_score_threshold == other.box_score_threshold + and self.image_size == other.image_size + and self.max_size == other.max_size + and self.min_size == other.min_size + and self.model_size == other.model_size + and self.multi_scale == other.multi_scale + and self.nms_iou_threshold == other.nms_iou_threshold + and self.tile_grid_size == other.tile_grid_size + and self.tile_overlap_ratio == other.tile_overlap_ratio + and self.tile_predictions_nms_threshold == other.tile_predictions_nms_threshold + and self.validation_iou_threshold == other.validation_iou_threshold + and self.validation_metric_type == other.validation_metric_type + ) + + def 
__ne__(self, other: object) -> bool:
+ return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_sweep_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_sweep_settings.py new file mode 100644 index 00000000..b5e9ffaf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_sweep_settings.py @@ -0,0 +1,86 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageSweepSettings as RestImageSweepSettings
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SamplingAlgorithmType
+from azure.ai.ml.entities._job.sweep.early_termination_policy import (
+ BanditPolicy,
+ EarlyTerminationPolicy,
+ MedianStoppingPolicy,
+ TruncationSelectionPolicy,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class ImageSweepSettings(RestTranslatableMixin):
+ """Sweep settings for all AutoML Image Verticals.
+
+ :keyword sampling_algorithm: Required. Type of the hyperparameter sampling algorithm.
+ Possible values include: "Grid", "Random", "Bayesian".
+ :paramtype sampling_algorithm: Union[
+ str,
+ ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.GRID,
+ ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.BAYESIAN,
+ ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.RANDOM
+
+ ]
+ :keyword early_termination: Type of early termination policy.
+ :paramtype early_termination: Union[
+
+ ~azure.mgmt.machinelearningservices.models.BanditPolicy,
+ ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy,
+ ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy
+
+ ]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_sweep_settings]
+ :end-before: [END automl.automl_image_job.image_sweep_settings]
+ :language: python
+ :dedent: 8
+ :caption: Defining the sweep settings for an automl image job.
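+
+ A minimal construction sketch is shown below. It is illustrative only (not one of the shipped
+ samples) and assumes random sampling with a bandit early-termination policy, imported from the
+ public ``azure.ai.ml.automl`` and ``azure.ai.ml.sweep`` namespaces:
+
+ .. code-block:: python
+
+ from azure.ai.ml.automl import ImageSweepSettings
+ from azure.ai.ml.sweep import BanditPolicy
+
+ # Randomly sample the search space and stop clearly underperforming trials early.
+ sweep_settings = ImageSweepSettings(
+ sampling_algorithm="Random",
+ early_termination=BanditPolicy(evaluation_interval=2, slack_factor=0.2),
+ )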
+ """ + + def __init__( + self, + *, + sampling_algorithm: Union[ + str, SamplingAlgorithmType.GRID, SamplingAlgorithmType.BAYESIAN, SamplingAlgorithmType.RANDOM + ], + early_termination: Optional[ + Union[EarlyTerminationPolicy, BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy] + ] = None, + ): + self.sampling_algorithm = sampling_algorithm + self.early_termination = early_termination + + def _to_rest_object(self) -> RestImageSweepSettings: + return RestImageSweepSettings( + sampling_algorithm=self.sampling_algorithm, + early_termination=self.early_termination._to_rest_object() if self.early_termination else None, + ) + + @classmethod + def _from_rest_object(cls, obj: RestImageSweepSettings) -> "ImageSweepSettings": + return cls( + sampling_algorithm=obj.sampling_algorithm, + early_termination=( + EarlyTerminationPolicy._from_rest_object(obj.early_termination) if obj.early_termination else None + ), + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageSweepSettings): + return NotImplemented + + return self.sampling_algorithm == other.sampling_algorithm and self.early_termination == other.early_termination + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/__init__.py new file mode 100644 index 00000000..9be7b483 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/__init__.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from .automl_nlp_job import AutoMLNLPJob +from .nlp_featurization_settings import NlpFeaturizationSettings +from .nlp_fixed_parameters import NlpFixedParameters +from .nlp_limit_settings import NlpLimitSettings +from .nlp_search_space import NlpSearchSpace +from .nlp_sweep_settings import NlpSweepSettings +from .text_classification_job import TextClassificationJob +from .text_classification_multilabel_job import TextClassificationMultilabelJob +from .text_ner_job import TextNerJob + +__all__ = [ + "AutoMLNLPJob", + "NlpFeaturizationSettings", + "NlpFixedParameters", + "NlpLimitSettings", + "NlpSearchSpace", + "NlpSweepSettings", + "TextClassificationJob", + "TextClassificationMultilabelJob", + "TextNerJob", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/automl_nlp_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/automl_nlp_job.py new file mode 100644 index 00000000..f0b3baa8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/automl_nlp_job.py @@ -0,0 +1,467 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from abc import ABC +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + LogVerbosity, + NlpLearningRateScheduler, + SamplingAlgorithmType, +) +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.automl.automl_vertical import AutoMLVertical +from azure.ai.ml.entities._job.automl.nlp.nlp_featurization_settings import NlpFeaturizationSettings +from azure.ai.ml.entities._job.automl.nlp.nlp_fixed_parameters import NlpFixedParameters +from azure.ai.ml.entities._job.automl.nlp.nlp_limit_settings import NlpLimitSettings +from azure.ai.ml.entities._job.automl.nlp.nlp_search_space import NlpSearchSpace +from azure.ai.ml.entities._job.automl.nlp.nlp_sweep_settings import NlpSweepSettings +from azure.ai.ml.entities._job.automl.search_space import SearchSpace +from azure.ai.ml.entities._job.automl.utils import cast_to_specific_search_space +from azure.ai.ml.entities._job.sweep.early_termination_policy import EarlyTerminationPolicy +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +# pylint: disable=too-many-instance-attributes,protected-access +class AutoMLNLPJob(AutoMLVertical, ABC): + """Base class for AutoML NLP jobs. + + You should not instantiate this class directly. Instead you should + create classes for specific NLP Jobs. + + :param task_type: NLP task type, must be one of 'TextClassification', + 'TextClassificationMultilabel', or 'TextNER' + :type task_type: str + :param primary_metric: Primary metric to display from NLP job + :type primary_metric: str + :param training_data: Training data + :type training_data: Input + :param validation_data: Validation data + :type validation_data: Input + :param target_column_name: Column name of the target column, defaults to None + :type target_column_name: Optional[str] + :param log_verbosity: The degree of verbosity used in logging, defaults to None, + must be one of 'NotSet', 'Debug', 'Info', 'Warning', 'Error', 'Critical', or None + :type log_verbosity: Optional[str] + :param featurization: Featurization settings used for NLP job, defaults to None + :type featurization: Optional[~azure.ai.ml.automl.NlpFeaturizationSettings] + :param limits: Limit settings for NLP jobs, defaults to None + :type limits: Optional[~azure.ai.ml.automl.NlpLimitSettings] + :param sweep: Sweep settings used for NLP job, defaults to None + :type sweep: Optional[~azure.ai.ml.automl.NlpSweepSettings] + :param training_parameters: Fixed parameters for the training of all candidates. 
+ , defaults to None + :type training_parameters: Optional[~azure.ai.ml.automl.NlpFixedParameters] + :param search_space: Search space(s) to sweep over for NLP sweep jobs, defaults to None + :type search_space: Optional[List[~azure.ai.ml.automl.NlpSearchSpace]] + """ + + def __init__( + self, + *, + task_type: str, + primary_metric: str, + training_data: Optional[Input], + validation_data: Optional[Input], + target_column_name: Optional[str] = None, + log_verbosity: Optional[str] = None, + featurization: Optional[NlpFeaturizationSettings] = None, + limits: Optional[NlpLimitSettings] = None, + sweep: Optional[NlpSweepSettings] = None, + training_parameters: Optional[NlpFixedParameters] = None, + search_space: Optional[List[NlpSearchSpace]] = None, + **kwargs: Any, + ): + self._training_parameters: Optional[NlpFixedParameters] = None + + super().__init__( + task_type, training_data=training_data, validation_data=validation_data, **kwargs # type: ignore + ) + self.log_verbosity = log_verbosity + self._primary_metric: str = "" + self.primary_metric = primary_metric + + self.target_column_name = target_column_name + + self._featurization = featurization + self._limits = limits or NlpLimitSettings() + self._sweep = sweep + self.training_parameters = training_parameters # via setter method. + self._search_space = search_space + + @property + def training_parameters(self) -> Optional[NlpFixedParameters]: + """Parameters that are used for all submitted jobs. + + :return: fixed training parameters for NLP jobs + :rtype: ~azure.ai.ml.automl.NlpFixedParameters + """ + return self._training_parameters + + @training_parameters.setter + def training_parameters(self, value: Union[Dict, NlpFixedParameters]) -> None: + if value is None: + self._training_parameters = None + elif isinstance(value, NlpFixedParameters): + self._training_parameters = value + # Convert parameters from snake case to enum. + self.set_training_parameters(learning_rate_scheduler=value.learning_rate_scheduler) + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for nlp training parameters." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_training_parameters(**value) + + @property + def search_space(self) -> Optional[List[NlpSearchSpace]]: + """Search space(s) to sweep over for NLP sweep jobs + + :return: list of search spaces to sweep over for NLP jobs + :rtype: List[~azure.ai.ml.automl.NlpSearchSpace] + """ + return self._search_space + + @search_space.setter + def search_space(self, value: Union[List[dict], List[SearchSpace]]) -> None: + if not isinstance(value, list): + msg = "Expected a list for search space." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + all_dict_type = all(isinstance(item, dict) for item in value) + all_search_space_type = all(isinstance(item, SearchSpace) for item in value) + + if not (all_search_space_type or all_dict_type): + msg = "Expected all items in the list to be either dictionaries or SearchSpace objects." 
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + self._search_space = [ + cast_to_specific_search_space(item, NlpSearchSpace, self.task_type) for item in value # type: ignore + ] + + @property + def primary_metric(self) -> str: + """Primary metric to display from NLP job + + :return: primary metric to display + :rtype: str + """ + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: str) -> None: + self._primary_metric = value + + @property + def log_verbosity(self) -> LogVerbosity: + """Log verbosity configuration + + :return: the degree of verbosity used in logging + :rtype: ~azure.mgmt.machinelearningservices.models.LogVerbosity + """ + return self._log_verbosity + + @log_verbosity.setter + def log_verbosity(self, value: Union[str, LogVerbosity]) -> None: + self._log_verbosity = None if value is None else LogVerbosity[camel_to_snake(value).upper()] + + @property + def limits(self) -> NlpLimitSettings: + """Limit settings for NLP jobs + + :return: limit configuration for NLP job + :rtype: ~azure.ai.ml.automl.NlpLimitSettings + """ + return self._limits + + @limits.setter + def limits(self, value: Union[Dict, NlpLimitSettings]) -> None: + if isinstance(value, NlpLimitSettings): + self._limits = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for limit settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_limits(**value) + + @property + def sweep(self) -> Optional[NlpSweepSettings]: + """Sweep settings used for NLP job + + :return: sweep settings + :rtype: ~azure.ai.ml.automl.NlpSweepSettings + """ + return self._sweep + + @sweep.setter + def sweep(self, value: Union[Dict, NlpSweepSettings]) -> None: + if isinstance(value, NlpSweepSettings): + self._sweep = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for sweep settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_sweep(**value) + + @property + def featurization(self) -> Optional[NlpFeaturizationSettings]: + """Featurization settings used for NLP job + + :return: featurization settings + :rtype: ~azure.ai.ml.automl.NlpFeaturizationSettings + """ + return self._featurization + + @featurization.setter + def featurization(self, value: Union[Dict, NlpFeaturizationSettings]) -> None: + if isinstance(value, NlpFeaturizationSettings): + self._featurization = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for featurization settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_featurization(**value) + + def set_data(self, *, training_data: Input, target_column_name: str, validation_data: Input) -> None: + """Define data configuration for NLP job + + :keyword training_data: Training data + :type training_data: ~azure.ai.ml.Input + :keyword target_column_name: Column name of the target column. 
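The limits, sweep, training_parameters and featurization setters above accept either the dedicated settings objects or plain dictionaries, which they forward to the corresponding set_* methods. A sketch of the dictionary form, assuming the public azure.ai.ml.automl re-export of TextClassificationJob (defined later in this diff) and hypothetical data paths and column name:

from azure.ai.ml import Input
from azure.ai.ml.automl import TextClassificationJob  # assumed public re-export

# Hypothetical MLTable folders and target column, for illustration only.
job = TextClassificationJob(
    target_column_name="sentiment",
    training_data=Input(type="mltable", path="./training-mltable"),
    validation_data=Input(type="mltable", path="./validation-mltable"),
)

# Dictionaries are routed through set_limits/set_sweep/set_training_parameters/set_featurization.
job.limits = {"max_trials": 4, "max_concurrent_trials": 2, "timeout_minutes": 120}
job.sweep = {"sampling_algorithm": "Random"}
job.training_parameters = {"model_name": "roberta-base", "number_of_epochs": 3}
job.featurization = {"dataset_language": "eng"}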
+ :type target_column_name: str + :keyword validation_data: Validation data + :type validation_data: ~azure.ai.ml.Input + """ + # Properties for NlpVerticalDataSettings + self.target_column_name = target_column_name + self.training_data = training_data + self.validation_data = validation_data + + def set_limits( + self, + *, + max_trials: int = 1, + max_concurrent_trials: int = 1, + max_nodes: int = 1, + timeout_minutes: Optional[int] = None, + trial_timeout_minutes: Optional[int] = None, + ) -> None: + """Define limit configuration for AutoML NLP job + + :keyword max_trials: Maximum number of AutoML iterations, defaults to 1 + :type max_trials: int, optional + :keyword max_concurrent_trials: Maximum number of concurrent AutoML iterations, defaults to 1 + :type max_concurrent_trials: int, optional + :keyword max_nodes: Maximum number of nodes used for sweep, defaults to 1 + :type max_nodes: int, optional + :keyword timeout_minutes: Timeout for the AutoML job, defaults to None + :type timeout_minutes: Optional[int] + :keyword trial_timeout_minutes: Timeout for each AutoML trial, defaults to None + :type trial_timeout_minutes: Optional[int] + """ + self._limits = NlpLimitSettings( + max_trials=max_trials, + max_concurrent_trials=max_concurrent_trials, + max_nodes=max_nodes, + timeout_minutes=timeout_minutes, + trial_timeout_minutes=trial_timeout_minutes, + ) + + def set_sweep( + self, + *, + sampling_algorithm: Union[str, SamplingAlgorithmType], + early_termination: Optional[EarlyTerminationPolicy] = None, + ) -> None: + """Define sweep configuration for AutoML NLP job + + :keyword sampling_algorithm: Required. Specifies type of hyperparameter sampling algorithm. + Possible values include: "Grid", "Random", and "Bayesian". + :type sampling_algorithm: Union[str, ~azure.ai.ml.automl.SamplingAlgorithmType] + :keyword early_termination: Optional. early termination policy to end poorly performing training candidates, + defaults to None. + :type early_termination: Optional[~azure.mgmt.machinelearningservices.models.EarlyTerminationPolicy] + """ + if self._sweep: + self._sweep.sampling_algorithm = sampling_algorithm + else: + self._sweep = NlpSweepSettings(sampling_algorithm=sampling_algorithm) + + self._sweep.early_termination = early_termination or self._sweep.early_termination + + def set_training_parameters( + self, + *, + gradient_accumulation_steps: Optional[int] = None, + learning_rate: Optional[float] = None, + learning_rate_scheduler: Optional[Union[str, NlpLearningRateScheduler]] = None, + model_name: Optional[str] = None, + number_of_epochs: Optional[int] = None, + training_batch_size: Optional[int] = None, + validation_batch_size: Optional[int] = None, + warmup_ratio: Optional[float] = None, + weight_decay: Optional[float] = None, + ) -> None: + """Fix certain training parameters throughout the training procedure for all candidates. + + :keyword gradient_accumulation_steps: number of steps over which to accumulate gradients before a backward + pass. This must be a positive integer., defaults to None + :type gradient_accumulation_steps: Optional[int] + :keyword learning_rate: initial learning rate. Must be a float in (0, 1)., defaults to None + :type learning_rate: Optional[float] + :keyword learning_rate_scheduler: the type of learning rate scheduler. 
Must choose from 'linear', 'cosine', + 'cosine_with_restarts', 'polynomial', 'constant', and 'constant_with_warmup'., defaults to None + :type learning_rate_scheduler: Optional[Union[str, ~azure.ai.ml.automl.NlpLearningRateScheduler]] + :keyword model_name: the model name to use during training. Must choose from 'bert-base-cased', + 'bert-base-uncased', 'bert-base-multilingual-cased', 'bert-base-german-cased', 'bert-large-cased', + 'bert-large-uncased', 'distilbert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'roberta-large', + 'distilroberta-base', 'xlm-roberta-base', 'xlm-roberta-large', xlnet-base-cased', and 'xlnet-large-cased'., + defaults to None + :type model_name: Optional[str] + :keyword number_of_epochs: the number of epochs to train with. Must be a positive integer., defaults to None + :type number_of_epochs: Optional[int] + :keyword training_batch_size: the batch size during training. Must be a positive integer., defaults to None + :type training_batch_size: Optional[int] + :keyword validation_batch_size: the batch size during validation. Must be a positive integer., defaults to None + :type validation_batch_size: Optional[int] + :keyword warmup_ratio: ratio of total training steps used for a linear warmup from 0 to learning_rate. + Must be a float in [0, 1]., defaults to None + :type warmup_ratio: Optional[float] + :keyword weight_decay: value of weight decay when optimizer is sgd, adam, or adamw. This must be a float in + the range [0, 1]., defaults to None + :type weight_decay: Optional[float] + """ + self._training_parameters = self._training_parameters or NlpFixedParameters() + + self._training_parameters.gradient_accumulation_steps = ( + gradient_accumulation_steps + if gradient_accumulation_steps is not None + else self._training_parameters.gradient_accumulation_steps + ) + + self._training_parameters.learning_rate = ( + learning_rate if learning_rate is not None else self._training_parameters.learning_rate + ) + + self._training_parameters.learning_rate_scheduler = ( + NlpLearningRateScheduler[camel_to_snake(learning_rate_scheduler).upper()] + if learning_rate_scheduler is not None + else self._training_parameters.learning_rate_scheduler + ) + + self._training_parameters.model_name = ( + model_name if model_name is not None else self._training_parameters.model_name + ) + + self._training_parameters.number_of_epochs = ( + number_of_epochs if number_of_epochs is not None else self._training_parameters.number_of_epochs + ) + + self._training_parameters.training_batch_size = ( + training_batch_size if training_batch_size is not None else self._training_parameters.training_batch_size + ) + + self._training_parameters.validation_batch_size = ( + validation_batch_size + if validation_batch_size is not None + else self._training_parameters.validation_batch_size + ) + + self._training_parameters.warmup_ratio = ( + warmup_ratio if warmup_ratio is not None else self._training_parameters.warmup_ratio + ) + + self._training_parameters.weight_decay = ( + weight_decay if weight_decay is not None else self._training_parameters.weight_decay + ) + + def set_featurization(self, *, dataset_language: Optional[str] = None) -> None: + """Define featurization configuration for AutoML NLP job. 
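The same configuration can also be applied through the set_limits, set_sweep and set_training_parameters methods shown above. This sketch continues the hypothetical `job` from the previous sketch and passes the keyword arguments directly (values are illustrative):

# Continuing the hypothetical `job` from the sketch above.
job.set_limits(max_trials=8, max_concurrent_trials=2, max_nodes=2, timeout_minutes=180)
job.set_sweep(sampling_algorithm="Grid")
job.set_training_parameters(
    model_name="bert-base-cased",
    learning_rate=2e-5,
    learning_rate_scheduler="linear",  # normalized to NlpLearningRateScheduler.LINEAR by the setter
    number_of_epochs=3,
    training_batch_size=32,
)
job.set_featurization(dataset_language="eng")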
+ + :keyword dataset_language: Language of the dataset, defaults to None + :type dataset_language: Optional[str] + """ + self._featurization = NlpFeaturizationSettings( + dataset_language=dataset_language, + ) + + def extend_search_space(self, value: Union[SearchSpace, List[SearchSpace]]) -> None: + """Add (a) search space(s) for an AutoML NLP job. + + :param value: either a SearchSpace object or a list of SearchSpace objects with nlp-specific parameters. + :type value: Union[~azure.ai.ml.automl.SearchSpace, List[~azure.ai.ml.automl.SearchSpace]] + """ + self._search_space = self._search_space or [] + if isinstance(value, list): + self._search_space.extend( + [cast_to_specific_search_space(item, NlpSearchSpace, self.task_type) for item in value] # type: ignore + ) + else: + self._search_space.append( + cast_to_specific_search_space(value, NlpSearchSpace, self.task_type) # type: ignore + ) + + @classmethod + def _get_search_space_from_str(cls, search_space_str: Optional[str]) -> Optional[List]: + if search_space_str is not None: + return [NlpSearchSpace._from_rest_object(entry) for entry in search_space_str if entry is not None] + return None + + def _restore_data_inputs(self) -> None: + """Restore MLTableJobInputs to Inputs within data_settings. + + self.training_data and self.validation_data should reflect what user passed in (Input) Once we get response back + from service (as MLTableJobInput), we should set responsible ones back to Input + """ + super()._restore_data_inputs() + self.training_data = self.training_data if self.training_data else None # type: ignore + self.validation_data = self.validation_data if self.validation_data else None # type: ignore + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AutoMLNLPJob): + return NotImplemented + + return ( + self.primary_metric == other.primary_metric + and self.log_verbosity == other.log_verbosity + and self.training_data == other.training_data + and self.validation_data == other.validation_data + and self._featurization == other._featurization + and self._limits == other._limits + and self._sweep == other._sweep + and self._training_parameters == other._training_parameters + and self._search_space == other._search_space + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_featurization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_featurization_settings.py new file mode 100644 index 00000000..5649dea2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_featurization_settings.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + NlpVerticalFeaturizationSettings as RestNlpVerticalFeaturizationSettings, +) +from azure.ai.ml.entities._job.automl.featurization_settings import FeaturizationSettings, FeaturizationSettingsType + + +class NlpFeaturizationSettings(FeaturizationSettings): + """Featurization settings for all AutoML NLP Verticals. + + :ivar type: Specifies the type of FeaturizationSettings. Set automatically to "NLP" for this class. + :vartype type: str + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_nlp.py + :start-after: [START automl.nlp_featurization_settings] + :end-before: [END automl.nlp_featurization_settings] + :language: python + :dedent: 8 + :caption: creating an nlp featurization settings + """ + + type = FeaturizationSettingsType.NLP + + def _to_rest_object(self) -> RestNlpVerticalFeaturizationSettings: + return RestNlpVerticalFeaturizationSettings( + dataset_language=self.dataset_language, + ) + + @classmethod + def _from_rest_object(cls, obj: RestNlpVerticalFeaturizationSettings) -> "NlpFeaturizationSettings": + return cls( + dataset_language=obj.dataset_language, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NlpFeaturizationSettings): + return NotImplemented + + return super().__eq__(other) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_fixed_parameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_fixed_parameters.py new file mode 100644 index 00000000..13c594b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_fixed_parameters.py @@ -0,0 +1,117 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpFixedParameters as RestNlpFixedParameters +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class NlpFixedParameters(RestTranslatableMixin): + """Configuration of fixed parameters for all candidates of an AutoML NLP Job + + :param gradient_accumulation_steps: number of steps over which to accumulate gradients before a backward + pass. This must be a positive integer, defaults to None + :type gradient_accumulation_steps: Optional[int] + :param learning_rate: initial learning rate. Must be a float in (0, 1), defaults to None + :type learning_rate: Optional[float] + :param learning_rate_scheduler: the type of learning rate scheduler. Must choose from 'linear', 'cosine', + 'cosine_with_restarts', 'polynomial', 'constant', and 'constant_with_warmup', defaults to None + :type learning_rate_scheduler: Optional[str] + :param model_name: the model name to use during training. Must choose from 'bert-base-cased', + 'bert-base-uncased', 'bert-base-multilingual-cased', 'bert-base-german-cased', 'bert-large-cased', + 'bert-large-uncased', 'distilbert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'roberta-large', + 'distilroberta-base', 'xlm-roberta-base', 'xlm-roberta-large', xlnet-base-cased', and 'xlnet-large-cased', + defaults to None + :type model_name: Optional[str] + :param number_of_epochs: the number of epochs to train with. Must be a positive integer, defaults to None + :type number_of_epochs: Optional[int] + :param training_batch_size: the batch size during training. Must be a positive integer, defaults to None + :type training_batch_size: Optional[int] + :param validation_batch_size: the batch size during validation. Must be a positive integer, defaults to None + :type validation_batch_size: Optional[int] + :param warmup_ratio: ratio of total training steps used for a linear warmup from 0 to learning_rate. 
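As a sketch of the class documented above, the snippet below pins the model and a few optimizer settings for every candidate; the azure.ai.ml.automl import path follows the docstring's cross-references and the values are arbitrary:

from azure.ai.ml.automl import NlpFixedParameters  # path taken from the docstring cross-references

fixed_parameters = NlpFixedParameters(
    model_name="bert-base-uncased",
    learning_rate=2e-5,
    learning_rate_scheduler="linear",
    number_of_epochs=4,
    training_batch_size=32,
    warmup_ratio=0.1,
    weight_decay=0.01,
)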
+ Must be a float in [0, 1], defaults to None + :type warmup_ratio: Optional[float] + :param weight_decay: value of weight decay when optimizer is sgd, adam, or adamw. This must be a float in + the range [0, 1] defaults to None + :type weight_decay: Optional[float] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_automl_nlp.py + :start-after: [START automl.nlp_fixed_parameters] + :end-before: [END automl.nlp_fixed_parameters] + :language: python + :dedent: 8 + :caption: creating an nlp fixed parameters + """ + + def __init__( + self, + *, + gradient_accumulation_steps: Optional[int] = None, + learning_rate: Optional[float] = None, + learning_rate_scheduler: Optional[str] = None, + model_name: Optional[str] = None, + number_of_epochs: Optional[int] = None, + training_batch_size: Optional[int] = None, + validation_batch_size: Optional[int] = None, + warmup_ratio: Optional[float] = None, + weight_decay: Optional[float] = None, + ): + self.gradient_accumulation_steps = gradient_accumulation_steps + self.learning_rate = learning_rate + self.learning_rate_scheduler = learning_rate_scheduler + self.model_name = model_name + self.number_of_epochs = number_of_epochs + self.training_batch_size = training_batch_size + self.validation_batch_size = validation_batch_size + self.warmup_ratio = warmup_ratio + self.weight_decay = weight_decay + + def _to_rest_object(self) -> RestNlpFixedParameters: + return RestNlpFixedParameters( + gradient_accumulation_steps=self.gradient_accumulation_steps, + learning_rate=self.learning_rate, + learning_rate_scheduler=self.learning_rate_scheduler, + model_name=self.model_name, + number_of_epochs=self.number_of_epochs, + training_batch_size=self.training_batch_size, + validation_batch_size=self.validation_batch_size, + warmup_ratio=self.warmup_ratio, + weight_decay=self.weight_decay, + ) + + @classmethod + def _from_rest_object(cls, obj: RestNlpFixedParameters) -> "NlpFixedParameters": + return cls( + gradient_accumulation_steps=obj.gradient_accumulation_steps, + learning_rate=obj.learning_rate, + learning_rate_scheduler=obj.learning_rate_scheduler, + model_name=obj.model_name, + number_of_epochs=obj.number_of_epochs, + training_batch_size=obj.training_batch_size, + validation_batch_size=obj.validation_batch_size, + warmup_ratio=obj.warmup_ratio, + weight_decay=obj.weight_decay, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NlpFixedParameters): + return NotImplemented + + return ( + self.gradient_accumulation_steps == other.gradient_accumulation_steps + and self.learning_rate == other.learning_rate + and self.learning_rate_scheduler == other.learning_rate_scheduler + and self.model_name == other.model_name + and self.number_of_epochs == other.number_of_epochs + and self.training_batch_size == other.training_batch_size + and self.validation_batch_size == other.validation_batch_size + and self.warmup_ratio == other.warmup_ratio + and self.weight_decay == other.weight_decay + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_limit_settings.py new file mode 100644 index 00000000..1e99f4f0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_limit_settings.py @@ -0,0 +1,79 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft 
Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpVerticalLimitSettings as RestNlpLimitSettings +from azure.ai.ml._utils.utils import from_iso_duration_format_mins, to_iso_duration_format_mins +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class NlpLimitSettings(RestTranslatableMixin): + """Limit settings for all AutoML NLP Verticals. + + :param max_concurrent_trials: Maximum number of concurrent AutoML iterations. + :type max_concurrent_trials: int + :param max_trials: Maximum number of AutoML iterations. + :type max_trials: int + :param timeout_minutes: AutoML job timeout. + :type timeout_minutes: int + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_automl_nlp.py + :start-after: [START automl.nlp_limit_settings] + :end-before: [END automl.nlp_limit_settings] + :language: python + :dedent: 8 + :caption: creating an nlp limit settings + """ + + def __init__( + self, + *, + max_concurrent_trials: Optional[int] = None, + max_trials: int = 1, + max_nodes: int = 1, + timeout_minutes: Optional[int] = None, + trial_timeout_minutes: Optional[int] = None, + ): + self.max_concurrent_trials = max_concurrent_trials + self.max_trials = max_trials + self.max_nodes = max_nodes + self.timeout_minutes = timeout_minutes + self.trial_timeout_minutes = trial_timeout_minutes + + def _to_rest_object(self) -> RestNlpLimitSettings: + return RestNlpLimitSettings( + max_concurrent_trials=self.max_concurrent_trials, + max_trials=self.max_trials, + max_nodes=self.max_nodes, + timeout=to_iso_duration_format_mins(self.timeout_minutes), + trial_timeout=to_iso_duration_format_mins(self.trial_timeout_minutes), + ) + + @classmethod + def _from_rest_object(cls, obj: RestNlpLimitSettings) -> "NlpLimitSettings": + return cls( + max_concurrent_trials=obj.max_concurrent_trials, + max_trials=obj.max_trials, + max_nodes=obj.max_nodes, + timeout_minutes=from_iso_duration_format_mins(obj.timeout), + trial_timeout_minutes=from_iso_duration_format_mins(obj.trial_timeout), + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NlpLimitSettings): + return NotImplemented + + return ( + self.max_concurrent_trials == other.max_concurrent_trials + and self.max_trials == other.max_trials + and self.max_nodes == other.max_nodes + and self.timeout_minutes == other.timeout_minutes + and self.trial_timeout_minutes == other.trial_timeout_minutes + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_search_space.py new file mode 100644 index 00000000..e4ad435f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_search_space.py @@ -0,0 +1,185 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
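A short sketch of the limit settings class above, assuming the azure.ai.ml.automl re-export and illustrative values:

from azure.ai.ml.automl import NlpLimitSettings  # assumed public re-export

limits = NlpLimitSettings(
    max_trials=8,              # total AutoML iterations
    max_concurrent_trials=2,   # at most two trials in parallel
    max_nodes=2,               # nodes available to the sweep
    timeout_minutes=240,       # overall job timeout
    trial_timeout_minutes=60,  # per-trial timeout
)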
+# --------------------------------------------------------- + +from typing import Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpLearningRateScheduler, NlpParameterSubspace +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants import NlpModels +from azure.ai.ml.entities._job.automl.search_space import SearchSpace +from azure.ai.ml.entities._job.automl.search_space_utils import _convert_from_rest_object, _convert_to_rest_object +from azure.ai.ml.entities._job.sweep.search_space import Choice, SweepDistribution +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class NlpSearchSpace(RestTranslatableMixin): + """Search space for AutoML NLP tasks. + + :param gradient_accumulation_steps: number of steps over which to accumulate gradients before a backward + pass. This must be a positive integer., defaults to None + :type gradient_accumulation_steps: Optional[Union[int, SweepDistribution]] + :param learning_rate: initial learning rate. Must be a float in (0, 1), defaults to None + :type learning_rate: Optional[Union[float, SweepDistribution]] + :param learning_rate_scheduler: the type of learning rate scheduler. Must choose from 'linear', 'cosine', + 'cosine_with_restarts', 'polynomial', 'constant', and 'constant_with_warmup', defaults to None + :type learning_rate_scheduler: Optional[Union[str, SweepDistribution]] + :param model_name: the model name to use during training. Must choose from 'bert-base-cased', + 'bert-base-uncased', 'bert-base-multilingual-cased', 'bert-base-german-cased', 'bert-large-cased', + 'bert-large-uncased', 'distilbert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'roberta-large', + 'distilroberta-base', 'xlm-roberta-base', 'xlm-roberta-large', xlnet-base-cased', and 'xlnet-large-cased', + defaults to None + :type model_name: Optional[Union[str, SweepDistribution]] + :param number_of_epochs: the number of epochs to train with. Must be a positive integer, defaults to None + :type number_of_epochs: Optional[Union[int, SweepDistribution]] + :param training_batch_size: the batch size during training. Must be a positive integer, defaults to None + :type training_batch_size: Optional[Union[int, SweepDistribution]] + :param validation_batch_size: the batch size during validation. Must be a positive integer, defaults to None + :type validation_batch_size: Optional[Union[int, SweepDistribution]] + :param warmup_ratio: ratio of total training steps used for a linear warmup from 0 to learning_rate. + Must be a float in [0, 1], defaults to None + :type warmup_ratio: Optional[Union[float, SweepDistribution]] + :param weight_decay: value of weight decay when optimizer is sgd, adam, or adamw. This must be a float in + the range [0, 1], defaults to None + :type weight_decay: Optional[Union[float, SweepDistribution]] + + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_nlp.py + :start-after: [START automl.nlp_search_space] + :end-before: [END automl.nlp_search_space] + :language: python + :dedent: 8 + :caption: creating an nlp search space + """ + + def __init__( + self, + *, + gradient_accumulation_steps: Optional[Union[int, SweepDistribution]] = None, + learning_rate: Optional[Union[float, SweepDistribution]] = None, + learning_rate_scheduler: Optional[Union[str, SweepDistribution]] = None, + model_name: Optional[Union[str, SweepDistribution]] = None, + number_of_epochs: Optional[Union[int, SweepDistribution]] = None, + training_batch_size: Optional[Union[int, SweepDistribution]] = None, + validation_batch_size: Optional[Union[int, SweepDistribution]] = None, + warmup_ratio: Optional[Union[float, SweepDistribution]] = None, + weight_decay: Optional[Union[float, SweepDistribution]] = None + ): + # Since we want customers to be able to specify enums as well rather than just strings, we need to access + # the enum values here before we serialize them ('NlpModels.BERT_BASE_CASED' vs. 'bert-base-cased'). + if isinstance(learning_rate_scheduler, NlpLearningRateScheduler): + learning_rate_scheduler = camel_to_snake(learning_rate_scheduler.value) + elif isinstance(learning_rate_scheduler, Choice): + if learning_rate_scheduler.values is not None: + learning_rate_scheduler.values = [ + camel_to_snake(item.value) if isinstance(item, NlpLearningRateScheduler) else item + for item in learning_rate_scheduler.values + ] + + if isinstance(model_name, NlpModels): + model_name = model_name.value + elif isinstance(model_name, Choice): + if model_name.values is not None: + model_name.values = [item.value if isinstance(item, NlpModels) else item for item in model_name.values] + + self.gradient_accumulation_steps = gradient_accumulation_steps + self.learning_rate = learning_rate + self.learning_rate_scheduler = learning_rate_scheduler + self.model_name = model_name + self.number_of_epochs = number_of_epochs + self.training_batch_size = training_batch_size + self.validation_batch_size = validation_batch_size + self.warmup_ratio = warmup_ratio + self.weight_decay = weight_decay + + def _to_rest_object(self) -> NlpParameterSubspace: + return NlpParameterSubspace( + gradient_accumulation_steps=( + _convert_to_rest_object(self.gradient_accumulation_steps) + if self.gradient_accumulation_steps is not None + else None + ), + learning_rate=_convert_to_rest_object(self.learning_rate) if self.learning_rate is not None else None, + learning_rate_scheduler=( + _convert_to_rest_object(self.learning_rate_scheduler) + if self.learning_rate_scheduler is not None + else None + ), + model_name=_convert_to_rest_object(self.model_name) if self.model_name is not None else None, + number_of_epochs=( + _convert_to_rest_object(self.number_of_epochs) if self.number_of_epochs is not None else None + ), + training_batch_size=( + _convert_to_rest_object(self.training_batch_size) if self.training_batch_size is not None else None + ), + validation_batch_size=( + _convert_to_rest_object(self.validation_batch_size) if self.validation_batch_size is not None else None + ), + warmup_ratio=_convert_to_rest_object(self.warmup_ratio) if self.warmup_ratio is not None else None, + weight_decay=_convert_to_rest_object(self.weight_decay) if self.weight_decay is not None else None, + ) + + @classmethod + def _from_rest_object(cls, obj: NlpParameterSubspace) -> "NlpSearchSpace": + return cls( + gradient_accumulation_steps=( + 
_convert_from_rest_object(obj.gradient_accumulation_steps) + if obj.gradient_accumulation_steps is not None + else None + ), + learning_rate=_convert_from_rest_object(obj.learning_rate) if obj.learning_rate is not None else None, + learning_rate_scheduler=( + _convert_from_rest_object(obj.learning_rate_scheduler) + if obj.learning_rate_scheduler is not None + else None + ), + model_name=_convert_from_rest_object(obj.model_name) if obj.model_name is not None else None, + number_of_epochs=( + _convert_from_rest_object(obj.number_of_epochs) if obj.number_of_epochs is not None else None + ), + training_batch_size=( + _convert_from_rest_object(obj.training_batch_size) if obj.training_batch_size is not None else None + ), + validation_batch_size=( + _convert_from_rest_object(obj.validation_batch_size) if obj.validation_batch_size is not None else None + ), + warmup_ratio=_convert_from_rest_object(obj.warmup_ratio) if obj.warmup_ratio is not None else None, + weight_decay=_convert_from_rest_object(obj.weight_decay) if obj.weight_decay is not None else None, + ) + + @classmethod + def _from_search_space_object(cls, obj: SearchSpace) -> "NlpSearchSpace": + return cls( + gradient_accumulation_steps=( + obj.gradient_accumulation_steps if hasattr(obj, "gradient_accumulation_steps") else None + ), + learning_rate=obj.learning_rate if hasattr(obj, "learning_rate") else None, + learning_rate_scheduler=obj.learning_rate_scheduler if hasattr(obj, "learning_rate_scheduler") else None, + model_name=obj.model_name if hasattr(obj, "model_name") else None, + number_of_epochs=obj.number_of_epochs if hasattr(obj, "number_of_epochs") else None, + training_batch_size=obj.training_batch_size if hasattr(obj, "training_batch_size") else None, + validation_batch_size=obj.validation_batch_size if hasattr(obj, "validation_batch_size") else None, + warmup_ratio=obj.warmup_ratio if hasattr(obj, "warmup_ratio") else None, + weight_decay=obj.weight_decay if hasattr(obj, "weight_decay") else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NlpSearchSpace): + return NotImplemented + + return ( + self.gradient_accumulation_steps == other.gradient_accumulation_steps + and self.learning_rate == other.learning_rate + and self.learning_rate_scheduler == other.learning_rate_scheduler + and self.model_name == other.model_name + and self.number_of_epochs == other.number_of_epochs + and self.training_batch_size == other.training_batch_size + and self.validation_batch_size == other.validation_batch_size + and self.warmup_ratio == other.warmup_ratio + and self.weight_decay == other.weight_decay + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_sweep_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_sweep_settings.py new file mode 100644 index 00000000..e446a30c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_sweep_settings.py @@ -0,0 +1,65 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
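Each field of the search space above accepts either a fixed value or a sweep distribution, and NlpModels / NlpLearningRateScheduler enum members inside Choice values are normalized to their string form. A sketch mixing Choice and Uniform distributions, assuming the azure.ai.ml.automl and azure.ai.ml.sweep import paths and arbitrary values:

from azure.ai.ml.automl import NlpSearchSpace  # assumed public re-export
from azure.ai.ml.sweep import Choice, Uniform  # assumed distribution helpers

# A list of these objects can be passed as the search_space= argument of an NLP job.
nlp_search_space = NlpSearchSpace(
    model_name=Choice(values=["bert-base-cased", "roberta-base"]),
    learning_rate=Uniform(min_value=1e-5, max_value=5e-5),
    number_of_epochs=Choice(values=[3, 4]),
    weight_decay=0.01,  # fixed values may be mixed with distributions
)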
+# --------------------------------------------------------- + +from typing import Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpSweepSettings as RestNlpSweepSettings +from azure.ai.ml._restclient.v2023_04_01_preview.models import SamplingAlgorithmType +from azure.ai.ml.entities._job.sweep.early_termination_policy import EarlyTerminationPolicy +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +# pylint: disable=protected-access +class NlpSweepSettings(RestTranslatableMixin): + """Sweep settings for all AutoML NLP tasks. + + :param sampling_algorithm: Required. Specifies type of hyperparameter sampling algorithm. + Possible values include: "Grid", "Random", and "Bayesian". + :type sampling_algorithm: Union[str, ~azure.ai.ml.automl.SamplingAlgorithmType] + :param early_termination: Early termination policy to end poorly performing training candidates, + defaults to None. + :type early_termination: Optional[~azure.mgmt.machinelearningservices.models.EarlyTerminationPolicy] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_automl_nlp.py + :start-after: [START automl.nlp_sweep_settings] + :end-before: [END automl.nlp_sweep_settings] + :language: python + :dedent: 8 + :caption: creating an nlp sweep settings + """ + + def __init__( + self, + *, + sampling_algorithm: Union[str, SamplingAlgorithmType], + early_termination: Optional[EarlyTerminationPolicy] = None, + ): + self.sampling_algorithm = sampling_algorithm + self.early_termination = early_termination + + def _to_rest_object(self) -> RestNlpSweepSettings: + return RestNlpSweepSettings( + sampling_algorithm=self.sampling_algorithm, + early_termination=self.early_termination._to_rest_object() if self.early_termination else None, + ) + + @classmethod + def _from_rest_object(cls, obj: RestNlpSweepSettings) -> "NlpSweepSettings": + return cls( + sampling_algorithm=obj.sampling_algorithm, + early_termination=( + EarlyTerminationPolicy._from_rest_object(obj.early_termination) if obj.early_termination else None + ), + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NlpSweepSettings): + return NotImplemented + + return self.sampling_algorithm == other.sampling_algorithm and self.early_termination == other.early_termination + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_job.py new file mode 100644 index 00000000..290f4f70 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_job.py @@ -0,0 +1,248 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
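To tie the NLP settings in this group of files together, the sketch below builds NlpSweepSettings with an early-termination policy and attaches the hypothetical objects from the earlier sketches to the job; SearchSpace, BanditPolicy and Choice are assumed public exports and all values are illustrative:

from azure.ai.ml.automl import NlpSweepSettings, SearchSpace  # assumed public re-exports
from azure.ai.ml.sweep import BanditPolicy, Choice

sweep = NlpSweepSettings(
    sampling_algorithm="Random",
    early_termination=BanditPolicy(slack_factor=0.2, evaluation_interval=2),
)

# Reusing the hypothetical `job`, `limits` and `fixed_parameters` from the sketches above.
job.limits = limits
job.sweep = sweep
job.training_parameters = fixed_parameters
job.extend_search_space(SearchSpace(model_name=Choice(values=["bert-base-cased", "roberta-base"])))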
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, TaskType +from azure.ai.ml._restclient.v2023_04_01_preview.models._azure_machine_learning_workspaces_enums import ( + ClassificationPrimaryMetrics, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import TextClassification as RestTextClassification +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.nlp.automl_nlp_job import AutoMLNLPJob +from azure.ai.ml.entities._job.automl.nlp.nlp_featurization_settings import NlpFeaturizationSettings +from azure.ai.ml.entities._job.automl.nlp.nlp_fixed_parameters import NlpFixedParameters +from azure.ai.ml.entities._job.automl.nlp.nlp_limit_settings import NlpLimitSettings +from azure.ai.ml.entities._job.automl.nlp.nlp_sweep_settings import NlpSweepSettings +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._component.component import Component + + +class TextClassificationJob(AutoMLNLPJob): + """Configuration for AutoML Text Classification Job. + + :param target_column_name: The name of the target column, defaults to None + :type target_column_name: Optional[str] + :param training_data: Training data to be used for training, defaults to None + :type training_data: Optional[~azure.ai.ml.Input] + :param validation_data: Validation data to be used for evaluating the trained model, defaults to None + :type validation_data: Optional[~azure.ai.ml.Input] + :param primary_metric: The primary metric to be displayed, defaults to None + :type primary_metric: Optional[~azure.ai.ml.automl.ClassificationPrimaryMetrics] + :param log_verbosity: Log verbosity level, defaults to None + :type log_verbosity: Optional[str] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_nlp.py + :start-after: [START automl.automl_nlp_job.text_classification_job] + :end-before: [END automl.automl_nlp_job.text_classification_job] + :language: python + :dedent: 8 + :caption: creating an automl text classification job + """ + + _DEFAULT_PRIMARY_METRIC = ClassificationPrimaryMetrics.ACCURACY + + def __init__( + self, + *, + target_column_name: Optional[str] = None, + training_data: Optional[Input] = None, + validation_data: Optional[Input] = None, + primary_metric: Optional[ClassificationPrimaryMetrics] = None, + log_verbosity: Optional[str] = None, + **kwargs: Any + ): + super().__init__( + task_type=TaskType.TEXT_CLASSIFICATION, + primary_metric=primary_metric or TextClassificationJob._DEFAULT_PRIMARY_METRIC, + target_column_name=target_column_name, + training_data=training_data, + validation_data=validation_data, + log_verbosity=log_verbosity, + **kwargs, + ) + + @property + def primary_metric(self) -> Union[str, ClassificationPrimaryMetrics]: + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, ClassificationPrimaryMetrics]) -> None: + """setter for primary metric + + :param value: _description_ + :type value: Union[str, ClassificationPrimaryMetrics] + """ + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + + self._primary_metric = ( + TextClassificationJob._DEFAULT_PRIMARY_METRIC + if value is None + else ClassificationPrimaryMetrics[camel_to_snake(value).upper()] + ) + + def _to_rest_object(self) -> JobBase: + text_classification = RestTextClassification( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + limit_settings=self._limits._to_rest_object() if self._limits else None, + sweep_settings=self._sweep._to_rest_object() if self._sweep else None, + fixed_parameters=self._training_parameters._to_rest_object() if self._training_parameters else None, + search_space=( + [entry._to_rest_object() for entry in self._search_space if entry is not None] + if self._search_space is not None + else None + ), + featurization_settings=self._featurization._to_rest_object() if self._featurization else None, + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + ) + # resolve data inputs in rest object + self._resolve_data_inputs(text_classification) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=text_classification, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "TextClassificationJob": + properties: RestAutoMLJob = obj.properties + task_details: RestTextClassification = properties.task_details + assert isinstance(task_details, RestTextClassification) + limits = ( + NlpLimitSettings._from_rest_object(task_details.limit_settings) if task_details.limit_settings else None + ) + featurization = ( + NlpFeaturizationSettings._from_rest_object(task_details.featurization_settings) + 
if task_details.featurization_settings + else None + ) + sweep = NlpSweepSettings._from_rest_object(task_details.sweep_settings) if task_details.sweep_settings else None + training_parameters = ( + NlpFixedParameters._from_rest_object(task_details.fixed_parameters) + if task_details.fixed_parameters + else None + ) + + text_classification_job = cls( + # ----- job specific params + id=obj.id, + name=obj.name, + description=properties.description, + tags=properties.tags, + properties=properties.properties, + experiment_name=properties.experiment_name, + services=properties.services, + status=properties.status, + creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + display_name=properties.display_name, + compute=properties.compute_id, + outputs=from_rest_data_outputs(properties.outputs), + resources=properties.resources, + # ----- task specific params + primary_metric=task_details.primary_metric, + log_verbosity=task_details.log_verbosity, + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + limits=limits, + sweep=sweep, + training_parameters=training_parameters, + search_space=cls._get_search_space_from_str(task_details.search_space), + featurization=featurization, + identity=( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + queue_settings=properties.queue_settings, + ) + + text_classification_job._restore_data_inputs() + + return text_classification_job + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "Component": + raise NotImplementedError() + + @classmethod + def _load_from_dict( + cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any + ) -> "TextClassificationJob": + from azure.ai.ml._schema.automl.nlp_vertical.text_classification import TextClassificationSchema + + if kwargs.pop("inside_pipeline", False): + from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextClassificationNode + + loaded_data = load_from_dict( + AutoMLTextClassificationNode, + data, + context, + additional_message, + **kwargs, + ) + else: + loaded_data = load_from_dict(TextClassificationSchema, data, context, additional_message, **kwargs) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "TextClassificationJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + return TextClassificationJob(**loaded_data) + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.nlp_vertical.text_classification import TextClassificationSchema + from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextClassificationNode + + if inside_pipeline: + res_autoML: dict = AutoMLTextClassificationNode(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res_autoML + + res: dict = TextClassificationSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TextClassificationJob): + return NotImplemented + + if not super(TextClassificationJob, self).__eq__(other): + return False + + return self.primary_metric == other.primary_metric + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_multilabel_job.py 
b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_multilabel_job.py new file mode 100644 index 00000000..ac19b451 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_multilabel_job.py @@ -0,0 +1,252 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationMultilabelPrimaryMetrics, JobBase, TaskType +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + TextClassificationMultilabel as RestTextClassificationMultilabel, +) +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.nlp.automl_nlp_job import AutoMLNLPJob +from azure.ai.ml.entities._job.automl.nlp.nlp_featurization_settings import NlpFeaturizationSettings +from azure.ai.ml.entities._job.automl.nlp.nlp_fixed_parameters import NlpFixedParameters +from azure.ai.ml.entities._job.automl.nlp.nlp_limit_settings import NlpLimitSettings +from azure.ai.ml.entities._job.automl.nlp.nlp_sweep_settings import NlpSweepSettings +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._component.component import Component + + +class TextClassificationMultilabelJob(AutoMLNLPJob): + """Configuration for AutoML Text Classification Multilabel Job. + + :param target_column_name: The name of the target column, defaults to None + :type target_column_name: Optional[str] + :param training_data: Training data to be used for training, defaults to None + :type training_data: Optional[~azure.ai.ml.Input] + :param validation_data: Validation data to be used for evaluating the trained model, defaults to None + :type validation_data: Optional[~azure.ai.ml.Input] + :param primary_metric: The primary metric to be displayed., defaults to None + :type primary_metric: Optional[str] + :param log_verbosity: Log verbosity level, defaults to None + :type log_verbosity: Optional[str] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_nlp.py + :start-after: [START automl.text_classification_multilabel_job] + :end-before: [END automl.text_classification_multilabel_job] + :language: python + :dedent: 8 + :caption: creating an automl text classification multilabel job + """ + + _DEFAULT_PRIMARY_METRIC = ClassificationMultilabelPrimaryMetrics.ACCURACY + + def __init__( + self, + *, + target_column_name: Optional[str] = None, + training_data: Optional[Input] = None, + validation_data: Optional[Input] = None, + primary_metric: Optional[str] = None, + log_verbosity: Optional[str] = None, + **kwargs: Any + ): + super().__init__( + task_type=TaskType.TEXT_CLASSIFICATION_MULTILABEL, + primary_metric=primary_metric or TextClassificationMultilabelJob._DEFAULT_PRIMARY_METRIC, + target_column_name=target_column_name, + training_data=training_data, + validation_data=validation_data, + log_verbosity=log_verbosity, + **kwargs, + ) + + @property + def primary_metric(self) -> Union[str, ClassificationMultilabelPrimaryMetrics]: + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, ClassificationMultilabelPrimaryMetrics]) -> None: + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + + self._primary_metric = ( + TextClassificationMultilabelJob._DEFAULT_PRIMARY_METRIC + if value is None + else ClassificationMultilabelPrimaryMetrics[camel_to_snake(value).upper()] + ) + + def _to_rest_object(self) -> JobBase: + text_classification_multilabel = RestTextClassificationMultilabel( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + limit_settings=self._limits._to_rest_object() if self._limits else None, + sweep_settings=self._sweep._to_rest_object() if self._sweep else None, + fixed_parameters=self._training_parameters._to_rest_object() if self._training_parameters else None, + search_space=( + [entry._to_rest_object() for entry in self._search_space if entry is not None] + if self._search_space is not None + else None + ), + featurization_settings=self._featurization._to_rest_object() if self._featurization else None, + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + ) + # resolve data inputs in rest object + self._resolve_data_inputs(text_classification_multilabel) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=text_classification_multilabel, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "TextClassificationMultilabelJob": + properties: RestAutoMLJob = obj.properties + task_details: RestTextClassificationMultilabel = properties.task_details + assert isinstance(task_details, RestTextClassificationMultilabel) + limits = ( + NlpLimitSettings._from_rest_object(task_details.limit_settings) if task_details.limit_settings else None + ) + featurization = ( + NlpFeaturizationSettings._from_rest_object(task_details.featurization_settings) + if 
task_details.featurization_settings + else None + ) + sweep = NlpSweepSettings._from_rest_object(task_details.sweep_settings) if task_details.sweep_settings else None + training_parameters = ( + NlpFixedParameters._from_rest_object(task_details.fixed_parameters) + if task_details.fixed_parameters + else None + ) + + text_classification_multilabel_job = cls( + # ----- job specific params + id=obj.id, + name=obj.name, + description=properties.description, + tags=properties.tags, + properties=properties.properties, + experiment_name=properties.experiment_name, + services=properties.services, + status=properties.status, + creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + display_name=properties.display_name, + compute=properties.compute_id, + outputs=from_rest_data_outputs(properties.outputs), + resources=properties.resources, + # ----- task specific params + primary_metric=task_details.primary_metric, + log_verbosity=task_details.log_verbosity, + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + limits=limits, + sweep=sweep, + training_parameters=training_parameters, + search_space=cls._get_search_space_from_str(task_details.search_space), + featurization=featurization, + identity=( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + queue_settings=properties.queue_settings, + ) + + text_classification_multilabel_job._restore_data_inputs() + + return text_classification_multilabel_job + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "Component": + raise NotImplementedError() + + @classmethod + def _load_from_dict( + cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any + ) -> "TextClassificationMultilabelJob": + from azure.ai.ml._schema.automl.nlp_vertical.text_classification_multilabel import ( + TextClassificationMultilabelSchema, + ) + + if kwargs.pop("inside_pipeline", False): + from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextClassificationMultilabelNode + + loaded_data = load_from_dict( + AutoMLTextClassificationMultilabelNode, + data, + context, + additional_message, + **kwargs, + ) + else: + loaded_data = load_from_dict( + TextClassificationMultilabelSchema, + data, + context, + additional_message, + **kwargs, + ) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "TextClassificationMultilabelJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + return TextClassificationMultilabelJob(**loaded_data) + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.nlp_vertical.text_classification_multilabel import ( + TextClassificationMultilabelSchema, + ) + from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextClassificationMultilabelNode + + if inside_pipeline: + res_autoML: dict = AutoMLTextClassificationMultilabelNode(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res_autoML + + res: dict = TextClassificationMultilabelSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TextClassificationMultilabelJob): + return NotImplemented + + if not super(TextClassificationMultilabelJob, self).__eq__(other): + return False + + return self.primary_metric == 
other.primary_metric + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_ner_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_ner_job.py new file mode 100644 index 00000000..a87965f1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_ner_job.py @@ -0,0 +1,231 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, TaskType +from azure.ai.ml._restclient.v2023_04_01_preview.models._azure_machine_learning_workspaces_enums import ( + ClassificationPrimaryMetrics, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import TextNer as RestTextNER +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.nlp.automl_nlp_job import AutoMLNLPJob +from azure.ai.ml.entities._job.automl.nlp.nlp_featurization_settings import NlpFeaturizationSettings +from azure.ai.ml.entities._job.automl.nlp.nlp_fixed_parameters import NlpFixedParameters +from azure.ai.ml.entities._job.automl.nlp.nlp_limit_settings import NlpLimitSettings +from azure.ai.ml.entities._job.automl.nlp.nlp_sweep_settings import NlpSweepSettings +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._component.component import Component + + +class TextNerJob(AutoMLNLPJob): + """Configuration for AutoML Text NER Job. + + :param training_data: Training data to be used for training, defaults to None + :type training_data: Optional[~azure.ai.ml.Input] + :param validation_data: Validation data to be used for evaluating the trained model, + defaults to None + :type validation_data: Optional[~azure.ai.ml.Input] + :param primary_metric: The primary metric to be displayed, defaults to None + :type primary_metric: Optional[str] + :param log_verbosity: Log verbosity level, defaults to None + :type log_verbosity: Optional[str] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_automl_nlp.py + :start-after: [START automl.text_ner_job] + :end-before: [END automl.text_ner_job] + :language: python + :dedent: 8 + :caption: creating an automl text ner job + """ + + _DEFAULT_PRIMARY_METRIC = ClassificationPrimaryMetrics.ACCURACY + + def __init__( + self, + *, + training_data: Optional[Input] = None, + validation_data: Optional[Input] = None, + primary_metric: Optional[str] = None, + log_verbosity: Optional[str] = None, + **kwargs: Any + ): + super(TextNerJob, self).__init__( + task_type=TaskType.TEXT_NER, + primary_metric=primary_metric or TextNerJob._DEFAULT_PRIMARY_METRIC, + training_data=training_data, + validation_data=validation_data, + log_verbosity=log_verbosity, + **kwargs, + ) + + @property + def primary_metric(self) -> Union[str, ClassificationPrimaryMetrics]: + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, ClassificationPrimaryMetrics]) -> None: + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + + self._primary_metric = ( + TextNerJob._DEFAULT_PRIMARY_METRIC + if value is None + else ClassificationPrimaryMetrics[camel_to_snake(value).upper()] + ) + + def _to_rest_object(self) -> JobBase: + text_ner = RestTextNER( + training_data=self.training_data, + validation_data=self.validation_data, + limit_settings=self._limits._to_rest_object() if self._limits else None, + sweep_settings=self._sweep._to_rest_object() if self._sweep else None, + fixed_parameters=self._training_parameters._to_rest_object() if self._training_parameters else None, + search_space=( + [entry._to_rest_object() for entry in self._search_space if entry is not None] + if self._search_space is not None + else None + ), + featurization_settings=self._featurization._to_rest_object() if self._featurization else None, + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + ) + # resolve data inputs in rest object + self._resolve_data_inputs(text_ner) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=text_ner, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "TextNerJob": + properties: RestAutoMLJob = obj.properties + task_details: RestTextNER = properties.task_details + assert isinstance(task_details, RestTextNER) + limits = ( + NlpLimitSettings._from_rest_object(task_details.limit_settings) if task_details.limit_settings else None + ) + featurization = ( + NlpFeaturizationSettings._from_rest_object(task_details.featurization_settings) + if task_details.featurization_settings + else None + ) + sweep = NlpSweepSettings._from_rest_object(task_details.sweep_settings) if task_details.sweep_settings else None + training_parameters = ( + NlpFixedParameters._from_rest_object(task_details.fixed_parameters) + if task_details.fixed_parameters + else None + ) + + text_ner_job = cls( + # ----- job specific params + id=obj.id, + name=obj.name, + description=properties.description, 
+ tags=properties.tags, + properties=properties.properties, + experiment_name=properties.experiment_name, + services=properties.services, + status=properties.status, + creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + display_name=properties.display_name, + compute=properties.compute_id, + outputs=from_rest_data_outputs(properties.outputs), + resources=properties.resources, + # ----- task specific params + primary_metric=task_details.primary_metric, + log_verbosity=task_details.log_verbosity, + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + limits=limits, + sweep=sweep, + training_parameters=training_parameters, + search_space=cls._get_search_space_from_str(task_details.search_space), + featurization=featurization, + identity=( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + queue_settings=properties.queue_settings, + ) + + text_ner_job._restore_data_inputs() + + return text_ner_job + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "Component": + raise NotImplementedError() + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "TextNerJob": + from azure.ai.ml._schema.automl.nlp_vertical.text_ner import TextNerSchema + + if kwargs.pop("inside_pipeline", False): + from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextNerNode + + loaded_data = load_from_dict(AutoMLTextNerNode, data, context, additional_message, **kwargs) + else: + loaded_data = load_from_dict(TextNerSchema, data, context, additional_message, **kwargs) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "TextNerJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + return TextNerJob(**loaded_data) + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.nlp_vertical.text_ner import TextNerSchema + from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextNerNode + + if inside_pipeline: + res_autoML: dict = AutoMLTextNerNode(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res_autoML + + res: dict = TextNerSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TextNerJob): + return NotImplemented + + if not super(TextNerJob, self).__eq__(other): + return False + + return self.primary_metric == other.primary_metric + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space.py new file mode 100644 index 00000000..a958de56 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space.py @@ -0,0 +1,14 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
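TextNerJob follows the same pattern as the other NLP jobs above but takes no target_column_name, since NER labels live in the (typically CoNLL-formatted) data itself. A minimal sketch with placeholder paths:

# Illustrative only: NER variant of the NLP job family defined above.
from azure.ai.ml import Input
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities._job.automl.nlp.text_ner_job import TextNerJob

ner_job = TextNerJob(
    training_data=Input(type=AssetTypes.MLTABLE, path="./data/ner-train"),   # placeholder paths
    validation_data=Input(type=AssetTypes.MLTABLE, path="./data/ner-valid"),
    log_verbosity="info",
)
# With primary_metric left unset, the class default ClassificationPrimaryMetrics.ACCURACY applies.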
+# --------------------------------------------------------- + + +from typing import Any + + +class SearchSpace: + """SearchSpace class for AutoML verticals.""" + + def __init__(self, **kwargs: Any) -> None: + for k, v in kwargs.items(): + self.__setattr__(k, v) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space_utils.py new file mode 100644 index 00000000..732030d4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space_utils.py @@ -0,0 +1,276 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import re +from typing import Any, List, Union + +from marshmallow import fields + +from azure.ai.ml._schema._sweep.search_space import ( + ChoiceSchema, + NormalSchema, + QNormalSchema, + QUniformSchema, + RandintSchema, + UniformSchema, +) +from azure.ai.ml._schema.core.fields import DumpableIntegerField, DumpableStringField, NestedField, UnionField +from azure.ai.ml._utils.utils import float_to_str +from azure.ai.ml.constants._job.sweep import SearchSpace +from azure.ai.ml.entities._job.sweep.search_space import ( + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + SweepDistribution, + Uniform, +) +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +def _convert_to_rest_object(sweep_distribution: Union[bool, int, float, str, SweepDistribution]) -> str: + if isinstance(sweep_distribution, float): + # Float requires some special handling for small values that get auto-represented with scientific notation. + res: str = float_to_str(sweep_distribution) + return res + if not isinstance(sweep_distribution, SweepDistribution): + # Convert [bool, float, str] types to str + return str(sweep_distribution) + + rest_object = sweep_distribution._to_rest_object() + if not isinstance(rest_object, list): + msg = "Rest Object for sweep distribution should be a list." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + if len(rest_object) <= 1: + msg = "Rest object for sweep distribution should contain at least two elements." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + sweep_distribution_type = rest_object[0] + sweep_distribution_args = [] + + if not isinstance(rest_object[1], list): + msg = "The second element of Rest object for sweep distribution should be a list." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + if sweep_distribution_type == SearchSpace.CHOICE: + # Rest objects for choice distribution are of format ["choice", [[0, 1, 2]]] + if not isinstance(rest_object[1][0], list): + msg = "The second element of Rest object for choice distribution should be a list of list." 
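The SearchSpace container defined just above is deliberately schema-free: every keyword argument becomes an attribute, so each AutoML vertical can attach whatever hyperparameter distributions it supports. A small sketch using the public sweep distributions (the model names are made up):

# SearchSpace stores arbitrary keyword arguments as attributes; the distributions are
# the SweepDistribution objects that the helpers in search_space_utils serialize.
from azure.ai.ml.sweep import Choice, Uniform
from azure.ai.ml.entities._job.automl.search_space import SearchSpace

space = SearchSpace(
    model_name=Choice(values=["bert-base-cased", "roberta-base"]),   # hypothetical model names
    learning_rate=Uniform(min_value=1e-5, max_value=5e-5),
)
print(space.model_name.values)   # ['bert-base-cased', 'roberta-base']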
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + for value in rest_object[1][0]: + if isinstance(value, str): + sweep_distribution_args.append("'" + value + "'") + elif isinstance(value, float): + sweep_distribution_args.append(float_to_str(value)) + else: + sweep_distribution_args.append(str(value)) + else: + for value in rest_object[1]: + if isinstance(value, float): + sweep_distribution_args.append(float_to_str(value)) + else: + sweep_distribution_args.append(str(value)) + + sweep_distribution_str: str = sweep_distribution_type + "(" + sweep_distribution_str += ",".join(sweep_distribution_args) + sweep_distribution_str += ")" + return sweep_distribution_str + + +def _is_int(value: str) -> bool: + try: + int(value) + return True + except ValueError: + return False + + +def _is_float(value: str) -> bool: + try: + float(value) + return True + except ValueError: + return False + + +def _get_type_inferred_value(value: str) -> Union[bool, int, float, str]: + value = value.strip() + if _is_int(value): + # Int + return int(value) + if _is_float(value): + # Float + return float(value) + if value in ["True", "False"]: + # Convert "True", "False" to python boolean literals + return value == "True" + # string value. Remove quotes before returning. + return value.strip("'\"") + + +def _convert_from_rest_object( + sweep_distribution_str: str, +) -> Any: + # sweep_distribution_str can be a distribution like "choice('vitb16r224', 'vits16r224')" or + # a single value like "True", "1", "1.0567", "vitb16r224" + + sweep_distribution_str = sweep_distribution_str.strip() + # Filter by the delimiters and remove splits that are empty strings + sweep_distribution_separated = list(filter(None, re.split("[ ,()]+", sweep_distribution_str))) + + if len(sweep_distribution_separated) == 1: + # Single value. 
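A worked example of the conversion this module implements: one direction flattens a SweepDistribution into the compact string form carried on the REST contract, the other parses such a string back or type-infers a bare value. Values are illustrative.

# Round trip between SweepDistribution objects and their REST string form, using the
# private helpers defined in this module.
from azure.ai.ml.sweep import Choice
from azure.ai.ml.entities._job.automl.search_space_utils import (
    _convert_from_rest_object,
    _convert_to_rest_object,
)

as_str = _convert_to_rest_object(Choice(values=["vitb16r224", "vits16r224"]))
print(as_str)                              # choice('vitb16r224','vits16r224')

dist = _convert_from_rest_object(as_str)   # parsed back into a Choice instance
print(type(dist).__name__, dist.values)    # Choice ['vitb16r224', 'vits16r224']

# Bare values are type-inferred instead of being wrapped in a distribution:
print(_convert_from_rest_object("0.01"), _convert_from_rest_object("True"))   # 0.01 True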
+ return _get_type_inferred_value(sweep_distribution_separated[0]) + + # Distribution string + sweep_distribution_type = sweep_distribution_separated[0].strip().lower() + sweep_distribution_args: List = [] + for value in sweep_distribution_separated[1:]: + sweep_distribution_args.append(_get_type_inferred_value(value)) + + if sweep_distribution_type == SearchSpace.CHOICE: + sweep_distribution_args = [sweep_distribution_args] # Choice values are list of lists + + sweep_distribution = SweepDistribution._from_rest_object([sweep_distribution_type, sweep_distribution_args]) + return sweep_distribution + + +def _convert_sweep_dist_dict_to_str_dict(sweep_distribution: dict) -> dict: + for k, sweep_dist_dict in sweep_distribution.items(): + if sweep_dist_dict is not None: + sweep_distribution[k] = _convert_sweep_dist_dict_item_to_str(sweep_dist_dict) + return sweep_distribution + + +class ChoicePlusSchema(ChoiceSchema): + """Choice schema that allows boolean values also""" + + values = fields.List( + UnionField( + [ + DumpableIntegerField(strict=True), + DumpableStringField(), + fields.Float(), + fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + NestedField("ChoicePlusSchema"), + NestedField(NormalSchema()), + NestedField(QNormalSchema()), + NestedField(RandintSchema()), + NestedField(UniformSchema()), + NestedField(QUniformSchema()), + DumpableIntegerField(strict=True), + fields.Float(), + fields.Str(), + fields.Boolean(), + ] + ), + ), + fields.Boolean(), + ] + ) + ) + + +def _convert_sweep_dist_dict_item_to_str(sweep_distribution: Union[bool, int, float, str, dict]) -> str: + # Convert a Sweep Distribution dict to Sweep Distribution string + # Eg. {type: 'choice', values: ['vitb16r224','vits16r224']} => "Choice('vitb16r224','vits16r224')" + if isinstance(sweep_distribution, dict): + sweep_dist_type = sweep_distribution["type"] + if sweep_dist_type == SearchSpace.CHOICE: + sweep_dist_obj = ChoicePlusSchema().load(sweep_distribution) # pylint: disable=no-member + elif sweep_dist_type in SearchSpace.UNIFORM_LOGUNIFORM: + sweep_dist_obj = UniformSchema().load(sweep_distribution) # pylint: disable=no-member + elif sweep_dist_type in SearchSpace.NORMAL_LOGNORMAL: + sweep_dist_obj = NormalSchema().load(sweep_distribution) # pylint: disable=no-member + elif sweep_dist_type in SearchSpace.QUNIFORM_QLOGUNIFORM: + sweep_dist_obj = QUniformSchema().load(sweep_distribution) # pylint: disable=no-member + elif sweep_dist_type in SearchSpace.QNORMAL_QLOGNORMAL: + sweep_dist_obj = QNormalSchema().load(sweep_distribution) # pylint: disable=no-member + elif sweep_dist_type in SearchSpace.RANDINT: + sweep_dist_obj = RandintSchema().load(sweep_distribution) # pylint: disable=no-member + else: + msg = f"Unsupported sweep distribution type {sweep_dist_type}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + else: # Case for other primitive types + sweep_dist_obj = sweep_distribution + + sweep_dist_str = _convert_to_rest_object(sweep_dist_obj) + return sweep_dist_str + + +def _convert_sweep_dist_str_to_dict(sweep_dist_str_list: dict) -> dict: + for k, val in sweep_dist_str_list.items(): + if isinstance(val, str): + sweep_dist_str_list[k] = _convert_sweep_dist_str_item_to_dict(val) + return sweep_dist_str_list + + +def _convert_sweep_dist_str_item_to_dict( + sweep_distribution_str: str, +) -> Union[bool, int, float, str, dict]: + # sweep_distribution_str can be a distribution like "choice('vitb16r224', 
'vits16r224')" + # return type is {type: 'choice', values: ['vitb16r224', 'vits16r224']} + sweep_dist_obj = _convert_from_rest_object(sweep_distribution_str) + sweep_dist: Union[bool, int, float, str, dict] = "" + if isinstance(sweep_dist_obj, SweepDistribution): + if isinstance(sweep_dist_obj, Choice): + sweep_dist = ChoicePlusSchema().dump(sweep_dist_obj) # pylint: disable=no-member + elif isinstance(sweep_dist_obj, (QNormal, QLogNormal)): + sweep_dist = QNormalSchema().dump(sweep_dist_obj) # pylint: disable=no-member + elif isinstance(sweep_dist_obj, (QUniform, QLogUniform)): + sweep_dist = QUniformSchema().dump(sweep_dist_obj) # pylint: disable=no-member + elif isinstance(sweep_dist_obj, (Uniform, LogUniform)): + sweep_dist = UniformSchema().dump(sweep_dist_obj) # pylint: disable=no-member + elif isinstance(sweep_dist_obj, (Normal, LogNormal)): + sweep_dist = NormalSchema().dump(sweep_dist_obj) # pylint: disable=no-member + elif isinstance(sweep_dist_obj, Randint): + sweep_dist = RandintSchema().dump(sweep_dist_obj) # pylint: disable=no-member + else: + msg = "Invalid sweep distribution {}" + raise ValidationException( + message=msg.format(sweep_distribution_str), + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + else: # Case for other primitive types + sweep_dist = sweep_dist_obj + + return sweep_dist diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/stack_ensemble_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/stack_ensemble_settings.py new file mode 100644 index 00000000..c17fa7e3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/stack_ensemble_settings.py @@ -0,0 +1,70 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import StackEnsembleSettings as RestStackEnsembleSettings +from azure.ai.ml._restclient.v2023_04_01_preview.models import StackMetaLearnerType +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class StackEnsembleSettings(RestTranslatableMixin): + """Advance setting to customize StackEnsemble run.""" + + def __init__( + self, + *, + stack_meta_learner_k_wargs: Optional[Any] = None, + stack_meta_learner_train_percentage: float = 0.2, + stack_meta_learner_type: Optional[StackMetaLearnerType] = None, + **kwargs: Any + ): + """ + :param stack_meta_learner_k_wargs: Optional parameters to pass to the initializer of the + meta-learner. + :type stack_meta_learner_k_wargs: any + :param stack_meta_learner_train_percentage: Specifies the proportion of the training set + (when choosing train and validation type of training) to be reserved for training the + meta-learner. Default value is 0.2. + :type stack_meta_learner_train_percentage: float + :param stack_meta_learner_type: The meta-learner is a model trained on the output of the + individual heterogeneous models. Possible values include: "None", "LogisticRegression", + "LogisticRegressionCV", "LightGBMClassifier", "ElasticNet", "ElasticNetCV", + "LightGBMRegressor", "LinearRegression". 
+ :type stack_meta_learner_type: str or + ~azure.mgmt.machinelearningservices.models.StackMetaLearnerType + """ + super(StackEnsembleSettings, self).__init__(**kwargs) + self.stack_meta_learner_k_wargs = stack_meta_learner_k_wargs + self.stack_meta_learner_train_percentage = stack_meta_learner_train_percentage + self.stack_meta_learner_type = stack_meta_learner_type + + def _to_rest_object(self) -> RestStackEnsembleSettings: + return RestStackEnsembleSettings( + stack_meta_learner_k_wargs=self.stack_meta_learner_k_wargs, + stack_meta_learner_train_percentage=self.stack_meta_learner_train_percentage, + stack_meta_learner_type=self.stack_meta_learner_type, + ) + + @classmethod + def _from_rest_object(cls, obj: RestStackEnsembleSettings) -> "StackEnsembleSettings": + return cls( + stack_meta_learner_k_wargs=obj.stack_meta_learner_k_wargs, + stack_meta_learner_train_percentage=obj.stack_meta_learner_train_percentage, + stack_meta_learner_type=obj.stack_meta_learner_type, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, StackEnsembleSettings): + return NotImplemented + + return ( + super().__eq__(other) + and self.stack_meta_learner_k_wargs == other.stack_meta_learner_k_wargs + and self.stack_meta_learner_train_percentage == other.stack_meta_learner_train_percentage + and self.stack_meta_learner_type == other.stack_meta_learner_type + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/__init__.py new file mode 100644 index 00000000..c0373010 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/__init__.py @@ -0,0 +1,22 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from .automl_tabular import AutoMLTabular +from .classification_job import ClassificationJob +from .featurization_settings import ColumnTransformer, TabularFeaturizationSettings +from .forecasting_job import ForecastingJob +from .forecasting_settings import ForecastingSettings +from .limit_settings import TabularLimitSettings +from .regression_job import RegressionJob + +__all__ = [ + "AutoMLTabular", + "ClassificationJob", + "ColumnTransformer", + "ForecastingJob", + "ForecastingSettings", + "RegressionJob", + "TabularFeaturizationSettings", + "TabularLimitSettings", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/automl_tabular.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/automl_tabular.py new file mode 100644 index 00000000..5f4ed22b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/automl_tabular.py @@ -0,0 +1,607 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
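A minimal sketch of the StackEnsembleSettings entity above; the meta-learner is given here as one of the string names listed in its docstring (on the assumption that the service-side enum accepts the string form), and the resulting object is normally handed to a tabular job via set_training(stack_ensemble_settings=...), shown further below.

# Illustrative StackEnsembleSettings construction and REST round trip.
from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings

stack_settings = StackEnsembleSettings(
    stack_meta_learner_type="LogisticRegressionCV",   # string form of StackMetaLearnerType
    stack_meta_learner_train_percentage=0.3,          # reserve 30% of training data for the meta-learner
)
rest_settings = stack_settings._to_rest_object()                    # RestStackEnsembleSettings
restored = StackEnsembleSettings._from_rest_object(rest_settings)   # back to the SDK entity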
+# --------------------------------------------------------- + +# pylint: disable=too-many-instance-attributes + +from abc import ABC +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + AutoNCrossValidations, + BlockedTransformers, + CustomNCrossValidations, + LogVerbosity, +) +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants import TabularTrainingMode +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.automl.automl_vertical import AutoMLVertical +from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings +from azure.ai.ml.entities._job.automl.tabular.featurization_settings import ( + ColumnTransformer, + TabularFeaturizationSettings, +) +from azure.ai.ml.entities._job.automl.tabular.limit_settings import TabularLimitSettings +from azure.ai.ml.entities._job.automl.training_settings import TrainingSettings +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class AutoMLTabular(AutoMLVertical, ABC): + """Initialize an AutoML job entity for tabular data. + + Constructor for AutoMLTabular. + + :keyword task_type: The type of task to run. Possible values include: "classification", "regression" + , "forecasting". + :paramtype task_type: str + :keyword featurization: featurization settings. Defaults to None. + :paramtype featurization: typing.Optional[TabularFeaturizationSettings] + :keyword limits: limits settings. Defaults to None. + :paramtype limits: typing.Optional[TabularLimitSettings] + :keyword training: training settings. Defaults to None. + :paramtype training: typing.Optional[TrainingSettings] + :keyword log_verbosity: Verbosity of logging. Possible values include: "debug", "info", "warning", "error", + "critical". Defaults to "info". + :paramtype log_verbosity: str + :keyword target_column_name: The name of the target column. Defaults to None. + :paramtype target_column_name: typing.Optional[str] + :keyword weight_column_name: The name of the weight column. Defaults to None. + :paramtype weight_column_name: typing.Optional[str] + :keyword validation_data_size: The size of the validation data. Defaults to None. + :paramtype validation_data_size: typing.Optional[float] + :keyword cv_split_column_names: The names of the columns to use for cross validation. Defaults to None. + :paramtype cv_split_column_names: typing.Optional[List[str]] + :keyword n_cross_validations: The number of cross validations to run. Defaults to None. + :paramtype n_cross_validations: typing.Optional[int] + :keyword test_data_size: The size of the test data. Defaults to None. + :paramtype test_data_size: typing.Optional[float] + :keyword training_data: The training data. Defaults to None. + :paramtype training_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input] + :keyword validation_data: The validation data. Defaults to None. + :paramtype validation_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input] + :keyword test_data: The test data. Defaults to None. + :paramtype test_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input] + """ + + def __init__( + self, + *, + task_type: str, + featurization: Optional[TabularFeaturizationSettings] = None, + limits: Optional[TabularLimitSettings] = None, + training: Optional[Any] = None, + **kwargs: Any, + ) -> None: + """Initialize an AutoML job entity for tabular data. 
+ + Constructor for AutoMLTabular. + + :keyword task_type: The type of task to run. Possible values include: "classification", "regression" + , "forecasting". + :paramtype task_type: str + :keyword featurization: featurization settings. Defaults to None. + :paramtype featurization: typing.Optional[TabularFeaturizationSettings] + :keyword limits: limits settings. Defaults to None. + :paramtype limits: typing.Optional[TabularLimitSettings] + :keyword training: training settings. Defaults to None. + :paramtype training: typing.Optional[TrainingSettings] + :keyword log_verbosity: Verbosity of logging. Possible values include: "debug", "info", "warning", "error", + "critical". Defaults to "info". + :paramtype log_verbosity: str + :keyword target_column_name: The name of the target column. Defaults to None. + :paramtype target_column_name: typing.Optional[str] + :keyword weight_column_name: The name of the weight column. Defaults to None. + :paramtype weight_column_name: typing.Optional[str] + :keyword validation_data_size: The size of the validation data. Defaults to None. + :paramtype validation_data_size: typing.Optional[float] + :keyword cv_split_column_names: The names of the columns to use for cross validation. Defaults to None. + :paramtype cv_split_column_names: typing.Optional[List[str]] + :keyword n_cross_validations: The number of cross validations to run. Defaults to None. + :paramtype n_cross_validations: typing.Optional[int] + :keyword test_data_size: The size of the test data. Defaults to None. + :paramtype test_data_size: typing.Optional[float] + :keyword training_data: The training data. Defaults to None. + :paramtype training_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input] + :keyword validation_data: The validation data. Defaults to None. + :paramtype validation_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input] + :keyword test_data: The test data. Defaults to None. + :paramtype test_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input] + :raises: :class:`azure.ai.ml.exceptions.ValidationException` + """ + self.log_verbosity = kwargs.pop("log_verbosity", LogVerbosity.INFO) + + self.target_column_name = kwargs.pop("target_column_name", None) + self.weight_column_name = kwargs.pop("weight_column_name", None) + self.validation_data_size = kwargs.pop("validation_data_size", None) + self.cv_split_column_names = kwargs.pop("cv_split_column_names", None) + self.n_cross_validations = kwargs.pop("n_cross_validations", None) + self.test_data_size = kwargs.pop("test_data_size", None) + + super().__init__( + task_type=task_type, + training_data=kwargs.pop("training_data", None), + validation_data=kwargs.pop("validation_data", None), + test_data=kwargs.pop("test_data", None), + **kwargs, + ) + + self._featurization = featurization + self._limits = limits + self._training = training + + @property + def log_verbosity(self) -> LogVerbosity: + """Get the log verbosity for the AutoML job. + + :return: log verbosity for the AutoML job + :rtype: LogVerbosity + """ + return self._log_verbosity + + @log_verbosity.setter + def log_verbosity(self, value: Union[str, LogVerbosity]) -> None: + """Set the log verbosity for the AutoML job. + + :param value: str or LogVerbosity + :type value: typing.Union[str, LogVerbosity] + """ + self._log_verbosity = None if value is None else LogVerbosity[camel_to_snake(value).upper()] + + @property + def limits(self) -> Optional[TabularLimitSettings]: + """Get the tabular limits for the AutoML job. 
+ + :return: Tabular limits for the AutoML job + :rtype: TabularLimitSettings + """ + return self._limits + + @limits.setter + def limits(self, value: Union[Dict, TabularLimitSettings]) -> None: + """Set the limits for the AutoML job. + + :param value: typing.Dict or TabularLimitSettings + :type value: typing.Union[typing.Dict, TabularLimitSettings] + :raises ValidationException: Expected a dictionary for limit settings. + """ + if isinstance(value, TabularLimitSettings): + self._limits = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for limit settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_limits(**value) + + @property + def training(self) -> Any: + """Get the training settings for the AutoML job. + + :return: Training settings for the AutoML job. + :rtype: TrainingSettings + """ + return self._training + + @training.setter + def training(self, value: Union[Dict, TrainingSettings]) -> None: + """Set the training settings for the AutoML job. + + :param value: typing.Dict or TrainingSettings + :type value: typing.Union[typing.Dict, TrainingSettings] + :raises ValidationException: Expected a dictionary for training settings. + """ + if isinstance(value, TrainingSettings): + self._training = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for training settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_training(**value) + + @property + def featurization(self) -> Optional[TabularFeaturizationSettings]: + """Get the tabular featurization settings for the AutoML job. + + :return: Tabular featurization settings for the AutoML job + :rtype: TabularFeaturizationSettings + """ + return self._featurization + + @featurization.setter + def featurization(self, value: Union[Dict, TabularFeaturizationSettings]) -> None: + """Set the featurization settings for the AutoML job. + + :param value: typing.Dict or TabularFeaturizationSettings + :type value: typing.Union[typing.Dict, TabularFeaturizationSettings] + :raises ValidationException: Expected a dictionary for featurization settings + """ + if isinstance(value, TabularFeaturizationSettings): + self._featurization = value + else: + if not isinstance(value, dict): + msg = "Expected a dictionary for featurization settings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + self.set_featurization(**value) + + def set_limits( + self, + *, + enable_early_termination: Optional[bool] = None, + exit_score: Optional[float] = None, + max_concurrent_trials: Optional[int] = None, + max_cores_per_trial: Optional[int] = None, + max_nodes: Optional[int] = None, + max_trials: Optional[int] = None, + timeout_minutes: Optional[int] = None, + trial_timeout_minutes: Optional[int] = None, + ) -> None: + """Set limits for the job. + + :keyword enable_early_termination: Whether to enable early termination if the score is not improving in the + short term, defaults to None. + + Early stopping logic: + + * No early stopping for first 20 iterations (landmarks). + * Early stopping window starts on the 21st iteration and looks for early_stopping_n_iters iterations + (currently set to 10). This means that the first iteration where stopping can occur is the 31st. 
+ * AutoML still schedules 2 ensemble iterations AFTER early stopping, which might result in higher scores. + * Early stopping is triggered if the absolute value of best score calculated is the same for past + early_stopping_n_iters iterations, that is, if there is no improvement in score for + early_stopping_n_iters iterations. + + :paramtype enable_early_termination: typing.Optional[bool] + :keyword exit_score: Target score for experiment. The experiment terminates after this score is reached. + If not specified (no criteria), the experiment runs until no further progress is made + on the primary metric. For for more information on exit criteria, see this `article + <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#exit-criteria>`_ + , defaults to None + :paramtype exit_score: typing.Optional[float] + :keyword max_concurrent_trials: This is the maximum number of iterations that would be executed in parallel. + The default value is 1. + + * AmlCompute clusters support one iteration running per node. For multiple AutoML experiment parent runs + executed in parallel on a single AmlCompute cluster, the sum of the ``max_concurrent_trials`` values + for all experiments should be less than or equal to the maximum number of nodes. Otherwise, runs + will be queued until nodes are available. + + * DSVM supports multiple iterations per node. ``max_concurrent_trials`` should + be less than or equal to the number of cores on the DSVM. For multiple experiments + run in parallel on a single DSVM, the sum of the ``max_concurrent_trials`` values for all + experiments should be less than or equal to the maximum number of nodes. + + * Databricks - ``max_concurrent_trials`` should be less than or equal to the number of + worker nodes on Databricks. + + ``max_concurrent_trials`` does not apply to local runs. Formerly, this parameter + was named ``concurrent_iterations``. + :paramtype max_concurrent_trials: typing.Optional[int] + :keyword max_cores_per_trial: The maximum number of threads to use for a given training iteration. + Acceptable values: + + * Greater than 1 and less than or equal to the maximum number of cores on the compute target. + + * Equal to -1, which means to use all the possible cores per iteration per child-run. + + * Equal to 1, the default. + + :paramtype max_cores_per_trial: typing.Optional[int] + :keyword max_nodes: [Experimental] The maximum number of nodes to use for distributed training. + + * For forecasting, each model is trained using max(2, int(max_nodes / max_concurrent_trials)) nodes. + + * For classification/regression, each model is trained using max_nodes nodes. + + Note- This parameter is in public preview and might change in future. + :paramtype max_nodes: typing.Optional[int] + :keyword max_trials: The total number of different algorithm and parameter combinations to test during an + automated ML experiment. If not specified, the default is 1000 iterations. + :paramtype max_trials: typing.Optional[int] + :keyword timeout_minutes: Maximum amount of time in minutes that all iterations combined can take before the + experiment terminates. If not specified, the default experiment timeout is 6 days. To specify a timeout + less than or equal to 1 hour, make sure your dataset's size is not greater than + 10,000,000 (rows times column) or an error results, defaults to None + :paramtype timeout_minutes: typing.Optional[int] + :keyword trial_timeout_minutes: Maximum time in minutes that each iteration can run for before it terminates. 
+ If not specified, a value of 1 month or 43200 minutes is used, defaults to None + :paramtype trial_timeout_minutes: typing.Optional[int] + """ + self._limits = self._limits or TabularLimitSettings() + self._limits.enable_early_termination = ( + enable_early_termination if enable_early_termination is not None else self._limits.enable_early_termination + ) + self._limits.exit_score = exit_score if exit_score is not None else self._limits.exit_score + self._limits.max_concurrent_trials = ( + max_concurrent_trials if max_concurrent_trials is not None else self._limits.max_concurrent_trials + ) + self._limits.max_cores_per_trial = ( + max_cores_per_trial if max_cores_per_trial is not None else self._limits.max_cores_per_trial + ) + self._limits.max_nodes = max_nodes if max_nodes is not None else self._limits.max_nodes + self._limits.max_trials = max_trials if max_trials is not None else self._limits.max_trials + self._limits.timeout_minutes = timeout_minutes if timeout_minutes is not None else self._limits.timeout_minutes + self._limits.trial_timeout_minutes = ( + trial_timeout_minutes if trial_timeout_minutes is not None else self._limits.trial_timeout_minutes + ) + + def set_training( + self, + *, + enable_onnx_compatible_models: Optional[bool] = None, + enable_dnn_training: Optional[bool] = None, + enable_model_explainability: Optional[bool] = None, + enable_stack_ensemble: Optional[bool] = None, + enable_vote_ensemble: Optional[bool] = None, + stack_ensemble_settings: Optional[StackEnsembleSettings] = None, + ensemble_model_download_timeout: Optional[int] = None, + allowed_training_algorithms: Optional[List[str]] = None, + blocked_training_algorithms: Optional[List[str]] = None, + training_mode: Optional[Union[str, TabularTrainingMode]] = None, + ) -> None: + """The method to configure training related settings. + + :keyword enable_onnx_compatible_models: Whether to enable or disable enforcing the ONNX-compatible models. + The default is False. For more information about Open Neural Network Exchange (ONNX) and Azure Machine + Learning,see this `article <https://learn.microsoft.com/azure/machine-learning/concept-onnx>`__. + :paramtype enable_onnx_compatible_models: typing.Optional[bool] + :keyword enable_dnn_training: Whether to include DNN based models during model selection. + However, the default is True for DNN NLP tasks, and it's False for all other AutoML tasks. + :paramtype enable_dnn_training: typing.Optional[bool] + :keyword enable_model_explainability: Whether to enable explaining the best AutoML model at the end of all + AutoML training iterations. For more information, see + `Interpretability: model explanations in automated machine learning + <https://learn.microsoft.com/azure/machine-learning/how-to-machine-learning-interpretability-automl>`__. + , defaults to None + :paramtype enable_model_explainability: typing.Optional[bool] + :keyword enable_stack_ensemble: Whether to enable/disable StackEnsemble iteration. + If `enable_onnx_compatible_models` flag is being set, then StackEnsemble iteration will be disabled. + Similarly, for Timeseries tasks, StackEnsemble iteration will be disabled by default, to avoid risks of + overfitting due to small training set used in fitting the meta learner. 
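A minimal sketch of set_limits on a concrete subclass (ClassificationJob, defined later in this diff); all numbers are illustrative, and a bare job is used on the assumption that every data setting may stay None until set_data is called.

# set_limits only overwrites the arguments that are actually passed; anything left as
# None keeps its current value on the underlying TabularLimitSettings.
from azure.ai.ml.entities._job.automl.tabular.classification_job import ClassificationJob

clf_job = ClassificationJob(primary_metric="accuracy")
clf_job.set_limits(
    timeout_minutes=120,        # budget for the whole experiment
    trial_timeout_minutes=20,   # budget per iteration
    max_trials=40,
    max_concurrent_trials=4,
    enable_early_termination=True,
)
# Assigning a plain dict to the property routes through the same method:
clf_job.limits = {"timeout_minutes": 120, "max_trials": 40}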
+ For more information about ensembles, see `Ensemble configuration + <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#ensemble>`__ + , defaults to None + :paramtype enable_stack_ensemble: typing.Optional[bool] + :keyword enable_vote_ensemble: Whether to enable/disable VotingEnsemble iteration. + For more information about ensembles, see `Ensemble configuration + <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#ensemble>`__ + , defaults to None + :paramtype enable_vote_ensemble: typing.Optional[bool] + :keyword stack_ensemble_settings: Settings for StackEnsemble iteration, defaults to None + :paramtype stack_ensemble_settings: typing.Optional[StackEnsembleSettings] + :keyword ensemble_model_download_timeout: During VotingEnsemble and StackEnsemble model generation, + multiple fitted models from the previous child runs are downloaded. Configure this parameter with a + higher value than 300 secs, if more time is needed, defaults to None + :paramtype ensemble_model_download_timeout: typing.Optional[int] + :keyword allowed_training_algorithms: A list of model names to search for an experiment. If not specified, + then all models supported for the task are used minus any specified in ``blocked_training_algorithms`` + or deprecated TensorFlow models, defaults to None + :paramtype allowed_training_algorithms: typing.Optional[List[str]] + :keyword blocked_training_algorithms: A list of algorithms to ignore for an experiment, defaults to None + :paramtype blocked_training_algorithms: typing.Optional[List[str]] + :keyword training_mode: [Experimental] The training mode to use. + The possible values are- + + * distributed- enables distributed training for supported algorithms. + + * non_distributed- disables distributed training. + + * auto- Currently, it is same as non_distributed. In future, this might change. + + Note: This parameter is in public preview and may change in future. 
+ :paramtype training_mode: typing.Optional[typing.Union[str, azure.ai.ml.constants.TabularTrainingMode]] + """ + # get training object by calling training getter of respective tabular task + self._training = self.training + if self._training is not None: + self._training.enable_onnx_compatible_models = ( + enable_onnx_compatible_models + if enable_onnx_compatible_models is not None + else self._training.enable_onnx_compatible_models + ) + self._training.enable_dnn_training = ( + enable_dnn_training if enable_dnn_training is not None else self._training.enable_dnn_training + ) + self._training.enable_model_explainability = ( + enable_model_explainability + if enable_model_explainability is not None + else self._training.enable_model_explainability + ) + self._training.enable_stack_ensemble = ( + enable_stack_ensemble if enable_stack_ensemble is not None else self._training.enable_stack_ensemble + ) + self._training.enable_vote_ensemble = ( + enable_vote_ensemble if enable_vote_ensemble is not None else self._training.enable_vote_ensemble + ) + self._training.stack_ensemble_settings = ( + stack_ensemble_settings + if stack_ensemble_settings is not None + else self._training.stack_ensemble_settings + ) + self._training.ensemble_model_download_timeout = ( + ensemble_model_download_timeout + if ensemble_model_download_timeout is not None + else self._training.ensemble_model_download_timeout + ) + + self._training.allowed_training_algorithms = allowed_training_algorithms + self._training.blocked_training_algorithms = blocked_training_algorithms + self._training.training_mode = training_mode if training_mode is not None else self._training.training_mode + + def set_featurization( + self, + *, + blocked_transformers: Optional[List[Union[BlockedTransformers, str]]] = None, + column_name_and_types: Optional[Dict[str, str]] = None, + dataset_language: Optional[str] = None, + transformer_params: Optional[Dict[str, List[ColumnTransformer]]] = None, + mode: Optional[str] = None, + enable_dnn_featurization: Optional[bool] = None, + ) -> None: + """Define feature engineering configuration. + + :keyword blocked_transformers: A list of transformer names to be blocked during featurization, defaults to None + :paramtype blocked_transformers: Optional[List[Union[BlockedTransformers, str]]] + :keyword column_name_and_types: A dictionary of column names and feature types used to update column purpose + , defaults to None + :paramtype column_name_and_types: Optional[Dict[str, str]] + :keyword dataset_language: Three character ISO 639-3 code for the language(s) contained in the dataset. + Languages other than English are only supported if you use GPU-enabled compute. The language_code + 'mul' should be used if the dataset contains multiple languages. 
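Continuing the ClassificationJob sketch above for set_training; note that, per the implementation, allowed_training_algorithms and blocked_training_algorithms are always overwritten (even to None), while every other argument is applied only when it is not None. The algorithm names and the lazily created training-settings object are assumptions based on the classification vertical.

# Illustrative set_training call, reusing clf_job and stack_settings from the sketches above.
clf_job.set_training(
    enable_model_explainability=True,
    enable_stack_ensemble=True,
    stack_ensemble_settings=stack_settings,
    allowed_training_algorithms=["LightGBM", "XGBoostClassifier"],   # ClassificationModels names
)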
To find ISO 639-3 codes for different + languages, please refer to https://en.wikipedia.org/wiki/List_of_ISO_639-3_codes, defaults to None + :paramtype dataset_language: Optional[str] + :keyword transformer_params: A dictionary of transformer and corresponding customization parameters + , defaults to None + :paramtype transformer_params: Optional[Dict[str, List[ColumnTransformer]]] + :keyword mode: "off", "auto", defaults to "auto", defaults to None + :paramtype mode: Optional[str] + :keyword enable_dnn_featurization: Whether to include DNN based feature engineering methods, defaults to None + :paramtype enable_dnn_featurization: Optional[bool] + """ + self._featurization = self._featurization or TabularFeaturizationSettings() + self._featurization.blocked_transformers = ( + blocked_transformers if blocked_transformers is not None else self._featurization.blocked_transformers + ) + self._featurization.column_name_and_types = ( + column_name_and_types if column_name_and_types is not None else self._featurization.column_name_and_types + ) + self._featurization.dataset_language = ( + dataset_language if dataset_language is not None else self._featurization.dataset_language + ) + self._featurization.transformer_params = ( + transformer_params if transformer_params is not None else self._featurization.transformer_params + ) + self._featurization.mode = mode or self._featurization.mode + self._featurization.enable_dnn_featurization = ( + enable_dnn_featurization + if enable_dnn_featurization is not None + else self._featurization.enable_dnn_featurization + ) + + def set_data( + self, + *, + training_data: Input, + target_column_name: str, + weight_column_name: Optional[str] = None, + validation_data: Optional[Input] = None, + validation_data_size: Optional[float] = None, + n_cross_validations: Optional[Union[str, int]] = None, + cv_split_column_names: Optional[List[str]] = None, + test_data: Optional[Input] = None, + test_data_size: Optional[float] = None, + ) -> None: + """Define data configuration. + + :keyword training_data: Training data. + :paramtype training_data: Input + :keyword target_column_name: Column name of the target column. 
+ :paramtype target_column_name: str + :keyword weight_column_name: Weight column name, defaults to None + :paramtype weight_column_name: typing.Optional[str] + :keyword validation_data: Validation data, defaults to None + :paramtype validation_data: typing.Optional[Input] + :keyword validation_data_size: Validation data size, defaults to None + :paramtype validation_data_size: typing.Optional[float] + :keyword n_cross_validations: n_cross_validations, defaults to None + :paramtype n_cross_validations: typing.Optional[typing.Union[str, int]] + :keyword cv_split_column_names: cv_split_column_names, defaults to None + :paramtype cv_split_column_names: typing.Optional[typing.List[str]] + :keyword test_data: Test data, defaults to None + :paramtype test_data: typing.Optional[Input] + :keyword test_data_size: Test data size, defaults to None + :paramtype test_data_size: typing.Optional[float] + """ + self.target_column_name = target_column_name if target_column_name is not None else self.target_column_name + self.weight_column_name = weight_column_name if weight_column_name is not None else self.weight_column_name + self.training_data = training_data if training_data is not None else self.training_data + self.validation_data = validation_data if validation_data is not None else self.validation_data + self.validation_data_size = ( + validation_data_size if validation_data_size is not None else self.validation_data_size + ) + self.cv_split_column_names = ( + cv_split_column_names if cv_split_column_names is not None else self.cv_split_column_names + ) + self.n_cross_validations = n_cross_validations if n_cross_validations is not None else self.n_cross_validations + self.test_data = test_data if test_data is not None else self.test_data + self.test_data_size = test_data_size if test_data_size is not None else self.test_data_size + + def _validation_data_to_rest(self, rest_obj: "AutoMLTabular") -> None: + """Validation data serialization. + + :param rest_obj: Serialized object + :type rest_obj: AutoMLTabular + """ + if rest_obj.n_cross_validations: + n_cross_val = rest_obj.n_cross_validations + # Convert n_cross_validations int value to CustomNCrossValidations + if isinstance(n_cross_val, int) and n_cross_val > 1: + rest_obj.n_cross_validations = CustomNCrossValidations(value=n_cross_val) + # Convert n_cross_validations str value to AutoNCrossValidations + elif isinstance(n_cross_val, str): + rest_obj.n_cross_validations = AutoNCrossValidations() + + def _validation_data_from_rest(self) -> None: + """Validation data deserialization.""" + if self.n_cross_validations: + n_cross_val = self.n_cross_validations + # Convert n_cross_validations CustomNCrossValidations back into int value + if isinstance(n_cross_val, CustomNCrossValidations): + self.n_cross_validations = n_cross_val.value + # Convert n_cross_validations AutoNCrossValidations to str value + elif isinstance(n_cross_val, AutoNCrossValidations): + self.n_cross_validations = AutoMLConstants.AUTO + + def __eq__(self, other: object) -> bool: + """Return True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. 
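Finally, data and featurization for the same sketch; training_data and target_column_name are the only required arguments of set_data, and the paths, column name, and blocked transformer are placeholders.

# Illustrative set_data/set_featurization calls, continuing the clf_job sketch above.
from azure.ai.ml import Input
from azure.ai.ml.constants import AssetTypes

clf_job.set_data(
    training_data=Input(type=AssetTypes.MLTABLE, path="./data/train"),
    target_column_name="label",
    validation_data_size=0.2,   # hold out 20% of the training data for validation
)
clf_job.set_featurization(
    mode="auto",
    blocked_transformers=["LabelEncoder"],   # a BlockedTransformers value, given as a string
    enable_dnn_featurization=False,
)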
+ + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, AutoMLTabular): + return NotImplemented + + return ( + self.target_column_name == other.target_column_name + and self.weight_column_name == other.weight_column_name + and self.training_data == other.training_data + and self.validation_data == other.validation_data + and self.validation_data_size == other.validation_data_size + and self.cv_split_column_names == other.cv_split_column_names + and self.n_cross_validations == other.n_cross_validations + and self.test_data == other.test_data + and self.test_data_size == other.test_data_size + and self._featurization == other._featurization + and self._limits == other._limits + and self._training == other._training + ) + + def __ne__(self, other: object) -> bool: + """Check inequality between two AutoMLTabular objects. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/classification_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/classification_job.py new file mode 100644 index 00000000..6f5ab271 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/classification_job.py @@ -0,0 +1,352 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import Classification as RestClassification +from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics, JobBase, TaskType +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.tabular.automl_tabular import AutoMLTabular +from azure.ai.ml.entities._job.automl.tabular.featurization_settings import TabularFeaturizationSettings +from azure.ai.ml.entities._job.automl.tabular.limit_settings import TabularLimitSettings +from azure.ai.ml.entities._job.automl.training_settings import ( # noqa: F401 # pylint: disable=unused-import + ClassificationTrainingSettings, + TrainingSettings, +) +from azure.ai.ml.entities._util import load_from_dict + + +class ClassificationJob(AutoMLTabular): + """Configuration for AutoML Classification Job. + + :keyword primary_metric: The primary metric to use for optimization, defaults to None + :paramtype primary_metric: typing.Optional[str] + :keyword positive_label: Positive label for binary metrics calculation, defaults to None + :paramtype positive_label: typing.Optional[str] + :keyword featurization: Featurization settings. Defaults to None. + :paramtype featurization: typing.Optional[TabularFeaturizationSettings] + :keyword limits: Limits settings. Defaults to None. 
+ :paramtype limits: typing.Optional[TabularLimitSettings] + :keyword training: Training settings. Defaults to None. + :paramtype training: typing.Optional[TrainingSettings] + :return: An instance of ClassificationJob object. + :rtype: ~azure.ai.ml.entities.automl.ClassificationJob + :raises ValueError: If primary_metric is not a valid primary metric + :raises ValueError: If positive_label is not a valid positive label + :raises ValueError: If featurization is not a valid featurization settings + :raises ValueError: If limits is not a valid limits settings + :raises ValueError: If training is not a valid training settings + """ + + _DEFAULT_PRIMARY_METRIC = ClassificationPrimaryMetrics.ACCURACY + + def __init__( + self, + *, + primary_metric: Optional[str] = None, + positive_label: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a new AutoML Classification task. + + :keyword primary_metric: The primary metric to use for optimization, defaults to None + :paramtype primary_metric: typing.Optional[str] + :keyword positive_label: Positive label for binary metrics calculation, defaults to None + :paramtype positive_label: typing.Optional[str] + :keyword featurization: featurization settings. Defaults to None. + :paramtype featurization: typing.Optional[TabularFeaturizationSettings] + :keyword limits: limits settings. Defaults to None. + :paramtype limits: typing.Optional[TabularLimitSettings] + :keyword training: training settings. Defaults to None. + :paramtype training: typing.Optional[TrainingSettings] + :raises ValueError: If primary_metric is not a valid primary metric + :raises ValueError: If positive_label is not a valid positive label + :raises ValueError: If featurization is not a valid featurization settings + :raises ValueError: If limits is not a valid limits settings + :raises ValueError: If training is not a valid training settings + """ + # Extract any task specific settings + featurization = kwargs.pop("featurization", None) + limits = kwargs.pop("limits", None) + training = kwargs.pop("training", None) + + super().__init__( + task_type=TaskType.CLASSIFICATION, + featurization=featurization, + limits=limits, + training=training, + **kwargs, + ) + + self.primary_metric = primary_metric or ClassificationJob._DEFAULT_PRIMARY_METRIC + self.positive_label = positive_label + + @property + def primary_metric(self) -> Union[str, ClassificationPrimaryMetrics]: + """The primary metric to use for optimization. + + :return: The primary metric to use for optimization. + :rtype: typing.Union[str, ClassificationPrimaryMetrics] + """ + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, ClassificationPrimaryMetrics]) -> None: + """The primary metric to use for optimization setter. + + :param value: Primary metric to use for optimization. + :type value: typing.Union[str, ClassificationPrimaryMetrics] + """ + # TODO: better way to do this + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + self._primary_metric = ( + ClassificationJob._DEFAULT_PRIMARY_METRIC + if value is None + else ClassificationPrimaryMetrics[camel_to_snake(value).upper()] + ) + + @property # type: ignore + def training(self) -> ClassificationTrainingSettings: + """Training Settings for AutoML Classification Job. + + :return: Training settings used for AutoML Classification Job. 
+ :rtype: ClassificationTrainingSettings + """ + return self._training or ClassificationTrainingSettings() + + @training.setter + def training(self, value: Union[Dict, ClassificationTrainingSettings]) -> None: # pylint: disable=unused-argument + ... + + def _to_rest_object(self) -> JobBase: + """Convert ClassificationJob object to a REST object. + + :return: REST object representation of this object. + :rtype: JobBase + """ + classification_task = RestClassification( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + validation_data_size=self.validation_data_size, + weight_column_name=self.weight_column_name, + cv_split_column_names=self.cv_split_column_names, + n_cross_validations=self.n_cross_validations, + test_data=self.test_data, + test_data_size=self.test_data_size, + featurization_settings=self._featurization._to_rest_object() if self._featurization else None, + limit_settings=self._limits._to_rest_object() if self._limits else None, + training_settings=self._training._to_rest_object() if self._training else None, + primary_metric=self.primary_metric, + positive_label=self.positive_label, + log_verbosity=self.log_verbosity, + ) + self._resolve_data_inputs(classification_task) + self._validation_data_to_rest(classification_task) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=classification_task, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "ClassificationJob": + """Convert a REST object to ClassificationJob object. + + :param obj: ClassificationJob in Rest format. + :type obj: JobBase + :return: ClassificationJob objects. 
+ :rtype: ClassificationJob + """ + + properties: RestAutoMLJob = obj.properties + task_details: RestClassification = properties.task_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": properties.experiment_name, + "services": properties.services, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": properties.display_name, + "compute": properties.compute_id, + "outputs": from_rest_data_outputs(properties.outputs), + "resources": properties.resources, + "identity": ( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + "queue_settings": properties.queue_settings, + } + + classification_job = cls( + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + validation_data_size=task_details.validation_data_size, + weight_column_name=task_details.weight_column_name, + cv_split_column_names=task_details.cv_split_column_names, + n_cross_validations=task_details.n_cross_validations, + test_data=task_details.test_data, + test_data_size=task_details.test_data_size, + featurization=( + TabularFeaturizationSettings._from_rest_object(task_details.featurization_settings) + if task_details.featurization_settings + else None + ), + limits=( + TabularLimitSettings._from_rest_object(task_details.limit_settings) + if task_details.limit_settings + else None + ), + training=( + ClassificationTrainingSettings._from_rest_object(task_details.training_settings) + if task_details.training_settings + else None + ), + primary_metric=task_details.primary_metric, + positive_label=task_details.positive_label, + log_verbosity=task_details.log_verbosity, + **job_args_dict, + ) + + classification_job._restore_data_inputs() + classification_job._validation_data_from_rest() + + return classification_job + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "ClassificationJob": + """Load from a dictionary. + + :param data: dictionary representation of the object. + :type data: typing.Dict + :param context: dictionary containing the context. + :type context: typing.Dict + :param additional_message: additional message to be added to the error message. + :type additional_message: str + :return: ClassificationJob object. + :rtype: ClassificationJob + """ + from azure.ai.ml._schema.automl.table_vertical.classification import AutoMLClassificationSchema + from azure.ai.ml._schema.pipeline.automl_node import AutoMLClassificationNodeSchema + + if kwargs.pop("inside_pipeline", False): + loaded_data = load_from_dict( + AutoMLClassificationNodeSchema, + data, + context, + additional_message, + **kwargs, + ) + else: + loaded_data = load_from_dict(AutoMLClassificationSchema, data, context, additional_message, **kwargs) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ClassificationJob": + """Create an instance from a schema dictionary. + + :param loaded_data: dictionary containing the data. + :type loaded_data: typing.Dict + :return: ClassificationJob object. 
+ :rtype: ClassificationJob
+ """
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ data_settings = {
+ "training_data": loaded_data.pop("training_data"),
+ "target_column_name": loaded_data.pop("target_column_name"),
+ "weight_column_name": loaded_data.pop("weight_column_name", None),
+ "validation_data": loaded_data.pop("validation_data", None),
+ "validation_data_size": loaded_data.pop("validation_data_size", None),
+ "cv_split_column_names": loaded_data.pop("cv_split_column_names", None),
+ "n_cross_validations": loaded_data.pop("n_cross_validations", None),
+ "test_data": loaded_data.pop("test_data", None),
+ "test_data_size": loaded_data.pop("test_data_size", None),
+ }
+ job = ClassificationJob(**loaded_data)
+ job.set_data(**data_settings)
+ return job
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ """Convert the object to a dictionary.
+
+ :param inside_pipeline: whether the job is inside a pipeline or not, defaults to False
+ :type inside_pipeline: bool
+ :return: dictionary representation of the object.
+ :rtype: typing.Dict
+ """
+ from azure.ai.ml._schema.automl.table_vertical.classification import AutoMLClassificationSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLClassificationNodeSchema
+
+ schema_dict: dict = {}
+ if inside_pipeline:
+ schema_dict = AutoMLClassificationNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ else:
+ schema_dict = AutoMLClassificationSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ """Return True if both instances have the same values.
+
+ This method checks instance equality and returns True if both of
+ the instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, ClassificationJob):
+ return NotImplemented
+
+ if not super().__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two ClassificationJob objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
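The flow above is the intended usage pattern: construct the job, then bind data and featurization through the AutoMLTabular setters, exactly as _create_instance_from_schema_dict does for YAML input. Below is a minimal illustrative sketch; the workspace coordinates, data asset path, and column names are placeholders rather than values taken from this diff, and the private module paths are simply the files added here.

from azure.ai.ml import Input, MLClient
from azure.ai.ml.constants import AssetTypes
# The public azure.ai.ml.automl namespace normally re-exports these classes; the
# private module paths below are the files added in this diff.
from azure.ai.ml.entities._job.automl.tabular.classification_job import ClassificationJob
from azure.ai.ml.entities._job.automl.tabular.featurization_settings import ColumnTransformer
from azure.identity import DefaultAzureCredential

# Placeholder workspace coordinates.
ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")

# The string form of the metric is normalized to ClassificationPrimaryMetrics by the setter.
job = ClassificationJob(primary_metric="accuracy", positive_label="1")
job.set_data(
    training_data=Input(type=AssetTypes.MLTABLE, path="azureml:credit-training:1"),  # placeholder asset
    target_column_name="default_next_month",  # placeholder column name
    n_cross_validations=5,
)
job.set_featurization(
    mode="auto",
    enable_dnn_featurization=False,
    transformer_params={
        # "Imputer" is mapped through AutoMLTransformerParameterKeys by the featurization setter.
        "Imputer": [ColumnTransformer(fields=["age"], parameters={"strategy": "most_frequent"})],
    },
)

returned_job = ml_client.jobs.create_or_update(job)  # submit the AutoML job to the workspace

The same split between constructor keywords and set_data arguments is what _load_from_dict relies on when materializing a job from YAML.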
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/featurization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/featurization_settings.py
new file mode 100644
index 00000000..6ef2332e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/featurization_settings.py
@@ -0,0 +1,170 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import BlockedTransformers
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ColumnTransformer as RestColumnTransformer
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ TableVerticalFeaturizationSettings as RestTabularFeaturizationSettings,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLTransformerParameterKeys
+from azure.ai.ml.entities._job.automl.featurization_settings import FeaturizationSettings, FeaturizationSettingsType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class ColumnTransformer(RestTranslatableMixin):
+ """Column transformer settings.
+
+ :param fields: The fields on which to perform custom featurization.
+ :type fields: List[str]
+ :param parameters: Parameters used for custom featurization.
+ :type parameters: Dict[str, Union[str, float]]
+ """
+
+ def __init__(
+ self,
+ *,
+ fields: Optional[List[str]] = None,
+ parameters: Optional[Dict[str, Union[str, float]]] = None,
+ ):
+ self.fields = fields
+ self.parameters = parameters
+
+ def _to_rest_object(self) -> RestColumnTransformer:
+ return RestColumnTransformer(fields=self.fields, parameters=self.parameters)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestColumnTransformer) -> Optional["ColumnTransformer"]:
+ if obj:
+ fields = obj.fields
+ parameters = obj.parameters
+ return ColumnTransformer(fields=fields, parameters=parameters)
+ return None
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ColumnTransformer):
+ return NotImplemented
+ return self.fields == other.fields and self.parameters == other.parameters
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+
+class TabularFeaturizationSettings(FeaturizationSettings):
+ """Featurization settings for an AutoML Job."""
+
+ def __init__(
+ self,
+ *,
+ blocked_transformers: Optional[List[Union[BlockedTransformers, str]]] = None,
+ column_name_and_types: Optional[Dict[str, str]] = None,
+ dataset_language: Optional[str] = None,
+ transformer_params: Optional[Dict[str, List[ColumnTransformer]]] = None,
+ mode: Optional[str] = None,
+ enable_dnn_featurization: Optional[bool] = None,
+ ):
+ """
+ :param blocked_transformers: A list of transformers to ignore when featurizing.
+ :type blocked_transformers: List[Union[BlockedTransformers, str]]
+ :param column_name_and_types: A dictionary of column names and feature types used to update column purpose.
+ :type column_name_and_types: Dict[str, str]
+ :param dataset_language: The language of the dataset.
+ :type dataset_language: str
+ :param transformer_params: A dictionary of transformers and their parameters.
+ :type transformer_params: Dict[str, List[ColumnTransformer]]
+ :param mode: The mode of the featurization.
+ :type mode: str
+ :param enable_dnn_featurization: Whether to enable DNN featurization.
+ :type enable_dnn_featurization: bool
+ :ivar type: Specifies the type of FeaturizationSettings. Set automatically to "Tabular" for this class.
+ :vartype type: str + """ + super().__init__(dataset_language=dataset_language) + self.blocked_transformers = blocked_transformers + self.column_name_and_types = column_name_and_types + self.transformer_params = transformer_params + self.mode = mode + self.enable_dnn_featurization = enable_dnn_featurization + self.type = FeaturizationSettingsType.TABULAR + + @property + def transformer_params(self) -> Optional[Dict[str, List[ColumnTransformer]]]: + """A dictionary of transformers and their parameters.""" + return self._transformer_params + + @transformer_params.setter + def transformer_params(self, value: Dict[str, List[ColumnTransformer]]) -> None: + self._transformer_params = ( + None + if not value + else {(AutoMLTransformerParameterKeys[camel_to_snake(k).upper()].value): v for k, v in value.items()} + ) + + @property + def blocked_transformers(self) -> Optional[List[Union[BlockedTransformers, str]]]: + """A list of transformers to ignore when featurizing.""" + return self._blocked_transformers + + @blocked_transformers.setter + def blocked_transformers(self, blocked_transformers_list: List[Union[BlockedTransformers, str]]) -> None: + self._blocked_transformers = ( + None + if blocked_transformers_list is None + else [BlockedTransformers[camel_to_snake(o)] for o in blocked_transformers_list] + ) + + def _to_rest_object(self) -> RestTabularFeaturizationSettings: + transformer_dict = {} + if self.transformer_params: + for key, settings in self.transformer_params.items(): + transformer_dict[key] = [o._to_rest_object() for o in settings] + return RestTabularFeaturizationSettings( + blocked_transformers=self.blocked_transformers, + column_name_and_types=self.column_name_and_types, + dataset_language=self.dataset_language, + mode=self.mode, + transformer_params=transformer_dict, + enable_dnn_featurization=self.enable_dnn_featurization, + ) + + @classmethod + def _from_rest_object(cls, obj: RestTabularFeaturizationSettings) -> "TabularFeaturizationSettings": + rest_transformers_params = obj.transformer_params + transformer_dict: Optional[Dict] = None + if rest_transformers_params: + transformer_dict = {} + for key, settings in rest_transformers_params.items(): + transformer_dict[key] = [ColumnTransformer._from_rest_object(o) for o in settings] + transformer_params = transformer_dict + + return TabularFeaturizationSettings( + blocked_transformers=obj.blocked_transformers, + column_name_and_types=obj.column_name_and_types, + dataset_language=obj.dataset_language, + transformer_params=transformer_params, + mode=obj.mode, + enable_dnn_featurization=obj.enable_dnn_featurization, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TabularFeaturizationSettings): + return NotImplemented + return ( + super().__eq__(other) + and self.blocked_transformers == other.blocked_transformers + and self.column_name_and_types == other.column_name_and_types + and self.transformer_params == other.transformer_params + and self.mode == other.mode + and self.enable_dnn_featurization == other.enable_dnn_featurization + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_job.py new file mode 100644 index 00000000..9bd10b19 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_job.py @@ -0,0 +1,686 @@ +# 
--------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import Forecasting as RestForecasting +from azure.ai.ml._restclient.v2023_04_01_preview.models import ForecastingPrimaryMetrics, JobBase, TaskType +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants import TabularTrainingMode +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings +from azure.ai.ml.entities._job.automl.tabular.automl_tabular import AutoMLTabular +from azure.ai.ml.entities._job.automl.tabular.featurization_settings import TabularFeaturizationSettings +from azure.ai.ml.entities._job.automl.tabular.forecasting_settings import ForecastingSettings +from azure.ai.ml.entities._job.automl.tabular.limit_settings import TabularLimitSettings +from azure.ai.ml.entities._job.automl.training_settings import ForecastingTrainingSettings +from azure.ai.ml.entities._util import load_from_dict + + +class ForecastingJob(AutoMLTabular): + """ + Configuration for AutoML Forecasting Task. + + :param primary_metric: The primary metric to use for model selection. + :type primary_metric: Optional[str] + :param forecasting_settings: The settings for the forecasting task. + :type forecasting_settings: + Optional[~azure.ai.ml.automl.ForecastingSettings] + :param kwargs: Job-specific arguments + :type kwargs: Dict[str, Any] + """ + + _DEFAULT_PRIMARY_METRIC = ForecastingPrimaryMetrics.NORMALIZED_ROOT_MEAN_SQUARED_ERROR + + def __init__( + self, + *, + primary_metric: Optional[str] = None, + forecasting_settings: Optional[ForecastingSettings] = None, + **kwargs: Any, + ) -> None: + """Initialize a new AutoML Forecasting task.""" + # Extract any task specific settings + featurization = kwargs.pop("featurization", None) + limits = kwargs.pop("limits", None) + training = kwargs.pop("training", None) + + super().__init__( + task_type=TaskType.FORECASTING, + featurization=featurization, + limits=limits, + training=training, + **kwargs, + ) + + self.primary_metric = primary_metric or ForecastingJob._DEFAULT_PRIMARY_METRIC + self._forecasting_settings = forecasting_settings + + @property + def primary_metric(self) -> Optional[str]: + """ + Return the primary metric to use for model selection. + + :return: The primary metric for model selection. + :rtype: Optional[str] + """ + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, ForecastingPrimaryMetrics]) -> None: + """ + Set the primary metric to use for model selection. + + :param value: The primary metric for model selection. 
+ :type: Union[str, ~azure.ai.ml.automl.ForecastingPrimaryMetrics] + """ + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + self._primary_metric = ( + ForecastingJob._DEFAULT_PRIMARY_METRIC + if value is None + else ForecastingPrimaryMetrics[camel_to_snake(value).upper()] + ) + + @property # type: ignore + def training(self) -> ForecastingTrainingSettings: + """ + Return the forecast training settings. + + :return: training settings. + :rtype: ~azure.ai.ml.automl.ForecastingTrainingSettings + """ + return self._training or ForecastingTrainingSettings() + + @training.setter + def training(self, value: Union[Dict, ForecastingTrainingSettings]) -> None: # pylint: disable=unused-argument + ... + + @property + def forecasting_settings(self) -> Optional[ForecastingSettings]: + """ + Return the forecast settings. + + :return: forecast settings. + :rtype: ~azure.ai.ml.automl.ForecastingSettings + """ + return self._forecasting_settings + + def set_forecast_settings( + self, + *, + time_column_name: Optional[str] = None, + forecast_horizon: Optional[Union[str, int]] = None, + time_series_id_column_names: Optional[Union[str, List[str]]] = None, + target_lags: Optional[Union[str, int, List[int]]] = None, + feature_lags: Optional[str] = None, + target_rolling_window_size: Optional[Union[str, int]] = None, + country_or_region_for_holidays: Optional[str] = None, + use_stl: Optional[str] = None, + seasonality: Optional[Union[str, int]] = None, + short_series_handling_config: Optional[str] = None, + frequency: Optional[str] = None, + target_aggregate_function: Optional[str] = None, + cv_step_size: Optional[int] = None, + features_unknown_at_forecast_time: Optional[Union[str, List[str]]] = None, + ) -> None: + """Manage parameters used by forecasting tasks. + + :keyword time_column_name: + The name of the time column. This parameter is required when forecasting to specify the datetime + column in the input data used for building the time series and inferring its frequency. + :paramtype time_column_name: Optional[str] + :keyword forecast_horizon: + The desired maximum forecast horizon in units of time-series frequency. The default value is 1. + + Units are based on the time interval of your training data, e.g., monthly, weekly that the forecaster + should predict out. When task type is forecasting, this parameter is required. For more information on + setting forecasting parameters, see `Auto-train a time-series forecast model <https://learn.microsoft.com/ + azure/machine-learning/how-to-auto-train-forecast>`_. + :type forecast_horizon: Optional[Union[int, str]] + :keyword time_series_id_column_names: + The names of columns used to group a time series. + It can be used to create multiple series. If time series id column names is not defined or + the identifier columns specified do not identify all the series in the dataset, the time series identifiers + will be automatically created for your data set. + :paramtype time_series_id_column_names: Optional[Union[str, List[str]]] + :keyword target_lags: The number of past periods to lag from the target column. By default the lags are turned + off. + + When forecasting, this parameter represents the number of rows to lag the target values based + on the frequency of the data. This is represented as a list or single integer. Lag should be used + when the relationship between the independent variables and dependent variable do not match up or + correlate by default. 
For example, when trying to forecast demand for a product, the demand in any + month may depend on the price of specific commodities 3 months prior. In this example, you may want + to lag the target (demand) negatively by 3 months so that the model is training on the correct + relationship. For more information, see `Auto-train a time-series forecast model + <https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-forecast>`_. + + **Note on auto detection of target lags and rolling window size. + Please see the corresponding comments in the rolling window section.** + We use the next algorithm to detect the optimal target lag and rolling window size. + + #. Estimate the maximum lag order for the look back feature selection. In our case it is the number of + periods till the next date frequency granularity i.e. if frequency is daily, it will be a week (7), + if it is a week, it will be month (4). That values multiplied by two is the largest + possible values of lags/rolling windows. In our examples, we will consider the maximum lag + order of 14 and 8 respectively). + #. Create a de-seasonalized series by adding trend and residual components. This will be used + in the next step. + #. Estimate the PACF - Partial Auto Correlation Function on the on the data from (2) + and search for points, where the auto correlation is significant i.e. its absolute + value is more then 1.96/square_root(maximal lag value), which correspond to significance of 95%. + #. If all points are significant, we consider it being strong seasonality + and do not create look back features. + #. We scan the PACF values from the beginning and the value before the first insignificant + auto correlation will designate the lag. If first significant element (value correlate with + itself) is followed by insignificant, the lag will be 0 and we will not use look back features. + + :type target_lags: Optional[Union[str, int, List[int]]] + :keyword feature_lags: Flag for generating lags for the numeric features with 'auto' or None. + :paramtype feature_lags: Optional[str] + :keyword target_rolling_window_size: The number of past periods used to create a rolling window average of the + target column. + + When forecasting, this parameter represents `n` historical periods to use to generate forecasted values, + <= training set size. If omitted, `n` is the full training set size. Specify this parameter + when you only want to consider a certain amount of history when training the model. + If set to 'auto', rolling window will be estimated as the last + value where the PACF is more then the significance threshold. Please see target_lags section for details. + :paramtype target_rolling_window_size: Optional[Union[str, int]] + :keyword country_or_region_for_holidays: The country/region used to generate holiday features. + These should be ISO 3166 two-letter country/region codes, for example 'US' or 'GB'. + :paramtype country_or_region_for_holidays: Optional[str] + :keyword use_stl: Configure STL Decomposition of the time-series target column. + use_stl can take three values: None (default) - no stl decomposition, 'season' - only generate + season component and season_trend - generate both season and trend components. + :type use_stl: Optional[str] + :keyword seasonality: Set time series seasonality as an integer multiple of the series frequency. + If seasonality is set to 'auto', it will be inferred. + If set to None, the time series is assumed non-seasonal which is equivalent to seasonality=1. 
+ :paramtype seasonality: Optional[Union[int, str] + :keyword short_series_handling_config: + The parameter defining how if AutoML should handle short time series. + + Possible values: 'auto' (default), 'pad', 'drop' and None. + + * **auto** short series will be padded if there are no long series, + otherwise short series will be dropped. + * **pad** all the short series will be padded. + * **drop** all the short series will be dropped". + * **None** the short series will not be modified. + + If set to 'pad', the table will be padded with the zeroes and + empty values for the regressors and random values for target with the mean + equal to target value median for given time series id. If median is more or equal + to zero, the minimal padded value will be clipped by zero: + Input: + + +------------+---------------+----------+--------+ + | Date | numeric_value | string | target | + +============+===============+==========+========+ + | 2020-01-01 | 23 | green | 55 | + +------------+---------------+----------+--------+ + + Output assuming minimal number of values is four: + + +------------+---------------+----------+--------+ + | Date | numeric_value | string | target | + +============+===============+==========+========+ + | 2019-12-29 | 0 | NA | 55.1 | + +------------+---------------+----------+--------+ + | 2019-12-30 | 0 | NA | 55.6 | + +------------+---------------+----------+--------+ + | 2019-12-31 | 0 | NA | 54.5 | + +------------+---------------+----------+--------+ + | 2020-01-01 | 23 | green | 55 | + +------------+---------------+----------+--------+ + + **Note:** We have two parameters short_series_handling_configuration and + legacy short_series_handling. When both parameters are set we are + synchronize them as shown in the table below (short_series_handling_configuration and + short_series_handling for brevity are marked as handling_configuration and handling + respectively). + + +------------+--------------------------+----------------------+-----------------------------+ + | | handling | | handling | | resulting | | resulting | + | | | configuration | | handling | | handling | + | | | | | configuration | + +============+==========================+======================+=============================+ + | True | auto | True | auto | + +------------+--------------------------+----------------------+-----------------------------+ + | True | pad | True | auto | + +------------+--------------------------+----------------------+-----------------------------+ + | True | drop | True | auto | + +------------+--------------------------+----------------------+-----------------------------+ + | True | None | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + | False | auto | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + | False | pad | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + | False | drop | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + | False | None | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + + :type short_series_handling_config: Optional[str] + :keyword frequency: Forecast frequency. + + When forecasting, this parameter represents the period with which the forecast is desired, + for example daily, weekly, yearly, etc. 
The forecast frequency is dataset frequency by default. + You can optionally set it to greater (but not lesser) than dataset frequency. + We'll aggregate the data and generate the results at forecast frequency. For example, + for daily data, you can set the frequency to be daily, weekly or monthly, but not hourly. + The frequency needs to be a pandas offset alias. + Please refer to pandas documentation for more information: + https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#dateoffset-objects + :type frequency: Optional[str] + :keyword target_aggregate_function: The function to be used to aggregate the time series target + column to conform to a user specified frequency. If the target_aggregation_function is set, + but the freq parameter is not set, the error is raised. The possible target aggregation + functions are: "sum", "max", "min" and "mean". + + * The target column values are aggregated based on the specified operation. + Typically, sum is appropriate for most scenarios. + * Numerical predictor columns in your data are aggregated by sum, mean, minimum value, + and maximum value. As a result, automated ML generates new columns suffixed with the + aggregation function name and applies the selected aggregate operation. + * For categorical predictor columns, the data is aggregated by mode, + the most prominent category in the window. + * Date predictor columns are aggregated by minimum value, maximum value and mode. + + +----------------+-------------------------------+--------------------------------------+ + | | freq | | target_aggregation_function | | Data regularity | + | | | | fixing mechanism | + +================+===============================+======================================+ + | None (Default) | None (Default) | | The aggregation | + | | | | is not applied. | + | | | | If the valid | + | | | | frequency can | + | | | | not be | + | | | | determined | + | | | | the error | + | | | | will be raised. | + +----------------+-------------------------------+--------------------------------------+ + | Some Value | None (Default) | | The aggregation | + | | | | is not applied. | + | | | | If the number | + | | | | of data points | + | | | | compliant to | + | | | | given frequency | + | | | | grid is | + | | | | less then 90% | + | | | | these points | + | | | | will be | + | | | | removed, | + | | | | otherwise | + | | | | the error will | + | | | | be raised. | + +----------------+-------------------------------+--------------------------------------+ + | None (Default) | Aggregation function | | The error about | + | | | | missing | + | | | | frequency | + | | | | parameter is | + | | | | raised. | + +----------------+-------------------------------+--------------------------------------+ + | Some Value | Aggregation function | | Aggregate to | + | | | | frequency using | + | | | | provided | + | | | | aggregation | + | | | | function. | + +----------------+-------------------------------+--------------------------------------+ + + :type target_aggregate_function: Optional[str] + :keyword cv_step_size: Number of periods between the origin_time of one CV fold and the next fold. + For example, if `n_step` = 3 for daily data, the origin time for each fold will be three days apart. + :paramtype cv_step_size: Optional[int] + :keyword features_unknown_at_forecast_time: The feature columns that are available for training but + unknown at the time of forecast/inference. 
If features_unknown_at_forecast_time is set to an empty + list, it is assumed that all the feature columns in the dataset are known at inference time. If this + parameter is not set the support for future features is not enabled. + :paramtype features_unknown_at_forecast_time: Optional[Union[str, List[str]]] + """ + self._forecasting_settings = self._forecasting_settings or ForecastingSettings() + + self._forecasting_settings.country_or_region_for_holidays = ( + country_or_region_for_holidays + if country_or_region_for_holidays is not None + else self._forecasting_settings.country_or_region_for_holidays + ) + self._forecasting_settings.cv_step_size = ( + cv_step_size if cv_step_size is not None else self._forecasting_settings.cv_step_size + ) + self._forecasting_settings.forecast_horizon = ( + forecast_horizon if forecast_horizon is not None else self._forecasting_settings.forecast_horizon + ) + self._forecasting_settings.target_lags = ( + target_lags if target_lags is not None else self._forecasting_settings.target_lags + ) + self._forecasting_settings.target_rolling_window_size = ( + target_rolling_window_size + if target_rolling_window_size is not None + else self._forecasting_settings.target_rolling_window_size + ) + self._forecasting_settings.frequency = ( + frequency if frequency is not None else self._forecasting_settings.frequency + ) + self._forecasting_settings.feature_lags = ( + feature_lags if feature_lags is not None else self._forecasting_settings.feature_lags + ) + self._forecasting_settings.seasonality = ( + seasonality if seasonality is not None else self._forecasting_settings.seasonality + ) + self._forecasting_settings.use_stl = use_stl if use_stl is not None else self._forecasting_settings.use_stl + self._forecasting_settings.short_series_handling_config = ( + short_series_handling_config + if short_series_handling_config is not None + else self._forecasting_settings.short_series_handling_config + ) + self._forecasting_settings.target_aggregate_function = ( + target_aggregate_function + if target_aggregate_function is not None + else self._forecasting_settings.target_aggregate_function + ) + self._forecasting_settings.time_column_name = ( + time_column_name if time_column_name is not None else self._forecasting_settings.time_column_name + ) + self._forecasting_settings.time_series_id_column_names = ( + time_series_id_column_names + if time_series_id_column_names is not None + else self._forecasting_settings.time_series_id_column_names + ) + self._forecasting_settings.features_unknown_at_forecast_time = ( + features_unknown_at_forecast_time + if features_unknown_at_forecast_time is not None + else self._forecasting_settings.features_unknown_at_forecast_time + ) + + # override + def set_training( + self, + *, + enable_onnx_compatible_models: Optional[bool] = None, + enable_dnn_training: Optional[bool] = None, + enable_model_explainability: Optional[bool] = None, + enable_stack_ensemble: Optional[bool] = None, + enable_vote_ensemble: Optional[bool] = None, + stack_ensemble_settings: Optional[StackEnsembleSettings] = None, + ensemble_model_download_timeout: Optional[int] = None, + allowed_training_algorithms: Optional[List[str]] = None, + blocked_training_algorithms: Optional[List[str]] = None, + training_mode: Optional[Union[str, TabularTrainingMode]] = None, + ) -> None: + """ + The method to configure forecast training related settings. + + :keyword enable_onnx_compatible_models: + Whether to enable or disable enforcing the ONNX-compatible models. + The default is False. 
For more information about Open Neural Network Exchange (ONNX) and Azure Machine + Learning, see this `article <https://learn.microsoft.com/azure/machine-learning/concept-onnx>`__. + :type enable_onnx_compatible: Optional[bool] + :keyword enable_dnn_training: + Whether to include DNN based models during model selection. + However, the default is True for DNN NLP tasks, and it's False for all other AutoML tasks. + :paramtype enable_dnn_training: Optional[bool] + :keyword enable_model_explainability: + Whether to enable explaining the best AutoML model at the end of all AutoML training iterations. + For more information, see `Interpretability: model explanations in automated machine learning + <https://learn.microsoft.com/azure/machine-learning/how-to-machine-learning-interpretability-automl>`__. + , defaults to None + :type enable_model_explainability: Optional[bool] + :keyword enable_stack_ensemble: + Whether to enable/disable StackEnsemble iteration. + If `enable_onnx_compatible_models` flag is being set, then StackEnsemble iteration will be disabled. + Similarly, for Timeseries tasks, StackEnsemble iteration will be disabled by default, to avoid risks of + overfitting due to small training set used in fitting the meta learner. + For more information about ensembles, see `Ensemble configuration + <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#ensemble>`__ + , defaults to None + :type enable_stack_ensemble: Optional[bool] + :keyword enable_vote_ensemble: + Whether to enable/disable VotingEnsemble iteration. + For more information about ensembles, see `Ensemble configuration + <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#ensemble>`__ + , defaults to None + :type enable_vote_ensemble: Optional[bool] + :keyword stack_ensemble_settings: + Settings for StackEnsemble iteration, defaults to None + :paramtype stack_ensemble_settings: Optional[StackEnsembleSettings] + :keyword ensemble_model_download_timeout: + During VotingEnsemble and StackEnsemble model generation, + multiple fitted models from the previous child runs are downloaded. Configure this parameter with a + higher value than 300 secs, if more time is needed, defaults to None + :paramtype ensemble_model_download_timeout: Optional[int] + :keyword allowed_training_algorithms: + A list of model names to search for an experiment. If not specified, + then all models supported for the task are used minus any specified in ``blocked_training_algorithms`` + or deprecated TensorFlow models, defaults to None + :paramtype allowed_training_algorithms: Optional[List[str]] + :keyword blocked_training_algorithms: + A list of algorithms to ignore for an experiment, defaults to None + :paramtype blocked_training_algorithms: Optional[List[str]] + :keyword training_mode: + [Experimental] The training mode to use. + The possible values are- + + * distributed- enables distributed training for supported algorithms. + + * non_distributed- disables distributed training. + + * auto- Currently, it is same as non_distributed. In future, this might change. + + Note: This parameter is in public preview and may change in future. 
+ :type training_mode: Optional[Union[~azure.ai.ml.constants.TabularTrainingMode, str]] + """ + super().set_training( + enable_onnx_compatible_models=enable_onnx_compatible_models, + enable_dnn_training=enable_dnn_training, + enable_model_explainability=enable_model_explainability, + enable_stack_ensemble=enable_stack_ensemble, + enable_vote_ensemble=enable_vote_ensemble, + stack_ensemble_settings=stack_ensemble_settings, + ensemble_model_download_timeout=ensemble_model_download_timeout, + allowed_training_algorithms=allowed_training_algorithms, + blocked_training_algorithms=blocked_training_algorithms, + training_mode=training_mode, + ) + + # Disable stack ensemble by default, since it is currently not supported for forecasting tasks + if enable_stack_ensemble is None: + if self._training is not None: + self._training.enable_stack_ensemble = False + + def _to_rest_object(self) -> JobBase: + if self._forecasting_settings is not None: + forecasting_task = RestForecasting( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + validation_data_size=self.validation_data_size, + weight_column_name=self.weight_column_name, + cv_split_column_names=self.cv_split_column_names, + n_cross_validations=self.n_cross_validations, + test_data=self.test_data, + test_data_size=self.test_data_size, + featurization_settings=self._featurization._to_rest_object() if self._featurization else None, + limit_settings=self._limits._to_rest_object() if self._limits else None, + training_settings=self._training._to_rest_object() if self._training else None, + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + forecasting_settings=self._forecasting_settings._to_rest_object(), + ) + else: + forecasting_task = RestForecasting( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + validation_data_size=self.validation_data_size, + weight_column_name=self.weight_column_name, + cv_split_column_names=self.cv_split_column_names, + n_cross_validations=self.n_cross_validations, + test_data=self.test_data, + test_data_size=self.test_data_size, + featurization_settings=self._featurization._to_rest_object() if self._featurization else None, + limit_settings=self._limits._to_rest_object() if self._limits else None, + training_settings=self._training._to_rest_object() if self._training else None, + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + forecasting_settings=None, + ) + + self._resolve_data_inputs(forecasting_task) + self._validation_data_to_rest(forecasting_task) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=forecasting_task, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "ForecastingJob": + properties: RestAutoMLJob = obj.properties + task_details: RestForecasting = properties.task_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + 
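+ # The remaining entries are common AutoML job metadata copied from the REST job resource;
+ # the identity block is converted back into an SDK credential configuration.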
"description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": properties.experiment_name, + "services": properties.services, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": properties.display_name, + "compute": properties.compute_id, + "outputs": from_rest_data_outputs(properties.outputs), + "resources": properties.resources, + "identity": ( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + "queue_settings": properties.queue_settings, + } + + forecasting_job = cls( + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + validation_data_size=task_details.validation_data_size, + weight_column_name=task_details.weight_column_name, + cv_split_column_names=task_details.cv_split_column_names, + n_cross_validations=task_details.n_cross_validations, + test_data=task_details.test_data, + test_data_size=task_details.test_data_size, + featurization=( + TabularFeaturizationSettings._from_rest_object(task_details.featurization_settings) + if task_details.featurization_settings + else None + ), + limits=( + TabularLimitSettings._from_rest_object(task_details.limit_settings) + if task_details.limit_settings + else None + ), + training=( + ForecastingTrainingSettings._from_rest_object(task_details.training_settings) + if task_details.training_settings + else None + ), + primary_metric=task_details.primary_metric, + forecasting_settings=( + ForecastingSettings._from_rest_object(task_details.forecasting_settings) + if task_details.forecasting_settings + else None + ), + log_verbosity=task_details.log_verbosity, + **job_args_dict, + ) + + forecasting_job._restore_data_inputs() + forecasting_job._validation_data_from_rest() + + return forecasting_job + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "ForecastingJob": + from azure.ai.ml._schema.automl.table_vertical.forecasting import AutoMLForecastingSchema + from azure.ai.ml._schema.pipeline.automl_node import AutoMLForecastingNodeSchema + + if kwargs.pop("inside_pipeline", False): + loaded_data = load_from_dict(AutoMLForecastingNodeSchema, data, context, additional_message, **kwargs) + else: + loaded_data = load_from_dict(AutoMLForecastingSchema, data, context, additional_message, **kwargs) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ForecastingJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + data_settings = { + "training_data": loaded_data.pop("training_data"), + "target_column_name": loaded_data.pop("target_column_name"), + "weight_column_name": loaded_data.pop("weight_column_name", None), + "validation_data": loaded_data.pop("validation_data", None), + "validation_data_size": loaded_data.pop("validation_data_size", None), + "cv_split_column_names": loaded_data.pop("cv_split_column_names", None), + "n_cross_validations": loaded_data.pop("n_cross_validations", None), + "test_data": loaded_data.pop("test_data", None), + "test_data_size": loaded_data.pop("test_data_size", None), + } + job = ForecastingJob(**loaded_data) + job.set_data(**data_settings) + return job + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.table_vertical.forecasting 
import AutoMLForecastingSchema + from azure.ai.ml._schema.pipeline.automl_node import AutoMLForecastingNodeSchema + + schema_dict: dict = {} + if inside_pipeline: + schema_dict = AutoMLForecastingNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + else: + schema_dict = AutoMLForecastingSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return schema_dict + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ForecastingJob): + return NotImplemented + + if not super(ForecastingJob, self).__eq__(other): + return False + + return self.primary_metric == other.primary_metric and self._forecasting_settings == other._forecasting_settings + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_settings.py new file mode 100644 index 00000000..09439483 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_settings.py @@ -0,0 +1,383 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=too-many-instance-attributes + +from typing import List, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + AutoForecastHorizon, + AutoSeasonality, + AutoTargetLags, + AutoTargetRollingWindowSize, + CustomForecastHorizon, + CustomSeasonality, + CustomTargetLags, + CustomTargetRollingWindowSize, + ForecastHorizonMode, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ForecastingSettings as RestForecastingSettings, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + SeasonalityMode, + TargetLagsMode, + TargetRollingWindowSizeMode, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class ForecastingSettings(RestTranslatableMixin): + """Forecasting settings for an AutoML Job. + + :param country_or_region_for_holidays: The country/region used to generate holiday features. These should be ISO + 3166 two-letter country/region code, for example 'US' or 'GB'. + :type country_or_region_for_holidays: Optional[str] + :param cv_step_size: + Number of periods between the origin_time of one CV fold and the next fold. For + example, if `n_step` = 3 for daily data, the origin time for each fold will be + three days apart. + :type cv_step_size: Optional[int] + :param forecast_horizon: + The desired maximum forecast horizon in units of time-series frequency. The default value is 1. + + Units are based on the time interval of your training data, e.g., monthly, weekly that the forecaster + should predict out. When task type is forecasting, this parameter is required. For more information on + setting forecasting parameters, see `Auto-train a time-series forecast model <https://learn.microsoft.com/ + azure/machine-learning/how-to-auto-train-forecast>`_. + :type forecast_horizon: Optional[Union[int, str]] + :param target_lags: + The number of past periods to lag from the target column. By default the lags are turned off. + + When forecasting, this parameter represents the number of rows to lag the target values based + on the frequency of the data. This is represented as a list or single integer. 
Lag should be used + when the relationship between the independent variables and dependent variable do not match up or + correlate by default. For example, when trying to forecast demand for a product, the demand in any + month may depend on the price of specific commodities 3 months prior. In this example, you may want + to lag the target (demand) negatively by 3 months so that the model is training on the correct + relationship. For more information, see `Auto-train a time-series forecast model + <https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-forecast>`_. + + **Note on auto detection of target lags and rolling window size. + Please see the corresponding comments in the rolling window section.** + We use the next algorithm to detect the optimal target lag and rolling window size. + + #. Estimate the maximum lag order for the look back feature selection. In our case it is the number of + periods till the next date frequency granularity i.e. if frequency is daily, it will be a week (7), + if it is a week, it will be month (4). That values multiplied by two is the largest + possible values of lags/rolling windows. In our examples, we will consider the maximum lag + order of 14 and 8 respectively). + #. Create a de-seasonalized series by adding trend and residual components. This will be used + in the next step. + #. Estimate the PACF - Partial Auto Correlation Function on the on the data from (2) + and search for points, where the auto correlation is significant i.e. its absolute + value is more then 1.96/square_root(maximal lag value), which correspond to significance of 95%. + #. If all points are significant, we consider it being strong seasonality + and do not create look back features. + #. We scan the PACF values from the beginning and the value before the first insignificant + auto correlation will designate the lag. If first significant element (value correlate with + itself) is followed by insignificant, the lag will be 0 and we will not use look back features. + :type target_lags: Union[str, int, List[int]] + :param target_rolling_window_size: + The number of past periods used to create a rolling window average of the target column. + + When forecasting, this parameter represents `n` historical periods to use to generate forecasted values, + <= training set size. If omitted, `n` is the full training set size. Specify this parameter + when you only want to consider a certain amount of history when training the model. + If set to 'auto', rolling window will be estimated as the last + value where the PACF is more then the significance threshold. Please see target_lags section for details. + :type target_rolling_window_size: Optional[Union[str, int]] + :param frequency: Forecast frequency. + + When forecasting, this parameter represents the period with which the forecast is desired, + for example daily, weekly, yearly, etc. The forecast frequency is dataset frequency by default. + You can optionally set it to greater (but not lesser) than dataset frequency. + We'll aggregate the data and generate the results at forecast frequency. For example, + for daily data, you can set the frequency to be daily, weekly or monthly, but not hourly. + The frequency needs to be a pandas offset alias. + Please refer to pandas documentation for more information: + https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#dateoffset-objects + :type frequency: Optional[str] + :param feature_lags: Flag for generating lags for the numeric features with 'auto' or None. 
+ :type feature_lags: Optional[str] + :param seasonality: Set time series seasonality as an integer multiple of the series frequency. + If seasonality is set to 'auto', it will be inferred. + If set to None, the time series is assumed non-seasonal which is equivalent to seasonality=1. + :type seasonality: Optional[Union[int, str]] + :param use_stl: Configure STL Decomposition of the time-series target column. + use_stl can take three values: None (default) - no stl decomposition, 'season' - only generate + season component and season_trend - generate both season and trend components. + :type use_stl: Optional[str] + :param short_series_handling_config: + The parameter defining how if AutoML should handle short time series. + + Possible values: 'auto' (default), 'pad', 'drop' and None. + * **auto** short series will be padded if there are no long series, + otherwise short series will be dropped. + * **pad** all the short series will be padded. + * **drop** all the short series will be dropped". + * **None** the short series will not be modified. + If set to 'pad', the table will be padded with the zeroes and + empty values for the regressors and random values for target with the mean + equal to target value median for given time series id. If median is more or equal + to zero, the minimal padded value will be clipped by zero. + Input: + + +------------+---------------+----------+--------+ + | Date | numeric_value | string | target | + +============+===============+==========+========+ + | 2020-01-01 | 23 | green | 55 | + +------------+---------------+----------+--------+ + + Output assuming minimal number of values is four: + + +------------+---------------+----------+--------+ + | Date | numeric_value | string | target | + +============+===============+==========+========+ + | 2019-12-29 | 0 | NA | 55.1 | + +------------+---------------+----------+--------+ + | 2019-12-30 | 0 | NA | 55.6 | + +------------+---------------+----------+--------+ + | 2019-12-31 | 0 | NA | 54.5 | + +------------+---------------+----------+--------+ + | 2020-01-01 | 23 | green | 55 | + +------------+---------------+----------+--------+ + + **Note:** We have two parameters short_series_handling_configuration and + legacy short_series_handling. When both parameters are set we are + synchronize them as shown in the table below (short_series_handling_configuration and + short_series_handling for brevity are marked as handling_configuration and handling + respectively). 
+ + +------------+--------------------------+----------------------+-----------------------------+ + | | handling | | handling configuration | | resulting handling | | resulting handling | + | | | | | configuration | + +============+==========================+======================+=============================+ + | True | auto | True | auto | + +------------+--------------------------+----------------------+-----------------------------+ + | True | pad | True | auto | + +------------+--------------------------+----------------------+-----------------------------+ + | True | drop | True | auto | + +------------+--------------------------+----------------------+-----------------------------+ + | True | None | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + | False | auto | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + | False | pad | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + | False | drop | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + | False | None | False | None | + +------------+--------------------------+----------------------+-----------------------------+ + + :type short_series_handling_config: Optional[str] + :param target_aggregate_function: The function to be used to aggregate the time series target + column to conform to a user specified frequency. If the + target_aggregation_function is set, but the freq parameter + is not set, the error is raised. The possible target + aggregation functions are: "sum", "max", "min" and "mean". + + * The target column values are aggregated based on the specified operation. + Typically, sum is appropriate for most scenarios. + * Numerical predictor columns in your data are aggregated by sum, mean, minimum value, + and maximum value. As a result, automated ML generates new columns suffixed with the + aggregation function name and applies the selected aggregate operation. + * For categorical predictor columns, the data is aggregated by mode, + the most prominent category in the window. + * Date predictor columns are aggregated by minimum value, maximum value and mode. + + +----------------+-------------------------------+--------------------------------------+ + | | freq | | target_aggregation_function | | Data regularity | + | | | | fixing mechanism | + +================+===============================+======================================+ + | None (Default) | None (Default) | | The aggregation is not | + | | | | applied. If the valid | + | | | | frequency can not be | + | | | | determined the error will | + | | | | be raised. | + +----------------+-------------------------------+--------------------------------------+ + | Some Value | None (Default) | | The aggregation is not | + | | | | applied. If the number | + | | | | of data points compliant | + | | | | to given frequency grid | + | | | | is less then 90% these points | + | | | | will be removed, otherwise | + | | | | the error will be raised. | + +----------------+-------------------------------+--------------------------------------+ + | None (Default) | Aggregation function | | The error about missing | + | | | | frequency parameter | + | | | | is raised. 
| + +----------------+-------------------------------+--------------------------------------+ + | Some Value | Aggregation function | | Aggregate to frequency using | + | | | | provided aggregation function. | + +----------------+-------------------------------+--------------------------------------+ + :type target_aggregate_function: str + :param time_column_name: + The name of the time column. This parameter is required when forecasting to specify the datetime + column in the input data used for building the time series and inferring its frequency. + :type time_column_name: Optional[str] + :param time_series_id_column_names: + The names of columns used to group a timeseries. + It can be used to create multiple series. If time series id column names is not defined or + the identifier columns specified do not identify all the series in the dataset, the time series identifiers + will be automatically created for your dataset. + :type time_series_id_column_names: Union[str, List[str]] + :param features_unknown_at_forecast_time: + The feature columns that are available for training but unknown at the time of forecast/inference. + If features_unknown_at_forecast_time is set to an empty list, it is assumed that + all the feature columns in the dataset are known at inference time. If this parameter is not set + the support for future features is not enabled. + :type features_unknown_at_forecast_time: Optional[Union[str, List[str]]] + """ + + def __init__( + self, + *, + country_or_region_for_holidays: Optional[str] = None, + cv_step_size: Optional[int] = None, + forecast_horizon: Optional[Union[str, int]] = None, + target_lags: Optional[Union[str, int, List[int]]] = None, + target_rolling_window_size: Optional[Union[str, int]] = None, + frequency: Optional[str] = None, + feature_lags: Optional[str] = None, + seasonality: Optional[Union[str, int]] = None, + use_stl: Optional[str] = None, + short_series_handling_config: Optional[str] = None, + target_aggregate_function: Optional[str] = None, + time_column_name: Optional[str] = None, + time_series_id_column_names: Optional[Union[str, List[str]]] = None, + features_unknown_at_forecast_time: Optional[Union[str, List[str]]] = None, + ): + self.country_or_region_for_holidays = country_or_region_for_holidays + self.cv_step_size = cv_step_size + self.forecast_horizon = forecast_horizon + self.target_lags = target_lags + self.target_rolling_window_size = target_rolling_window_size + self.frequency = frequency + self.feature_lags = feature_lags + self.seasonality = seasonality + self.use_stl = use_stl + self.short_series_handling_config = short_series_handling_config + self.target_aggregate_function = target_aggregate_function + self.time_column_name = time_column_name + self.time_series_id_column_names = time_series_id_column_names + self.features_unknown_at_forecast_time = features_unknown_at_forecast_time + + def _to_rest_object(self) -> RestForecastingSettings: + forecast_horizon = None + if isinstance(self.forecast_horizon, str): + forecast_horizon = AutoForecastHorizon() + elif self.forecast_horizon: + forecast_horizon = CustomForecastHorizon(value=self.forecast_horizon) + + target_lags = None + if isinstance(self.target_lags, str): + target_lags = AutoTargetLags() + elif self.target_lags: + lags = [self.target_lags] if not isinstance(self.target_lags, list) else self.target_lags + target_lags = CustomTargetLags(values=lags) + + target_rolling_window_size = None + if isinstance(self.target_rolling_window_size, str): + target_rolling_window_size = 
AutoTargetRollingWindowSize() + elif self.target_rolling_window_size: + target_rolling_window_size = CustomTargetRollingWindowSize(value=self.target_rolling_window_size) + + seasonality = None + if isinstance(self.seasonality, str): + seasonality = AutoSeasonality() + elif self.seasonality: + seasonality = CustomSeasonality(value=self.seasonality) + + time_series_id_column_names = self.time_series_id_column_names + if isinstance(self.time_series_id_column_names, str) and self.time_series_id_column_names: + time_series_id_column_names = [self.time_series_id_column_names] + + features_unknown_at_forecast_time = self.features_unknown_at_forecast_time + if isinstance(self.features_unknown_at_forecast_time, str) and self.features_unknown_at_forecast_time: + features_unknown_at_forecast_time = [self.features_unknown_at_forecast_time] + + return RestForecastingSettings( + country_or_region_for_holidays=self.country_or_region_for_holidays, + cv_step_size=self.cv_step_size, + forecast_horizon=forecast_horizon, + time_column_name=self.time_column_name, + target_lags=target_lags, + target_rolling_window_size=target_rolling_window_size, + seasonality=seasonality, + frequency=self.frequency, + feature_lags=self.feature_lags, + use_stl=self.use_stl, + short_series_handling_config=self.short_series_handling_config, + target_aggregate_function=self.target_aggregate_function, + time_series_id_column_names=time_series_id_column_names, + features_unknown_at_forecast_time=features_unknown_at_forecast_time, + ) + + @classmethod + def _from_rest_object(cls, obj: RestForecastingSettings) -> "ForecastingSettings": + forecast_horizon = None + if obj.forecast_horizon and obj.forecast_horizon.mode == ForecastHorizonMode.AUTO: + forecast_horizon = obj.forecast_horizon.mode.lower() + elif obj.forecast_horizon: + forecast_horizon = obj.forecast_horizon.value + + rest_target_lags = obj.target_lags + target_lags = None + if rest_target_lags and rest_target_lags.mode == TargetLagsMode.AUTO: + target_lags = rest_target_lags.mode.lower() + elif rest_target_lags: + target_lags = rest_target_lags.values + + target_rolling_window_size = None + if obj.target_rolling_window_size and obj.target_rolling_window_size.mode == TargetRollingWindowSizeMode.AUTO: + target_rolling_window_size = obj.target_rolling_window_size.mode.lower() + elif obj.target_rolling_window_size: + target_rolling_window_size = obj.target_rolling_window_size.value + + seasonality = None + if obj.seasonality and obj.seasonality.mode == SeasonalityMode.AUTO: + seasonality = obj.seasonality.mode.lower() + elif obj.seasonality: + seasonality = obj.seasonality.value + + return cls( + country_or_region_for_holidays=obj.country_or_region_for_holidays, + cv_step_size=obj.cv_step_size, + forecast_horizon=forecast_horizon, + target_lags=target_lags, + target_rolling_window_size=target_rolling_window_size, + frequency=obj.frequency, + feature_lags=obj.feature_lags, + seasonality=seasonality, + use_stl=obj.use_stl, + short_series_handling_config=obj.short_series_handling_config, + target_aggregate_function=obj.target_aggregate_function, + time_column_name=obj.time_column_name, + time_series_id_column_names=obj.time_series_id_column_names, + features_unknown_at_forecast_time=obj.features_unknown_at_forecast_time, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ForecastingSettings): + return NotImplemented + return ( + self.country_or_region_for_holidays == other.country_or_region_for_holidays + and self.cv_step_size == other.cv_step_size + and 
self.forecast_horizon == other.forecast_horizon
+ and self.target_lags == other.target_lags
+ and self.target_rolling_window_size == other.target_rolling_window_size
+ and self.frequency == other.frequency
+ and self.feature_lags == other.feature_lags
+ and self.seasonality == other.seasonality
+ and self.use_stl == other.use_stl
+ and self.short_series_handling_config == other.short_series_handling_config
+ and self.target_aggregate_function == other.target_aggregate_function
+ and self.time_column_name == other.time_column_name
+ and self.time_series_id_column_names == other.time_series_id_column_names
+ and self.features_unknown_at_forecast_time == other.features_unknown_at_forecast_time
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/limit_settings.py
new file mode 100644
index 00000000..1024f504
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/limit_settings.py
@@ -0,0 +1,101 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import TableVerticalLimitSettings as RestTabularLimitSettings
+from azure.ai.ml._utils.utils import from_iso_duration_format_mins, to_iso_duration_format_mins
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class TabularLimitSettings(RestTranslatableMixin):
+ """Limit settings for AutoML Table Verticals.
+
+ :param enable_early_termination: Whether to enable early termination if the score is not improving in
+ the short term. The default is True.
+ :type enable_early_termination: bool
+ :param exit_score: Target score for the experiment. The experiment terminates after this score is reached.
+ :type exit_score: float
+ :param max_concurrent_trials: Maximum number of concurrent AutoML iterations.
+ :type max_concurrent_trials: int
+ :param max_cores_per_trial: The maximum number of threads to use for a given training iteration.
+ :type max_cores_per_trial: int
+ :param max_nodes: [Experimental] The maximum number of nodes to use for distributed training.
+
+ * For forecasting, each model is trained using max(2, int(max_nodes / max_concurrent_trials)) nodes.
+
+ * For classification/regression, each model is trained using max_nodes nodes.
+
+ Note: This parameter is in public preview and might change in the future.
+ :type max_nodes: int
+ :param max_trials: Maximum number of AutoML iterations.
+ :type max_trials: int
+ :param timeout_minutes: AutoML job timeout, in minutes.
+ :type timeout_minutes: int
+ :param trial_timeout_minutes: Timeout for each individual AutoML trial, in minutes.
+ :type trial_timeout_minutes: int + """ + + def __init__( + self, + *, + enable_early_termination: Optional[bool] = None, + exit_score: Optional[float] = None, + max_concurrent_trials: Optional[int] = None, + max_cores_per_trial: Optional[int] = None, + max_nodes: Optional[int] = None, + max_trials: Optional[int] = None, + timeout_minutes: Optional[int] = None, + trial_timeout_minutes: Optional[int] = None, + ): + self.enable_early_termination = enable_early_termination + self.exit_score = exit_score + self.max_concurrent_trials = max_concurrent_trials + self.max_cores_per_trial = max_cores_per_trial + self.max_nodes = max_nodes + self.max_trials = max_trials + self.timeout_minutes = timeout_minutes + self.trial_timeout_minutes = trial_timeout_minutes + + def _to_rest_object(self) -> RestTabularLimitSettings: + return RestTabularLimitSettings( + enable_early_termination=self.enable_early_termination, + exit_score=self.exit_score, + max_concurrent_trials=self.max_concurrent_trials, + max_cores_per_trial=self.max_cores_per_trial, + max_nodes=self.max_nodes, + max_trials=self.max_trials, + timeout=to_iso_duration_format_mins(self.timeout_minutes), + trial_timeout=to_iso_duration_format_mins(self.trial_timeout_minutes), + ) + + @classmethod + def _from_rest_object(cls, obj: RestTabularLimitSettings) -> "TabularLimitSettings": + return cls( + enable_early_termination=obj.enable_early_termination, + exit_score=obj.exit_score, + max_concurrent_trials=obj.max_concurrent_trials, + max_cores_per_trial=obj.max_cores_per_trial, + max_nodes=obj.max_nodes, + max_trials=obj.max_trials, + timeout_minutes=from_iso_duration_format_mins(obj.timeout), + trial_timeout_minutes=from_iso_duration_format_mins(obj.trial_timeout), + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TabularLimitSettings): + return NotImplemented + return ( + self.enable_early_termination == other.enable_early_termination + and self.exit_score == other.exit_score + and self.max_concurrent_trials == other.max_concurrent_trials + and self.max_cores_per_trial == other.max_cores_per_trial + and self.max_nodes == other.max_nodes + and self.max_trials == other.max_trials + and self.timeout_minutes == other.timeout_minutes + and self.trial_timeout_minutes == other.trial_timeout_minutes + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/regression_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/regression_job.py new file mode 100644 index 00000000..3531e52c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/regression_job.py @@ -0,0 +1,239 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase +from azure.ai.ml._restclient.v2023_04_01_preview.models import Regression as RestRegression +from azure.ai.ml._restclient.v2023_04_01_preview.models import RegressionPrimaryMetrics, TaskType +from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.automl.tabular import AutoMLTabular, TabularFeaturizationSettings, TabularLimitSettings +from azure.ai.ml.entities._job.automl.training_settings import RegressionTrainingSettings +from azure.ai.ml.entities._util import load_from_dict + + +class RegressionJob(AutoMLTabular): + """Configuration for AutoML Regression Job.""" + + _DEFAULT_PRIMARY_METRIC = RegressionPrimaryMetrics.NORMALIZED_ROOT_MEAN_SQUARED_ERROR + + def __init__( + self, + *, + primary_metric: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a new AutoML Regression task. + + :param primary_metric: The primary metric to use for optimization + :type primary_metric: str + :param kwargs: Job-specific arguments + :type kwargs: dict + """ + # Extract any task specific settings + featurization = kwargs.pop("featurization", None) + limits = kwargs.pop("limits", None) + training = kwargs.pop("training", None) + + super().__init__( + task_type=TaskType.REGRESSION, + featurization=featurization, + limits=limits, + training=training, + **kwargs, + ) + + self.primary_metric = primary_metric or RegressionJob._DEFAULT_PRIMARY_METRIC + + @property + def primary_metric(self) -> Union[str, RegressionPrimaryMetrics]: + return self._primary_metric + + @primary_metric.setter + def primary_metric(self, value: Union[str, RegressionPrimaryMetrics]) -> None: + # TODO: better way to do this + if is_data_binding_expression(str(value), ["parent"]): + self._primary_metric = value + return + self._primary_metric = ( + RegressionJob._DEFAULT_PRIMARY_METRIC + if value is None + else RegressionPrimaryMetrics[camel_to_snake(value).upper()] + ) + + @property + def training(self) -> RegressionTrainingSettings: + return self._training or RegressionTrainingSettings() + + @training.setter + def training(self, value: Union[Dict, RegressionTrainingSettings]) -> None: # pylint: disable=unused-argument + ... 
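The sketch below is an illustrative usage example and is not part of the committed package diff: it assumes the public `automl.regression()` factory of azure-ai-ml v2, and the compute target, experiment name, and MLTable path are placeholders. It shows how a RegressionJob like the one defined here is typically assembled together with its tabular limit and training settings before submission.

from azure.ai.ml import Input, automl
from azure.ai.ml.constants import AssetTypes

# Build an AutoML regression job; the factory returns a RegressionJob instance.
regression_job = automl.regression(
    compute="cpu-cluster",  # placeholder compute target
    experiment_name="automl-regression-sample",  # placeholder experiment name
    training_data=Input(type=AssetTypes.MLTABLE, path="./train-mltable-folder"),  # placeholder path
    target_column_name="target",
    primary_metric="r2_score",
    n_cross_validations=5,
)

# These calls populate the TabularLimitSettings and RegressionTrainingSettings
# entities defined in this package.
regression_job.set_limits(timeout_minutes=60, trial_timeout_minutes=20, max_trials=20, max_concurrent_trials=4)
regression_job.set_training(enable_onnx_compatible_models=True, enable_model_explainability=True)

# An MLClient would then submit it, e.g. ml_client.jobs.create_or_update(regression_job).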
+ + def _to_rest_object(self) -> JobBase: + regression_task = RestRegression( + target_column_name=self.target_column_name, + training_data=self.training_data, + validation_data=self.validation_data, + validation_data_size=self.validation_data_size, + weight_column_name=self.weight_column_name, + cv_split_column_names=self.cv_split_column_names, + n_cross_validations=self.n_cross_validations, + test_data=self.test_data, + test_data_size=self.test_data_size, + featurization_settings=self._featurization._to_rest_object() if self._featurization else None, + limit_settings=self._limits._to_rest_object() if self._limits else None, + training_settings=self._training._to_rest_object() if self._training else None, + primary_metric=self.primary_metric, + log_verbosity=self.log_verbosity, + ) + self._resolve_data_inputs(regression_task) + self._validation_data_to_rest(regression_task) + + properties = RestAutoMLJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + compute_id=self.compute, + properties=self.properties, + environment_id=self.environment_id, + environment_variables=self.environment_variables, + services=self.services, + outputs=to_rest_data_outputs(self.outputs), + resources=self.resources, + task_details=regression_task, + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings, + ) + + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _from_rest_object(cls, obj: JobBase) -> "RegressionJob": + properties: RestAutoMLJob = obj.properties + task_details: RestRegression = properties.task_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": properties.experiment_name, + "services": properties.services, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": properties.display_name, + "compute": properties.compute_id, + "outputs": from_rest_data_outputs(properties.outputs), + "resources": properties.resources, + "identity": ( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + "queue_settings": properties.queue_settings, + } + + regression_job = cls( + target_column_name=task_details.target_column_name, + training_data=task_details.training_data, + validation_data=task_details.validation_data, + validation_data_size=task_details.validation_data_size, + weight_column_name=task_details.weight_column_name, + cv_split_column_names=task_details.cv_split_column_names, + n_cross_validations=task_details.n_cross_validations, + test_data=task_details.test_data, + test_data_size=task_details.test_data_size, + featurization=( + TabularFeaturizationSettings._from_rest_object(task_details.featurization_settings) + if task_details.featurization_settings + else None + ), + limits=( + TabularLimitSettings._from_rest_object(task_details.limit_settings) + if task_details.limit_settings + else None + ), + training=( + RegressionTrainingSettings._from_rest_object(task_details.training_settings) + if task_details.training_settings + else None + ), + primary_metric=task_details.primary_metric, + log_verbosity=task_details.log_verbosity, + **job_args_dict, + ) + + regression_job._restore_data_inputs() + regression_job._validation_data_from_rest() + + return regression_job + + @classmethod + def 
_load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "RegressionJob": + from azure.ai.ml._schema.automl.table_vertical.regression import AutoMLRegressionSchema + from azure.ai.ml._schema.pipeline.automl_node import AutoMLRegressionNodeSchema + + if kwargs.pop("inside_pipeline", False): + loaded_data = load_from_dict(AutoMLRegressionNodeSchema, data, context, additional_message, **kwargs) + else: + loaded_data = load_from_dict(AutoMLRegressionSchema, data, context, additional_message, **kwargs) + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "RegressionJob": + loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None) + data_settings = { + "training_data": loaded_data.pop("training_data"), + "target_column_name": loaded_data.pop("target_column_name"), + "weight_column_name": loaded_data.pop("weight_column_name", None), + "validation_data": loaded_data.pop("validation_data", None), + "validation_data_size": loaded_data.pop("validation_data_size", None), + "cv_split_column_names": loaded_data.pop("cv_split_column_names", None), + "n_cross_validations": loaded_data.pop("n_cross_validations", None), + "test_data": loaded_data.pop("test_data", None), + "test_data_size": loaded_data.pop("test_data_size", None), + } + job = RegressionJob(**loaded_data) + job.set_data(**data_settings) + return job + + def _to_dict(self, inside_pipeline: bool = False) -> Dict: + from azure.ai.ml._schema.automl.table_vertical.regression import AutoMLRegressionSchema + from azure.ai.ml._schema.pipeline.automl_node import AutoMLRegressionNodeSchema + + schema_dict: dict = {} + if inside_pipeline: + schema_dict = AutoMLRegressionNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + else: + schema_dict = AutoMLRegressionSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return schema_dict + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RegressionJob): + return NotImplemented + + if not super(RegressionJob, self).__eq__(other): + return False + + return self.primary_metric == other.primary_metric + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/training_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/training_settings.py new file mode 100644 index 00000000..97bc7e17 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/training_settings.py @@ -0,0 +1,357 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=R0902,protected-access + +from typing import Any, List, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationModels +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ClassificationTrainingSettings as RestClassificationTrainingSettings, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import ForecastingModels +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ForecastingTrainingSettings as RestForecastingTrainingSettings, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import RegressionModels +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + RegressionTrainingSettings as RestRegressionTrainingSettings, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import TrainingSettings as RestTrainingSettings +from azure.ai.ml._utils.utils import camel_to_snake, from_iso_duration_format_mins, to_iso_duration_format_mins +from azure.ai.ml.constants import TabularTrainingMode +from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class TrainingSettings(RestTranslatableMixin): + """TrainingSettings class for Azure Machine Learning.""" + + def __init__( + self, + *, + enable_onnx_compatible_models: Optional[bool] = None, + enable_dnn_training: Optional[bool] = None, + enable_model_explainability: Optional[bool] = None, + enable_stack_ensemble: Optional[bool] = None, + enable_vote_ensemble: Optional[bool] = None, + stack_ensemble_settings: Optional[StackEnsembleSettings] = None, + ensemble_model_download_timeout: Optional[int] = None, + allowed_training_algorithms: Optional[List[str]] = None, + blocked_training_algorithms: Optional[List[str]] = None, + training_mode: Optional[Union[str, TabularTrainingMode]] = None, + ): + """TrainingSettings class for Azure Machine Learning. + + :param enable_onnx_compatible_models: If set to True, the model will be trained to be compatible with ONNX + :type enable_onnx_compatible_models: typing.Optional[bool] + :param enable_dnn_training: If set to True,the model will use DNN training + :type enable_dnn_training: typing.Optional[bool] + :param enable_model_explainability: If set to True, the model will be trained to be explainable + :type enable_model_explainability: typing.Optional[bool] + :param enable_stack_ensemble: If set to True, a final ensemble model will be created using a stack of models + :type enable_stack_ensemble: typing.Optional[bool] + :param enable_vote_ensemble: If set to True, a final ensemble model will be created using a voting ensemble + :type enable_vote_ensemble: typing.Optional[bool] + :param stack_ensemble_settings: Settings for stack ensemble + :type stack_ensemble_settings: typing.Optional[azure.ai.ml.automl.StackEnsembleSettings] + :param ensemble_model_download_timeout: Timeout for downloading ensemble models + :type ensemble_model_download_timeout: typing.Optional[typing.List[int]] + :param allowed_training_algorithms: Models to train + :type allowed_training_algorithms: typing.Optional[typing.List[str]] + :param blocked_training_algorithms: Models that will not be considered for training + :type blocked_training_algorithms: typing.Optional[typing.List[str]] + :param training_mode: [Experimental] The training mode to use. 
+ The possible values are- + + * distributed- enables distributed training for supported algorithms. + + * non_distributed- disables distributed training. + + * auto- Currently, it is same as non_distributed. In future, this might change. + + Note: This parameter is in public preview and may change in future. + :type training_mode: typing.Optional[typing.Union[str, azure.ai.ml.constants.TabularTrainingMode]] + """ + self.enable_onnx_compatible_models = enable_onnx_compatible_models + self.enable_dnn_training = enable_dnn_training + self.enable_model_explainability = enable_model_explainability + self.enable_stack_ensemble = enable_stack_ensemble + self.enable_vote_ensemble = enable_vote_ensemble + self.stack_ensemble_settings = stack_ensemble_settings + self.ensemble_model_download_timeout = ensemble_model_download_timeout + self.allowed_training_algorithms = allowed_training_algorithms + self.blocked_training_algorithms = blocked_training_algorithms + self.training_mode = training_mode + + @property + def training_mode(self) -> Optional[TabularTrainingMode]: + return self._training_mode + + @training_mode.setter + def training_mode(self, value: Optional[Union[str, TabularTrainingMode]]) -> None: + if value is None or value is TabularTrainingMode: + self._training_mode = value + elif hasattr(TabularTrainingMode, camel_to_snake(value).upper()): + self._training_mode = TabularTrainingMode[camel_to_snake(value).upper()] + else: + supported_values = ", ".join([f'"{camel_to_snake(mode.value)}"' for mode in TabularTrainingMode]) + msg = ( + f"Unsupported training mode: {value}. Supported values are- {supported_values}. " + "Or you can use azure.ai.ml.constants.TabularTrainingMode enum." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + @property + def allowed_training_algorithms(self) -> Optional[List[str]]: + return self._allowed_training_algorithms + + @allowed_training_algorithms.setter + def allowed_training_algorithms(self, value: Optional[List[str]]) -> None: + self._allowed_training_algorithms = value + + @property + def blocked_training_algorithms(self) -> Optional[List[str]]: + return self._blocked_training_algorithms + + @blocked_training_algorithms.setter + def blocked_training_algorithms(self, value: Optional[List[str]]) -> None: + self._blocked_training_algorithms = value + + def _to_rest_object(self) -> RestTrainingSettings: + return RestTrainingSettings( + enable_dnn_training=self.enable_dnn_training, + enable_onnx_compatible_models=self.enable_onnx_compatible_models, + enable_model_explainability=self.enable_model_explainability, + enable_stack_ensemble=self.enable_stack_ensemble, + enable_vote_ensemble=self.enable_vote_ensemble, + stack_ensemble_settings=( + self.stack_ensemble_settings._to_rest_object() if self.stack_ensemble_settings else None + ), + ensemble_model_download_timeout=to_iso_duration_format_mins(self.ensemble_model_download_timeout), + training_mode=self.training_mode, + ) + + @classmethod + def _from_rest_object(cls, obj: RestTrainingSettings) -> "TrainingSettings": + return cls( + enable_dnn_training=obj.enable_dnn_training, + enable_onnx_compatible_models=obj.enable_onnx_compatible_models, + enable_model_explainability=obj.enable_model_explainability, + enable_stack_ensemble=obj.enable_stack_ensemble, + enable_vote_ensemble=obj.enable_vote_ensemble, + ensemble_model_download_timeout=from_iso_duration_format_mins(obj.ensemble_model_download_timeout), + 
stack_ensemble_settings=( + StackEnsembleSettings._from_rest_object(obj.stack_ensemble_settings) + if obj.stack_ensemble_settings + else None + ), + training_mode=obj.training_mode, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TrainingSettings): + return NotImplemented + return ( + self.enable_dnn_training == other.enable_dnn_training + and self.enable_onnx_compatible_models == other.enable_onnx_compatible_models + and self.enable_model_explainability == other.enable_model_explainability + and self.enable_stack_ensemble == other.enable_stack_ensemble + and self.enable_vote_ensemble == other.enable_vote_ensemble + and self.ensemble_model_download_timeout == other.ensemble_model_download_timeout + and self.stack_ensemble_settings == other.stack_ensemble_settings + and self.allowed_training_algorithms == other.allowed_training_algorithms + and self.blocked_training_algorithms == other.blocked_training_algorithms + and self.training_mode == other.training_mode + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class ClassificationTrainingSettings(TrainingSettings): + """Classification TrainingSettings class for Azure Machine Learning.""" + + def __init__( + self, + **kwargs: Any, + ): + super().__init__(**kwargs) + + @property + def allowed_training_algorithms(self) -> Optional[List]: + return self._allowed_training_algorithms + + @allowed_training_algorithms.setter + def allowed_training_algorithms(self, allowed_model_list: Union[List[str], List[ClassificationModels]]) -> None: + self._allowed_training_algorithms = ( + None + if allowed_model_list is None + else [ClassificationModels[camel_to_snake(o)] for o in allowed_model_list] + ) + + @property + def blocked_training_algorithms(self) -> Optional[List]: + return self._blocked_training_algorithms + + @blocked_training_algorithms.setter + def blocked_training_algorithms(self, blocked_model_list: Union[List[str], List[ClassificationModels]]) -> None: + self._blocked_training_algorithms = ( + None + if blocked_model_list is None + else [ClassificationModels[camel_to_snake(o)] for o in blocked_model_list] + ) + + def _to_rest_object(self) -> RestClassificationTrainingSettings: + return RestClassificationTrainingSettings( + enable_dnn_training=self.enable_dnn_training, + enable_onnx_compatible_models=self.enable_onnx_compatible_models, + enable_model_explainability=self.enable_model_explainability, + enable_stack_ensemble=self.enable_stack_ensemble, + enable_vote_ensemble=self.enable_vote_ensemble, + stack_ensemble_settings=self.stack_ensemble_settings, + ensemble_model_download_timeout=to_iso_duration_format_mins(self.ensemble_model_download_timeout), + allowed_training_algorithms=self.allowed_training_algorithms, + blocked_training_algorithms=self.blocked_training_algorithms, + training_mode=self.training_mode, + ) + + @classmethod + def _from_rest_object(cls, obj: RestClassificationTrainingSettings) -> "ClassificationTrainingSettings": + return cls( + enable_dnn_training=obj.enable_dnn_training, + enable_onnx_compatible_models=obj.enable_onnx_compatible_models, + enable_model_explainability=obj.enable_model_explainability, + enable_stack_ensemble=obj.enable_stack_ensemble, + enable_vote_ensemble=obj.enable_vote_ensemble, + ensemble_model_download_timeout=from_iso_duration_format_mins(obj.ensemble_model_download_timeout), + stack_ensemble_settings=obj.stack_ensemble_settings, + allowed_training_algorithms=obj.allowed_training_algorithms, + 
blocked_training_algorithms=obj.blocked_training_algorithms, + training_mode=obj.training_mode, + ) + + +class ForecastingTrainingSettings(TrainingSettings): + """Forecasting TrainingSettings class for Azure Machine Learning.""" + + def __init__( + self, + **kwargs: Any, + ): + super().__init__(**kwargs) + + @property + def allowed_training_algorithms(self) -> Optional[List]: + return self._allowed_training_algorithms + + @allowed_training_algorithms.setter + def allowed_training_algorithms(self, allowed_model_list: Union[List[str], List[ForecastingModels]]) -> None: + self._allowed_training_algorithms = ( + None if allowed_model_list is None else [ForecastingModels[camel_to_snake(o)] for o in allowed_model_list] + ) + + @property + def blocked_training_algorithms(self) -> Optional[List]: + return self._blocked_training_algorithms + + @blocked_training_algorithms.setter + def blocked_training_algorithms(self, blocked_model_list: Union[List[str], List[ForecastingModels]]) -> None: + self._blocked_training_algorithms = ( + None if blocked_model_list is None else [ForecastingModels[camel_to_snake(o)] for o in blocked_model_list] + ) + + def _to_rest_object(self) -> RestForecastingTrainingSettings: + return RestForecastingTrainingSettings( + enable_dnn_training=self.enable_dnn_training, + enable_onnx_compatible_models=self.enable_onnx_compatible_models, + enable_model_explainability=self.enable_model_explainability, + enable_stack_ensemble=self.enable_stack_ensemble, + enable_vote_ensemble=self.enable_vote_ensemble, + stack_ensemble_settings=self.stack_ensemble_settings, + ensemble_model_download_timeout=to_iso_duration_format_mins(self.ensemble_model_download_timeout), + allowed_training_algorithms=self.allowed_training_algorithms, + blocked_training_algorithms=self.blocked_training_algorithms, + training_mode=self.training_mode, + ) + + @classmethod + def _from_rest_object(cls, obj: RestForecastingTrainingSettings) -> "ForecastingTrainingSettings": + return cls( + enable_dnn_training=obj.enable_dnn_training, + enable_onnx_compatible_models=obj.enable_onnx_compatible_models, + enable_model_explainability=obj.enable_model_explainability, + enable_stack_ensemble=obj.enable_stack_ensemble, + enable_vote_ensemble=obj.enable_vote_ensemble, + ensemble_model_download_timeout=from_iso_duration_format_mins(obj.ensemble_model_download_timeout), + stack_ensemble_settings=obj.stack_ensemble_settings, + allowed_training_algorithms=obj.allowed_training_algorithms, + blocked_training_algorithms=obj.blocked_training_algorithms, + training_mode=obj.training_mode, + ) + + +class RegressionTrainingSettings(TrainingSettings): + """Regression TrainingSettings class for Azure Machine Learning.""" + + def __init__( + self, + **kwargs: Any, + ): + super().__init__(**kwargs) + + @property + def allowed_training_algorithms(self) -> Optional[List]: + return self._allowed_training_algorithms + + @allowed_training_algorithms.setter + def allowed_training_algorithms(self, allowed_model_list: Union[List[str], List[ForecastingModels]]) -> None: + self._allowed_training_algorithms = ( + None if allowed_model_list is None else [RegressionModels[camel_to_snake(o)] for o in allowed_model_list] + ) + + @property + def blocked_training_algorithms(self) -> Optional[List]: + return self._blocked_training_algorithms + + @blocked_training_algorithms.setter + def blocked_training_algorithms(self, blocked_model_list: Union[List[str], List[ForecastingModels]]) -> None: + self._blocked_training_algorithms = ( + None if blocked_model_list 
is None else [RegressionModels[camel_to_snake(o)] for o in blocked_model_list] + ) + + def _to_rest_object(self) -> RestRegressionTrainingSettings: + return RestRegressionTrainingSettings( + enable_dnn_training=self.enable_dnn_training, + enable_onnx_compatible_models=self.enable_onnx_compatible_models, + enable_model_explainability=self.enable_model_explainability, + enable_stack_ensemble=self.enable_stack_ensemble, + enable_vote_ensemble=self.enable_vote_ensemble, + stack_ensemble_settings=self.stack_ensemble_settings, + ensemble_model_download_timeout=to_iso_duration_format_mins(self.ensemble_model_download_timeout), + allowed_training_algorithms=self.allowed_training_algorithms, + blocked_training_algorithms=self.blocked_training_algorithms, + training_mode=self.training_mode, + ) + + @classmethod + def _from_rest_object(cls, obj: RestRegressionTrainingSettings) -> "RegressionTrainingSettings": + return cls( + enable_dnn_training=obj.enable_dnn_training, + enable_onnx_compatible_models=obj.enable_onnx_compatible_models, + enable_model_explainability=obj.enable_model_explainability, + enable_stack_ensemble=obj.enable_stack_ensemble, + enable_vote_ensemble=obj.enable_vote_ensemble, + ensemble_model_download_timeout=from_iso_duration_format_mins(obj.ensemble_model_download_timeout), + stack_ensemble_settings=obj.stack_ensemble_settings, + allowed_training_algorithms=obj.allowed_training_algorithms, + blocked_training_algorithms=obj.blocked_training_algorithms, + training_mode=obj.training_mode, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py new file mode 100644 index 00000000..08521d7e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import TYPE_CHECKING, Dict, Type, Union + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +if TYPE_CHECKING: + from azure.ai.ml.entities._job.automl.image.image_classification_search_space import ImageClassificationSearchSpace + from azure.ai.ml.entities._job.automl.image.image_object_detection_search_space import ( + ImageObjectDetectionSearchSpace, + ) + from azure.ai.ml.entities._job.automl.nlp.nlp_search_space import NlpSearchSpace + from azure.ai.ml.entities._job.automl.search_space import SearchSpace + + +def cast_to_specific_search_space( + input: Union[Dict, "SearchSpace"], # pylint: disable=redefined-builtin + class_name: Union[ + Type["ImageClassificationSearchSpace"], Type["ImageObjectDetectionSearchSpace"], Type["NlpSearchSpace"] + ], + task_type: str, +) -> Union["ImageClassificationSearchSpace", "ImageObjectDetectionSearchSpace", "NlpSearchSpace"]: + def validate_searchspace_args(input_dict: dict) -> None: + searchspace = class_name() + for key in input_dict: + if not hasattr(searchspace, key): + msg = f"Received unsupported search space parameter for {task_type} Job." 
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) + + if isinstance(input, dict): + validate_searchspace_args(input) + specific_search_space = class_name(**input) + else: + validate_searchspace_args(input.__dict__) + specific_search_space = class_name._from_search_space_object(input) # pylint: disable=protected-access + + res: Union["ImageClassificationSearchSpace", "ImageObjectDetectionSearchSpace", "NlpSearchSpace"] = ( + specific_search_space + ) + return res diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/base_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/base_job.py new file mode 100644 index 00000000..72b464e5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/base_job.py @@ -0,0 +1,85 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import Any, Dict + +from azure.ai.ml._restclient.runhistory.models import Run +from azure.ai.ml._schema.job import BaseJobSchema +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict + +from .job import Job + +module_logger = logging.getLogger(__name__) + +""" +TODO[Joe]: This class is temporarily created to handle "Base" job type from the service. + We will be working on a more granular job type for pipeline child jobs in the future. + Spec Ref: https://github.com/Azure/azureml_run_specification/pull/340 + MFE PR: https://msdata.visualstudio.com/DefaultCollection/Vienna/_workitems/edit/1167303/ +""" + + +class _BaseJob(Job): + """Base Job, only used in pipeline child jobs. + + :param name: Name of the resource. + :type name: str + :param description: Description of the resource. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param experiment_name: Name of the experiment the job will be created under, + if None is provided, default will be set to current directory name. + :type experiment_name: str + :param services: Information on services associated with the job, readonly. + :type services: dict[str, JobService] + :param compute: The compute target the job runs on. + :type compute: str + :param kwargs: A dictionary of additional configuration parameters. 
+ :type kwargs: dict + """ + + def __init__(self, **kwargs: Any): + kwargs[TYPE] = JobType.BASE + + super().__init__(**kwargs) + + def _to_dict(self) -> Dict: + res: dict = BaseJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "_BaseJob": + loaded_data = load_from_dict(BaseJobSchema, data, context, additional_message, **kwargs) + return _BaseJob(**loaded_data) + + @classmethod + def _load_from_rest(cls, obj: Run) -> "_BaseJob": + creation_context = SystemData( + created_by=obj.created_by, + created_by_type=obj.created_from, + created_at=obj.created_utc, + last_modified_by=obj.last_modified_by, + last_modified_at=obj.last_modified_utc, + ) + base_job = _BaseJob( + name=obj.run_id, + display_name=obj.display_name, + description=obj.description, + tags=obj.tags, + properties=obj.properties, + experiment_name=obj.experiment_id, + services=obj.services, + status=obj.status, + creation_context=creation_context, + compute=f"{obj.compute.target}" if obj.compute else None, + ) + + return base_job diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py new file mode 100644 index 00000000..0a0c7e82 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py @@ -0,0 +1,314 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import copy +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2025_01_01_preview.models import CommandJob as RestCommandJob +from azure.ai.ml._restclient.v2025_01_01_preview.models import JobBase +from azure.ai.ml._schema.job.command_job import CommandJobSchema +from azure.ai.ml._utils.utils import map_single_brackets_and_warn +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LOCAL_COMPUTE_PROPERTY, LOCAL_COMPUTE_TARGET, TYPE +from azure.ai.ml.entities import Environment +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, + _BaseJobIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job._input_output_helpers import ( + from_rest_data_outputs, + from_rest_inputs_to_dataset_literal, + to_rest_data_outputs, + to_rest_dataset_literal_inputs, + validate_inputs_for_command, +) +from azure.ai.ml.entities._job.distribution import DistributionConfiguration +from azure.ai.ml.entities._job.job_service import ( + JobService, + JobServiceBase, + JupyterLabJobService, + SshJobService, + TensorBoardJobService, + VsCodeJobService, +) +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .job import Job +from .job_io_mixin import JobIOMixin +from .job_limits import CommandJobLimits +from .job_resource_configuration import JobResourceConfiguration +from .parameterized_command import ParameterizedCommand +from .queue_settings import QueueSettings + +# avoid circular import error +if 
TYPE_CHECKING: + from azure.ai.ml.entities import CommandComponent + from azure.ai.ml.entities._builders import Command + +module_logger = logging.getLogger(__name__) + + +class CommandJob(Job, ParameterizedCommand, JobIOMixin): + """Command job. + + .. note:: + For sweep jobs, inputs, outputs, and parameters are accessible as environment variables using the prefix + ``AZUREML_PARAMETER_``. For example, if you have a parameter named "input_data", you can access it as + ``AZUREML_PARAMETER_input_data``. + + :keyword services: Read-only information on services associated with the job. + :paramtype services: Optional[dict[str, ~azure.ai.ml.entities.JobService]] + :keyword inputs: Mapping of output data bindings used in the command. + :paramtype inputs: Optional[dict[str, Union[~azure.ai.ml.Input, str, bool, int, float]]] + :keyword outputs: Mapping of output data bindings used in the job. + :paramtype outputs: Optional[dict[str, ~azure.ai.ml.Output]] + :keyword identity: The identity that the job will use while running on compute. + :paramtype identity: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration]] + :keyword limits: The limits for the job. + :paramtype limits: Optional[~azure.ai.ml.entities.CommandJobLimits] + :keyword parent_job_name: parent job id for command job + :paramtype parent_job_name: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_command_configurations.py + :start-after: [START command_job_definition] + :end-before: [END command_job_definition] + :language: python + :dedent: 8 + :caption: Configuring a CommandJob. + """ + + def __init__( + self, + *, + inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None, + outputs: Optional[Dict[str, Output]] = None, + limits: Optional[CommandJobLimits] = None, + identity: Optional[ + Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + services: Optional[ + Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]] + ] = None, + parent_job_name: Optional[str] = None, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = JobType.COMMAND + self._parameters: dict = kwargs.pop("parameters", {}) + self.parent_job_name = parent_job_name + + super().__init__(**kwargs) + + self.outputs = outputs # type: ignore[assignment] + self.inputs = inputs # type: ignore[assignment] + self.limits = limits + self.identity = identity + self.services = services + + @property + def parameters(self) -> Dict[str, str]: + """MLFlow parameters. + + :return: MLFlow parameters logged in job. 
+ :rtype: dict[str, str] + """ + return self._parameters + + def _to_dict(self) -> Dict: + res: dict = CommandJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def _to_rest_object(self) -> JobBase: + self._validate() + self.command = map_single_brackets_and_warn(self.command) + modified_properties = copy.deepcopy(self.properties) + # Remove any properties set on the service as read-only + modified_properties.pop("_azureml.ComputeTargetType", None) + # Handle local compute case + compute = self.compute + resources = self.resources + if self.compute == LOCAL_COMPUTE_TARGET: + compute = None + if resources is None: + resources = JobResourceConfiguration() + if not isinstance(resources, Dict): + if resources.properties is None: + resources.properties = {} + # This is the format of the October Api response. We need to match it exactly + resources.properties[LOCAL_COMPUTE_PROPERTY] = {LOCAL_COMPUTE_PROPERTY: True} + + properties = RestCommandJob( + display_name=self.display_name, + description=self.description, + command=self.command, + code_id=self.code, + compute_id=compute, + properties=modified_properties, + experiment_name=self.experiment_name, + inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=self.type), + outputs=to_rest_data_outputs(self.outputs), + environment_id=self.environment, + distribution=( + self.distribution._to_rest_object() + if self.distribution and not isinstance(self.distribution, Dict) + else None + ), + tags=self.tags, + identity=( + self.identity._to_job_rest_object() if self.identity and not isinstance(self.identity, Dict) else None + ), + environment_variables=self.environment_variables, + resources=resources._to_rest_object() if resources and not isinstance(resources, Dict) else None, + limits=self.limits._to_rest_object() if self.limits else None, + services=JobServiceBase._to_rest_job_services(self.services), + queue_settings=self.queue_settings._to_rest_object() if self.queue_settings else None, + parent_job_name=self.parent_job_name, + ) + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "CommandJob": + loaded_data = load_from_dict(CommandJobSchema, data, context, additional_message, **kwargs) + return CommandJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + @classmethod + def _load_from_rest(cls, obj: JobBase) -> "CommandJob": + rest_command_job: RestCommandJob = obj.properties + command_job = CommandJob( + name=obj.name, + id=obj.id, + display_name=rest_command_job.display_name, + description=rest_command_job.description, + tags=rest_command_job.tags, + properties=rest_command_job.properties, + command=rest_command_job.command, + experiment_name=rest_command_job.experiment_name, + services=JobServiceBase._from_rest_job_services(rest_command_job.services), + status=rest_command_job.status, + creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + code=rest_command_job.code_id, + compute=rest_command_job.compute_id, + environment=rest_command_job.environment_id, + distribution=DistributionConfiguration._from_rest_object(rest_command_job.distribution), + parameters=rest_command_job.parameters, + # pylint: disable=protected-access + identity=( + _BaseJobIdentityConfiguration._from_rest_object(rest_command_job.identity) + if rest_command_job.identity + else None + ), + environment_variables=rest_command_job.environment_variables, 
+ resources=JobResourceConfiguration._from_rest_object(rest_command_job.resources), + limits=CommandJobLimits._from_rest_object(rest_command_job.limits), + inputs=from_rest_inputs_to_dataset_literal(rest_command_job.inputs), + outputs=from_rest_data_outputs(rest_command_job.outputs), + queue_settings=QueueSettings._from_rest_object(rest_command_job.queue_settings), + parent_job_name=rest_command_job.parent_job_name, + ) + # Handle special case of local job + if ( + command_job.resources is not None + and not isinstance(command_job.resources, Dict) + and command_job.resources.properties is not None + and command_job.resources.properties.get(LOCAL_COMPUTE_PROPERTY, None) + ): + command_job.compute = LOCAL_COMPUTE_TARGET + command_job.resources.properties.pop(LOCAL_COMPUTE_PROPERTY) + return command_job + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "CommandComponent": + """Translate a command job to component. + + :param context: Context of command job YAML file. + :type context: dict + :return: Translated command component. + :rtype: CommandComponent + """ + from azure.ai.ml.entities import CommandComponent + + pipeline_job_dict = kwargs.get("pipeline_job_dict", {}) + context = context or {BASE_PATH_CONTEXT_KEY: Path("./")} + + # Create anonymous command component with default version as 1 + return CommandComponent( + tags=self.tags, + is_anonymous=True, + base_path=context[BASE_PATH_CONTEXT_KEY], + code=self.code, + command=self.command, + environment=self.environment, + description=self.description, + inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict), + outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict), + resources=self.resources if self.resources else None, + distribution=self.distribution if self.distribution else None, + ) + + def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Command": + """Translate a command job to a pipeline node. + + :param context: Context of command job YAML file. + :type context: dict + :return: Translated command component. + :rtype: Command + """ + from azure.ai.ml.entities._builders import Command + + component = self._to_component(context, **kwargs) + + return Command( + component=component, + compute=self.compute, + # Need to supply the inputs with double curly. 
+ inputs=self.inputs, # type: ignore[arg-type] + outputs=self.outputs, # type: ignore[arg-type] + environment_variables=self.environment_variables, + description=self.description, + tags=self.tags, + display_name=self.display_name, + limits=self.limits, + services=self.services, + properties=self.properties, + identity=self.identity, + queue_settings=self.queue_settings, + ) + + def _validate(self) -> None: + if self.command is None: + msg = "command is required" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + if self.environment is None: + msg = "environment is required for non-local runs" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + if isinstance(self.environment, Environment): + self.environment.validate() + validate_inputs_for_command(self.command, self.inputs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/compute_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/compute_configuration.py new file mode 100644 index 00000000..dcc00825 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/compute_configuration.py @@ -0,0 +1,110 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import json +import logging +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import ComputeConfiguration as RestComputeConfiguration +from azure.ai.ml.constants._common import LOCAL_COMPUTE_TARGET +from azure.ai.ml.constants._job.job import JobComputePropertyFields +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +class ComputeConfiguration(RestTranslatableMixin, DictMixin): + """Compute resource configuration + + :param target: The compute target. + :type target: Optional[str] + :param instance_count: The number of instances. + :type instance_count: Optional[int] + :param is_local: Specifies if the compute will be on the local machine. + :type is_local: Optional[bool] + :param location: The location of the compute resource. + :type location: Optional[str] + :param properties: The resource properties + :type properties: Optional[Dict[str, Any]] + :param deserialize_properties: Specifies if property bag should be deserialized. Defaults to False. 
+ :type deserialize_properties: bool + """ + + def __init__( + self, + *, + target: Optional[str] = None, + instance_count: Optional[int] = None, + is_local: Optional[bool] = None, + instance_type: Optional[str] = None, + location: Optional[str] = None, + properties: Optional[Dict[str, Any]] = None, + deserialize_properties: bool = False, + ) -> None: + self.instance_count = instance_count + self.target = target or LOCAL_COMPUTE_TARGET + self.is_local = is_local or self.target == LOCAL_COMPUTE_TARGET + self.instance_type = instance_type + self.location = location + self.properties = properties + if deserialize_properties and properties and self.properties is not None: + for key, value in self.properties.items(): + try: + self.properties[key] = json.loads(value) + except Exception: # pylint: disable=W0718 + # keep serialized string if load fails + pass + + def _to_rest_object(self) -> RestComputeConfiguration: + if self.properties: + serialized_properties = {} + for key, value in self.properties.items(): + try: + if key.lower() == JobComputePropertyFields.SINGULARITY.lower(): + # Map Singularity -> AISupercomputer in SDK until MFE does mapping + key = JobComputePropertyFields.AISUPERCOMPUTER + # Ensure keymatch is case invariant + elif key.lower() == JobComputePropertyFields.AISUPERCOMPUTER.lower(): + key = JobComputePropertyFields.AISUPERCOMPUTER + serialized_properties[key] = json.dumps(value) + except Exception: # pylint: disable=W0718 + pass + else: + serialized_properties = None + return RestComputeConfiguration( + target=self.target if not self.is_local else None, + is_local=self.is_local, + instance_count=self.instance_count, + instance_type=self.instance_type, + location=self.location, + properties=serialized_properties, + ) + + @classmethod + def _from_rest_object(cls, obj: RestComputeConfiguration) -> "ComputeConfiguration": + return ComputeConfiguration( + target=obj.target, + is_local=obj.is_local, + instance_count=obj.instance_count, + location=obj.location, + instance_type=obj.instance_type, + properties=obj.properties, + deserialize_properties=True, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ComputeConfiguration): + return NotImplemented + return ( + self.instance_count == other.instance_count + and self.target == other.target + and self.is_local == other.is_local + and self.location == other.location + and self.instance_type == other.instance_type + ) + + def __ne__(self, other: object) -> bool: + if not isinstance(other, ComputeConfiguration): + return NotImplemented + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
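A minimal usage sketch of the ComputeConfiguration entity above, assuming the vendored module layout shown in this diff; the cluster name, SKU, and property payload are illustrative placeholders rather than required values.

from azure.ai.ml.constants._common import LOCAL_COMPUTE_TARGET
from azure.ai.ml.entities._job.compute_configuration import ComputeConfiguration

# Omitting `target` falls back to the local compute target, and is_local is inferred.
local_config = ComputeConfiguration(instance_count=1)
assert local_config.is_local and local_config.target == LOCAL_COMPUTE_TARGET

# Property values arrive from the service as JSON strings; deserialize_properties=True
# converts them back to Python objects on construction (unparseable values are kept as-is).
remote_config = ComputeConfiguration(
    target="cpu-cluster",                                  # placeholder cluster name
    instance_count=2,
    instance_type="Standard_DS3_v2",                       # placeholder SKU
    properties={"example_setting": '{"nested": true}'},    # illustrative property payload
    deserialize_properties=True,
)
assert remote_config.properties["example_setting"] == {"nested": True}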
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/data_transfer_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/data_transfer_job.py new file mode 100644 index 00000000..b510da80 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/data_transfer_job.py @@ -0,0 +1,358 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase +from azure.ai.ml._schema.job.data_transfer_job import ( + DataTransferCopyJobSchema, + DataTransferExportJobSchema, + DataTransferImportJobSchema, +) +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.constants._component import DataTransferBuiltinComponentUri, DataTransferTaskType, ExternalDataType +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from ..job import Job +from ..job_io_mixin import JobIOMixin + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._builders import DataTransferCopy, DataTransferExport, DataTransferImport + from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent + +module_logger = logging.getLogger(__name__) + + +class DataTransferJob(Job, JobIOMixin): + """DataTransfer job. + + :param name: Name of the job. + :type name: str + :param description: Description of the job. + :type description: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: dict[str, str] + :param display_name: Display name of the job. + :type display_name: str + :param properties: The asset property dictionary. + :type properties: dict[str, str] + :param experiment_name: Name of the experiment the job will be created under. + If None is provided, default will be set to current directory name. + :type experiment_name: str + :param services: Information on services associated with the job, readonly. + :type services: dict[str, JobService] + :param inputs: Inputs to the command. + :type inputs: dict[str, Union[azure.ai.ml.Input, str, bool, int, float]] + :param outputs: Mapping of output data bindings used in the job. + :type outputs: dict[str, azure.ai.ml.Output] + :param compute: The compute target the job runs on. + :type compute: str + :param task: task type in data transfer component, possible value is "copy_data". + :type task: str + :param data_copy_mode: data copy mode in copy task, possible value is "merge_with_overwrite", "fail_if_conflict". + :type data_copy_mode: str + :keyword kwargs: A dictionary of additional configuration parameters. 
+ :paramtype kwargs: dict + """ + + def __init__( + self, + task: str, + **kwargs: Any, + ): + kwargs[TYPE] = JobType.DATA_TRANSFER + self._parameters: Dict = kwargs.pop("parameters", {}) + super().__init__(**kwargs) + self.task = task + + @property + def parameters(self) -> Dict: + """MLFlow parameters. + + :return: MLFlow parameters logged in job. + :rtype: Dict[str, str] + """ + return self._parameters + + def _validate(self) -> None: + if self.compute is None: + msg = "compute is required" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + + @classmethod + def _load_from_rest(cls, obj: JobBase) -> "DataTransferJob": + # Todo: need update rest api + raise NotImplementedError("Not support submit standalone job for now") + + def _to_rest_object(self) -> JobBase: + # Todo: need update rest api + raise NotImplementedError("Not support submit standalone job for now") + + @classmethod + def _build_source_sink( + cls, io_dict: Optional[Union[Dict, Database, FileSystem]] + ) -> Optional[Union[(Database, FileSystem)]]: + if io_dict is None: + return io_dict + if isinstance(io_dict, (Database, FileSystem)): + component_io = io_dict + else: + if isinstance(io_dict, dict): + data_type = io_dict.pop("type", None) + if data_type == ExternalDataType.DATABASE: + component_io = Database(**io_dict) + elif data_type == ExternalDataType.FILE_SYSTEM: + component_io = FileSystem(**io_dict) + else: + msg = "Type in source or sink only support {} and {}, currently got {}." + raise ValidationException( + message=msg.format( + ExternalDataType.DATABASE, + ExternalDataType.FILE_SYSTEM, + data_type, + ), + no_personal_data_message=msg.format( + ExternalDataType.DATABASE, + ExternalDataType.FILE_SYSTEM, + "data_type", + ), + target=ErrorTarget.DATA_TRANSFER_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + else: + msg = "Source or sink only support dict, Database and FileSystem" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.DATA_TRANSFER_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return component_io + + +class DataTransferCopyJob(DataTransferJob): + def __init__( + self, + *, + inputs: Optional[Dict[str, Union[Input, str]]] = None, + outputs: Optional[Dict[str, Union[Output]]] = None, + data_copy_mode: Optional[str] = None, + **kwargs: Any, + ): + kwargs["task"] = DataTransferTaskType.COPY_DATA + super().__init__(**kwargs) + + self.outputs = outputs # type: ignore[assignment] + self.inputs = inputs # type: ignore[assignment] + self.data_copy_mode = data_copy_mode + + def _to_dict(self) -> Dict: + res: dict = DataTransferCopyJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load_from_dict( + cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any + ) -> "DataTransferCopyJob": + loaded_data = load_from_dict(DataTransferCopyJobSchema, data, context, additional_message, **kwargs) + return DataTransferCopyJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "DataTransferCopyComponent": + """Translate a data transfer copy job to component. + + :param context: Context of data transfer job YAML file. 
+ :type context: dict + :return: Translated data transfer copy component. + :rtype: DataTransferCopyComponent + """ + from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent + + pipeline_job_dict = kwargs.get("pipeline_job_dict", {}) + context = context or {BASE_PATH_CONTEXT_KEY: Path("./")} + + # Create anonymous command component with default version as 1 + return DataTransferCopyComponent( + tags=self.tags, + is_anonymous=True, + base_path=context[BASE_PATH_CONTEXT_KEY], + description=self.description, + inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict), + outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict), + data_copy_mode=self.data_copy_mode, + ) + + def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "DataTransferCopy": + """Translate a data transfer copy job to a pipeline node. + + :param context: Context of data transfer job YAML file. + :type context: dict + :return: Translated data transfer component. + :rtype: DataTransferCopy + """ + from azure.ai.ml.entities._builders import DataTransferCopy + + component = self._to_component(context, **kwargs) + + return DataTransferCopy( + component=component, + compute=self.compute, + # Need to supply the inputs with double curly. + inputs=self.inputs, # type: ignore[arg-type] + outputs=self.outputs, # type: ignore[arg-type] + description=self.description, + tags=self.tags, + display_name=self.display_name, + ) + + +class DataTransferImportJob(DataTransferJob): + def __init__( + self, + *, + outputs: Optional[Dict[str, Union[Output]]] = None, + source: Optional[Union[Dict, Database, FileSystem]] = None, + **kwargs: Any, + ): + kwargs["task"] = DataTransferTaskType.IMPORT_DATA + super().__init__(**kwargs) + + self.outputs = outputs # type: ignore[assignment] + self.source = self._build_source_sink(source) + + def _to_dict(self) -> Dict: + res: dict = DataTransferImportJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load_from_dict( + cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any + ) -> "DataTransferImportJob": + loaded_data = load_from_dict(DataTransferImportJobSchema, data, context, additional_message, **kwargs) + return DataTransferImportJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> str: + """Translate a data transfer import job to component. + + :param context: Context of data transfer job YAML file. + :type context: dict + :return: Translated data transfer import component. + :rtype: str + """ + + component: str = "" + if self.source is not None and self.source.type == ExternalDataType.DATABASE: + component = DataTransferBuiltinComponentUri.IMPORT_DATABASE + else: + component = DataTransferBuiltinComponentUri.IMPORT_FILE_SYSTEM + + return component + + def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "DataTransferImport": + """Translate a data transfer import job to a pipeline node. + + :param context: Context of data transfer job YAML file. + :type context: dict + :return: Translated data transfer import node. 
+ :rtype: DataTransferImport + """ + from azure.ai.ml.entities._builders import DataTransferImport + + component = self._to_component(context, **kwargs) + + return DataTransferImport( + component=component, + compute=self.compute, + source=self.source, + outputs=self.outputs, # type: ignore[arg-type] + description=self.description, + tags=self.tags, + display_name=self.display_name, + properties=self.properties, + ) + + +class DataTransferExportJob(DataTransferJob): + def __init__( + self, + *, + inputs: Optional[Dict[str, Union[Input]]] = None, + sink: Optional[Union[Dict, Database, FileSystem]] = None, + **kwargs: Any, + ): + kwargs["task"] = DataTransferTaskType.EXPORT_DATA + super().__init__(**kwargs) + + self.inputs = inputs # type: ignore[assignment] + self.sink = self._build_source_sink(sink) + + def _to_dict(self) -> Dict: + res: dict = DataTransferExportJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _load_from_dict( + cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any + ) -> "DataTransferExportJob": + loaded_data = load_from_dict(DataTransferExportJobSchema, data, context, additional_message, **kwargs) + return DataTransferExportJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> str: + """Translate a data transfer export job to component. + + :param context: Context of data transfer job YAML file. + :type context: dict + :return: Translated data transfer export component. + :rtype: str + """ + component: str = "" + if self.sink is not None and self.sink.type == ExternalDataType.DATABASE: + component = DataTransferBuiltinComponentUri.EXPORT_DATABASE + else: + msg = "Sink is a required field for export data task and we don't support exporting file system for now." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.DATA_TRANSFER_JOB, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return component + + def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "DataTransferExport": + """Translate a data transfer export job to a pipeline node. + + :param context: Context of data transfer job YAML file. + :type context: dict + :return: Translated data transfer export node. + :rtype: DataTransferExport + """ + from azure.ai.ml.entities._builders import DataTransferExport + + component = self._to_component(context, **kwargs) + + return DataTransferExport( + component=component, + compute=self.compute, + sink=self.sink, + inputs=self.inputs, # type: ignore[arg-type] + description=self.description, + tags=self.tags, + display_name=self.display_name, + properties=self.properties, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
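A hedged sketch of how the data-transfer job entities above might be constructed. The datastore paths, connection name, query, and output types are placeholders, and the Database keyword arguments (query, connection) are assumed from the import-data YAML schema. As the _to_rest_object overrides above indicate, these standalone jobs are not submitted directly today; they are translated into pipeline nodes through _to_component/_to_node.

from azure.ai.ml import Input, Output
from azure.ai.ml.entities._inputs_outputs.external_data import Database
from azure.ai.ml.entities._job.data_transfer.data_transfer_job import (
    DataTransferCopyJob,
    DataTransferImportJob,
)

# Copy between storage locations; data_copy_mode may also be "fail_if_conflict".
copy_job = DataTransferCopyJob(
    inputs={"source_folder": Input(type="uri_folder", path="azureml://datastores/src/paths/data/")},
    outputs={"destination_folder": Output(type="uri_folder")},
    data_copy_mode="merge_with_overwrite",
)

# Import from an external database; the source may also be given as a dict with
# type "database" or "file_system", which _build_source_sink converts above.
import_job = DataTransferImportJob(
    source=Database(query="select * from REGION", connection="azureml:my_sql_connection"),
    outputs={"sink": Output(type="mltable")},
)

# Either job becomes a pipeline node via the _to_node translation helpers defined above;
# export jobs work the same way but take a Database `sink` instead of a `source`.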
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/constants.py new file mode 100644 index 00000000..5084ffbd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/constants.py @@ -0,0 +1,20 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +class AzureMLDistillationProperties: + ENABLE_DISTILLATION = "azureml.enable_distillation" + DATA_GENERATION_TYPE = "azureml.data_generation_type" + DATA_GENERATION_TASK_TYPE = "azureml.data_generation_task_type" + TEACHER_MODEL = "azureml.teacher_model" + INSTANCE_TYPE = "azureml.instance_type" + CONNECTION_INFORMATION = "azureml.connection_information" + + +class EndpointSettings: + VALID_SETTINGS = {"request_batch_size", "min_endpoint_success_ratio"} + + +class PromptSettingKeys: + VALID_SETTINGS = {"enable_chain_of_thought", "enable_chain_of_density", "max_len_summary"} diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/distillation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/distillation_job.py new file mode 100644 index 00000000..469fde98 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/distillation_job.py @@ -0,0 +1,542 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import json +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + CustomModelFineTuning as RestCustomModelFineTuningVertical, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import FineTuningJob as RestFineTuningJob +from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase as RestJobBase +from azure.ai.ml._restclient.v2024_01_01_preview.models import MLFlowModelJobInput, UriFileJobInput +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants import DataGenerationType, JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE, AssetTypes +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs +from azure.ai.ml.entities._job.distillation.constants import ( + AzureMLDistillationProperties, + EndpointSettings, + PromptSettingKeys, +) +from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings +from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings +from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin +from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection + + +# pylint: disable=too-many-instance-attributes +@experimental +class DistillationJob(Job, JobIOMixin): + def 
__init__( + self, + *, + data_generation_type: str, + data_generation_task_type: str, + teacher_model_endpoint_connection: WorkspaceConnection, + student_model: Input, + training_data: Optional[Input] = None, + validation_data: Optional[Input] = None, + teacher_model_settings: Optional[TeacherModelSettings] = None, + prompt_settings: Optional[PromptSettings] = None, + hyperparameters: Optional[Dict] = None, + resources: Optional[ResourceConfiguration] = None, + **kwargs: Any, + ) -> None: + self._data_generation_type = data_generation_type + self._data_generation_task_type = data_generation_task_type + self._teacher_model_endpoint_connection = teacher_model_endpoint_connection + self._student_model = student_model + self._training_data = training_data + self._validation_data = validation_data + self._teacher_model_settings = teacher_model_settings + self._prompt_settings = prompt_settings + self._hyperparameters = hyperparameters + self._resources = resources + + if self._training_data is None and self._data_generation_type == DataGenerationType.LABEL_GENERATION: + raise ValueError( + f"Training data can not be None when data generation type is set to " + f"{DataGenerationType.LABEL_GENERATION}." + ) + + if self._validation_data is None and self._data_generation_type == DataGenerationType.LABEL_GENERATION: + raise ValueError( + f"Validation data can not be None when data generation type is set to " + f"{DataGenerationType.LABEL_GENERATION}." + ) + + kwargs[TYPE] = JobType.DISTILLATION + self._outputs = kwargs.pop("outputs", None) + super().__init__(**kwargs) + + @property + def data_generation_type(self) -> str: + """Get the type of synthetic data generation to perform. + + :return: str representing the type of synthetic data generation to perform. + :rtype: str + """ + return self._data_generation_type + + @data_generation_type.setter + def data_generation_type(self, task: str) -> None: + """Set the data generation task. + + :param task: The data generation task. Possible values include 'Label_Generation' and 'Data_Generation'. + :type task: str + """ + self._data_generation_type = task + + @property + def data_generation_task_type(self) -> str: + """Get the type of synthetic data to generate. + + :return: str representing the type of synthetic data to generate. + :rtype: str + """ + return self._data_generation_task_type + + @data_generation_task_type.setter + def data_generation_task_type(self, task: str) -> None: + """Set the data generation type. + + :param task: The data generation type. Possible values include 'nli', 'nlu_qa', 'conversational', + 'math', and 'summarization'. + :type task: str + """ + self._data_generation_task_type = task + + @property + def teacher_model_endpoint_connection(self) -> WorkspaceConnection: + """Get the endpoint connection of the teacher model to use for data generation. + + :return: Endpoint connection + :rtype: WorkspaceConnection + """ + return self._teacher_model_endpoint_connection + + @teacher_model_endpoint_connection.setter + def teacher_model_endpoint_connection(self, connection: WorkspaceConnection) -> None: + """Set the endpoint information of the teacher model. 
+ + :param connection: Workspace connection + :type connection: WorkspaceConnection + """ + self._teacher_model_endpoint_connection = connection + + @property + def student_model(self) -> Input: + """Get the student model to be trained with synthetic data + + :return: The student model to be finetuned + :rtype: Input + """ + return self._student_model + + @student_model.setter + def student_model(self, model: Input) -> None: + """Set the student model to be trained. + + :param model: The model to use for finetuning + :type model: Input + """ + self._student_model = model + + @property + def training_data(self) -> Optional[Input]: + """Get the training data. + + :return: Training data input + :rtype: typing.Optional[Input] + """ + return self._training_data + + @training_data.setter + def training_data(self, training_data: Optional[Input]) -> None: + """Set the training data. + + :param training_data: Training data input + :type training_data: typing.Optional[Input] + """ + self._training_data = training_data + + @property + def validation_data(self) -> Optional[Input]: + """Get the validation data. + + :return: Validation data input + :rtype: typing.Optional[Input] + """ + return self._validation_data + + @validation_data.setter + def validation_data(self, validation_data: Optional[Input]) -> None: + """Set the validation data. + + :param validation_data: Validation data input + :type validation_data: typing.Optional[Input] + """ + self._validation_data = validation_data + + @property + def teacher_model_settings(self) -> Optional[TeacherModelSettings]: + """Get the teacher model settings. + + :return: The settings for the teacher model to use. + :rtype: typing.Optional[TeacherModelSettings] + """ + return self._teacher_model_settings + + @property + def prompt_settings(self) -> Optional[PromptSettings]: + """Get the settings for the prompt. + + :return: The settings for the prompt. + :rtype: typing.Optional[PromptSettings] + """ + return self._prompt_settings + + @property + def hyperparameters(self) -> Optional[Dict]: + """Get the finetuning hyperparameters. + + :return: The finetuning hyperparameters. + :rtype: typing.Optional[typing.Dict] + """ + return self._hyperparameters + + @property + def resources(self) -> Optional[ResourceConfiguration]: + """Get the resources for data generation. + + :return: The resources for data generation. + :rtype: typing.Optional[ResourceConfiguration] + """ + return self._resources + + @resources.setter + def resources(self, resource: Optional[ResourceConfiguration]) -> None: + """Set the resources for data generation. + + :param resource: The resources for data generation. + :type resource: typing.Optional[ResourceConfiguration] + """ + self._resources = resource + + def set_teacher_model_settings( + self, + inference_parameters: Optional[Dict] = None, + endpoint_request_settings: Optional[EndpointRequestSettings] = None, + ): + """Set settings related to the teacher model. + + :param inference_parameters: Settings the teacher model uses during inferencing. 
+ :type inference_parameters: typing.Optional[typing.Dict] + :param endpoint_request_settings: Settings for inference requests to the endpoint + :type endpoint_request_settings: typing.Optional[EndpointRequestSettings] + """ + self._teacher_model_settings = TeacherModelSettings( + inference_parameters=inference_parameters, endpoint_request_settings=endpoint_request_settings + ) + + def set_prompt_settings(self, prompt_settings: Optional[PromptSettings]): + """Set settings related to the system prompt used for generating data. + + :param prompt_settings: Settings related to the system prompt used for generating data. + :type prompt_settings: typing.Optional[PromptSettings] + """ + self._prompt_settings = prompt_settings if prompt_settings is not None else self._prompt_settings + + def set_finetuning_settings(self, hyperparameters: Optional[Dict]): + """Set the hyperparamters for finetuning. + + :param hyperparameters: The hyperparameters for finetuning. + :type hyperparameters: typing.Optional[typing.Dict] + """ + self._hyperparameters = hyperparameters if hyperparameters is not None else self._hyperparameters + + def _to_dict(self) -> Dict: + """Convert the object to a dictionary. + + :return: dictionary representation of the object. + :rtype: typing.Dict + """ + from azure.ai.ml._schema._distillation.distillation_job import DistillationJobSchema + + schema_dict: dict = {} + schema_dict = DistillationJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return schema_dict + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "DistillationJob": + """Load from a dictionary. + + :param data: dictionary representation of the object. + :type data: typing.Dict + :param context: dictionary containing the context. + :type context: typing.Dict + :param additional_message: additional message to be added to the error message. + :type additional_message: str + :return: DistillationJob object. + :rtype: DistillationJob + """ + from azure.ai.ml._schema._distillation.distillation_job import DistillationJobSchema + + loaded_data = load_from_dict(DistillationJobSchema, data, context, additional_message, **kwargs) + + training_data = loaded_data.get("training_data", None) + if isinstance(training_data, str): + loaded_data["training_data"] = Input(type="uri_file", path=training_data) + + validation_data = loaded_data.get("validation_data", None) + if isinstance(validation_data, str): + loaded_data["validation_data"] = Input(type="uri_file", path=validation_data) + + student_model = loaded_data.get("student_model", None) + if isinstance(student_model, str): + loaded_data["student_model"] = Input(type=AssetTypes.URI_FILE, path=student_model) + + job_instance = DistillationJob(**loaded_data) + return job_instance + + @classmethod + def _from_rest_object(cls, obj: RestJobBase) -> "DistillationJob": + """Convert a REST object to DistillationJob object. + + :param obj: CustomModelFineTuningJob in Rest format. + :type obj: JobBase + :return: DistillationJob objects. 
+ :rtype: DistillationJob + """ + properties: RestFineTuningJob = obj.properties + finetuning_details: RestCustomModelFineTuningVertical = properties.fine_tuning_details + + job_kwargs_dict = DistillationJob._filter_properties(properties=properties.properties) + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": properties.experiment_name, + "services": properties.services, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": properties.display_name, + "outputs": from_rest_data_outputs(properties.outputs), + } + + distillation_job = cls( + student_model=finetuning_details.model, + training_data=finetuning_details.training_data, + validation_data=finetuning_details.validation_data, + hyperparameters=finetuning_details.hyper_parameters, + **job_kwargs_dict, + **job_args_dict, + ) + + distillation_job._restore_inputs() + + return distillation_job + + def _to_rest_object(self) -> "RestFineTuningJob": + """Convert DistillationJob object to a RestFineTuningJob object. + + :return: REST object representation of this object. + :rtype: JobBase + """ + distillation = RestCustomModelFineTuningVertical( + task_type="ChatCompletion", + model=self.student_model, + model_provider="Custom", + training_data=self.training_data, + validation_data=self.validation_data, + hyper_parameters=self._hyperparameters, + ) + + if isinstance(distillation.training_data, Input): + distillation.training_data = UriFileJobInput(uri=distillation.training_data.path) + if isinstance(distillation.validation_data, Input): + distillation.validation_data = UriFileJobInput(uri=distillation.validation_data.path) + if isinstance(distillation.model, Input): + distillation.model = MLFlowModelJobInput(uri=distillation.model.path) + + self._add_distillation_properties(self.properties) + + finetuning_job = RestFineTuningJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + services=self.services, + tags=self.tags, + properties=self.properties, + fine_tuning_details=distillation, + outputs=to_rest_data_outputs(self.outputs), + ) + + result = RestJobBase(properties=finetuning_job) + result.name = self.name + + return result + + @classmethod + def _load_from_rest(cls, obj: RestJobBase) -> "DistillationJob": + """Loads the rest object to a dict containing items to init the AutoMLJob objects. + + :param obj: Azure Resource Manager resource envelope. 
+ :type obj: JobBase + :raises ValidationException: task type validation error + :return: A DistillationJob + :rtype: DistillationJob + """ + return DistillationJob._from_rest_object(obj) + + # TODO: Remove once Distillation is added to MFE + def _add_distillation_properties(self, properties: Dict) -> None: + """Adds DistillationJob attributes to properties to pass into the FT Overloaded API property bag + + :param properties: Current distillation properties + :type properties: typing.Dict + """ + properties[AzureMLDistillationProperties.ENABLE_DISTILLATION] = True + properties[AzureMLDistillationProperties.DATA_GENERATION_TASK_TYPE] = self._data_generation_task_type.upper() + properties[f"{AzureMLDistillationProperties.TEACHER_MODEL}.endpoint_name"] = ( + self._teacher_model_endpoint_connection.name + ) + + # Not needed for FT Overload API but additional info needed to convert from REST object to Distillation object + properties[AzureMLDistillationProperties.DATA_GENERATION_TYPE] = self._data_generation_type + properties[AzureMLDistillationProperties.CONNECTION_INFORMATION] = json.dumps( + self._teacher_model_endpoint_connection._to_dict() # pylint: disable=protected-access + ) + + if self._prompt_settings: + for setting, value in self._prompt_settings.items(): + if value is not None: + properties[f"azureml.{setting.strip('_')}"] = value + + if self._teacher_model_settings: + inference_settings = self._teacher_model_settings.inference_parameters + endpoint_settings = self._teacher_model_settings.endpoint_request_settings + + if inference_settings: + for inference_key, value in inference_settings.items(): + if value is not None: + properties[f"{AzureMLDistillationProperties.TEACHER_MODEL}.{inference_key}"] = value + + if endpoint_settings: + for setting, value in endpoint_settings.items(): + if value is not None: + properties[f"azureml.{setting.strip('_')}"] = value + + if self._resources and self._resources.instance_type: + properties[f"{AzureMLDistillationProperties.INSTANCE_TYPE}.data_generation"] = self._resources.instance_type + + # TODO: Remove once Distillation is added to MFE + @classmethod + def _filter_properties(cls, properties: Dict) -> Dict: + """Convert properties from REST object back to their original states. 
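For orientation, a sample of the flattened "azureml.*" keys that _add_distillation_properties above writes for a job configured with chain-of-thought prompts, a request batch size, and a data-generation instance type. The endpoint name, SKU, and the "temperature" inference key are assumed placeholders; the real bag additionally carries "azureml.connection_information" with the serialized connection so that _filter_properties can rebuild it.

sample_distillation_properties = {
    "azureml.enable_distillation": True,
    "azureml.data_generation_type": "Label_Generation",
    "azureml.data_generation_task_type": "NLI",                    # data_generation_task_type.upper()
    "azureml.teacher_model.endpoint_name": "my-teacher-endpoint",  # placeholder connection name
    "azureml.teacher_model.temperature": 0.1,                      # assumed inference parameter
    "azureml.enable_chain_of_thought": True,                       # PromptSettings key, underscore-stripped
    "azureml.request_batch_size": 8,                               # EndpointRequestSettings key
    "azureml.instance_type.data_generation": "Standard_D4s_v3",    # placeholder SKU
}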
+ + :param properties: Properties from a REST object + :type properties: typing.Dict + :return: A dict that can be used to create a DistillationJob + :rtype: typing.Dict + """ + inference_parameters = {} + endpoint_settings = {} + prompt_settings = {} + resources = {} + teacher_settings = {} + teacher_model_info = "" + for key, val in properties.items(): + param = key.split(".")[-1] + if AzureMLDistillationProperties.TEACHER_MODEL in key and param != "endpoint_name": + inference_parameters[param] = val + elif AzureMLDistillationProperties.INSTANCE_TYPE in key: + resources[key.split(".")[1]] = val + elif AzureMLDistillationProperties.CONNECTION_INFORMATION in key: + teacher_model_info = val + else: + if param in EndpointSettings.VALID_SETTINGS: + endpoint_settings[param] = val + elif param in PromptSettingKeys.VALID_SETTINGS: + prompt_settings[param] = val + + if inference_parameters: + teacher_settings["inference_parameters"] = inference_parameters + if endpoint_settings: + teacher_settings["endpoint_request_settings"] = EndpointRequestSettings(**endpoint_settings) # type: ignore + + return { + "data_generation_task_type": properties.get(AzureMLDistillationProperties.DATA_GENERATION_TASK_TYPE), + "data_generation_type": properties.get(AzureMLDistillationProperties.DATA_GENERATION_TYPE), + "teacher_model_endpoint_connection": WorkspaceConnection._load( # pylint: disable=protected-access + data=json.loads(teacher_model_info) + ), + "teacher_model_settings": ( + TeacherModelSettings(**teacher_settings) if teacher_settings else None # type: ignore + ), + "prompt_settings": PromptSettings(**prompt_settings) if prompt_settings else None, + "resources": ResourceConfiguration(**resources) if resources else None, + } + + def _restore_inputs(self) -> None: + """Restore UriFileJobInputs to JobInputs within data_settings.""" + if isinstance(self.training_data, UriFileJobInput): + self.training_data = Input(type=AssetTypes.URI_FILE, path=self.training_data.uri) + if isinstance(self.validation_data, UriFileJobInput): + self.validation_data = Input(type=AssetTypes.URI_FILE, path=self.validation_data.uri) + if isinstance(self.student_model, MLFlowModelJobInput): + self.student_model = Input(type=AssetTypes.MLFLOW_MODEL, path=self.student_model.uri) + + def __eq__(self, other: object) -> bool: + """Returns True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, DistillationJob): + return False + return ( + super().__eq__(other) + and self.data_generation_type == other.data_generation_type + and self.data_generation_task_type == other.data_generation_task_type + and self.teacher_model_endpoint_connection.name == other.teacher_model_endpoint_connection.name + and self.student_model == other.student_model + and self.training_data == other.training_data + and self.validation_data == other.validation_data + and self.teacher_model_settings == other.teacher_model_settings + and self.prompt_settings == other.prompt_settings + and self.hyperparameters == other.hyperparameters + and self.resources == other.resources + ) + + def __ne__(self, other: object) -> bool: + """Check inequality between two DistillationJob objects. 
+ + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/endpoint_request_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/endpoint_request_settings.py new file mode 100644 index 00000000..89fb8015 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/endpoint_request_settings.py @@ -0,0 +1,90 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Optional + +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class EndpointRequestSettings: + def __init__(self, *, request_batch_size: Optional[int] = None, min_endpoint_success_ratio: Optional[float] = None): + """Initialize EndpointRequestSettings. + + :param request_batch_size: The number of requests to send to the teacher model endpoint as a batch, + defaults to None + :type request_batch_size: typing.Optional[int], optional + :param min_endpoint_success_ratio: The ratio of (successful requests / total requests) needed for the + data generation step to be considered successful. Must be a value between 0 and 1 inclusive, + defaults to None + :type min_endpoint_success_ratio: typing.Optional[float], optional + """ + self._request_batch_size = request_batch_size + self._min_endpoint_success_ratio = min_endpoint_success_ratio + + @property + def request_batch_size(self) -> Optional[int]: + """Get the number of inference requests to send to the teacher model as a batch. + + :return: The number of inference requests to send to the teacher model as a batch. + :rtype: typing.Optional[int] + """ + return self._request_batch_size + + @request_batch_size.setter + def request_batch_size(self, value: Optional[int]) -> None: + """Set the number of inference requests to send to the teacher model as a batch. + + :param value: The number of inference requests to send to the teacher model as a batch. + :type value: typing.Optional[int] + """ + self._request_batch_size = value + + @property + def min_endpoint_success_ratio(self) -> Optional[float]: + """Get the minimum ratio of successful inferencing requests. + + :return: The minimum ratio of successful inferencing requests. + :rtype: typing.Optional[float] + """ + return self._min_endpoint_success_ratio + + @min_endpoint_success_ratio.setter + def min_endpoint_success_ratio(self, ratio: Optional[float]) -> None: + """Set the minimum ratio of successful inferencing requests. + + :param ratio: The minimum ratio of successful inferencing requests. + :type ratio: typing.Optional[float] + """ + self._min_endpoint_success_ratio = ratio + + def items(self): + return self.__dict__.items() + + def __eq__(self, other: object) -> bool: + """Returns True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. 
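Putting the pieces together, a hedged end-to-end sketch of constructing and submitting the DistillationJob defined above. The workspace identifiers, connection name, model and data paths, SKU, inference parameter, and hyperparameter names are all placeholders; import paths follow the vendored layout in this diff, and azure-identity is assumed to be available for authentication.

from azure.ai.ml import Input, MLClient
from azure.ai.ml.constants import AssetTypes, DataGenerationType
from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob
from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")

# An existing connection to the teacher model endpoint; the name is a placeholder.
teacher_connection = ml_client.connections.get("my-teacher-endpoint-connection")

distillation_job = DistillationJob(
    data_generation_type=DataGenerationType.LABEL_GENERATION,  # training/validation data required for this mode
    data_generation_task_type="nli",                           # or nlu_qa, conversational, math, summarization
    teacher_model_endpoint_connection=teacher_connection,
    student_model=Input(type=AssetTypes.MLFLOW_MODEL, path="azureml://registries/<registry>/models/<model>/versions/1"),
    training_data=Input(type=AssetTypes.URI_FILE, path="./train.jsonl"),
    validation_data=Input(type=AssetTypes.URI_FILE, path="./valid.jsonl"),
    prompt_settings=PromptSettings(enable_chain_of_thought=True),
    teacher_model_settings=TeacherModelSettings(
        inference_parameters={"temperature": 0.1},             # assumed parameter name
        endpoint_request_settings=EndpointRequestSettings(request_batch_size=8, min_endpoint_success_ratio=0.7),
    ),
    hyperparameters={"num_train_epochs": "1"},                 # assumed hyperparameter name
    resources=ResourceConfiguration(instance_type="Standard_D4s_v3"),  # placeholder SKU
    experiment_name="distillation-example",
)

# Submission goes through the regular jobs client; _to_rest_object above produces
# the FineTuningJob payload with the distillation property bag attached.
returned_job = ml_client.jobs.create_or_update(distillation_job)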
+ + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, EndpointRequestSettings): + return False + return ( + self.request_batch_size == other.request_batch_size + and self.min_endpoint_success_ratio == other.min_endpoint_success_ratio + ) + + def __ne__(self, other: object) -> bool: + """Check inequality between two EndpointRequestSettings objects. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/prompt_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/prompt_settings.py new file mode 100644 index 00000000..d74af748 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/prompt_settings.py @@ -0,0 +1,138 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Optional + +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class PromptSettings: + def __init__( + self, + *, + enable_chain_of_thought: bool = False, + enable_chain_of_density: bool = False, + max_len_summary: Optional[int] = None, + # custom_prompt: Optional[str] = None + ): + """Initialize PromptSettings. + + :param enable_chain_of_thought: Whether or not to enable chain of thought which modifies the system prompt + used. Can be used for all `data_generation_task_type` values except `SUMMARIZATION`, defaults to False + :type enable_chain_of_thought: bool, optional + :param enable_chain_of_density: Whether or not to enable chain of density which modifies the system prompt + used. Can only be used for `data_generation_task_type` of `SUMMARIZATION`, defaults to False + :type enable_chain_of_density: bool, optional + :param max_len_summary: The maximum length of the summary generated for data_generation_task_type` of + `SUMMARIZATION`, defaults to None + :type max_len_summary: typing.Optional[int] + """ + self._enable_chain_of_thought = enable_chain_of_thought + self._enable_chain_of_density = enable_chain_of_density + self._max_len_summary = max_len_summary + # self._custom_prompt = custom_prompt + + @property + def enable_chain_of_thought(self) -> bool: + """Get whether or not chain of thought is enabled. + + :return: Whether or not chain of thought is enabled. + :rtype: bool + """ + return self._enable_chain_of_thought + + @enable_chain_of_thought.setter + def enable_chain_of_thought(self, value: bool) -> None: + """Set chain of thought. + + :param value: Whether or not chain of thought is enabled. + :type value: bool + """ + self._enable_chain_of_thought = value + + @property + def enable_chain_of_density(self) -> bool: + """Get whether or not chain of density is enabled. + + :return: Whether or not chain of thought is enabled + :rtype: bool + """ + return self._enable_chain_of_density + + @enable_chain_of_density.setter + def enable_chain_of_density(self, value: bool) -> None: + """Set whether or not chain of thought is enabled. + + :param value: Whether or not chain of thought is enabled + :type value: bool + """ + self._enable_chain_of_density = value + + @property + def max_len_summary(self) -> Optional[int]: + """The number of tokens to use for summarization. 
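Two illustrative PromptSettings configurations following the constraints described in the docstrings above: chain-of-thought applies to every data_generation_task_type except summarization, while chain-of-density and max_len_summary are summarization-only (the 80-token limit here is an arbitrary example value).

from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings

qa_prompt_settings = PromptSettings(enable_chain_of_thought=True)
summarization_prompt_settings = PromptSettings(enable_chain_of_density=True, max_len_summary=80)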
+ + :return: The number of tokens to use for summarization + :rtype: typing.Optional[int] + """ + return self._max_len_summary + + @max_len_summary.setter + def max_len_summary(self, length: Optional[int]) -> None: + """Set the number of tokens to use for summarization. + + :param length: The number of tokens to use for summarization. + :type length: typing.Optional[int] + """ + self._max_len_summary = length + + # @property + # def custom_prompt(self) -> Optional[str]: + # """Get the custom system prompt to use for inferencing. + # :return: The custom prompt to use for inferencing. + # :rtype: Optional[str] + # """ + # return self._custom_prompt + + # @custom_prompt.setter + # def custom_prompt(self, prompt: Optional[str]) -> None: + # """Set the custom prompt to use for inferencing. + + # :param prompt: The custom prompt to use for inferencing. + # :type prompt: Optional[str] + # """ + # self._custom_prompt = prompt + + def items(self): + return self.__dict__.items() + + def __eq__(self, other: object) -> bool: + """Returns True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, PromptSettings): + return False + return ( + self.enable_chain_of_thought == other.enable_chain_of_thought + and self.enable_chain_of_density == other.enable_chain_of_density + and self.max_len_summary == other.max_len_summary + # self.custom_prompt == other.custom_prompt + ) + + def __ne__(self, other: object) -> bool: + """Check inequality between two PromptSettings objects. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/teacher_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/teacher_model_settings.py new file mode 100644 index 00000000..481800de --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/teacher_model_settings.py @@ -0,0 +1,93 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Dict, Optional + +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings + + +@experimental +class TeacherModelSettings: + def __init__( + self, + *, + inference_parameters: Optional[Dict] = None, + endpoint_request_settings: Optional[EndpointRequestSettings] = None, + ): + """Initialize TeacherModelSettings + + :param inference_parameters: The inference parameters inferencing requests will use, defaults to None + :type inference_parameters: typing.Optional[typing.Dict], optional + :param endpoint_request_settings: The settings to use for the endpoint, defaults to None + :type endpoint_request_settings: typing.Optional[EndpointRequestSettings], optional + """ + self._inference_parameters = inference_parameters + self._endpoint_request_settings = endpoint_request_settings + + @property + def inference_parameters(self) -> Optional[Dict]: + """Get the inference parameters. + + :return: The inference parameters. 
+ :rtype: typing.Optional[typing.Dict] + """ + return self._inference_parameters + + @inference_parameters.setter + def inference_parameters(self, params: Optional[Dict]) -> None: + """Set the inference parameters. + + :param params: Inference parameters. + :type params: typing.Optional[typing.Dict] + """ + self._inference_parameters = params + + @property + def endpoint_request_settings(self) -> Optional[EndpointRequestSettings]: + """Get the endpoint request settings. + + :return: The endpoint request settings. + :rtype: typing.Optional[EndpointRequestSettings] + """ + return self._endpoint_request_settings + + @endpoint_request_settings.setter + def endpoint_request_settings(self, endpoint_settings: Optional[EndpointRequestSettings]) -> None: + """Set the endpoint request settings. + + :param endpoint_settings: Endpoint request settings + :type endpoint_settings: typing.Optional[EndpointRequestSettings] + """ + self._endpoint_request_settings = endpoint_settings + + def items(self): + return self.__dict__.items() + + def __eq__(self, other: object) -> bool: + """Returns True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, TeacherModelSettings): + return False + return ( + self.inference_parameters == other.inference_parameters + and self.endpoint_request_settings == other.endpoint_request_settings + ) + + def __ne__(self, other: object) -> bool: + """Check inequality between two TeacherModelSettings objects. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py new file mode 100644 index 00000000..ec7277c6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py @@ -0,0 +1,229 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + DistributionConfiguration as RestDistributionConfiguration, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import DistributionType as RestDistributionType +from azure.ai.ml._restclient.v2023_04_01_preview.models import Mpi as RestMpi +from azure.ai.ml._restclient.v2023_04_01_preview.models import PyTorch as RestPyTorch +from azure.ai.ml._restclient.v2023_04_01_preview.models import Ray as RestRay +from azure.ai.ml._restclient.v2023_04_01_preview.models import TensorFlow as RestTensorFlow +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants import DistributionType +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +SDK_TO_REST = { + DistributionType.MPI: RestDistributionType.MPI, + DistributionType.TENSORFLOW: RestDistributionType.TENSOR_FLOW, + DistributionType.PYTORCH: RestDistributionType.PY_TORCH, + DistributionType.RAY: RestDistributionType.RAY, +} + + +class DistributionConfiguration(RestTranslatableMixin): + """Distribution configuration for a component or job. 
+ + This class is not meant to be instantiated directly. Instead, use one of its subclasses. + """ + + def __init__(self, **kwargs: Any) -> None: + self.type: Any = None + + @classmethod + def _from_rest_object( + cls, obj: Optional[Union[RestDistributionConfiguration, Dict]] + ) -> Optional["DistributionConfiguration"]: + """Constructs a DistributionConfiguration object from a REST object + + This function works for distribution property of a Job object and of a Component object() + + Distribution of Job when returned by MFE, is a RestDistributionConfiguration + + Distribution of Component when returned by MFE, is a Dict. + e.g. {'type': 'Mpi', 'process_count_per_instance': '1'} + + So in the job distribution case, we need to call as_dict() first and get type from "distribution_type" property. + In the componenet case, we need to extract type from key "type" + + + :param obj: The object to translate + :type obj: Optional[Union[RestDistributionConfiguration, Dict]] + :return: The distribution configuration + :rtype: DistributionConfiguration + """ + if obj is None: + return None + + if isinstance(obj, dict): + data = obj + else: + data = obj.as_dict() + + type_str = data.pop("distribution_type", None) or data.pop("type", None) + klass = DISTRIBUTION_TYPE_MAP[type_str.lower()] + res: DistributionConfiguration = klass(**data) + return res + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DistributionConfiguration): + return NotImplemented + res: bool = self._to_rest_object() == other._to_rest_object() + return res + + +class MpiDistribution(DistributionConfiguration): + """MPI distribution configuration. + + :keyword process_count_per_instance: The number of processes per node. + :paramtype process_count_per_instance: Optional[int] + :ivar type: Specifies the type of distribution. Set automatically to "mpi" for this class. + :vartype type: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START mpi_distribution_configuration] + :end-before: [END mpi_distribution_configuration] + :language: python + :dedent: 8 + :caption: Configuring a CommandComponent with an MpiDistribution. + """ + + def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.type = DistributionType.MPI + self.process_count_per_instance = process_count_per_instance + + def _to_rest_object(self) -> RestMpi: + return RestMpi(process_count_per_instance=self.process_count_per_instance) + + +class PyTorchDistribution(DistributionConfiguration): + """PyTorch distribution configuration. + + :keyword process_count_per_instance: The number of processes per node. + :paramtype process_count_per_instance: Optional[int] + :ivar type: Specifies the type of distribution. Set automatically to "pytorch" for this class. + :vartype type: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START pytorch_distribution_configuration] + :end-before: [END pytorch_distribution_configuration] + :language: python + :dedent: 8 + :caption: Configuring a CommandComponent with a PyTorchDistribution. 
+ """ + + def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.type = DistributionType.PYTORCH + self.process_count_per_instance = process_count_per_instance + + def _to_rest_object(self) -> RestPyTorch: + return RestPyTorch(process_count_per_instance=self.process_count_per_instance) + + +class TensorFlowDistribution(DistributionConfiguration): + """TensorFlow distribution configuration. + + :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType + :keyword parameter_server_count: The number of parameter server tasks. Defaults to 0. + :paramtype parameter_server_count: Optional[int] + :keyword worker_count: The number of workers. Defaults to the instance count. + :paramtype worker_count: Optional[int] + :ivar parameter_server_count: Number of parameter server tasks. + :vartype parameter_server_count: int + :ivar worker_count: Number of workers. If not specified, will default to the instance count. + :vartype worker_count: int + :ivar type: Specifies the type of distribution. Set automatically to "tensorflow" for this class. + :vartype type: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START tensorflow_distribution_configuration] + :end-before: [END tensorflow_distribution_configuration] + :language: python + :dedent: 8 + :caption: Configuring a CommandComponent with a TensorFlowDistribution. + """ + + def __init__( + self, *, parameter_server_count: Optional[int] = 0, worker_count: Optional[int] = None, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.type = DistributionType.TENSORFLOW + self.parameter_server_count = parameter_server_count + self.worker_count = worker_count + + def _to_rest_object(self) -> RestTensorFlow: + return RestTensorFlow(parameter_server_count=self.parameter_server_count, worker_count=self.worker_count) + + +@experimental +class RayDistribution(DistributionConfiguration): + """Ray distribution configuration. + + :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType + :ivar port: The port of the head ray process. + :vartype port: int + :ivar address: The address of Ray head node. + :vartype address: str + :ivar include_dashboard: Provide this argument to start the Ray dashboard GUI. + :vartype include_dashboard: bool + :ivar dashboard_port: The port to bind the dashboard server to. + :vartype dashboard_port: int + :ivar head_node_additional_args: Additional arguments passed to ray start in head node. + :vartype head_node_additional_args: str + :ivar worker_node_additional_args: Additional arguments passed to ray start in worker node. + :vartype worker_node_additional_args: str + :ivar type: Specifies the type of distribution. Set automatically to "Ray" for this class. 
+ :vartype type: str + """ + + def __init__( + self, + *, + port: Optional[int] = None, + address: Optional[str] = None, + include_dashboard: Optional[bool] = None, + dashboard_port: Optional[int] = None, + head_node_additional_args: Optional[str] = None, + worker_node_additional_args: Optional[str] = None, + **kwargs: Any + ): + super().__init__(**kwargs) + self.type = DistributionType.RAY + + self.port = port + self.address = address + self.include_dashboard = include_dashboard + self.dashboard_port = dashboard_port + self.head_node_additional_args = head_node_additional_args + self.worker_node_additional_args = worker_node_additional_args + + def _to_rest_object(self) -> RestRay: + return RestRay( + port=self.port, + address=self.address, + include_dashboard=self.include_dashboard, + dashboard_port=self.dashboard_port, + head_node_additional_args=self.head_node_additional_args, + worker_node_additional_args=self.worker_node_additional_args, + ) + + +DISTRIBUTION_TYPE_MAP = { + DistributionType.MPI: MpiDistribution, + DistributionType.TENSORFLOW: TensorFlowDistribution, + DistributionType.PYTORCH: PyTorchDistribution, + DistributionType.RAY: RayDistribution, +} diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_finetuning_job.py new file mode 100644 index 00000000..e659c634 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_finetuning_job.py @@ -0,0 +1,242 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
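+#
+# Illustrative usage sketch for the AzureOpenAIFineTuningJob class defined below.
+# This block is an annotation, not part of the upstream SDK module, and is kept as
+# comments so the module's behavior is unchanged. The model path, data paths, and
+# hyperparameter values are placeholders.
+#
+#   from azure.ai.ml import Input
+#   from azure.ai.ml.entities._job.finetuning.azure_openai_finetuning_job import AzureOpenAIFineTuningJob
+#   from azure.ai.ml.entities._job.finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparameters
+#
+#   aoai_job = AzureOpenAIFineTuningJob(
+#       task="chatCompletion",  # the first letter is lowercased by __init__ either way
+#       model=Input(type="custom_model", path="azureml://registries/azure-openai/models/gpt-35-turbo/versions/1"),
+#       training_data=Input(type="uri_file", path="./train.jsonl"),
+#       validation_data=Input(type="uri_file", path="./validation.jsonl"),
+#       # Must be an AzureOpenAIHyperparameters instance; anything else raises ValidationException.
+#       hyperparameters=AzureOpenAIHyperparameters(n_epochs=3, batch_size=1, learning_rate_multiplier=1.0),
+#   )
+#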
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+    ModelProvider as RestModelProvider,
+    AzureOpenAiFineTuning as RestAzureOpenAIFineTuning,
+    FineTuningJob as RestFineTuningJob,
+    JobBase as RestJobBase,
+)
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+
+from azure.ai.ml.entities._job.finetuning.finetuning_vertical import FineTuningVertical
+from azure.ai.ml.entities._job.finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparameters
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureOpenAIFineTuningJob(FineTuningVertical):
+    def __init__(
+        self,
+        **kwargs: Any,
+    ) -> None:
+        # Extract any task specific settings
+        model = kwargs.pop("model", None)
+        task = kwargs.pop("task", None)
+        # Convert the first letter of the task to lowercase. This happens when the
+        # object is created from the schema, using the dict from the REST API response.
+        # TextCompletion => textCompletion
+        if task:
+            task = task[0].lower() + task[1:]
+        training_data = kwargs.pop("training_data", None)
+        validation_data = kwargs.pop("validation_data", None)
+        hyperparameters = kwargs.pop("hyperparameters", None)
+        if hyperparameters and not isinstance(hyperparameters, AzureOpenAIHyperparameters):
+            raise ValidationException(
+                category=ErrorCategory.USER_ERROR,
+                target=ErrorTarget.JOB,
+                message="Hyperparameters, if provided, should be of type AzureOpenAIHyperparameters",
+                no_personal_data_message="Hyperparameters, if provided, should be of type AzureOpenAIHyperparameters",
+            )
+
+        self._hyperparameters = hyperparameters
+
+        super().__init__(
+            task=task,
+            model=model,
+            model_provider=RestModelProvider.AZURE_OPEN_AI,
+            training_data=training_data,
+            validation_data=validation_data,
+            **kwargs,
+        )
+
+    @property
+    def hyperparameters(self) -> AzureOpenAIHyperparameters:
+        """Get hyperparameters.
+
+        :return: Hyperparameters for finetuning the model.
+        :rtype: AzureOpenAIHyperparameters
+        """
+        return self._hyperparameters
+
+    @hyperparameters.setter
+    def hyperparameters(self, hyperparameters: AzureOpenAIHyperparameters) -> None:
+        """Set hyperparameters.
+
+        :param hyperparameters: Hyperparameters for finetuning the model.
+        :type hyperparameters: AzureOpenAIHyperparameters
+        """
+        self._hyperparameters = hyperparameters
+
+    def _to_rest_object(self) -> "RestFineTuningJob":
+        """Convert an AzureOpenAIFineTuningJob object to a RestFineTuningJob object.
+
+        :return: REST object representation of this object.
+ :rtype: JobBase + """ + aoai_finetuning_vertical = RestAzureOpenAIFineTuning( + task_type=self._task, + model=self._model, + model_provider=self._model_provider, + training_data=self._training_data, + validation_data=self._validation_data, + hyper_parameters=self.hyperparameters._to_rest_object() if self.hyperparameters else None, + ) + + self._resolve_inputs(aoai_finetuning_vertical) + + finetuning_job = RestFineTuningJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + tags=self.tags, + properties=self.properties, + fine_tuning_details=aoai_finetuning_vertical, + outputs=to_rest_data_outputs(self.outputs), + ) + + result = RestJobBase(properties=finetuning_job) + result.name = self.name + + return result + + def _to_dict(self) -> Dict: + """Convert the object to a dictionary. + + :return: dictionary representation of the object. + :rtype: typing.Dict + """ + from azure.ai.ml._schema._finetuning.azure_openai_finetuning import AzureOpenAIFineTuningSchema + + schema_dict: dict = {} + # TODO: Combeback to this later for FineTuningJob in Pipelines + # if inside_pipeline: + # schema_dict = AutoMLClassificationNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + # else: + schema_dict = AzureOpenAIFineTuningSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return schema_dict + + def __eq__(self, other: object) -> bool: + """Returns True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, AzureOpenAIFineTuningJob): + return NotImplemented + + return super().__eq__(other) and self.hyperparameters == other.hyperparameters + + def __ne__(self, other: object) -> bool: + """Check inequality between two AzureOpenAIFineTuningJob objects. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) + + @classmethod + def _from_rest_object(cls, obj: RestJobBase) -> "AzureOpenAIFineTuningJob": + """Convert a REST object to AzureOpenAIFineTuningJob object. + + :param obj: AzureOpenAIFineTuningJob in Rest format. + :type obj: JobBase + :return: AzureOpenAIFineTuningJob objects. + :rtype: AzureOpenAIFineTuningJob + """ + + properties: RestFineTuningJob = obj.properties + finetuning_details: RestAzureOpenAIFineTuning = properties.fine_tuning_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "experiment_name": properties.experiment_name, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": properties.display_name, + "outputs": from_rest_data_outputs(properties.outputs), + } + + aoai_finetuning_job = cls( + task=finetuning_details.task_type, + model=finetuning_details.model, + training_data=finetuning_details.training_data, + validation_data=finetuning_details.validation_data, + hyperparameters=AzureOpenAIHyperparameters._from_rest_object(finetuning_details.hyper_parameters), + **job_args_dict, + ) + + aoai_finetuning_job._restore_inputs() + + return aoai_finetuning_job + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "AzureOpenAIFineTuningJob": + """Load from a dictionary. 
+ + :param data: dictionary representation of the object. + :type data: typing.Dict + :param context: dictionary containing the context. + :type context: typing.Dict + :param additional_message: additional message to be added to the error message. + :type additional_message: str + :return: AzureOpenAIFineTuningJob object. + :rtype: AzureOpenAIFineTuningJob + """ + from azure.ai.ml._schema._finetuning.azure_openai_finetuning import AzureOpenAIFineTuningSchema + + # TODO: Combeback to this later - Pipeline part. + # from azure.ai.ml._schema.pipeline.automl_node import AutoMLClassificationNodeSchema + + # if kwargs.pop("inside_pipeline", False): + # loaded_data = load_from_dict( + # AutoMLClassificationNodeSchema, + # data, + # context, + # additional_message, + # **kwargs, + # ) + # else: + loaded_data = load_from_dict(AzureOpenAIFineTuningSchema, data, context, additional_message, **kwargs) + + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "AzureOpenAIFineTuningJob": + """Create an instance from a schema dictionary. + + :param loaded_data: dictionary containing the data. + :type loaded_data: typing.Dict + :return: AzureOpenAIFineTuningJob object. + :rtype: AzureOpenAIFineTuningJob + """ + + job = AzureOpenAIFineTuningJob(**loaded_data) + return job diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_hyperparameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_hyperparameters.py new file mode 100644 index 00000000..2b420a46 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_hyperparameters.py @@ -0,0 +1,125 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + AzureOpenAiHyperParameters as RestAzureOpenAiHyperParameters, +) +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class AzureOpenAIHyperparameters(RestTranslatableMixin): + """Hyperparameters for Azure OpenAI model finetuning.""" + + def __init__( + self, + *, + batch_size: Optional[int] = None, + learning_rate_multiplier: Optional[float] = None, + n_epochs: Optional[int] = None, + ): + """Initialize AzureOpenAIHyperparameters. + + param batch_size: Number of examples in each batch. + A larger batch size means that model parameters are updated less + frequently, but with lower variance. Defaults to None. + type batch_size: int + param learning_rate_multiplier: Scaling factor for the learning rate. + A smaller learning rate may be useful to avoid overfitting. + type learning_rate_multiplier: float + param n_epochs: The number of epochs to train the model for. + An epoch refers to one full cycle through the training dataset. 
+ type n_epochs: int + """ + self._batch_size = batch_size + self._learning_rate_multiplier = learning_rate_multiplier + self._n_epochs = n_epochs + # Not exposed in the public API, so need to check how to handle this + # self._additional_properties = kwargs + + @property + def batch_size(self) -> Optional[int]: + """Get the batch size for training.""" + return self._batch_size + + @batch_size.setter + def batch_size(self, value: Optional[int]) -> None: + """Set the batch size for training. + :param value: The batch size for training. + :type value: int + """ + self._batch_size = value + + @property + def learning_rate_multiplier(self) -> Optional[float]: + """Get the learning rate multiplier. + :return: The learning rate multiplier. + :rtype: float + """ + return self._learning_rate_multiplier + + @learning_rate_multiplier.setter + def learning_rate_multiplier(self, value: Optional[float]) -> None: + """Set the learning rate multiplier. + :param value: The learning rate multiplier. + :type value: float + """ + self._learning_rate_multiplier = value + + @property + def n_epochs(self) -> Optional[int]: + """Get the number of epochs. + :return: The number of epochs. + :rtype: int + """ + return self._n_epochs + + @n_epochs.setter + def n_epochs(self, value: Optional[int]) -> None: + """Set the number of epochs. + :param value: The number of epochs. + :type value: int + """ + self._n_epochs = value + + # Not exposed in the public API, so need to check how to handle this + # @property + # def additional_properties(self) -> dict: + # """Get additional properties.""" + # return self._additional_properties + + # @additional_properties.setter + # def additional_properties(self, value: dict) -> None: + # """Set additional properties.""" + # self._additional_properties = value + + def _to_rest_object(self) -> RestAzureOpenAiHyperParameters: + return RestAzureOpenAiHyperParameters( + batch_size=self._batch_size, + learning_rate_multiplier=self._learning_rate_multiplier, + n_epochs=self._n_epochs, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AzureOpenAIHyperparameters): + return NotImplemented + return ( + self._batch_size == other._batch_size + and self._learning_rate_multiplier == other._learning_rate_multiplier + and self._n_epochs == other._n_epochs + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + @classmethod + def _from_rest_object(cls, obj: RestAzureOpenAiHyperParameters) -> "AzureOpenAIHyperparameters": + aoai_hyperparameters = cls( + batch_size=obj.batch_size, + learning_rate_multiplier=obj.learning_rate_multiplier, + n_epochs=obj.n_epochs, + ) + return aoai_hyperparameters diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/custom_model_finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/custom_model_finetuning_job.py new file mode 100644 index 00000000..e6ddd86d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/custom_model_finetuning_job.py @@ -0,0 +1,258 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
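+#
+# Illustrative usage sketch for the CustomModelFineTuningJob class defined below.
+# This block is an annotation, not part of the upstream SDK module, and is kept as
+# comments so the module's behavior is unchanged. The registry model path, data
+# paths, hyperparameter names, and output name are placeholders.
+#
+#   from azure.ai.ml import Input, Output
+#   from azure.ai.ml.entities._job.finetuning.custom_model_finetuning_job import CustomModelFineTuningJob
+#
+#   ft_job = CustomModelFineTuningJob(
+#       task="textCompletion",
+#       model=Input(type="mlflow_model", path="azureml://registries/azureml/models/Llama-2-7b/versions/9"),
+#       training_data=Input(type="uri_file", path="./train.jsonl"),
+#       validation_data=Input(type="uri_file", path="./validation.jsonl"),
+#       # For custom models, hyperparameters are passed through as a plain Dict[str, str].
+#       hyperparameters={"learning_rate": "0.00002", "num_train_epochs": "1"},
+#       outputs={"registered_model": Output(type="mlflow_model", name="my-finetuned-model")},
+#   )
+#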
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict + +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + ModelProvider as RestModelProvider, + CustomModelFineTuning as RestCustomModelFineTuningVertical, + FineTuningJob as RestFineTuningJob, + JobBase as RestJobBase, +) +from azure.ai.ml.entities._job._input_output_helpers import ( + from_rest_data_outputs, + to_rest_data_outputs, +) +from azure.ai.ml.entities._job.job_resources import JobResources +from azure.ai.ml.entities._job.queue_settings import QueueSettings +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._job.finetuning.finetuning_vertical import FineTuningVertical +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class CustomModelFineTuningJob(FineTuningVertical): + def __init__( + self, + **kwargs: Any, + ) -> None: + # Extract any task specific settings + model = kwargs.pop("model", None) + task = kwargs.pop("task", None) + # Convert task to lowercase first letter, this is when we create + # object from the schema, using dict object from the REST api response. + # TextCompletion => textCompletion + if task: + task = task[0].lower() + task[1:] + training_data = kwargs.pop("training_data", None) + validation_data = kwargs.pop("validation_data", None) + self._hyperparameters = kwargs.pop("hyperparameters", None) + super().__init__( + task=task, + model=model, + model_provider=RestModelProvider.CUSTOM, + training_data=training_data, + validation_data=validation_data, + **kwargs, + ) + + @property + def hyperparameters(self) -> Dict[str, str]: + """Get hyperparameters. + + :return: + :rtype: hyperparameters: Dict[str,str] + """ + return self._hyperparameters + + @hyperparameters.setter + def hyperparameters(self, hyperparameters: Dict[str, str]) -> None: + """Set hyperparameters. + + :param hyperparameters: Hyperparameters for finetuning the model + :type hyperparameters: Dict[str,str] + """ + self._hyperparameters = hyperparameters + + def _to_rest_object(self) -> "RestFineTuningJob": + """Convert CustomFineTuningVertical object to a RestFineTuningJob object. + + :return: REST object representation of this object. + :rtype: JobBase + """ + custom_finetuning_vertical = RestCustomModelFineTuningVertical( + task_type=self._task, + model=self._model, + model_provider=self._model_provider, + training_data=self._training_data, + validation_data=self._validation_data, + hyper_parameters=self._hyperparameters, + ) + self._resolve_inputs(custom_finetuning_vertical) + + finetuning_job = RestFineTuningJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + services=self.services, + tags=self.tags, + properties=self.properties, + compute_id=self.compute, + fine_tuning_details=custom_finetuning_vertical, + outputs=to_rest_data_outputs(self.outputs), + ) + if self.resources: + finetuning_job.resources = self.resources._to_rest_object() + if self.queue_settings: + finetuning_job.queue_settings = self.queue_settings._to_rest_object() + + result = RestJobBase(properties=finetuning_job) + result.name = self.name + + return result + + def _to_dict(self) -> Dict: + """Convert the object to a dictionary. + + :return: dictionary representation of the object. 
+ :rtype: typing.Dict + """ + from azure.ai.ml._schema._finetuning.custom_model_finetuning import ( + CustomModelFineTuningSchema, + ) + + schema_dict: dict = {} + # TODO: Combeback to this later for FineTuningJob in pipeline + # if inside_pipeline: + # schema_dict = AutoMLClassificationNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + # else: + schema_dict = CustomModelFineTuningSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + return schema_dict + + def __eq__(self, other: object) -> bool: + """Returns True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, CustomModelFineTuningJob): + return NotImplemented + + return super().__eq__(other) and self.hyperparameters == other.hyperparameters + + def __ne__(self, other: object) -> bool: + """Check inequality between two CustomModelFineTuningJob objects. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) + + @classmethod + def _from_rest_object(cls, obj: RestJobBase) -> "CustomModelFineTuningJob": + """Convert a REST object to CustomModelFineTuningJob object. + + :param obj: CustomModelFineTuningJob in Rest format. + :type obj: JobBase + :return: CustomModelFineTuningJob objects. + :rtype: CustomModelFineTuningJob + """ + + properties: RestFineTuningJob = obj.properties + finetuning_details: RestCustomModelFineTuningVertical = properties.fine_tuning_details + + job_args_dict = { + "id": obj.id, + "name": obj.name, + "description": properties.description, + "tags": properties.tags, + "properties": properties.properties, + "services": properties.services, + "experiment_name": properties.experiment_name, + "status": properties.status, + "creation_context": obj.system_data, + "display_name": properties.display_name, + "compute": properties.compute_id, + "outputs": from_rest_data_outputs(properties.outputs), + } + + if properties.resources: + job_args_dict["resources"] = JobResources._from_rest_object(properties.resources) + if properties.queue_settings: + job_args_dict["queue_settings"] = QueueSettings._from_rest_object(properties.queue_settings) + + custom_model_finetuning_job = cls( + task=finetuning_details.task_type, + model=finetuning_details.model, + training_data=finetuning_details.training_data, + validation_data=finetuning_details.validation_data, + hyperparameters=finetuning_details.hyper_parameters, + **job_args_dict, + ) + + custom_model_finetuning_job._restore_inputs() + + return custom_model_finetuning_job + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "CustomModelFineTuningJob": + """Load from a dictionary. + + :param data: dictionary representation of the object. + :type data: typing.Dict + :param context: dictionary containing the context. + :type context: typing.Dict + :param additional_message: additional message to be added to the error message. + :type additional_message: str + :return: CustomModelFineTuningJob object. + :rtype: CustomModelFineTuningJob + """ + from azure.ai.ml._schema._finetuning.custom_model_finetuning import ( + CustomModelFineTuningSchema, + ) + + # TODO: Combeback to this later - Pipeline part. 
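+        # Annotation: after loading, string values for "training_data" and
+        # "validation_data" coming from the loaded dict are wrapped into
+        # Input(type="uri_file", path=...) below before the job is constructed.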
+ # from azure.ai.ml._schema.pipeline.automl_node import AutoMLClassificationNodeSchema + + # if kwargs.pop("inside_pipeline", False): + # loaded_data = load_from_dict( + # AutoMLClassificationNodeSchema, + # data, + # context, + # additional_message, + # **kwargs, + # ) + # else: + loaded_data = load_from_dict(CustomModelFineTuningSchema, data, context, additional_message, **kwargs) + + training_data = loaded_data.get("training_data", None) + if isinstance(training_data, str): + loaded_data["training_data"] = Input(type="uri_file", path=training_data) + + validation_data = loaded_data.get("validation_data", None) + if isinstance(validation_data, str): + loaded_data["validation_data"] = Input(type="uri_file", path=validation_data) + + job_instance = cls._create_instance_from_schema_dict(loaded_data) + return job_instance + + @classmethod + def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "CustomModelFineTuningJob": + """Create an instance from a schema dictionary. + + :param loaded_data: dictionary containing the data. + :type loaded_data: typing.Dict + :return: CustomModelFineTuningJob object. + :rtype: CustomModelFineTuningJob + """ + job = CustomModelFineTuningJob(**loaded_data) + return job diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_job.py new file mode 100644 index 00000000..ec8d9d5d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_job.py @@ -0,0 +1,224 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional + +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + ModelProvider as RestModelProvider, + JobBase as RestJobBase, +) +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import TYPE +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._job.job_resources import JobResources +from azure.ai.ml.entities._job.queue_settings import QueueSettings +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException +from azure.ai.ml.constants._job.finetuning import FineTuningConstants +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class FineTuningJob(Job, JobIOMixin): + def __init__( + self, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = JobType.FINE_TUNING + self.resources = kwargs.pop("resources", None) + self.queue_settings = kwargs.pop("queue_settings", None) + self.outputs = kwargs.pop("outputs", None) + super().__init__(**kwargs) + + @property + def resources(self) -> Optional[JobResources]: + """Job resources to use during job execution. + :return: Job Resources object. + :rtype: JobResources + """ + return self._resources if hasattr(self, "_resources") else None + + @resources.setter + def resources(self, value: JobResources) -> None: + """Set JobResources. + + :param value: JobResources object. + :type value: JobResources + :raises ValidationException: Expected a JobResources object. + """ + if isinstance(value, JobResources): + self._resources = value + elif value: + msg = "Expected an instance of JobResources." 
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.FINETUNING, + error_category=ErrorCategory.USER_ERROR, + ) + + @property + def queue_settings(self) -> Optional[QueueSettings]: + """Queue settings for job execution. + :return: QueueSettings object. + :rtype: QueueSettings + """ + return self._queue_settings if hasattr(self, "_queue_settings") else None + + @queue_settings.setter + def queue_settings(self, value: QueueSettings) -> None: + """Set queue settings for job execution. + + :param value: QueueSettings object. + :type value: QueueSettings + :raises ValidationException: Expected a QueueSettings object. + """ + if isinstance(value, QueueSettings): + self._queue_settings = value + elif value: + msg = "Expected an instance of QueueSettings." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.FINETUNING, + error_category=ErrorCategory.USER_ERROR, + ) + + def __eq__(self, other: object) -> bool: + """Returns True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, FineTuningJob): + return NotImplemented + + queue_settings_match = (not self.queue_settings and not other.queue_settings) or ( + self.queue_settings is not None + and other.queue_settings is not None + and self.queue_settings.job_tier is not None + and other.queue_settings.job_tier is not None + and self.queue_settings.job_tier.lower() == other.queue_settings.job_tier.lower() + ) + + outputs_match = not self.outputs and not other.outputs + if self.outputs and other.outputs: + outputs_match = ( + self.outputs["registered_model"].name == other.outputs["registered_model"].name + and self.outputs["registered_model"].type == other.outputs["registered_model"].type + ) + + return ( + outputs_match + and self.resources == other.resources + and queue_settings_match + # add properties from base class + and self.name == other.name + and self.description == other.description + and self.tags == other.tags + and self.properties == other.properties + and self.compute == other.compute + and self.id == other.id + and self.experiment_name == other.experiment_name + and self.status == other.status + ) + + def __ne__(self, other: object) -> bool: + """Check inequality between two FineTuningJob objects. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) + + @classmethod + def _get_model_provider_mapping(cls) -> Dict: + """Create a mapping of task type to job class. + + :return: An FineTuningVertical object containing the model provider type to job class mapping. + :rtype: FineTuningJob + """ + from .custom_model_finetuning_job import CustomModelFineTuningJob + from .azure_openai_finetuning_job import AzureOpenAIFineTuningJob + + return { + camel_to_snake(RestModelProvider.CUSTOM): CustomModelFineTuningJob, + camel_to_snake(RestModelProvider.AZURE_OPEN_AI): AzureOpenAIFineTuningJob, + } + + @classmethod + def _load_from_rest(cls, obj: RestJobBase) -> "FineTuningJob": + """Loads the rest object to a dict containing items to init the AutoMLJob objects. + + :param obj: Azure Resource Manager resource envelope. 
+ :type obj: JobBase + :raises ValidationException: task type validation error + :return: A FineTuningJob + :rtype: FineTuningJob + """ + model_provider = ( + camel_to_snake(obj.properties.fine_tuning_details.model_provider) + if obj.properties.fine_tuning_details.model_provider + else None + ) + class_type = cls._get_model_provider_mapping().get(model_provider, None) + if class_type: + res: FineTuningJob = class_type._from_rest_object(obj) + return res + msg = f"Unsupported model provider type: {obj.properties.fine_tuning_details.model_provider}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.FINETUNING, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + @classmethod + def _load_from_dict( + cls, + data: Dict, + context: Dict, + additional_message: str, + **kwargs: Any, + ) -> "FineTuningJob": + """Loads the dictionary objects to an FineTuningJob object. + + :param data: A data dictionary. + :type data: typing.Dict + :param context: A context dictionary. + :type context: typing.Dict + :param additional_message: An additional message to be logged in the ValidationException. + :type additional_message: str + + :raises ValidationException: task type validation error + :return: An FineTuningJob + :rtype: FineTuningJob + """ + model_provider = data.get(FineTuningConstants.ModelProvider) + class_type = cls._get_model_provider_mapping().get(model_provider, None) + if class_type: + res: FineTuningJob = class_type._load_from_dict( + data, + context, + additional_message, + **kwargs, + ) + return res + msg = f"Unsupported model provider type: {model_provider}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.AUTOML, + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_vertical.py new file mode 100644 index 00000000..c9a5fe41 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_vertical.py @@ -0,0 +1,202 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Optional, cast + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + ModelProvider as RestModelProvider, + FineTuningVertical as RestFineTuningVertical, + UriFileJobInput, + MLFlowModelJobInput, +) +from azure.ai.ml.constants._common import AssetTypes +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob + +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class FineTuningVertical(FineTuningJob): + def __init__( + self, + *, + task: str, + model: Input, + model_provider: Optional[str], + training_data: Input, + validation_data: Optional[Input] = None, + **kwargs: Any, + ) -> None: + self._task = task + self._model = model + self._model_provider = model_provider + self._training_data = training_data + self._validation_data = validation_data + super().__init__(**kwargs) + + @property + def task(self) -> str: + """Get finetuning task. + + :return: The type of task to run. 
Possible values include: "ChatCompletion" + "TextCompletion", "TextClassification", "QuestionAnswering","TextSummarization", + "TokenClassification", "TextTranslation", "ImageClassification", "ImageInstanceSegmentation", + "ImageObjectDetection","VideoMultiObjectTracking". + + :rtype: str + """ + return self._task + + @task.setter + def task(self, task: str) -> None: + """Set finetuning task. + + :param task: The type of task to run. Possible values include: "ChatCompletion" + "TextCompletion", "TextClassification", "QuestionAnswering","TextSummarization", + "TokenClassification", "TextTranslation", "ImageClassification", "ImageInstanceSegmentation", + "ImageObjectDetection","VideoMultiObjectTracking",. + :type task: str + + :return: None + """ + self._task = task + + @property + def model(self) -> Optional[Input]: + """The model to be fine-tuned. + :return: Input object representing the mlflow model to be fine-tuned. + :rtype: Input + """ + return self._model + + @model.setter + def model(self, value: Input) -> None: + """Set the model to be fine-tuned. + + :param value: Input object representing the mlflow model to be fine-tuned. + :type value: Input + :raises ValidationException: Expected a mlflow model input. + """ + if isinstance(value, Input) and (cast(Input, value).type in ("mlflow_model", "custom_model")): + self._model = value + else: + msg = "Expected a mlflow model input or custom model input." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.FINETUNING, + error_category=ErrorCategory.USER_ERROR, + ) + + @property + def model_provider(self) -> Optional[str]: + """The model provider. + :return: The model provider. + :rtype: str + """ + return self._model_provider + + @model_provider.setter + def model_provider(self, value: str) -> None: + """Set the model provider. + + :param value: The model provider. + :type value: str + """ + self._model_provider = RestModelProvider[camel_to_snake(value).upper()] if value else None + + @property + def training_data(self) -> Input: + """Get training data. + + :return: Training data input + :rtype: Input + """ + return self._training_data + + @training_data.setter + def training_data(self, training_data: Input) -> None: + """Set training data. + + :param training_data: Training data input + :type training_data: Input + """ + self._training_data = training_data + + @property + def validation_data(self) -> Optional[Input]: + """Get validation data. + + :return: Validation data input + :rtype: Input + """ + return self._validation_data + + @validation_data.setter + def validation_data(self, validation_data: Input) -> None: + """Set validation data. + + :param validation_data: Validation data input + :type validation_data: Input + """ + self._validation_data = validation_data + + def _resolve_inputs(self, rest_job: RestFineTuningVertical) -> None: + """Resolve JobInputs to UriFileJobInput within data_settings. + + :param rest_job: The rest job object. 
+ :type rest_job: RestFineTuningVertical + """ + if isinstance(rest_job.training_data, Input): + rest_job.training_data = UriFileJobInput(uri=rest_job.training_data.path) + if isinstance(rest_job.validation_data, Input): + rest_job.validation_data = UriFileJobInput(uri=rest_job.validation_data.path) + if isinstance(rest_job.model, Input): + rest_job.model = MLFlowModelJobInput(uri=rest_job.model.path) + + def _restore_inputs(self) -> None: + """Restore UriFileJobInputs to JobInputs within data_settings.""" + if isinstance(self.training_data, UriFileJobInput): + self.training_data = Input(type=AssetTypes.URI_FILE, path=self.training_data.uri) + if isinstance(self.validation_data, UriFileJobInput): + self.validation_data = Input(type=AssetTypes.URI_FILE, path=self.validation_data.uri) + if isinstance(self.model, MLFlowModelJobInput): + self.model = Input(type=AssetTypes.MLFLOW_MODEL, path=self.model.uri) + + def __eq__(self, other: object) -> bool: + """Returns True if both instances have the same values. + + This method check instances equality and returns True if both of + the instances have the same attributes with the same values. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + if not isinstance(other, FineTuningVertical): + return NotImplemented + + return ( + # TODO: Equality from base class does not work, no current precedence for this + super().__eq__(other) + and self.task == other.task + and self.model == other.model + and self.model_provider == other.model_provider + and self.training_data == other.training_data + and self.validation_data == other.validation_data + ) + + def __ne__(self, other: object) -> bool: + """Check inequality between two FineTuningJob objects. + + :param other: Any object + :type other: object + :return: True or False + :rtype: bool + """ + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/import_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/import_job.py new file mode 100644 index 00000000..24d4ec90 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/import_job.py @@ -0,0 +1,285 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
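+#
+# Illustrative usage sketch for the import job entities defined below. This block
+# is an annotation, not part of the upstream SDK module, and is kept as comments so
+# the module's behavior is unchanged. The source type, connection name, query, and
+# output path are placeholders, and import jobs are gated behind the private-preview
+# flag checked in ImportJob._to_rest_object.
+#
+#   from azure.ai.ml import Output
+#   from azure.ai.ml.entities._job.import_job import DatabaseImportSource, ImportJob
+#
+#   import_job = ImportJob(
+#       source=DatabaseImportSource(
+#           type="azuresqldb",
+#           connection="azureml:my_sql_connection",
+#           query="SELECT * FROM my_table",
+#       ),
+#       output=Output(type="uri_folder", path="azureml://datastores/workspaceblobstore/paths/imported/"),
+#       compute="DataFactory",  # a DataFactory ("clusterless" ADF) compute is expected; exact value is a placeholder
+#   )
+#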
+# --------------------------------------------------------- + +import logging +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional + +from azure.ai.ml._restclient.v2022_02_01_preview.models import CommandJob as RestCommandJob +from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData +from azure.ai.ml._schema.job.import_job import ImportJobSchema +from azure.ai.ml._utils.utils import is_private_preview_enabled +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.entities._inputs_outputs import Output +from azure.ai.ml.entities._job._input_output_helpers import ( + from_rest_data_outputs, + from_rest_inputs_to_dataset_literal, + to_rest_data_outputs, + to_rest_dataset_literal_inputs, +) +from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import MlException + +from .job import Job + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._builders import Import + from azure.ai.ml.entities._component.import_component import ImportComponent + +module_logger = logging.getLogger(__name__) + + +class ImportSource(ABC): + def __init__( + self, + *, + type: Optional[str] = None, # pylint: disable=redefined-builtin + connection: Optional[str] = None, + ): + self.type = type + self.connection = connection + + @abstractmethod + def _to_job_inputs(self) -> Dict[str, Optional[str]]: + pass + + @classmethod + def _from_job_inputs(cls, job_inputs: Dict[str, str]) -> "ImportSource": + """Translate job inputs to import source. + + :param job_inputs: The job inputs + :type job_inputs: Dict[str, str] + :return: The import source + :rtype: ImportSource + """ + type = job_inputs.get("type") # pylint: disable=redefined-builtin + connection = job_inputs.get("connection") + query = job_inputs.get("query") + path = job_inputs.get("path") + + import_source = ( + DatabaseImportSource(type=type, connection=connection, query=query) + if query is not None + else FileImportSource(type=type, connection=connection, path=path) + ) + return import_source + + +class DatabaseImportSource(ImportSource): + def __init__( + self, + *, + type: Optional[str] = None, # pylint: disable=redefined-builtin + connection: Optional[str] = None, + query: Optional[str] = None, + ): + ImportSource.__init__( + self, + type=type, + connection=connection, + ) + self.query = query + + def _to_job_inputs(self) -> Dict[str, Optional[str]]: + """Translate source to command Inputs. + + :return: The job inputs dict + :rtype: Dict[str, str] + """ + inputs = { + "type": self.type, + "connection": self.connection, + "query": self.query, + } + return inputs + + +class FileImportSource(ImportSource): + def __init__( + self, + *, + type: Optional[str] = None, # pylint: disable=redefined-builtin + connection: Optional[str] = None, + path: Optional[str] = None, + ): + ImportSource.__init__( + self, + type=type, + connection=connection, + ) + self.path = path + + def _to_job_inputs(self) -> Dict[str, Optional[str]]: + """Translate source to command Inputs. + + :return: The job inputs dict + :rtype: Dict[str, str] + """ + inputs = { + "type": self.type, + "connection": self.connection, + "path": self.path, + } + return inputs + + +class ImportJob(Job, JobIOMixin): + """Import job. + + :param name: Name of the job. + :type name: str + :param description: Description of the job. 
+ :type description: str + :param display_name: Display name of the job. + :type display_name: str + :param experiment_name: Name of the experiment the job will be created under. + If None is provided, default will be set to current directory name. + :type experiment_name: str + :param source: Input source parameters to the import job. + :type source: azure.ai.ml.entities.DatabaseImportSource or FileImportSource + :param output: output data binding used in the job. + :type output: azure.ai.ml.Output + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + def __init__( + self, + *, + name: Optional[str] = None, + description: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + source: Optional[ImportSource] = None, + output: Optional[Output] = None, + **kwargs: Any, + ): + kwargs[TYPE] = JobType.IMPORT + + Job.__init__( + self, + name=name, + display_name=display_name, + description=description, + experiment_name=experiment_name, + **kwargs, + ) + + self.source = source + self.output = output + + def _to_dict(self) -> Dict: + res: dict = ImportJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def _to_rest_object(self) -> JobBaseData: + # TODO: Remove in PuP + if not is_private_preview_enabled(): + msg = JobType.IMPORT + " job not supported." + raise MlException(message=msg, no_personal_data_message=msg) + + _inputs = self.source._to_job_inputs() if self.source is not None else None # pylint: disable=protected-access + if self.compute is None: + msg = "compute cannot be None." + raise MlException(message=msg, no_personal_data_message=msg) + + properties = RestCommandJob( + display_name=self.display_name, + description=self.description, + compute_id=self.compute, + experiment_name=self.experiment_name, + inputs=to_rest_dataset_literal_inputs(_inputs, job_type=self.type), + outputs=to_rest_data_outputs({"output": self.output}), + # TODO: Remove in PuP with native import job/component type support in MFE/Designer + # No longer applicable once new import job type is ready on MFE in PuP + # command and environment are required as we use command type for import + # command can be random string and the particular environment name here is defined as default in MFE + # public const string DefaultEnvironmentName = "AzureML-sklearn-0.24-ubuntu18.04-py37-cpu"; + # which is considered valid environment in MFE unless MFE changes current default logic + # but chance should be very low in PrP + command="import", + environment_id=self.compute.replace( + "/computes/DataFactory", "/environments/AzureML-sklearn-0.24-ubuntu18.04-py37-cpu" + ), + ) + result = JobBaseData(properties=properties) + result.name = self.name + return result + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "ImportJob": + loaded_data = load_from_dict(ImportJobSchema, data, context, additional_message, **kwargs) + return ImportJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + @classmethod + def _load_from_rest(cls, obj: JobBaseData) -> "ImportJob": + rest_command_job: RestCommandJob = obj.properties + outputs = from_rest_data_outputs(rest_command_job.outputs) + inputs = from_rest_inputs_to_dataset_literal(rest_command_job.inputs) + + import_job = ImportJob( + name=obj.name, + id=obj.id, + display_name=rest_command_job.display_name, + description=rest_command_job.description, + experiment_name=rest_command_job.experiment_name, + 
            status=rest_command_job.status,
+            creation_context=obj.system_data,
+            source=ImportSource._from_job_inputs(inputs),  # pylint: disable=protected-access
+            output=outputs["output"] if "output" in outputs else None,
+        )
+        return import_job
+
+    def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "ImportComponent":
+        """Translate an import job to a component.
+
+        :param context: Context of import job YAML file.
+        :type context: dict
+        :return: Translated import component.
+        :rtype: ImportComponent
+        """
+        from azure.ai.ml.entities._component.import_component import ImportComponent
+
+        pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+        context = context or {BASE_PATH_CONTEXT_KEY: Path("import/")}
+
+        _inputs = self.source._to_job_inputs() if self.source is not None else None  # pylint: disable=protected-access
+
+        # Create anonymous command component with default version as 1
+        return ImportComponent(
+            is_anonymous=True,
+            base_path=context[BASE_PATH_CONTEXT_KEY],
+            description=self.description,
+            source=self._to_inputs(
+                inputs=_inputs,
+                pipeline_job_dict=pipeline_job_dict,
+            ),
+            output=self._to_outputs(outputs={"output": self.output}, pipeline_job_dict=pipeline_job_dict)["output"],
+        )
+
+    def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Import":
+        """Translate an import job to a pipeline node.
+
+        :param context: Context of import job YAML file.
+        :type context: dict
+        :return: Translated import node.
+        :rtype: Import
+        """
+        from azure.ai.ml.entities._builders import Import
+
+        component = self._to_component(context, **kwargs)
+        _inputs = self.source._to_job_inputs() if self.source is not None else None  # pylint: disable=protected-access
+        return Import(
+            component=component,
+            compute=self.compute,
+            inputs=_inputs,
+            outputs={"output": self.output},
+            description=self.description,
+            display_name=self.display_name,
+            properties=self.properties,
+        )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_output_entry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_output_entry.py
new file mode 100644
index 00000000..aa0e73b1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_output_entry.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
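+#
+# Illustrative usage sketch for InputOutputEntry (defined below). This block is an
+# annotation, not part of the upstream SDK module; the dataset reference is a
+# placeholder.
+#
+#   from azure.ai.ml.constants import InputOutputModes
+#   from azure.ai.ml.entities._job.input_output_entry import InputOutputEntry
+#
+#   # data may be a dataset id/reference string or a mapping that is coerced to Data(**mapping)
+#   entry = InputOutputEntry(data="azureml:my_dataset:1", mode=InputOutputModes.DOWNLOAD)
+#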
+# --------------------------------------------------------- + +import collections.abc +import logging +from typing import Any, Optional, Union + +from azure.ai.ml.constants import InputOutputModes +from azure.ai.ml.entities._assets import Data +from azure.ai.ml.entities._mixins import DictMixin + +module_logger = logging.getLogger(__name__) + + +class InputOutputEntry(DictMixin): + def __init__( + self, # pylint: disable=unused-argument + data: Optional[Union[str, "Data"]] = None, + mode: Optional[str] = InputOutputModes.MOUNT, + **kwargs: Any, + ): + # Data will be either a dataset id, inline dataset definition + self.data = data + self.mode = mode + if isinstance(self.data, collections.abc.Mapping) and not isinstance(self.data, Data): + self.data = Data(**self.data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_port.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_port.py new file mode 100644 index 00000000..7953bbde --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_port.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import Optional, Union + +module_logger = logging.getLogger(__name__) + + +class InputPort: + def __init__(self, *, type_string: str, default: Optional[str] = None, optional: Optional[bool] = False): + self.type_string = type_string + self.optional = optional + if self.type_string == "number" and default is not None: + self.default: Union[float, Optional[str]] = float(default) + else: + self.default = default diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job.py new file mode 100644 index 00000000..b181636e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job.py @@ -0,0 +1,363 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
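+#
+# Illustrative sketch of how the "type" field dispatches to a concrete Job subclass
+# via Job._resolve_cls_and_type (defined below). This block is an annotation, not
+# part of the upstream SDK module, and is kept as comments.
+#
+#   from azure.ai.ml.entities._job.job import Job
+#
+#   job_cls, type_str = Job._resolve_cls_and_type({"type": "command"})
+#   # job_cls is the Command builder class and type_str == "command"; unknown type
+#   # strings raise a ValidationException with ValidationErrorType.INVALID_VALUE.
+#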
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import json +import logging +import traceback +from abc import abstractmethod +from collections import OrderedDict +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Type, Union + +from azure.ai.ml._restclient.runhistory.models import Run +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, JobService +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobType as RestJobType +from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase as JobBase_2401 +from azure.ai.ml._restclient.v2024_01_01_preview.models import JobType as RestJobType_20240101Preview +from azure.ai.ml._utils._html_utils import make_link, to_html +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, CommonYamlFields +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.constants._job.job import JobServices, JobType +from azure.ai.ml.entities._mixins import TelemetryMixin +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._util import find_type_in_override +from azure.ai.ml.exceptions import ( + ErrorCategory, + ErrorTarget, + JobException, + JobParsingError, + PipelineChildJobError, + ValidationErrorType, + ValidationException, +) + +from ._studio_url_from_job_id import studio_url_from_job_id +from .pipeline._component_translatable import ComponentTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +def _is_pipeline_child_job(job: JobBase) -> bool: + # pipeline child job has no properties, so we can check through testing job.properties + # if backend has spec changes, this method need to be updated + return job.properties is None + + +class Job(Resource, ComponentTranslatableMixin, TelemetryMixin): + """Base class for jobs. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :param name: The name of the job. + :type name: Optional[str] + :param display_name: The display name of the job. + :type display_name: Optional[str] + :param description: The description of the job. + :type description: Optional[str] + :param tags: Tag dictionary. Tags can be added, removed, and updated. + :type tags: Optional[dict[str, str]] + :param properties: The job property dictionary. + :type properties: Optional[dict[str, str]] + :param experiment_name: The name of the experiment the job will be created under. Defaults to the name of the + current directory. + :type experiment_name: Optional[str] + :param services: Information on services associated with the job. + :type services: Optional[dict[str, ~azure.ai.ml.entities.JobService]] + :param compute: Information about the compute resources associated with the job. 
+ :type compute: Optional[str] + """ + + def __init__( + self, + name: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + experiment_name: Optional[str] = None, + compute: Optional[str] = None, + services: Optional[Dict[str, JobService]] = None, + **kwargs: Any, + ) -> None: + self._type: Optional[str] = kwargs.pop("type", JobType.COMMAND) + self._status: Optional[str] = kwargs.pop("status", None) + self._log_files: Optional[Dict] = kwargs.pop("log_files", None) + + super().__init__( + name=name, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + + self.display_name = display_name + self.experiment_name = experiment_name + self.compute: Any = compute + self.services = services + + @property + def type(self) -> Optional[str]: + """The type of the job. + + :return: The type of the job. + :rtype: Optional[str] + """ + return self._type + + @property + def status(self) -> Optional[str]: + """The status of the job. + + Common values returned include "Running", "Completed", and "Failed". All possible values are: + + * NotStarted - This is a temporary state that client-side Run objects are in before cloud submission. + * Starting - The Run has started being processed in the cloud. The caller has a run ID at this point. + * Provisioning - On-demand compute is being created for a given job submission. + * Preparing - The run environment is being prepared and is in one of two stages: + * Docker image build + * conda environment setup + * Queued - The job is queued on the compute target. For example, in BatchAI, the job is in a queued state + while waiting for all the requested nodes to be ready. + * Running - The job has started to run on the compute target. + * Finalizing - User code execution has completed, and the run is in post-processing stages. + * CancelRequested - Cancellation has been requested for the job. + * Completed - The run has completed successfully. This includes both the user code execution and run + post-processing stages. + * Failed - The run failed. Usually the Error property on a run will provide details as to why. + * Canceled - Follows a cancellation request and indicates that the run is now successfully cancelled. + * NotResponding - For runs that have Heartbeats enabled, no heartbeat has been recently sent. + + :return: Status of the job. + :rtype: Optional[str] + """ + return self._status + + @property + def log_files(self) -> Optional[Dict[str, str]]: + """Job output files. + + :return: The dictionary of log names and URLs. + :rtype: Optional[Dict[str, str]] + """ + return self._log_files + + @property + def studio_url(self) -> Optional[str]: + """Azure ML studio endpoint. + + :return: The URL to the job details page. + :rtype: Optional[str] + """ + if self.services and (JobServices.STUDIO in self.services.keys()): + res: Optional[str] = self.services[JobServices.STUDIO].endpoint + return res + + return studio_url_from_job_id(self.id) if self.id else None + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dumps the job content into a file in YAML format. + + :param dest: The local path or file stream to write the YAML content to. + If dest is a file path, a new file will be created. + If dest is an open file, the file will be written to directly. + :type dest: Union[PathLike, str, IO[AnyStr]] + :raises FileExistsError: Raised if dest is a file path and the file already exists. 
+ :raises IOError: Raised if dest is an open file and the file is not writable. + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def _get_base_info_dict(self) -> OrderedDict: + return OrderedDict( + [ + ("Experiment", self.experiment_name), + ("Name", self.name), + ("Type", self._type), + ("Status", self._status), + ] + ) + + def _repr_html_(self) -> str: + info = self._get_base_info_dict() + if self.studio_url: + info.update( + [ + ( + "Details Page", + make_link(self.studio_url, "Link to Azure Machine Learning studio"), + ), + ] + ) + res: str = to_html(info) + return res + + @abstractmethod + def _to_dict(self) -> Dict: + pass + + @classmethod + def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] = None) -> Tuple: + from azure.ai.ml.entities._builders.command import Command + from azure.ai.ml.entities._builders.spark import Spark + from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob + from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob + from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob + from azure.ai.ml.entities._job.import_job import ImportJob + from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob + from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob + + job_type: Optional[Type["Job"]] = None + type_in_override = find_type_in_override(params_override) + type_str = type_in_override or data.get(CommonYamlFields.TYPE, JobType.COMMAND) # override takes the priority + if type_str == JobType.COMMAND: + job_type = Command + elif type_str == JobType.SPARK: + job_type = Spark + elif type_str == JobType.IMPORT: + job_type = ImportJob + elif type_str == JobType.SWEEP: + job_type = SweepJob + elif type_str == JobType.AUTOML: + job_type = AutoMLJob + elif type_str == JobType.PIPELINE: + job_type = PipelineJob + elif type_str == JobType.FINE_TUNING: + job_type = FineTuningJob + elif type_str == JobType.DISTILLATION: + job_type = DistillationJob + else: + msg = f"Unsupported job type: {type_str}." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return job_type, type_str + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Job": + """Load a job object from a yaml file. + + :param cls: Indicates that this is a class method. + :type cls: class + :param data: Data Dictionary, defaults to None + :type data: Dict + :param yaml_path: YAML Path, defaults to None + :type yaml_path: Union[PathLike, str] + :param params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}], defaults to None + :type params_override: List[Dict] + :raises Exception: An exception + :return: Loaded job object. 
+ :rtype: Job + """ + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + job_type, type_str = cls._resolve_cls_and_type(data, params_override) + job: Job = job_type._load_from_dict( + data=data, + context=context, + additional_message=f"If you are trying to configure a job that is not of type {type_str}, please specify " + f"the correct job type in the 'type' property.", + **kwargs, + ) + if yaml_path: + job._source_path = yaml_path + return job + + @classmethod + def _from_rest_object( # pylint: disable=too-many-return-statements + cls, obj: Union[JobBase, JobBase_2401, Run] + ) -> "Job": + from azure.ai.ml.entities import PipelineJob + from azure.ai.ml.entities._builders.command import Command + from azure.ai.ml.entities._builders.spark import Spark + from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob + from azure.ai.ml.entities._job.base_job import _BaseJob + from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob + from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob + from azure.ai.ml.entities._job.import_job import ImportJob + from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob + + try: + if isinstance(obj, Run): + # special handling for child jobs + return _BaseJob._load_from_rest(obj) + if _is_pipeline_child_job(obj): + raise PipelineChildJobError(job_id=obj.id) + if obj.properties.job_type == RestJobType.COMMAND: + # PrP only until new import job type is ready on MFE in PuP + # compute type 'DataFactory' is reserved compute name for 'clusterless' ADF jobs + if obj.properties.compute_id and obj.properties.compute_id.endswith("/" + ComputeType.ADF): + return ImportJob._load_from_rest(obj) + + res_command: Job = Command._load_from_rest_job(obj) + if hasattr(obj, "name"): + res_command._name = obj.name # type: ignore[attr-defined] + return res_command + if obj.properties.job_type == RestJobType.SPARK: + res_spark: Job = Spark._load_from_rest_job(obj) + if hasattr(obj, "name"): + res_spark._name = obj.name # type: ignore[attr-defined] + return res_spark + if obj.properties.job_type == RestJobType.SWEEP: + return SweepJob._load_from_rest(obj) + if obj.properties.job_type == RestJobType.AUTO_ML: + return AutoMLJob._load_from_rest(obj) + if obj.properties.job_type == RestJobType_20240101Preview.FINE_TUNING: + if obj.properties.properties.get("azureml.enable_distillation", False): + return DistillationJob._load_from_rest(obj) + return FineTuningJob._load_from_rest(obj) + if obj.properties.job_type == RestJobType.PIPELINE: + res_pipeline: Job = PipelineJob._load_from_rest(obj) + return res_pipeline + except PipelineChildJobError as ex: + raise ex + except Exception as ex: + error_message = json.dumps(obj.as_dict(), indent=2) if obj else None + module_logger.info( + "Exception: %s.\n%s\nUnable to parse the job resource: %s.\n", + ex, + traceback.format_exc(), + error_message, + ) + raise JobParsingError( + message=str(ex), + no_personal_data_message=f"Unable to parse a job resource of type:{type(obj).__name__}", + error_category=ErrorCategory.SYSTEM_ERROR, + ) from ex + msg = f"Unsupported job type {obj.properties.job_type}" + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + def _get_telemetry_values(self) -> Dict: # pylint: disable=arguments-differ + telemetry_values = 
{"type": self.type} + return telemetry_values + + @classmethod + @abstractmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Job": + pass diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_io_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_io_mixin.py new file mode 100644 index 00000000..21db73ba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_io_mixin.py @@ -0,0 +1,37 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from typing import Dict, Union + +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job._input_output_helpers import build_input_output + + +class JobIOMixin: + @property + def inputs(self) -> Dict[str, Union[Input, str, bool, int, float]]: + return self._inputs + + @inputs.setter + def inputs(self, value: Dict[str, Union[Input, str, bool, int, float]]) -> None: + self._inputs: Dict = {} + if not value: + return + + for input_name, input_value in value.items(): + self._inputs[input_name] = build_input_output(input_value) + + @property + def outputs(self) -> Dict[str, Output]: + return self._outputs + + @outputs.setter + def outputs(self, value: Dict[str, Output]) -> None: + self._outputs: Dict = {} + if not value: + return + + for output_name, output_value in value.items(): + self._outputs[output_name] = build_input_output(output_value, inputs=False) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_limits.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_limits.py new file mode 100644 index 00000000..7aae9263 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_limits.py @@ -0,0 +1,201 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from abc import ABC +from typing import Any, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import CommandJobLimits as RestCommandJobLimits +from azure.ai.ml._restclient.v2023_08_01_preview.models import SweepJobLimits as RestSweepJobLimits +from azure.ai.ml._utils.utils import from_iso_duration_format, is_data_binding_expression, to_iso_duration_format +from azure.ai.ml.constants import JobType +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +class JobLimits(RestTranslatableMixin, ABC): + """Base class for Job limits. + + This class should not be instantiated directly. Instead, one of its child classes should be used. + """ + + def __init__( + self, + ) -> None: + self.type: Any = None + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, JobLimits): + return NotImplemented + res: bool = self._to_rest_object() == other._to_rest_object() + return res + + +class CommandJobLimits(JobLimits): + """Limits for Command Jobs. + + :keyword timeout: The maximum run duration, in seconds, after which the job will be cancelled. + :paramtype timeout: Optional[Union[int, str]] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_command_configurations.py + :start-after: [START command_job_definition] + :end-before: [END command_job_definition] + :language: python + :dedent: 8 + :caption: Configuring a CommandJob with CommandJobLimits. + """ + + def __init__(self, *, timeout: Optional[Union[int, str]] = None) -> None: + super().__init__() + self.type = JobType.COMMAND + self.timeout = timeout + + def _to_rest_object(self) -> RestCommandJobLimits: + if is_data_binding_expression(self.timeout): + return RestCommandJobLimits(timeout=self.timeout) + return RestCommandJobLimits(timeout=to_iso_duration_format(self.timeout)) + + @classmethod + def _from_rest_object(cls, obj: Union[RestCommandJobLimits, dict]) -> Optional["CommandJobLimits"]: + if not obj: + return None + if isinstance(obj, dict): + timeout_value = obj.get("timeout", None) + # if timeout value is a binding string + if is_data_binding_expression(timeout_value): + return cls(timeout=timeout_value) + # if response timeout is a normal iso date string + obj = RestCommandJobLimits.from_dict(obj) + return cls(timeout=from_iso_duration_format(obj.timeout)) + + +class SweepJobLimits(JobLimits): + """Limits for Sweep Jobs. + + :keyword max_concurrent_trials: The maximum number of concurrent trials for the Sweep Job. + :paramtype max_concurrent_trials: Optional[int] + :keyword max_total_trials: The maximum number of total trials for the Sweep Job. + :paramtype max_total_trials: Optional[int] + :keyword timeout: The maximum run duration, in seconds, after which the job will be cancelled. + :paramtype timeout: Optional[int] + :keyword trial_timeout: The timeout value, in seconds, for each Sweep Job trial. + :paramtype trial_timeout: Optional[int] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_bayesian_sampling_algorithm] + :end-before: [END configure_sweep_job_bayesian_sampling_algorithm] + :language: python + :dedent: 8 + :caption: Assigning limits to a SweepJob + """ + + def __init__( + self, + *, + max_concurrent_trials: Optional[int] = None, + max_total_trials: Optional[int] = None, + timeout: Optional[int] = None, + trial_timeout: Optional[Union[int, str]] = None, + ) -> None: + super().__init__() + self.type = JobType.SWEEP + self.max_concurrent_trials = max_concurrent_trials + self.max_total_trials = max_total_trials + self._timeout = _get_floored_timeout(timeout) + self._trial_timeout = _get_floored_timeout(trial_timeout) + + @property + def timeout(self) -> Optional[Union[int, str]]: + """The maximum run duration, in seconds, after which the job will be cancelled. + + :return: The maximum run duration, in seconds, after which the job will be cancelled. + :rtype: int + """ + return self._timeout + + @timeout.setter + def timeout(self, value: int) -> None: + """Sets the maximum run duration. + + :param value: The maximum run duration, in seconds, after which the job will be cancelled. + :type value: int + """ + self._timeout = _get_floored_timeout(value) + + @property + def trial_timeout(self) -> Optional[Union[int, str]]: + """The timeout value, in seconds, for each Sweep Job trial. + + :return: The timeout value, in seconds, for each Sweep Job trial. + :rtype: int + """ + return self._trial_timeout + + @trial_timeout.setter + def trial_timeout(self, value: int) -> None: + """Sets the timeout value for each Sweep Job trial. + + :param value: The timeout value, in seconds, for each Sweep Job trial. 
+ :type value: int + """ + self._trial_timeout = _get_floored_timeout(value) + + def _to_rest_object(self) -> RestSweepJobLimits: + return RestSweepJobLimits( + max_concurrent_trials=self.max_concurrent_trials, + max_total_trials=self.max_total_trials, + timeout=to_iso_duration_format(self.timeout), + trial_timeout=to_iso_duration_format(self.trial_timeout), + ) + + @classmethod + def _from_rest_object(cls, obj: RestSweepJobLimits) -> Optional["SweepJobLimits"]: + if not obj: + return None + + return cls( + max_concurrent_trials=obj.max_concurrent_trials, + max_total_trials=obj.max_total_trials, + timeout=from_iso_duration_format(obj.timeout), + trial_timeout=from_iso_duration_format(obj.trial_timeout), + ) + + +def _get_floored_timeout(value: Optional[Union[int, str]]) -> Optional[Union[int, str]]: + # Bug 1335978: Service rounds durations less than 60 seconds to 60 days. + # If duration is non-0 and less than 60, set to 60. + if isinstance(value, int): + return value if not value or value > 60 else 60 + + return None + + +class DoWhileJobLimits(JobLimits): + """DoWhile Job limit class. + + :keyword max_iteration_count: The maximum number of iterations for the DoWhile Job. + :paramtype max_iteration_count: Optional[int] + """ + + def __init__( + self, # pylint: disable=unused-argument + *, + max_iteration_count: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__() + self._max_iteration_count = max_iteration_count + + @property + def max_iteration_count(self) -> Optional[int]: + """The maximum number of iterations for the DoWhile Job. + + :rtype: int + """ + return self._max_iteration_count diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_name_generator.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_name_generator.py new file mode 100644 index 00000000..e4f62d3d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_name_generator.py @@ -0,0 +1,487 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import random + +SUFFIX_LENGTH = 10 +ALLOWED_CHARS = "bcdfghjklmnpqrstvwxyz0123456789" + +ALLOWED_ADJECTIVES = [ + "affable", + "amiable", + "amusing", + "ashy", + "blue", + "bold", + "boring", + "brave", + "bright", + "bubbly", + "busy", + "calm", + "careful", + "clever", + "cool", + "coral", + "crimson", + "cyan", + "dreamy", + "dynamic", + "eager", + "elated", + "epic", + "frank", + "frosty", + "funny", + "gentle", + "gifted", + "good", + "goofy", + "gray", + "great", + "green", + "happy", + "helpful", + "heroic", + "honest", + "hungry", + "icy", + "ivory", + "jolly", + "jovial", + "joyful", + "keen", + "khaki", + "kind", + "lemon", + "lime", + "loving", + "loyal", + "lucid", + "magenta", + "mango", + "maroon", + "mighty", + "modest", + "musing", + "neat", + "nice", + "nifty", + "olden", + "olive", + "orange", + "patient", + "placid", + "plucky", + "plum", + "polite", + "purple", + "quiet", + "quirky", + "red", + "sad", + "salmon", + "serene", + "sharp", + "shy", + "silly", + "silver", + "sincere", + "sleepy", + "stoic", + "strong", + "sweet", + "teal", + "tender", + "tidy", + "tough", + "upbeat", + "wheat", + "willing", + "witty", + "yellow", + "zen", +] + +ALLOWED_NOUNS = [ + "actor", + "airport", + "angle", + "animal", + "answer", + "ant", + "apple", + "apricot", + "arch", + "arm", + "atemoya", + "avocado", + "bag", + "ball", + "balloon", + "band", + "basil", + "basin", + "basket", + "battery", + "beach", + "bean", + "bear", + "beard", + "bee", + "beet", + "bell", + "berry", + "bird", + "board", + "boat", + "bone", + "boniato", + "book", + "boot", + "bottle", + "box", + "brain", + "brake", + "branch", + "bread", + "brick", + "bridge", + "brush", + "bucket", + "bulb", + "button", + "cabbage", + "cake", + "calypso", + "camel", + "camera", + "candle", + "car", + "caravan", + "card", + "carnival", + "carpet", + "carrot", + "cart", + "cartoon", + "cassava", + "cat", + "celery", + "chaconia", + "chain", + "chayote", + "cheese", + "cheetah", + "cherry", + "chicken", + "chin", + "circle", + "clock", + "cloud", + "coat", + "coconut", + "collar", + "comb", + "cord", + "corn", + "cow", + "crayon", + "crowd", + "cumin", + "cup", + "curtain", + "cushion", + "date", + "deer", + "diamond", + "dinner", + "dog", + "dolphin", + "door", + "double", + "drain", + "drawer", + "dream", + "dress", + "drop", + "duck", + "eagle", + "ear", + "egg", + "endive", + "energy", + "engine", + "evening", + "eye", + "farm", + "feast", + "feather", + "feijoa", + "fennel", + "fig", + "fish", + "flag", + "floor", + "flower", + "fly", + "foot", + "forest", + "fork", + "fowl", + "fox", + "frame", + "frog", + "garage", + "garden", + "garlic", + "gas", + "ghost", + "giraffe", + "glass", + "glove", + "goat", + "gold", + "grape", + "grass", + "guava", + "guitar", + "gyro", + "hair", + "hamster", + "hand", + "hat", + "head", + "heart", + "helmet", + "holiday", + "hominy", + "honey", + "hook", + "horse", + "house", + "ice", + "insect", + "iron", + "island", + "jackal", + "jelly", + "jewel", + "jicama", + "juice", + "kale", + "kettle", + "key", + "king", + "kitchen", + "kite", + "kitten", + "kiwi", + "knee", + "knot", + "kumquat", + "lamp", + "leaf", + "leather", + "leek", + "leg", + "lemon", + "lettuce", + "library", + "lime", + "line", + "lion", + "lizard", + "lobster", + "lock", + "longan", + "loquat", + "lunch", + "lychee", + "machine", + "malanga", + "mango", + "mangos", + "map", + "market", + "match", + "melon", + "milk", + "monkey", + "moon", + "morning", + 
"muscle", + "music", + "nail", + "napa", + "napkin", + "neck", + "needle", + "nerve", + "nest", + "net", + "night", + "nose", + "nut", + "nutmeg", + "ocean", + "octopus", + "office", + "oil", + "okra", + "onion", + "orange", + "oregano", + "oven", + "owl", + "oxygen", + "oyster", + "panda", + "papaya", + "parang", + "parcel", + "parrot", + "parsnip", + "pasta", + "pea", + "peach", + "pear", + "pen", + "pencil", + "pepper", + "piano", + "picture", + "pig", + "pillow", + "pin", + "pipe", + "pizza", + "plane", + "planet", + "plastic", + "plate", + "plow", + "plum", + "pocket", + "pot", + "potato", + "prune", + "pummelo", + "pump", + "pumpkin", + "puppy", + "queen", + "quill", + "quince", + "rabbit", + "rail", + "rain", + "rainbow", + "raisin", + "rat", + "receipt", + "reggae", + "rhubarb", + "rhythm", + "rice", + "ring", + "river", + "rocket", + "rod", + "roof", + "room", + "root", + "rose", + "roti", + "sail", + "salt", + "sand", + "school", + "scooter", + "screw", + "seal", + "seed", + "shampoo", + "shark", + "sheep", + "shelf", + "ship", + "shirt", + "shoe", + "skin", + "snail", + "snake", + "soca", + "soccer", + "sock", + "soursop", + "spade", + "spider", + "spinach", + "sponge", + "spoon", + "spring", + "sprout", + "square", + "squash", + "stamp", + "star", + "station", + "steelpan", + "stem", + "stick", + "stomach", + "stone", + "store", + "street", + "sugar", + "sun", + "table", + "tail", + "tangelo", + "tent", + "thread", + "ticket", + "tiger", + "toe", + "tomato", + "tongue", + "tooth", + "town", + "train", + "tray", + "tree", + "truck", + "turnip", + "turtle", + "van", + "vase", + "vinegar", + "vulture", + "wall", + "watch", + "whale", + "wheel", + "whistle", + "window", + "wing", + "wire", + "wolf", + "worm", + "yacht", + "yak", + "yam", + "yogurt", + "yuca", + "zebra", + "zoo", +] + + +def generate_job_name() -> str: + adj = random.choice(ALLOWED_ADJECTIVES) + noun = random.choice(ALLOWED_NOUNS) + suffix = "".join(random.choices(ALLOWED_CHARS, k=SUFFIX_LENGTH)) + + return "_".join([adj, noun, suffix]) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resource_configuration.py new file mode 100644 index 00000000..a27b5ba1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resource_configuration.py @@ -0,0 +1,239 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import json +import logging +from typing import Any, Dict, List, Optional, Union, cast + +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration +from azure.ai.ml._restclient.v2025_01_01_preview.models import ( + JobResourceConfiguration as RestJobResourceConfiguration202501, +) +from azure.ai.ml.constants._job.job import JobComputePropertyFields +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin +from azure.ai.ml.entities._util import convert_ordered_dict_to_dict + +module_logger = logging.getLogger(__name__) + + +class BaseProperty(dict): + """Base class for entity classes to be used as value of JobResourceConfiguration.properties.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + for key, value in kwargs.items(): + setattr(self, key, value) + + def __setattr__(self, key: str, value: Any) -> None: + if key.startswith("_"): + super().__setattr__(key, value) + else: + self[key] = value + + def __getattr__(self, key: str) -> Any: + if key.startswith("_"): + super().__getattribute__(key) + return None + + return self[key] + + def __repr__(self) -> str: + return json.dumps(self.as_dict()) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, dict): + return self.as_dict() == other + if isinstance(other, BaseProperty): + return self.as_dict() == other.as_dict() + return False + + def as_dict(self) -> Dict[str, Any]: + res: dict = self._to_dict(self) + return res + + @classmethod + def _to_dict(cls, obj: Any) -> Any: + if isinstance(obj, dict): + result = {} + for key, value in obj.items(): + if value is None: + continue + if isinstance(value, dict): + result[key] = cls._to_dict(value) + else: + result[key] = value + return result + return obj + + +class Properties(BaseProperty): + # pre-defined properties are case-insensitive + # Map Singularity -> AISupercomputer in SDK until MFE does mapping + _KEY_MAPPING = { + JobComputePropertyFields.AISUPERCOMPUTER.lower(): JobComputePropertyFields.AISUPERCOMPUTER, + JobComputePropertyFields.SINGULARITY.lower(): JobComputePropertyFields.AISUPERCOMPUTER, + JobComputePropertyFields.ITP.lower(): JobComputePropertyFields.ITP, + JobComputePropertyFields.TARGET_SELECTOR.lower(): JobComputePropertyFields.TARGET_SELECTOR, + } + + def as_dict(self) -> Dict[str, Any]: + result = {} + for key, value in super().as_dict().items(): + if key.lower() in self._KEY_MAPPING: + key = self._KEY_MAPPING[key.lower()] + result[key] = value + # recursively convert Ordered Dict to dictionary + return cast(dict, convert_ordered_dict_to_dict(result)) + + +class JobResourceConfiguration(RestTranslatableMixin, DictMixin): + """Job resource configuration class, inherited and extended functionalities from ResourceConfiguration. + + :keyword locations: A list of locations where the job can run. + :paramtype locations: Optional[List[str]] + :keyword instance_count: The number of instances or nodes used by the compute target. + :paramtype instance_count: Optional[int] + :keyword instance_type: The type of VM to be used, as supported by the compute target. + :paramtype instance_type: Optional[str] + :keyword properties: A dictionary of properties for the job. + :paramtype properties: Optional[dict[str, Any]] + :keyword docker_args: Extra arguments to pass to the Docker run command. This would override any + parameters that have already been set by the system, or in this section. 
This parameter is only + supported for Azure ML compute types. + :paramtype docker_args: Optional[Union[str, List[str]]] + :keyword shm_size: The size of the docker container's shared memory block. This should be in the + format of (number)(unit) where the number has to be greater than 0 and the unit can be one of + b(bytes), k(kilobytes), m(megabytes), or g(gigabytes). + :paramtype shm_size: Optional[str] + :keyword max_instance_count: The maximum number of instances or nodes used by the compute target. + :paramtype max_instance_count: Optional[int] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_command_configurations.py + :start-after: [START command_job_resource_configuration] + :end-before: [END command_job_resource_configuration] + :language: python + :dedent: 8 + :caption: Configuring a CommandJob with a JobResourceConfiguration. + """ + + def __init__( + self, # pylint: disable=unused-argument + *, + locations: Optional[List[str]] = None, + instance_count: Optional[int] = None, + instance_type: Optional[Union[str, List]] = None, + properties: Optional[Union[Properties, Dict]] = None, + docker_args: Optional[Union[str, List[str]]] = None, + shm_size: Optional[str] = None, + max_instance_count: Optional[int] = None, + **kwargs: Any + ) -> None: + self.locations = locations + self.instance_count = instance_count + self.instance_type = instance_type + self.shm_size = shm_size + self.max_instance_count = max_instance_count + self.docker_args = docker_args + self._properties = None + self.properties = properties + + @property + def properties(self) -> Optional[Union[Properties, Dict]]: + """The properties of the job. + + :rtype: ~azure.ai.ml.entities._job.job_resource_configuration.Properties + """ + return self._properties + + @properties.setter + def properties(self, properties: Dict[str, Any]) -> None: + """Sets the properties of the job. + + :param properties: A dictionary of properties for the job. + :type properties: Dict[str, Any] + :raises TypeError: Raised if properties is not a dictionary type. 
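+
+        A minimal illustrative sketch (``"example_flag"`` is a made-up key and the VM size is only an
+        example value; ``JobResourceConfiguration`` is imported from ``azure.ai.ml.entities`` as elsewhere
+        in this package):
+
+        .. code-block:: python
+
+            from azure.ai.ml.entities import JobResourceConfiguration
+
+            config = JobResourceConfiguration(instance_count=2, instance_type="Standard_D2_v2")
+            # a plain dict is wrapped in a Properties bag
+            config.properties = {"example_flag": "enabled"}
+            # assigning None resets to an empty bag; any other non-dict value raises TypeError
+            config.properties = None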
+ """ + if properties is None: + self._properties = Properties() + elif isinstance(properties, dict): + self._properties = Properties(**properties) + else: + raise TypeError("properties must be a dict.") + + def _to_rest_object(self) -> Union[RestJobResourceConfiguration, RestJobResourceConfiguration202501]: + if self.docker_args and isinstance(self.docker_args, list): + return RestJobResourceConfiguration202501( + instance_count=self.instance_count, + instance_type=self.instance_type, + max_instance_count=self.max_instance_count, + properties=self.properties.as_dict() if isinstance(self.properties, Properties) else None, + docker_args_list=self.docker_args, + shm_size=self.shm_size, + ) + return RestJobResourceConfiguration( + locations=self.locations, + instance_count=self.instance_count, + instance_type=self.instance_type, + max_instance_count=self.max_instance_count, + properties=self.properties.as_dict() if isinstance(self.properties, Properties) else None, + docker_args=self.docker_args, + shm_size=self.shm_size, + ) + + @classmethod + def _from_rest_object( + cls, obj: Optional[Union[RestJobResourceConfiguration, RestJobResourceConfiguration202501]] + ) -> Optional["JobResourceConfiguration"]: + if obj is None: + return None + if isinstance(obj, dict): + return cls(**obj) + return JobResourceConfiguration( + locations=obj.locations if hasattr(obj, "locations") else None, + instance_count=obj.instance_count, + instance_type=obj.instance_type, + max_instance_count=obj.max_instance_count if hasattr(obj, "max_instance_count") else None, + properties=obj.properties, + docker_args=obj.docker_args_list if hasattr(obj, "docker_args_list") else obj.docker_args, + shm_size=obj.shm_size, + deserialize_properties=True, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, JobResourceConfiguration): + return NotImplemented + return ( + self.locations == other.locations + and self.instance_count == other.instance_count + and self.instance_type == other.instance_type + and self.max_instance_count == other.max_instance_count + and self.docker_args == other.docker_args + and self.shm_size == other.shm_size + ) + + def __ne__(self, other: object) -> bool: + if not isinstance(other, JobResourceConfiguration): + return NotImplemented + return not self.__eq__(other) + + def _merge_with(self, other: "JobResourceConfiguration") -> None: + if other: + if other.locations: + self.locations = other.locations + if other.instance_count: + self.instance_count = other.instance_count + if other.instance_type: + self.instance_type = other.instance_type + if other.max_instance_count: + self.max_instance_count = other.max_instance_count + if other.properties: + self.properties = other.properties + if other.docker_args: + self.docker_args = other.docker_args + if other.shm_size: + self.shm_size = other.shm_size diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resources.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resources.py new file mode 100644 index 00000000..bd1cdad5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resources.py @@ -0,0 +1,33 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from typing import Any, List +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml._restclient.v2024_10_01_preview.models import JobResources as RestJobResources + + +class JobResources(RestTranslatableMixin): + """Resource configuration for a job. + + This class should not be instantiated directly. Instead, use its subclasses. + """ + + def __init__(self, *, instance_types: List[str]) -> None: + self.instance_types = instance_types + + def _to_rest_object(self) -> Any: + return RestJobResources(instance_types=self.instance_types) + + @classmethod + def _from_rest_object(cls, obj: RestJobResources) -> "JobResources": + job_resources = cls(instance_types=obj.instance_types) + return job_resources + + def __eq__(self, other: object) -> bool: + if not isinstance(other, JobResources): + return NotImplemented + return self.instance_types == other.instance_types + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py new file mode 100644 index 00000000..a97048fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py @@ -0,0 +1,424 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from typing import Any, Dict, Optional, cast + +from typing_extensions import Literal + +from azure.ai.ml._restclient.v2023_04_01_preview.models import AllNodes +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobService as RestJobService +from azure.ai.ml.constants._job.job import JobServiceTypeNames +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +module_logger = logging.getLogger(__name__) + + +class JobServiceBase(RestTranslatableMixin, DictMixin): + """Base class for job service configuration. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code". + :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. 
+ :paramtype kwargs: dict + """ + + def __init__( # pylint: disable=unused-argument + self, + *, + endpoint: Optional[str] = None, + type: Optional[ # pylint: disable=redefined-builtin + Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"] + ] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Dict, + ) -> None: + self.endpoint = endpoint + self.type: Any = type + self.nodes = nodes + self.status = status + self.port = port + self.properties = properties + self._validate_nodes() + self._validate_type_name() + + def _validate_nodes(self) -> None: + if not self.nodes in ["all", None]: + msg = f"nodes should be either 'all' or None, but received '{self.nodes}'." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _validate_type_name(self) -> None: + if self.type and not self.type in JobServiceTypeNames.ENTITY_TO_REST: + msg = ( + f"type should be one of " f"{JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC}, but received '{self.type}'." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _to_rest_job_service(self, updated_properties: Optional[Dict[str, str]] = None) -> RestJobService: + return RestJobService( + endpoint=self.endpoint, + job_service_type=JobServiceTypeNames.ENTITY_TO_REST.get(self.type, None) if self.type else None, + nodes=AllNodes() if self.nodes else None, + status=self.status, + port=self.port, + properties=updated_properties if updated_properties else self.properties, + ) + + @classmethod + def _to_rest_job_services( + cls, + services: Optional[Dict], + ) -> Optional[Dict[str, RestJobService]]: + if services is None: + return None + + return {name: service._to_rest_object() for name, service in services.items()} + + @classmethod + def _from_rest_job_service_object(cls, obj: RestJobService) -> "JobServiceBase": + return cls( + endpoint=obj.endpoint, + type=( + JobServiceTypeNames.REST_TO_ENTITY.get(obj.job_service_type, None) # type: ignore[arg-type] + if obj.job_service_type + else None + ), + nodes="all" if obj.nodes else None, + status=obj.status, + port=obj.port, + # ssh_public_keys=_get_property(obj.properties, "sshPublicKeys"), + properties=obj.properties, + ) + + @classmethod + def _from_rest_job_services(cls, services: Dict[str, RestJobService]) -> Dict: + # """Resolve Dict[str, RestJobService] to Dict[str, Specific JobService]""" + if services is None: + return None + + result: dict = {} + for name, service in services.items(): + if service.job_service_type == JobServiceTypeNames.RestNames.JUPYTER_LAB: + result[name] = JupyterLabJobService._from_rest_object(service) + elif service.job_service_type == JobServiceTypeNames.RestNames.SSH: + result[name] = SshJobService._from_rest_object(service) + elif service.job_service_type == JobServiceTypeNames.RestNames.TENSOR_BOARD: + result[name] = TensorBoardJobService._from_rest_object(service) + elif service.job_service_type == JobServiceTypeNames.RestNames.VS_CODE: + result[name] = VsCodeJobService._from_rest_object(service) + else: + result[name] = JobService._from_rest_object(service) + return result + + +class JobService(JobServiceBase): + """Basic job service configuration for backward 
compatibility. + + This class is not intended to be used directly. Instead, use one of its subclasses specific to your job type. + + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code". + :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + """ + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "JobService": + return cast(JobService, cls._from_rest_job_service_object(obj)) + + def _to_rest_object(self) -> RestJobService: + return self._to_rest_job_service() + + +class SshJobService(JobServiceBase): + """SSH job service configuration. + + :ivar type: Specifies the type of job service. Set automatically to "ssh" for this class. + :vartype type: str + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword ssh_public_keys: The SSH Public Key to access the job container. + :paramtype ssh_public_keys: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START ssh_job_service_configuration] + :end-before: [END ssh_job_service_configuration] + :language: python + :dedent: 8 + :caption: Configuring a SshJobService configuration on a command job. + """ + + def __init__( + self, + *, + endpoint: Optional[str] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + ssh_public_keys: Optional[str] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + endpoint=endpoint, + nodes=nodes, + status=status, + port=port, + properties=properties, + **kwargs, + ) + self.type = JobServiceTypeNames.EntityNames.SSH + self.ssh_public_keys = ssh_public_keys + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "SshJobService": + ssh_job_service = cast(SshJobService, cls._from_rest_job_service_object(obj)) + ssh_job_service.ssh_public_keys = _get_property(obj.properties, "sshPublicKeys") + return ssh_job_service + + def _to_rest_object(self) -> RestJobService: + updated_properties = _append_or_update_properties(self.properties, "sshPublicKeys", self.ssh_public_keys) + return self._to_rest_job_service(updated_properties) + + +class TensorBoardJobService(JobServiceBase): + """TensorBoard job service configuration. + + :ivar type: Specifies the type of job service. Set automatically to "tensor_board" for this class. 
+ :vartype type: str + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword log_dir: The directory path for the log file. + :paramtype log_dir: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START ssh_job_service_configuration] + :end-before: [END ssh_job_service_configuration] + :language: python + :dedent: 8 + :caption: Configuring TensorBoardJobService configuration on a command job. + """ + + def __init__( + self, + *, + endpoint: Optional[str] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + log_dir: Optional[str] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + endpoint=endpoint, + nodes=nodes, + status=status, + port=port, + properties=properties, + **kwargs, + ) + self.type = JobServiceTypeNames.EntityNames.TENSOR_BOARD + self.log_dir = log_dir + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "TensorBoardJobService": + tensorboard_job_Service = cast(TensorBoardJobService, cls._from_rest_job_service_object(obj)) + tensorboard_job_Service.log_dir = _get_property(obj.properties, "logDir") + return tensorboard_job_Service + + def _to_rest_object(self) -> RestJobService: + updated_properties = _append_or_update_properties(self.properties, "logDir", self.log_dir) + return self._to_rest_job_service(updated_properties) + + +class JupyterLabJobService(JobServiceBase): + """JupyterLab job service configuration. + + :ivar type: Specifies the type of job service. Set automatically to "jupyter_lab" for this class. + :vartype type: str + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START ssh_job_service_configuration] + :end-before: [END ssh_job_service_configuration] + :language: python + :dedent: 8 + :caption: Configuring JupyterLabJobService configuration on a command job. 
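+
+    A minimal illustrative sketch (``my_command_job`` is a hypothetical job object, and the class is
+    assumed to be re-exported from ``azure.ai.ml.entities``):
+
+    .. code-block:: python
+
+        from azure.ai.ml.entities import JupyterLabJobService
+
+        # expose a JupyterLab endpoint on every node of the job
+        my_command_job.services = {"my_jupyterlab": JupyterLabJobService(nodes="all")}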
+ """ + + def __init__( + self, + *, + endpoint: Optional[str] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + endpoint=endpoint, + nodes=nodes, + status=status, + port=port, + properties=properties, + **kwargs, + ) + self.type = JobServiceTypeNames.EntityNames.JUPYTER_LAB + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "JupyterLabJobService": + return cast(JupyterLabJobService, cls._from_rest_job_service_object(obj)) + + def _to_rest_object(self) -> RestJobService: + return self._to_rest_job_service() + + +class VsCodeJobService(JobServiceBase): + """VS Code job service configuration. + + :ivar type: Specifies the type of job service. Set automatically to "vs_code" for this class. + :vartype type: str + :keyword endpoint: The endpoint URL. + :paramtype endpoint: Optional[str] + :keyword port: The port for the endpoint. + :paramtype port: Optional[int] + :keyword nodes: Indicates whether the service has to run in all nodes. + :paramtype nodes: Optional[Literal["all"]] + :keyword properties: Additional properties to set on the endpoint. + :paramtype properties: Optional[dict[str, str]] + :keyword status: The status of the endpoint. + :paramtype status: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START ssh_job_service_configuration] + :end-before: [END ssh_job_service_configuration] + :language: python + :dedent: 8 + :caption: Configuring a VsCodeJobService configuration on a command job. + """ + + def __init__( + self, + *, + endpoint: Optional[str] = None, + nodes: Optional[Literal["all"]] = None, + status: Optional[str] = None, + port: Optional[int] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + endpoint=endpoint, + nodes=nodes, + status=status, + port=port, + properties=properties, + **kwargs, + ) + self.type = JobServiceTypeNames.EntityNames.VS_CODE + + @classmethod + def _from_rest_object(cls, obj: RestJobService) -> "VsCodeJobService": + return cast(VsCodeJobService, cls._from_rest_job_service_object(obj)) + + def _to_rest_object(self) -> RestJobService: + return self._to_rest_job_service() + + +def _append_or_update_properties( + properties: Optional[Dict[str, str]], key: str, value: Optional[str] +) -> Dict[str, str]: + if value and not properties: + properties = {key: value} + + if value and properties: + properties.update({key: value}) + return properties if properties is not None else {} + + +def _get_property(properties: Dict[str, str], key: str) -> Optional[str]: + return properties.get(key, None) if properties else None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py new file mode 100644 index 00000000..49b2c992 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py @@ -0,0 +1,244 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData +from azure.ai.ml._schema.job.parallel_job import ParallelJobSchema +from azure.ai.ml._utils.utils import is_data_binding_expression +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from ..job import Job +from ..job_io_mixin import JobIOMixin +from .parameterized_parallel import ParameterizedParallel + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._builders import Parallel + from azure.ai.ml.entities._component.parallel_component import ParallelComponent + +module_logger = logging.getLogger(__name__) + + +class ParallelJob(Job, ParameterizedParallel, JobIOMixin): + """Parallel job. + + :param name: Name of the job. + :type name: str + :param version: Version of the job. + :type version: str + :param id: Global id of the resource, Azure Resource Manager ID. + :type id: str + :param type: Type of the job, supported is 'parallel'. + :type type: str + :param description: Description of the job. + :type description: str + :param tags: Internal use only. + :type tags: dict + :param properties: Internal use only. + :type properties: dict + :param display_name: Display name of the job. + :type display_name: str + :param retry_settings: parallel job run failed retry + :type retry_settings: BatchRetrySettings + :param logging_level: A string of the logging level name + :type logging_level: str + :param max_concurrency_per_instance: The max parallellism that each compute instance has. + :type max_concurrency_per_instance: int + :param error_threshold: The number of item processing failures should be ignored. + :type error_threshold: int + :param mini_batch_error_threshold: The number of mini batch processing failures should be ignored. + :type mini_batch_error_threshold: int + :keyword identity: The identity that the job will use while running on compute. + :paramtype identity: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration]] + :param task: The parallel task. + :type task: ParallelTask + :param mini_batch_size: The mini batch size. + :type mini_batch_size: str + :param partition_keys: The partition keys. + :type partition_keys: list + :param input_data: The input data. 
+ :type input_data: str + :param inputs: Inputs of the job. + :type inputs: dict + :param outputs: Outputs of the job. + :type outputs: dict + """ + + def __init__( + self, + *, + inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None, + outputs: Optional[Dict[str, Output]] = None, + identity: Optional[ + Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, Dict] + ] = None, + **kwargs: Any, + ): + kwargs[TYPE] = JobType.PARALLEL + + super().__init__(**kwargs) + + self.inputs = inputs # type: ignore[assignment] + self.outputs = outputs # type: ignore[assignment] + self.identity = identity + + def _to_dict(self) -> Dict: + res: dict = ParallelJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def _to_rest_object(self) -> None: + pass + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "ParallelJob": + loaded_data = load_from_dict(ParallelJobSchema, data, context, additional_message, **kwargs) + return ParallelJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + @classmethod + def _load_from_rest(cls, obj: JobBaseData) -> None: + pass + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "ParallelComponent": + """Translate a parallel job to component job. + + :param context: Context of parallel job YAML file. + :type context: dict + :return: Translated parallel component. + :rtype: ParallelComponent + """ + from azure.ai.ml.entities._component.parallel_component import ParallelComponent + + pipeline_job_dict = kwargs.get("pipeline_job_dict", {}) + context = context or {BASE_PATH_CONTEXT_KEY: Path("./")} + + # Create anonymous parallel component with default version as 1 + init_kwargs = {} + for key in [ + "mini_batch_size", + "partition_keys", + "logging_level", + "max_concurrency_per_instance", + "error_threshold", + "mini_batch_error_threshold", + "retry_settings", + "resources", + ]: + value = getattr(self, key) + from azure.ai.ml.entities import BatchRetrySettings, JobResourceConfiguration + + values_to_check: List = [] + if key == "retry_settings" and isinstance(value, BatchRetrySettings): + values_to_check = [value.max_retries, value.timeout] + elif key == "resources" and isinstance(value, JobResourceConfiguration): + values_to_check = [ + value.locations, + value.instance_count, + value.instance_type, + value.shm_size, + value.max_instance_count, + value.docker_args, + ] + else: + values_to_check = [value] + + # note that component level attributes can not be data binding expressions + # so filter out data binding expression properties here; + # they will still take effect at node level according to _to_node + if any( + map( + lambda x: is_data_binding_expression(x, binding_prefix=["parent", "inputs"], is_singular=False) + or is_data_binding_expression(x, binding_prefix=["inputs"], is_singular=False), + values_to_check, + ) + ): + continue + + init_kwargs[key] = getattr(self, key) + + return ParallelComponent( + base_path=context[BASE_PATH_CONTEXT_KEY], + # for parallel_job.task, all attributes for this are string for now so data binding expression is allowed + # in SDK level naturally, but not sure if such component is valid. leave the validation to service side. 
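+            # task, inputs, outputs and input_data are always forwarded; the remaining parallel
+            # settings arrive via **init_kwargs below, excluding any values that were data-binding
+            # expressions (those are applied only at node level by _to_node).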
+ task=self.task, + inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict), + outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict), + input_data=self.input_data, + # keep them if no data binding expression detected to keep the behavior of to_component + **init_kwargs, + ) + + def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Parallel": + """Translate a parallel job to a pipeline node. + + :param context: Context of parallel job YAML file. + :type context: dict + :return: Translated parallel component. + :rtype: Parallel + """ + from azure.ai.ml.entities._builders import Parallel + + component = self._to_component(context, **kwargs) + + return Parallel( + component=component, + compute=self.compute, + # Need to supply the inputs with double curly. + inputs=self.inputs, # type: ignore[arg-type] + outputs=self.outputs, # type: ignore[arg-type] + mini_batch_size=self.mini_batch_size, + partition_keys=self.partition_keys, + input_data=self.input_data, + # task will be inherited from component & base_path will be set correctly. + retry_settings=self.retry_settings, + logging_level=self.logging_level, + max_concurrency_per_instance=self.max_concurrency_per_instance, + error_threshold=self.error_threshold, + mini_batch_error_threshold=self.mini_batch_error_threshold, + environment_variables=self.environment_variables, + properties=self.properties, + identity=self.identity, + resources=self.resources if self.resources and not isinstance(self.resources, dict) else None, + ) + + def _validate(self) -> None: + if self.name is None: + msg = "Job name is required" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + if self.compute is None: + msg = "compute is required" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + if self.task is None: + msg = "task is required" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py new file mode 100644 index 00000000..7325aed3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py @@ -0,0 +1,119 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+# from azure.ai.ml.entities._deployment.code_configuration import CodeConfiguration
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets.environment import Environment
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class ParallelTask(RestTranslatableMixin, DictMixin):
+ """Parallel task.
+
+ :param type: The type of the parallel task.
+ Possible values are 'run_function' and 'model'.
+ :type type: str
+ :param code: A local or remote path pointing at source code.
+ :type code: str
+ :param entry_script: User script which will be run in parallel on multiple nodes. This is
+ specified as a local file path.
+ The entry_script should contain two functions:
+ ``init()``: this function should be used for any costly or common preparation for subsequent inferences,
+ e.g., deserializing and loading the model into a global object.
+ ``run(mini_batch)``: The method to be parallelized. Each invocation will handle one mini-batch.
+ 'mini_batch': Batch inference will invoke the run() method and pass either a list or a Pandas DataFrame as an
+ argument to the method. Each entry in mini_batch will be a file path if the input is a FileDataset,
+ or a Pandas DataFrame if the input is a TabularDataset.
+ The run() method should return a Pandas DataFrame or an array.
+ For the append_row output_action, these returned elements are appended into the common output file.
+ For summary_only, the contents of the elements are ignored. For all output actions,
+ each returned output element indicates one successful inference of an input element in the input mini-batch.
+ Each parallel worker process will call `init` once and then loop over the `run` function until all mini-batches
+ are processed.
+ :type entry_script: str
+ :param program_arguments: The arguments of the parallel task.
+ :type program_arguments: str
+ :param model: The model of the parallel task.
+ :type model: str
+ :param append_row_to: All values output by run() method invocations will be aggregated into
+ one unique file which is created in the output location.
+ If it is not set, 'summary_only' is used instead, which means the user script is expected to store the output itself.
+ :type append_row_to: str
+ :param environment: The environment that the job will run in.
+ :type environment: Union[Environment, str]
+ """
+
+ def __init__(
+ self, # pylint: disable=unused-argument
+ *,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ code: Optional[str] = None,
+ entry_script: Optional[str] = None,
+ program_arguments: Optional[str] = None,
+ model: Optional[str] = None,
+ append_row_to: Optional[str] = None,
+ environment: Optional[Union[Environment, str]] = None,
+ **kwargs: Any,
+ ):
+ self.type = type
+ self.code = code
+ self.entry_script = entry_script
+ self.program_arguments = program_arguments
+ self.model = model
+ self.append_row_to = append_row_to
+ self.environment: Any = environment
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = ComponentParallelTaskSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load(
+ cls, # pylint: disable=unused-argument
+ path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "ParallelTask":
+ params_override = params_override or []
+ data = load_yaml(path)
+ return ParallelTask._load_from_dict(data=data, path=path, params_override=params_override)
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: dict,
+ path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "ParallelTask":
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(path).parent if path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: ParallelTask = load_from_dict(ComponentParallelTaskSchema, data, context, **kwargs)
+ return res
+
+ @classmethod
+ def _from_dict(cls, dct: dict) -> "ParallelTask":
+ obj = cls(**dict(dct.items()))
+ return obj
+
+ def _validate(self) -> None:
+ if self.type is None:
+ msg = "'type' is required for ParallelTask {}."
+ raise ValidationException(
+ message=msg.format(self.type),
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg.format(""),
+ error_category=ErrorCategory.USER_ERROR,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py
new file mode 100644
index 00000000..6b5dbced
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py
@@ -0,0 +1,96 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from ..job_resource_configuration import JobResourceConfiguration
+from .parallel_task import ParallelTask
+from .retry_settings import RetrySettings
+
+module_logger = logging.getLogger(__name__)
+
+
+class ParameterizedParallel:
+ """Parallel component that contains the parallel task and supporting parameters for the parallel component or job.
+
+ :param retry_settings: Retry settings for a failed parallel component run.
+ :type retry_settings: BatchRetrySettings
+ :param logging_level: The name of the logging level.
+ :type logging_level: str
+ :param max_concurrency_per_instance: The maximum parallelism that each compute instance has.
+ :type max_concurrency_per_instance: int
+ :param error_threshold: The number of item processing failures that should be ignored.
+ :type error_threshold: int + :param mini_batch_error_threshold: The number of mini batch processing failures should be ignored. + :type mini_batch_error_threshold: int + :param task: The parallel task. + :type task: ParallelTask + :param mini_batch_size: The mini batch size. + :type mini_batch_size: str + :param input_data: The input data. + :type input_data: str + :param resources: Compute Resource configuration for the job. + :type resources: Union[Dict, ~azure.ai.ml.entities.JobResourceConfiguration] + """ + + # pylint: disable=too-many-instance-attributes + def __init__( + self, + retry_settings: Optional[RetrySettings] = None, + logging_level: Optional[str] = None, + max_concurrency_per_instance: Optional[int] = None, + error_threshold: Optional[int] = None, + mini_batch_error_threshold: Optional[int] = None, + input_data: Optional[str] = None, + task: Optional[ParallelTask] = None, + mini_batch_size: Optional[int] = None, + partition_keys: Optional[List] = None, + resources: Optional[Union[dict, JobResourceConfiguration]] = None, + environment_variables: Optional[Dict] = None, + ): + self.mini_batch_size = mini_batch_size + self.partition_keys = partition_keys + self.task = task + self.retry_settings = retry_settings + self.input_data = input_data + self.logging_level = logging_level + self.max_concurrency_per_instance = max_concurrency_per_instance + self.error_threshold = error_threshold + self.mini_batch_error_threshold = mini_batch_error_threshold + self.resources = resources + self.environment_variables = dict(environment_variables) if environment_variables else {} + + @property + def task(self) -> Optional[ParallelTask]: + res: Optional[ParallelTask] = self._task + return res + + @task.setter + def task(self, value: Any) -> None: + if isinstance(value, dict): + value = ParallelTask(**value) + self._task = value + + @property + def resources(self) -> Optional[Union[dict, JobResourceConfiguration]]: + res: Optional[Union[dict, JobResourceConfiguration]] = self._resources + return res + + @resources.setter + def resources(self, value: Any) -> None: + if isinstance(value, dict): + value = JobResourceConfiguration(**value) + self._resources = value + + @property + def retry_settings(self) -> Optional[RetrySettings]: + res: Optional[RetrySettings] = self._retry_settings + return res + + @retry_settings.setter + def retry_settings(self, value: Any) -> None: + if isinstance(value, dict): + value = RetrySettings(**value) + self._retry_settings = value diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py new file mode 100644 index 00000000..2fb19ba1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py @@ -0,0 +1,78 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
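The task, resources, and retry_settings setters above coerce plain dicts into their typed counterparts. A small sketch of that behavior, importing the private module added by this diff (values are illustrative only):

    from azure.ai.ml.entities._job.parallel.parameterized_parallel import ParameterizedParallel

    pp = ParameterizedParallel(
        task={"type": "run_function", "code": "./src", "entry_script": "score.py"},
        retry_settings={"max_retries": 3, "timeout": 60},
        resources={"instance_count": 2},
    )
    print(type(pp.task).__name__)            # ParallelTask
    print(type(pp.retry_settings).__name__)  # RetrySettings
    print(type(pp.resources).__name__)       # JobResourceConfiguration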
+# --------------------------------------------------------- +from os import PathLike +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema +from azure.ai.ml._utils.utils import load_yaml +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin +from azure.ai.ml.entities._util import load_from_dict + + +class RetrySettings(RestTranslatableMixin, DictMixin): + """Parallel RetrySettings. + + :param timeout: Timeout in seconds for each invocation of the run() method. + (optional) This value could be set through PipelineParameter. + :type timeout: int + :param max_retries: The number of maximum tries for a failed or timeout mini batch. + The range is [1, int.max]. This value could be set through PipelineParameter. + A mini batch with dequeue count greater than this won't be processed again and will be deleted directly. + :type max_retries: int + """ + + def __init__( + self, # pylint: disable=unused-argument + *, + timeout: Optional[Union[int, str]] = None, + max_retries: Optional[Union[int, str]] = None, + **kwargs: Any, + ): + self.timeout = timeout + self.max_retries = max_retries + + def _to_dict(self) -> Dict: + res: dict = RetrySettingsSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) # pylint: disable=no-member + return res + + @classmethod + def _load( + cls, # pylint: disable=unused-argument + path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "RetrySettings": + params_override = params_override or [] + data = load_yaml(path) + return RetrySettings._load_from_dict(data=data, path=path, params_override=params_override) + + @classmethod + def _load_from_dict( + cls, + data: dict, + path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "RetrySettings": + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(path).parent if path else Path.cwd(), + PARAMS_OVERRIDE_KEY: params_override, + } + res: RetrySettings = load_from_dict(RetrySettingsSchema, data, context, **kwargs) + return res + + @classmethod + def _from_dict(cls, dct: dict) -> "RetrySettings": + obj = cls(**dict(dct.items())) + return obj + + def _to_rest_object(self) -> Dict: + return self._to_dict() + + @classmethod + def _from_rest_object(cls, obj: dict) -> "RetrySettings": + return cls._from_dict(obj) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py new file mode 100644 index 00000000..180cee76 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py @@ -0,0 +1,66 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from typing import Any, Optional, Union + +from azure.ai.ml.constants import ParallelTaskType +from azure.ai.ml.entities._assets.environment import Environment + +from .parallel_task import ParallelTask + + +class RunFunction(ParallelTask): + """Run Function. + + :param code: A local or remote path pointing at source code. + :type code: str + :param entry_script: User script which will be run in parallel on multiple nodes. 
This is
+ specified as a local file path.
+ The entry_script should contain two functions:
+ ``init()``: this function should be used for any costly or common preparation for subsequent inferences,
+ e.g., deserializing and loading the model into a global object.
+ ``run(mini_batch)``: The method to be parallelized. Each invocation will handle one mini-batch.
+ 'mini_batch': Batch inference will invoke the run() method and pass either a list or a Pandas DataFrame as an
+ argument to the method. Each entry in mini_batch will be a file path if the input is a FileDataset,
+ or a Pandas DataFrame if the input is a TabularDataset.
+ The run() method should return a Pandas DataFrame or an array.
+ For the append_row output_action, these returned elements are appended into the common output file.
+ For summary_only, the contents of the elements are ignored. For all output actions,
+ each returned output element indicates one successful inference of an input element in the input mini-batch.
+ Each parallel worker process will call `init` once and then loop over the `run` function until all mini-batches
+ are processed.
+ :type entry_script: str
+ :param program_arguments: The arguments of the parallel task.
+ :type program_arguments: str
+ :param model: The model of the parallel task.
+ :type model: str
+ :param append_row_to: All values output by run() method invocations will be aggregated into
+ one unique file which is created in the output location.
+ If it is not set, 'summary_only' is used instead, which means the user script is expected to store the output itself.
+ :type append_row_to: str
+ :param environment: The environment that the job will run in.
+ :type environment: Union[Environment, str]
+ """
+
+ def __init__(
+ self,
+ *,
+ code: Optional[str] = None,
+ entry_script: Optional[str] = None,
+ program_arguments: Optional[str] = None,
+ model: Optional[str] = None,
+ append_row_to: Optional[str] = None,
+ environment: Optional[Union[Environment, str]] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ code=code,
+ entry_script=entry_script,
+ program_arguments=program_arguments,
+ model=model,
+ append_row_to=append_row_to,
+ environment=environment,
+ type=ParallelTaskType.RUN_FUNCTION,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_command.py
new file mode 100644
index 00000000..57604b38
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_command.py
@@ -0,0 +1,170 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
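To make the init()/run(mini_batch) contract described in the RunFunction docstring above concrete, a minimal entry-script sketch follows; the pickled model file and the scoring call are placeholders, not part of this package:

    # score.py - hypothetical entry script referenced by a RunFunction parallel task.
    import pickle

    import pandas as pd

    model = None

    def init():
        # One-time, per-worker setup: load the model into a global object.
        global model
        with open("model.pkl", "rb") as f:  # hypothetical model file shipped alongside `code`
            model = pickle.load(f)

    def run(mini_batch):
        # mini_batch is a list of file paths (file data) or a pandas DataFrame (tabular data).
        results = []
        for item in mini_batch:
            results.append({"input": str(item), "prediction": model.predict([item])[0]})
        # With the append_row output_action, each returned row is appended to the single output file.
        return pd.DataFrame(results)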
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +import os +from typing import Dict, Optional, Union + +from marshmallow import INCLUDE + +from azure.ai.ml._restclient.v2023_04_01_preview.models import SweepJob +from azure.ai.ml._schema.core.fields import ExperimentalField +from azure.ai.ml.entities._assets import Environment + +from ..._schema import NestedField, UnionField +from ..._schema.job.distribution import ( + MPIDistributionSchema, + PyTorchDistributionSchema, + RayDistributionSchema, + TensorFlowDistributionSchema, +) +from .distribution import ( + DistributionConfiguration, + MpiDistribution, + PyTorchDistribution, + RayDistribution, + TensorFlowDistribution, +) +from .job_resource_configuration import JobResourceConfiguration +from .queue_settings import QueueSettings + +module_logger = logging.getLogger(__name__) + +# no reference found. leave it for future use. +INPUT_BINDING_PREFIX = "AZURE_ML_INPUT_" +OLD_INPUT_BINDING_PREFIX = "AZURE_ML_INPUT" + + +class ParameterizedCommand: + """Command component version that contains the command and supporting parameters for a Command component + or job. + + This class should not be instantiated directly. Instead, use the child class + ~azure.ai.ml.entities.CommandComponent. + + :param command: The command to be executed. Defaults to "". + :type command: str + :param resources: The compute resource configuration for the command. + :type resources: Optional[Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration]] + :param code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing + to a remote location. + :type code: Optional[str] + :param environment_variables: A dictionary of environment variable names and values. + These environment variables are set on the process where user script is being executed. + :type environment_variables: Optional[dict[str, str]] + :param distribution: The distribution configuration for distributed jobs. + :type distribution: Optional[Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution, + ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]] + :param environment: The environment that the job will run in. + :type environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + :param queue_settings: The queue settings for the job. + :type queue_settings: Optional[~azure.ai.ml.entities.QueueSettings] + :keyword kwargs: A dictionary of additional configuration parameters. 
+ :paramtype kwargs: dict + """ + + def __init__( + self, + command: Optional[str] = "", + resources: Optional[Union[dict, JobResourceConfiguration]] = None, + code: Optional[Union[str, os.PathLike]] = None, + environment_variables: Optional[Dict] = None, + distribution: Optional[ + Union[ + Dict, + MpiDistribution, + TensorFlowDistribution, + PyTorchDistribution, + RayDistribution, + DistributionConfiguration, + ] + ] = None, + environment: Optional[Union[Environment, str]] = None, + queue_settings: Optional[QueueSettings] = None, + **kwargs: Dict, + ) -> None: + super().__init__(**kwargs) + self.command = command + self.code = code + self.environment_variables = dict(environment_variables) if environment_variables else {} + self.environment = environment + self.distribution = distribution + self.resources = resources # type: ignore[assignment] + self.queue_settings = queue_settings + + @property + def distribution( + self, + ) -> Optional[ + Union[ + dict, + MpiDistribution, + TensorFlowDistribution, + PyTorchDistribution, + RayDistribution, + DistributionConfiguration, + ] + ]: + """The configuration for the distributed command component or job. + + :return: The distribution configuration. + :rtype: Union[~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution, + ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution] + """ + return self._distribution + + @distribution.setter + def distribution(self, value: Union[dict, PyTorchDistribution, MpiDistribution]) -> None: + """Sets the configuration for the distributed command component or job. + + :param value: The distribution configuration for distributed jobs. + :type value: Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution, + ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution] + """ + if isinstance(value, dict): + dist_schema = UnionField( + [ + NestedField(PyTorchDistributionSchema, unknown=INCLUDE), + NestedField(TensorFlowDistributionSchema, unknown=INCLUDE), + NestedField(MPIDistributionSchema, unknown=INCLUDE), + ExperimentalField(NestedField(RayDistributionSchema, unknown=INCLUDE)), + ] + ) + value = dist_schema._deserialize(value=value, attr=None, data=None) + self._distribution = value + + @property + def resources(self) -> JobResourceConfiguration: + """The compute resource configuration for the command component or job. + + :return: The compute resource configuration for the command component or job. + :rtype: ~azure.ai.ml.entities.JobResourceConfiguration + """ + return self._resources + + @resources.setter + def resources(self, value: Union[dict, JobResourceConfiguration]) -> None: + """Sets the compute resource configuration for the command component or job. + + :param value: The compute resource configuration for the command component or job. 
+ :type value: Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration] + """ + if isinstance(value, dict): + value = JobResourceConfiguration(**value) + self._resources = value + + @classmethod + def _load_from_sweep_job(cls, sweep_job: SweepJob) -> "ParameterizedCommand": + parameterized_command = cls( + command=sweep_job.trial.command, + code=sweep_job.trial.code_id, + environment_variables=sweep_job.trial.environment_variables, + environment=sweep_job.trial.environment_id, + distribution=DistributionConfiguration._from_rest_object(sweep_job.trial.distribution), + resources=JobResourceConfiguration._from_rest_object(sweep_job.trial.resources), + queue_settings=QueueSettings._from_rest_object(sweep_job.queue_settings), + ) + return parameterized_command diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_spark.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_spark.py new file mode 100644 index 00000000..c8a9a0c0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_spark.py @@ -0,0 +1,88 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import os +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml.entities._assets import Environment +from azure.ai.ml.entities._job.spark_job_entry import SparkJobEntry + +from .._job.spark_job_entry_mixin import SparkJobEntryMixin + +DUMMY_IMAGE = "conda/miniconda3" + + +class ParameterizedSpark(SparkJobEntryMixin): + """ + This class should not be instantiated directly. Instead, use the child class ~azure.ai.ml.entities.SparkComponent. + + Spark component that contains supporting parameters. + + :param code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing + to a remote location. + :type code: Optional[Union[str, os.PathLike]] + :param entry: The file or class entry point. + :type entry: dict[str, str] + :param py_files: The list of .zip, .egg or .py files to place on the PYTHONPATH for Python apps. + :type py_files: Optional[list[str]] + :param jars: The list of .JAR files to include on the driver and executor classpaths. + :type jars: Optional[list[str]] + :param files: The list of files to be placed in the working directory of each executor. + :type files: Optional[list[str]] + :param archives: The list of archives to be extracted into the working directory of each executor. + :type archives: Optional[list[str]] + :param conf: A dictionary with pre-defined Spark configurations key and values. + :type conf: Optional[dict[str, str]] + :param environment: The Azure ML environment to run the job in. + :type environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + :param args: The arguments for the job. + :type args: Optional[str] + :keyword kwargs: A dictionary of additional configuration parameters. 
+ :paramtype kwargs: dict + """ + + def __init__( + self, + code: Optional[Union[str, os.PathLike]] = ".", + entry: Optional[Union[Dict[str, str], SparkJobEntry]] = None, + py_files: Optional[List[str]] = None, + jars: Optional[List[str]] = None, + files: Optional[List[str]] = None, + archives: Optional[List[str]] = None, + conf: Optional[Dict[str, str]] = None, + environment: Optional[Union[str, Environment]] = None, + args: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.args = None + + super().__init__(**kwargs) + self.code = code + self.entry = entry + self.py_files = py_files + self.jars = jars + self.files = files + self.archives = archives + self.conf = conf + self.environment = environment + self.args = args + + @property + def environment(self) -> Optional[Union[str, Environment]]: + """The Azure ML environment to run the Spark component or job in. + + :return: The Azure ML environment to run the Spark component or job in. + :rtype: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + """ + if isinstance(self._environment, Environment) and self._environment.image is None: + return Environment(conda_file=self._environment.conda_file, image=DUMMY_IMAGE) + return self._environment + + @environment.setter + def environment(self, value: Optional[Union[str, Environment]]) -> None: + """Sets the Azure ML environment to run the Spark component or job in. + + :param value: The Azure ML environment to run the Spark component or job in. + :type value: Optional[Union[str, ~azure.ai.ml.entities.Environment]] + """ + self._environment = value diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_attr_dict.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_attr_dict.py new file mode 100644 index 00000000..cf8d92be --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_attr_dict.py @@ -0,0 +1,161 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from abc import ABC +from typing import Any, Dict, Generic, List, Optional, TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +class _AttrDict(Generic[K, V], Dict, ABC): + """This class is used for accessing values with instance.some_key. It supports the following scenarios: + + 1. 
Setting arbitrary attribute, eg: obj.resource_layout.node_count = 2 + 1.1 Setting same nested filed twice will return same object, eg: + obj.resource_layout.node_count = 2 + obj.resource_layout.process_count_per_node = 2 + obj.resource_layout will be {"node_count": 2, "process_count_per_node": 2} + 1.2 Only public attribute is supported, eg: obj._resource_layout._node_count = 2 will raise AttributeError + 1.3 All set attribute can be recorded, eg: + obj.target = "aml" + obj.resource_layout.process_count_per_node = 2 + obj.get_attr() will return {"target": "aml", "resource_layout": {"process_count_per_node": 2}} + 2. Getting arbitrary attribute, getting non-exist attribute will return an empty dict. + 3. Calling arbitrary methods is not allowed, eg: obj.resource_layout() should raise AttributeError + """ + + def __init__(self, allowed_keys: Optional[Dict] = None, **kwargs: Any): + """Initialize a attribute dictionary. + + :param allowed_keys: A dictionary of keys that allowed to set as arbitrary attributes. None means all keys can + be set as arbitrary attributes. + + :type dict + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + super(_AttrDict, self).__init__(**kwargs) + if allowed_keys is None: + # None allowed_keys means no restriction on keys can be set for _AttrDict + self._allowed_keys = {} + self._key_restriction = False + else: + # Otherwise use allowed_keys to restrict keys can be set for _AttrDict + self._allowed_keys = dict(allowed_keys) + self._key_restriction = True + self._logger = logging.getLogger("attr_dict") + + def _initializing(self) -> bool: + # use this to indicate ongoing init process, sub class need to make sure this return True during init process. + return False + + def _get_attrs(self) -> dict: + """Get all arbitrary attributes which has been set, empty values are excluded. + + :return: A dict which contains all arbitrary attributes set by user. + :rtype: dict + """ + + # TODO: check this + def remove_empty_values(data: Dict) -> Dict: + if not isinstance(data, dict): + return data + # skip empty dicts as default value of _AttrDict is empty dict + return {k: remove_empty_values(v) for k, v in data.items() if v or not isinstance(v, dict)} + + return remove_empty_values(self) + + def _is_arbitrary_attr(self, attr_name: str) -> bool: + """Checks if a given attribute name should be treat as arbitrary attribute. + + Attributes inside _AttrDict can be non-arbitrary attribute or arbitrary attribute. + Non-arbitrary attributes are normal attributes like other object which stores in self.__dict__. + Arbitrary attributes are attributes stored in the dictionary it self, what makes it special it it's value + can be an instance of _AttrDict + Take `obj = _AttrDict(allowed_keys={"resource_layout": {"node_count": None}})` as an example. + `obj.some_key` is accessing non-arbitrary attribute. + `obj.resource_layout` is accessing arbitrary attribute, user can use `obj.resource_layout.node_count = 1` to + assign value to it. + + :param attr_name: Attribute name + :type attr_name: str + :return: If the given attribute name should be treated as arbitrary attribute. + :rtype: bool + """ + # Internal attribute won't be set as arbitrary attribute. + if attr_name.startswith("_"): + return False + # All attributes set in __init__ won't be set as arbitrary attribute + if self._initializing(): + return False + # If there's key restriction, only keys in it can be set as arbitrary attribute. 
+ if self._key_restriction and attr_name not in self._allowed_keys: + return False + # Attributes already in attribute dict will not be set as arbitrary attribute. + try: + self.__getattribute__(attr_name) + except AttributeError: + return True + return False + + def __getattr__(self, key: Any) -> Any: + if not self._is_arbitrary_attr(key): + return super().__getattribute__(key) + self._logger.debug("getting %s", key) + try: + return super().__getitem__(key) + except KeyError: + allowed_keys = self._allowed_keys.get(key, None) if self._key_restriction else None + result: Any = _AttrDict(allowed_keys=allowed_keys) + self.__setattr__(key, result) + return result + + def __setattr__(self, key: Any, value: V) -> None: + if not self._is_arbitrary_attr(key): + super().__setattr__(key, value) + else: + self._logger.debug("setting %s to %s", key, value) + super().__setitem__(key, value) + + def __setitem__(self, key: Any, value: V) -> None: + self.__setattr__(key, value) + + def __getitem__(self, item: V) -> Any: + # support attr_dict[item] since dumping it in marshmallow requires this. + return self.__getattr__(item) + + def __dir__(self) -> List: + # For Jupyter Notebook auto-completion + return list(super().__dir__()) + list(self.keys()) + + +def has_attr_safe(obj: Any, attr: Any) -> bool: + if isinstance(obj, _AttrDict): + has_attr = not obj._is_arbitrary_attr(attr) + elif isinstance(obj, dict): + return attr in obj + else: + has_attr = hasattr(obj, attr) + return has_attr + + +def try_get_non_arbitrary_attr(obj: Any, attr: str) -> Optional[Any]: + """Try to get non-arbitrary attribute for potential attribute dict. + + Will not create target attribute if it is an arbitrary attribute in _AttrDict. + + :param obj: The obj + :type obj: Any + :param attr: The attribute name + :type attr: str + :return: obj.attr + :rtype: Any + """ + if has_attr_safe(obj, attr): + return obj[attr] if isinstance(obj, dict) else getattr(obj, attr) + return None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_component_translatable.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_component_translatable.py new file mode 100644 index 00000000..22be939d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_component_translatable.py @@ -0,0 +1,412 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
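A brief illustration of the _AttrDict behavior documented above: arbitrary nested attributes are recorded and can be read back with _get_attrs(), and reading an unset attribute yields an empty dict. This is a sketch against the private module added by this diff, not public API:

    from azure.ai.ml.entities._job.pipeline._attr_dict import _AttrDict

    obj = _AttrDict()
    obj.resource_layout.node_count = 2               # repeated access returns the same nested object
    obj.resource_layout.process_count_per_node = 2
    obj.target = "aml"
    print(obj._get_attrs())
    # {'resource_layout': {'node_count': 2, 'process_count_per_node': 2}, 'target': 'aml'}
    print(obj.not_configured)                        # getting an unset attribute returns an empty dict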
+# --------------------------------------------------------- +# pylint: disable=protected-access, redefined-builtin +# disable redefined-builtin to use input as argument name +import re +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +from pydash import get + +from azure.ai.ml._utils.utils import is_data_binding_expression +from azure.ai.ml.constants._common import AssetTypes +from azure.ai.ml.constants._component import ComponentJobConstants +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.pipeline._io import PipelineInput, PipelineOutput +from azure.ai.ml.entities._job.sweep.search_space import Choice, Randint, SweepDistribution +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._builders import BaseNode + from azure.ai.ml.entities._component.component import Component + + +class ComponentTranslatableMixin: + _PYTHON_SDK_TYPE_MAPPING = { + float: "number", + int: "integer", + bool: "boolean", + str: "string", + } + + @classmethod + def _find_source_from_parent_inputs(cls, input: str, pipeline_job_inputs: dict) -> Tuple[str, Optional[str]]: + """Find source type and mode of input/output from parent input. + + :param input: The input name + :type input: str + :param pipeline_job_inputs: The pipeline job inputs + :type pipeline_job_inputs: dict + :return: A 2-tuple of the type and the mode + :rtype: Tuple[str, Optional[str]] + """ + _input_name = input.split(".")[2][:-2] + if _input_name not in pipeline_job_inputs.keys(): + msg = "Failed to find top level definition for input binding {}." + raise JobException( + message=msg.format(input), + no_personal_data_message=msg.format("[input]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + input_data = pipeline_job_inputs[_input_name] + input_type = type(input_data) + if input_type in cls._PYTHON_SDK_TYPE_MAPPING: + return cls._PYTHON_SDK_TYPE_MAPPING[input_type], None + return getattr(input_data, "type", AssetTypes.URI_FOLDER), getattr(input_data, "mode", None) + + @classmethod + def _find_source_from_parent_outputs(cls, input: str, pipeline_job_outputs: dict) -> Tuple[str, Optional[str]]: + """Find source type and mode of input/output from parent output. + + :param input: The input name + :type input: str + :param pipeline_job_outputs: The pipeline job outputs + :type pipeline_job_outputs: dict + :return: A 2-tuple of the type and the mode + :rtype: Tuple[str, Optional[str]] + """ + _output_name = input.split(".")[2][:-2] + if _output_name not in pipeline_job_outputs.keys(): + msg = "Failed to find top level definition for output binding {}." 
+ raise JobException( + message=msg.format(input), + no_personal_data_message=msg.format("[input]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + output_data = pipeline_job_outputs[_output_name] + output_type = type(output_data) + if output_type in cls._PYTHON_SDK_TYPE_MAPPING: + return cls._PYTHON_SDK_TYPE_MAPPING[output_type], None + if isinstance(output_data, dict): + if "type" in output_data: + output_data_type = output_data["type"] + else: + output_data_type = AssetTypes.URI_FOLDER + if "mode" in output_data: + output_data_mode = output_data["mode"] + else: + output_data_mode = None + return output_data_type, output_data_mode + return getattr(output_data, "type", AssetTypes.URI_FOLDER), getattr(output_data, "mode", None) + + @classmethod + def _find_source_from_other_jobs( + cls, input: str, jobs_dict: dict, pipeline_job_dict: dict + ) -> Tuple[str, Optional[str]]: + """Find source type and mode of input/output from other job. + + :param input: The input name + :type input: str + :param jobs_dict: The job dict + :type jobs_dict: + :param pipeline_job_dict: The pipeline job dict + :type pipeline_job_dict: dict + :return: A 2-tuple of the type and the mode + :rtype: Tuple[str, Optional[str]] + """ + from azure.ai.ml.entities import CommandJob + from azure.ai.ml.entities._builders import BaseNode + from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob + from azure.ai.ml.parallel import ParallelJob + + _input_regex = r"\${{parent.jobs.([^.]+).([^.]+).([^.]+)}}" + m = re.match(_input_regex, input) + if m is None: + msg = "Failed to find top level definition for job binding {}." + raise JobException( + message=msg.format(input), + no_personal_data_message=msg.format("[input]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + _input_job_name, _io_type, _name = m.groups() + _input_job = jobs_dict[_input_job_name] + + # we only support input of one job is from output of another output, but input mode should be decoupled with + # output mode, so we always return None source_mode + source_mode = None + if isinstance(_input_job, BaseNode): + # If source is base node, get type from io builder + _source = _input_job[_io_type][_name] + try: + source_type = _source.type + # Todo: get component type for registered component, and no need following codes + # source_type is None means _input_job's component is registered component which results in its + # input/output type is None. + if source_type is None: + if _source._data is None: + # return default type if _input_job's output data is None + source_type = AssetTypes.URI_FOLDER + elif isinstance(_source._data, Output): + # if _input_job data is a Output object and we return its type. + source_type = _source._data.type + else: + # otherwise _input_job's input/output is bound to pipeline input/output, we continue + # infer the type according to _source._data. Will return corresponding pipeline + # input/output type because we didn't get the component. + source_type, _ = cls._find_source_input_output_type(_source._data, pipeline_job_dict) + return source_type, source_mode + except AttributeError as e: + msg = "Failed to get referenced component type {}." 
+ raise JobException( + message=msg.format(_input_regex), + no_personal_data_message=msg.format("[_input_regex]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) from e + if isinstance(_input_job, (CommandJob, ParallelJob)): + # If source has not parsed to Command yet, infer type + _source = get(_input_job, f"{_io_type}.{_name}") + if isinstance(_source, str): + source_type, _ = cls._find_source_input_output_type(_source, pipeline_job_dict) + return source_type, source_mode + return getattr(_source, "type", AssetTypes.URI_FOLDER), source_mode + if isinstance(_input_job, AutoMLJob): + # If source is AutoMLJob, only outputs is supported + if _io_type != "outputs": + msg = f"Only binding to AutoMLJob output is supported, currently got {_io_type}" + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + # AutoMLJob's output type can only be MLTABLE + return AssetTypes.MLTABLE, source_mode + msg = f"Unknown referenced source job type: {type(_input_job)}." + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + + @classmethod + def _find_source_input_output_type(cls, input: str, pipeline_job_dict: dict) -> Tuple[str, Optional[str]]: + """Find source type and mode of input/output. + + :param input: The input binding + :type input: str + :param pipeline_job_dict: The pipeline job dict + :type pipeline_job_dict: dict + :return: A 2-tuple of the type and the mode + :rtype: Tuple[str, Optional[str]] + """ + pipeline_job_inputs = pipeline_job_dict.get("inputs", {}) + pipeline_job_outputs = pipeline_job_dict.get("outputs", {}) + jobs_dict = pipeline_job_dict.get("jobs", {}) + if is_data_binding_expression(input, ["parent", "inputs"]): + return cls._find_source_from_parent_inputs(input, pipeline_job_inputs) + if is_data_binding_expression(input, ["parent", "outputs"]): + return cls._find_source_from_parent_outputs(input, pipeline_job_outputs) + if is_data_binding_expression(input, ["parent", "jobs"]): + try: + return cls._find_source_from_other_jobs(input, jobs_dict, pipeline_job_dict) + except JobException as e: + raise e + except Exception as e: + msg = "Failed to find referenced source for input binding {}" + raise JobException( + message=msg.format(input), + no_personal_data_message=msg.format("[input]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.SYSTEM_ERROR, + ) from e + else: + msg = "Job input in a pipeline can bind only to a job output or a pipeline input" + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + + @classmethod + def _to_input( + cls, # pylint: disable=unused-argument + input: Union[Input, str, bool, int, float], + pipeline_job_dict: Optional[dict] = None, + **kwargs: Any, + ) -> Input: + """Convert a single job input value to component input. 
+ + :param input: The input + :type input: Union[Input, str, bool, int, float] + :param pipeline_job_dict: The pipeline job dict + :type pipeline_job_dict: Optional[dict] + :return: The Component Input + :rtype: Input + """ + pipeline_job_dict = pipeline_job_dict or {} + input_variable: Dict = {} + + if isinstance(input, str) and bool(re.search(ComponentJobConstants.INPUT_PATTERN, input)): + # handle input bindings + input_variable["type"], input_variable["mode"] = cls._find_source_input_output_type( + input, pipeline_job_dict + ) + + elif isinstance(input, Input): + input_variable = input._to_dict() + elif isinstance(input, SweepDistribution): + if isinstance(input, Choice): + if input.values is not None: + input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(input.values[0])] + elif isinstance(input, Randint): + input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[int] + else: + input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[float] + + input_variable["optional"] = False + elif type(input) in cls._PYTHON_SDK_TYPE_MAPPING: + input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(input)] + input_variable["default"] = input + elif isinstance(input, PipelineInput): + # Infer input type from input data + input_variable = input._to_input()._to_dict() + else: + msg = "'{}' is not supported as component input, supported types are '{}'.".format( + type(input), cls._PYTHON_SDK_TYPE_MAPPING.keys() + ) + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + return Input(**input_variable) + + @classmethod + def _to_input_builder_function(cls, input: Union[Dict, SweepDistribution, Input, str, bool, int, float]) -> Input: + input_variable = {} + + if isinstance(input, Input): + input_variable = input._to_dict() + elif isinstance(input, SweepDistribution): + if isinstance(input, Choice): + if input.values is not None: + input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(input.values[0])] + elif isinstance(input, Randint): + input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[int] + else: + input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[float] + + input_variable["optional"] = False + else: + input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(input)] + input_variable["default"] = input + return Input(**input_variable) + + @classmethod + def _to_output( + cls, # pylint: disable=unused-argument + output: Optional[Union[Output, Dict, str, bool, int, float]], + pipeline_job_dict: Optional[dict] = None, + **kwargs: Any, + ) -> Output: + """Translate output value to Output and infer component output type + from linked pipeline output, its original type or default type. 
+ + :param output: The output + :type output: Union[Output, str, bool, int, float] + :param pipeline_job_dict: The pipeline job dict + :type pipeline_job_dict: Optional[dict] + :return: The output object + :rtype: Output + """ + pipeline_job_dict = pipeline_job_dict or {} + output_type = None + if not pipeline_job_dict or output is None: + try: + output_type = output.type # type: ignore + except AttributeError: + # default to url_folder if failed to get type + output_type = AssetTypes.URI_FOLDER + output_variable = {"type": output_type} + return Output(**output_variable) + output_variable = {} + + if isinstance(output, str) and bool(re.search(ComponentJobConstants.OUTPUT_PATTERN, output)): + # handle output bindings + output_variable["type"], output_variable["mode"] = cls._find_source_input_output_type( + output, pipeline_job_dict + ) + + elif isinstance(output, Output): + output_variable = output._to_dict() + + elif isinstance(output, PipelineOutput): + output_variable = output._to_output()._to_dict() + + elif type(output) in cls._PYTHON_SDK_TYPE_MAPPING: + output_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(output)] + output_variable["default"] = output + else: + msg = "'{}' is not supported as component output, supported types are '{}'.".format( + type(output), cls._PYTHON_SDK_TYPE_MAPPING.keys() + ) + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + return Output(**output_variable) + + def _to_inputs(self, inputs: Optional[Dict], **kwargs: Any) -> Dict: + """Translate inputs to Inputs. + + :param inputs: mapping from input name to input object. + :type inputs: Dict[str, Union[Input, str, bool, int, float]] + :return: mapping from input name to translated component input. + :rtype: Dict[str, Input] + """ + pipeline_job_dict = kwargs.get("pipeline_job_dict", {}) + translated_component_inputs = {} + if inputs is not None: + for io_name, io_value in inputs.items(): + translated_component_inputs[io_name] = self._to_input(io_value, pipeline_job_dict) + return translated_component_inputs + + def _to_outputs(self, outputs: Optional[Dict], **kwargs: Any) -> Dict: + """Translate outputs to Outputs. + + :param outputs: mapping from output name to output object. + :type outputs: Dict[str, Output] + :return: mapping from output name to translated component output. + :rtype: Dict[str, Output] + """ + # Translate outputs to Outputs. + pipeline_job_dict = kwargs.get("pipeline_job_dict", {}) + translated_component_outputs = {} + if outputs is not None: + for output_name, output_value in outputs.items(): + translated_component_outputs[output_name] = self._to_output(output_value, pipeline_job_dict) + return translated_component_outputs + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> Union["Component", str]: + """Translate to Component. + + :param context: The context + :type context: Optional[context] + :return: Translated Component. + :rtype: Component + """ + # Note: Source of translated component should be same with Job + # And should be set after called _to_component/_to_node as job has no _source now. + raise NotImplementedError() + + def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "BaseNode": + """Translate to pipeline node. + + :param context: The context + :type context: Optional[context] + :return: Translated node. 
+ :rtype: BaseNode + """ + # Note: Source of translated component should be same with Job + # And should be set after called _to_component/_to_node as job has no _source now. + raise NotImplementedError() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/__init__.py new file mode 100644 index 00000000..3ccde947 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/__init__.py @@ -0,0 +1,21 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Classes in this package converts input & output set by user to pipeline job input & output.""" + +from .attr_dict import OutputsAttrDict, _GroupAttrDict +from .base import InputOutputBase, NodeInput, NodeOutput, PipelineInput, PipelineOutput +from .mixin import AutoMLNodeIOMixin, NodeWithGroupInputMixin, PipelineJobIOMixin + +__all__ = [ + "PipelineOutput", + "PipelineInput", + "NodeOutput", + "NodeInput", + "InputOutputBase", + "OutputsAttrDict", + "_GroupAttrDict", + "NodeWithGroupInputMixin", + "AutoMLNodeIOMixin", + "PipelineJobIOMixin", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/attr_dict.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/attr_dict.py new file mode 100644 index 00000000..0ae08bcd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/attr_dict.py @@ -0,0 +1,170 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from azure.ai.ml.entities._assets import Data +from azure.ai.ml.entities._inputs_outputs import GroupInput, Input, Output +from azure.ai.ml.entities._job.pipeline._attr_dict import K +from azure.ai.ml.entities._job.pipeline._io.base import NodeInput, NodeOutput, PipelineInput +from azure.ai.ml.exceptions import ( + ErrorCategory, + ErrorTarget, + UnexpectedAttributeError, + UnexpectedKeywordError, + ValidationException, +) + + +class InputsAttrDict(dict): + def __init__(self, inputs: dict, **kwargs: Any): + self._validate_inputs(inputs) + super(InputsAttrDict, self).__init__(**inputs, **kwargs) + + @classmethod + def _validate_inputs(cls, inputs: Any) -> None: + msg = "Pipeline/component input should be a \ + azure.ai.ml.entities._job.pipeline._io.NodeInput with owner, got {}." + for val in inputs.values(): + if isinstance(val, NodeInput) and val._owner is not None: # pylint: disable=protected-access + continue + if isinstance(val, _GroupAttrDict): + continue + raise ValidationException( + message=msg.format(val), + no_personal_data_message=msg.format("[val]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + + def __setattr__( + self, + key: str, + value: Union[int, bool, float, str, NodeOutput, PipelineInput, Input], + ) -> None: + # Extract enum value. + value = value.value if isinstance(value, Enum) else value + original_input = self.__getattr__(key) # Note that an exception will be raised if the keyword is invalid. + if isinstance(original_input, _GroupAttrDict) or isinstance(value, _GroupAttrDict): + # Set the value directly if is parameter group. 
+ self._set_group_with_type_check(key, GroupInput.custom_class_value_to_attr_dict(value)) + return + original_input._data = original_input._build_data(value) + + def _set_group_with_type_check(self, key: Any, value: Any) -> None: + msg = "{!r} is expected to be a parameter group, but got {}." + if not isinstance(value, _GroupAttrDict): + raise ValidationException( + message=msg.format(key, type(value)), + no_personal_data_message=msg.format("[key]", "[value_type]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + self.__setitem__(key, GroupInput.custom_class_value_to_attr_dict(value)) + + def __getattr__(self, item: Any) -> NodeInput: + res: NodeInput = self.__getitem__(item) + return res + + +class _GroupAttrDict(InputsAttrDict): + """This class is used for accessing values with instance.some_key.""" + + @classmethod + def _validate_inputs(cls, inputs: Any) -> None: + msg = "Pipeline/component input should be a azure.ai.ml.entities._job.pipeline._io.NodeInput, got {}." + for val in inputs.values(): + if isinstance(val, NodeInput) and val._owner is not None: # pylint: disable=protected-access + continue + if isinstance(val, _GroupAttrDict): + continue + # Allow PipelineInput as Group may appear at top level pipeline input. + if isinstance(val, PipelineInput): + continue + raise ValidationException( + message=msg.format(val), + no_personal_data_message=msg.format("[val]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + + def __getattr__(self, name: K) -> Any: + if name not in self: + raise UnexpectedAttributeError(keyword=name, keywords=list(self)) + return super().__getitem__(name) + + def __getitem__(self, item: K) -> Any: + # We raise this exception instead of KeyError + if item not in self: + raise UnexpectedKeywordError(func_name="ParameterGroup", keyword=item, keywords=list(self)) + return super().__getitem__(item) + + # For Jupyter Notebook auto-completion + def __dir__(self) -> List: + return list(super().__dir__()) + list(self.keys()) + + def flatten(self, group_parameter_name: Optional[str]) -> Dict: + # Return the flattened result of self + + group_parameter_name = group_parameter_name if group_parameter_name else "" + flattened_parameters = {} + msg = "'%s' in parameter group should be a azure.ai.ml.entities._job._io.NodeInput, got '%s'." + for k, v in self.items(): + flattened_name = ".".join([group_parameter_name, k]) + if isinstance(v, _GroupAttrDict): + flattened_parameters.update(v.flatten(flattened_name)) + elif isinstance(v, NodeInput): + flattened_parameters[flattened_name] = v._to_job_input() # pylint: disable=protected-access + else: + raise ValidationException( + message=msg % (flattened_name, type(v)), + no_personal_data_message=msg % ("name", "type"), + target=ErrorTarget.PIPELINE, + ) + return flattened_parameters + + def insert_group_name_for_items(self, group_name: Any) -> None: + # Insert one group name for all items. + for v in self.values(): + if isinstance(v, _GroupAttrDict): + v.insert_group_name_for_items(group_name) + elif isinstance(v, PipelineInput): + # Insert group names for pipeline input + v._group_names = [group_name] + v._group_names # pylint: disable=protected-access + + +class OutputsAttrDict(dict): + def __init__(self, outputs: dict, **kwargs: Any): + for val in outputs.values(): + if not isinstance(val, NodeOutput) or val._owner is None: + msg = "Pipeline/component output should be a azure.ai.ml.dsl.Output with owner, got {}." 
+ raise ValidationException( + message=msg.format(val), + no_personal_data_message=msg.format("[val]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + super(OutputsAttrDict, self).__init__(**outputs, **kwargs) + + def __getattr__(self, item: Any) -> NodeOutput: + return self.__getitem__(item) + + def __getitem__(self, item: Any) -> NodeOutput: + if item not in self: + # We raise this exception instead of KeyError as OutputsAttrDict doesn't support add new item after + # __init__. + raise UnexpectedAttributeError(keyword=item, keywords=list(self)) + res: NodeOutput = super().__getitem__(item) + return res + + def __setattr__(self, key: str, value: Union[Data, Output]) -> None: + if isinstance(value, Output): + mode = value.mode + value = Output(type=value.type, path=value.path, mode=mode, name=value.name, version=value.version) + original_output = self.__getattr__(key) # Note that an exception will be raised if the keyword is invalid. + original_output._data = original_output._build_data(value) + + def __setitem__(self, key: str, value: Output) -> None: + return self.__setattr__(key, value) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/base.py new file mode 100644 index 00000000..b17972ae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/base.py @@ -0,0 +1,848 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import copy +import re +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union, cast, overload + +from azure.ai.ml._utils.utils import is_data_binding_expression +from azure.ai.ml.constants import AssetTypes +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.entities._assets._artifacts.data import Data +from azure.ai.ml.entities._assets._artifacts.model import Model +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpressionMixin +from azure.ai.ml.entities._util import resolve_pipeline_parameter +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, UserErrorException, ValidationException + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities import PipelineJob + from azure.ai.ml.entities._builders import BaseNode + +T = TypeVar("T") + + +def _build_data_binding(data: Union[str, "PipelineInput", "Output"]) -> Union[str, Output]: + """Build input builders to data bindings. + + :param data: The data to build a data binding from + :type data: Union[str, PipelineInput, Output] + :return: A data binding string if data isn't a str, otherwise data + :rtype: str + """ + result: Union[str, Output] = "" + + if isinstance(data, (InputOutputBase)): + # Build data binding when data is PipelineInput, Output + result = data._data_binding() + else: + # Otherwise just return the data + result = data + return result + + +def _resolve_builders_2_data_bindings( + data: Union[list, dict, str, "PipelineInput", "Output"] +) -> Union[dict, list, str, Output]: + """Traverse data and build input builders inside it to data bindings. 
+ + :param data: The bindings to resolve + :type data: Union[list, dict, str, "PipelineInput", "Output"] + :return: + * A dict if data was a dict + * A list if data was a list + * A str otherwise + :rtype: Union[list, dict, str] + """ + if isinstance(data, dict): + for key, val in data.items(): + if isinstance(val, (dict, list)): + data[key] = _resolve_builders_2_data_bindings(val) + else: + data[key] = _build_data_binding(val) + return data + if isinstance(data, list): + resolved_data = [] + for val in data: + resolved_data.append(_resolve_builders_2_data_bindings(val)) + return resolved_data + return _build_data_binding(data) + + +def _data_to_input(data: Union[Data, Model]) -> Input: + """Convert a Data object to an Input object. + + :param data: The data to convert + :type data: Data + :return: The Input object + :rtype: Input + """ + if data.id: + return Input(type=data.type, path=data.id) + return Input(type=data.type, path=f"{data.name}:{data.version}") + + +class InputOutputBase(ABC): + # TODO: refine this code, always use _data to store builder level settings and use _meta to store definition + # TODO: when _data missing, return value from _meta + + def __init__( + self, + meta: Optional[Union[Input, Output]], + data: Optional[Union[int, bool, float, str, Input, Output, "PipelineInput"]], + default_data: Optional[Union[int, bool, float, str, Input, Output]] = None, + **kwargs: Any, + ): + """Base class of input & output. + + :param meta: Metadata of this input/output, eg: type, min, max, etc. + :type meta: Union[Input, Output] + :param data: Actual value of input/output, None means un-configured data. + :type data: Union[None, int, bool, float, str, + azure.ai.ml.Input, + azure.ai.ml.Output] + :param default_data: default value of input/output, None means un-configured data. + :type default_data: Union[None, int, bool, float, str, + azure.ai.ml.Input, + azure.ai.ml.Output] + """ + self._meta = meta + self._original_data = data + self._data: Any = self._build_data(data) + self._default_data = default_data + self._type: str = meta.type if meta is not None else kwargs.pop("type", None) + self._mode = self._get_mode(original_data=data, data=self._data, kwargs=kwargs) + self._description = ( + self._data.description + if self._data is not None and hasattr(self._data, "description") and self._data.description + else kwargs.pop("description", None) + ) + # TODO: remove this + self._attribute_map: Dict = {} + self._name: Optional[str] = "" + self._version: Optional[str] = "" + super(InputOutputBase, self).__init__(**kwargs) + + @abstractmethod + def _build_data(self, data: T) -> Union[T, str, Input, "InputOutputBase"]: + """Validate if data matches type and translate it to Input/Output acceptable type. + + :param data: The data + :type data: T + :return: The built data + :rtype: Union[T, str, Input, InputOutputBase] + """ + + @abstractmethod + def _build_default_data(self) -> None: + """Build default data when data not configured.""" + + @property + def type(self) -> str: + """Type of input/output. + + :return: The type + :rtype: str + """ + return self._type + + @type.setter + def type(self, type: Any) -> None: # pylint: disable=redefined-builtin + # For un-configured input/output, we build a default data entry for them. 
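+        # Illustrative sketch (input/attribute names are hypothetical): assigning
+        # node.inputs.training_data.type = "uri_file" on an un-configured input first
+        # materializes a default Input()/Output() via _build_default_data() and then
+        # propagates the new type onto that data object below.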
+ self._build_default_data() + self._type = type + if isinstance(self._data, (Input, Output)): + self._data.type = type + elif self._data is not None and not isinstance( + self._data, (int, float, str) + ): # when type of self._data is InputOutputBase or its child class + self._data._type = type + + @property + def mode(self) -> Optional[str]: + return self._mode + + @mode.setter + def mode(self, mode: Optional[str]) -> None: + # For un-configured input/output, we build a default data entry for them. + self._build_default_data() + self._mode = mode + if isinstance(self._data, (Input, Output)): + self._data.mode = mode + elif self._data is not None and not isinstance(self._data, (int, float, str)): + self._data._mode = mode + + @property + def description(self) -> Any: + return self._description + + @description.setter + def description(self, description: str) -> None: + # For un-configured input/output, we build a default data entry for them. + self._build_default_data() + self._description = description + if isinstance(self._data, (Input, Output)): + self._data.description = description + elif self._data is not None and not isinstance(self._data, (int, float, str)): + self._data._description = description + + @property + def path(self) -> Optional[str]: + # This property is introduced for static intellisense. + if hasattr(self._data, "path"): + if self._data is not None and not isinstance(self._data, (int, float, str)): + res: Optional[str] = self._data.path + return res + msg = f"{type(self._data)} does not have path." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + + @path.setter + def path(self, path: str) -> None: + # For un-configured input/output, we build a default data entry for them. + self._build_default_data() + if hasattr(self._data, "path"): + if self._data is not None and not isinstance(self._data, (int, float, str)): + self._data.path = path + else: + msg = f"{type(self._data)} does not support setting path." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + + def _data_binding(self) -> str: + """Return data binding string representation for this input/output. + + :return: The data binding string + :rtype: str + """ + raise NotImplementedError() + + # Why did we have this function? It prevents the DictMixin from being applied. + # Unclear if we explicitly do NOT want the mapping protocol to be applied to this, or it this was just + # confirmation that it didn't at the time. + def keys(self) -> None: + # This property is introduced to raise catchable Exception in marshmallow mapping validation trial. + raise TypeError(f"'{type(self).__name__}' object is not a mapping") + + def __str__(self) -> str: + try: + return self._data_binding() + except AttributeError: + return super(InputOutputBase, self).__str__() + + def __hash__(self) -> int: + return id(self) + + @classmethod + def _get_mode( + cls, + original_data: Optional[Union[int, bool, float, str, Input, Output, "PipelineInput"]], + data: Optional[Union[int, bool, float, str, Input, Output]], + kwargs: dict, + ) -> Optional[str]: + """Get mode of this input/output builder. + + :param original_data: Original value of input/output. 
+ :type original_data: Union[None, int, bool, float, str + azure.ai.ml.Input, + azure.ai.ml.Output, + azure.ai.ml.entities._job.pipeline._io.PipelineInput] + :param data: Built input/output data. + :type data: Union[None, int, bool, float, str + azure.ai.ml.Input, + azure.ai.ml.Output] + :param kwargs: The kwargs + :type kwargs: Dict + :return: The mode + :rtype: Optional[str] + """ + # pipeline level inputs won't pass mode to bound node level inputs + if isinstance(original_data, PipelineInput): + return None + return data.mode if data is not None and hasattr(data, "mode") else kwargs.pop("mode", None) + + @property + def _is_primitive_type(self) -> bool: + return self.type in IOConstants.PRIMITIVE_STR_2_TYPE + + +class NodeInput(InputOutputBase): + """Define one input of a Component.""" + + def __init__( + self, + port_name: str, + meta: Optional[Input], + *, + data: Optional[Union[int, bool, float, str, Output, "PipelineInput", Input]] = None, + # TODO: Bug Item number: 2883405 + owner: Optional[Union["BaseComponent", "PipelineJob"]] = None, # type: ignore + **kwargs: Any, + ): + """Initialize an input of a component. + + :param name: The name of the input. + :type name: str + :param meta: Metadata of this input, eg: type, min, max, etc. + :type meta: Input + :param data: The input data. Valid types include int, bool, float, str, + Output of another component or pipeline input and Input. + Note that the output of another component or pipeline input associated should be reachable in the scope + of current pipeline. Input is introduced to support case like + TODO: new examples + component.inputs.xxx = Input(path="arm_id") + :type data: Union[int, bool, float, str + azure.ai.ml.Output, + azure.ai.ml.Input] + :param owner: The owner component of the input, used to calculate binding. + :type owner: Union[azure.ai.ml.entities.BaseNode, azure.ai.ml.entities.PipelineJob] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + # TODO: validate data matches type in meta + # TODO: validate supported data + self._port_name = port_name + self._owner = owner + super().__init__(meta=meta, data=data, **kwargs) + + def _build_default_data(self) -> None: + """Build default data when input not configured.""" + if self._data is None: + self._data = Input() + + def _build_data(self, data: T) -> Union[T, str, Input, InputOutputBase]: + """Build input data according to assigned input + + eg: node.inputs.key = data + + :param data: The data + :type data: T + :return: The built data + :rtype: Union[T, str, Input, "PipelineInput", "NodeOutput"] + """ + _data: Union[T, str, NodeOutput] = resolve_pipeline_parameter(data) + if _data is None: + return _data + # Unidiomatic typecheck: Checks that data is _exactly_ this type, and not potentially a subtype + if type(_data) is NodeInput: # pylint: disable=unidiomatic-typecheck + msg = "Can not bind input to another component's input." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + if isinstance(_data, (PipelineInput, NodeOutput)): + # If value is input or output, it's a data binding, we require it have a owner so we can convert it to + # a data binding, eg: ${{inputs.xxx}} + if isinstance(_data, NodeOutput) and _data._owner is None: + msg = "Setting input binding {} to output without owner is not allowed." 
+ raise ValidationException( + message=msg.format(_data), + no_personal_data_message=msg.format("[_data]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + return _data + # for data binding case, set is_singular=False for case like "${{parent.inputs.job_in_folder}}/sample1.csv" + if isinstance(_data, Input) or is_data_binding_expression(_data, is_singular=False): + return _data + if isinstance(_data, (Data, Model)): + return _data_to_input(_data) + # self._meta.type could be None when sub pipeline has no annotation + if isinstance(self._meta, Input) and self._meta.type and not self._meta._is_primitive_type: + if isinstance(_data, str): + return Input(type=self._meta.type, path=_data) + msg = "only path input is supported now but get {}: {}." + raise UserErrorException( + message=msg.format(type(_data), _data), + no_personal_data_message=msg.format(type(_data), "[_data]"), + ) + return _data + + def _to_job_input(self) -> Optional[Union[Input, str, Output]]: + """convert the input to Input, this logic will change if backend contract changes.""" + result: Optional[Union[Input, str, Output]] = None + + if self._data is None: + # None data means this input is not configured. + result = None + elif isinstance(self._data, (PipelineInput, NodeOutput)): + # Build data binding when data is PipelineInput, Output + result = Input(path=self._data._data_binding(), mode=self.mode) + elif is_data_binding_expression(self._data): + result = Input(path=self._data, mode=self.mode) + else: + data_binding = _build_data_binding(self._data) + if is_data_binding_expression(self._data): + result = Input(path=data_binding, mode=self.mode) + else: + result = data_binding + # TODO: validate is self._data is supported + + return result + + def _data_binding(self) -> str: + msg = "Input binding {} can only come from a pipeline, currently got {}" + # call type(self._owner) to avoid circular import + raise ValidationException( + message=msg.format(self._port_name, type(self._owner)), + target=ErrorTarget.PIPELINE, + no_personal_data_message=msg.format("[port_name]", "[owner]"), + error_category=ErrorCategory.USER_ERROR, + ) + + def _copy(self, owner: Any) -> "NodeInput": + return NodeInput( + port_name=self._port_name, + data=self._data, + owner=owner, + meta=cast(Input, self._meta), + ) + + def _deepcopy(self) -> "NodeInput": + return NodeInput( + port_name=self._port_name, + data=copy.copy(self._data), + owner=self._owner, + meta=cast(Input, self._meta), + ) + + def _get_data_owner(self) -> Optional["BaseNode"]: + """Gets the data owner of the node + + Note: This only works for @pipeline, not for YAML pipeline. + + Note: Inner step will be returned as the owner when node's input is from sub pipeline's output. + @pipeline + def sub_pipeline(): + inner_node = component_func() + return inner_node.outputs + + @pipeline + def root_pipeline(): + pipeline_node = sub_pipeline() + node = copy_files_component_func(input_dir=pipeline_node.outputs.output_dir) + owner = node.inputs.input_dir._get_data_owner() + assert owner == pipeline_node.nodes[0] + + :return: The node if Input is from another node's output. Returns None for literal value. 
+ :rtype: Optional[BaseNode] + """ + from azure.ai.ml.entities import Pipeline + from azure.ai.ml.entities._builders import BaseNode + + def _resolve_data_owner(data: Any) -> Optional["BaseNode"]: + if isinstance(data, BaseNode) and not isinstance(data, Pipeline): + return data + while isinstance(data, PipelineInput): + # for pipeline input, it's original value(can be literal value or another node's output) + # is stored in _original_data + return _resolve_data_owner(data._original_data) + if isinstance(data, NodeOutput): + if isinstance(data._owner, Pipeline): + # for input from subgraph's output, trace back to inner node + return _resolve_data_owner(data._binding_output) + # for input from another node's output, return the node + return _resolve_data_owner(data._owner) + return None + + return _resolve_data_owner(self._data) + + +class NodeOutput(InputOutputBase, PipelineExpressionMixin): + """Define one output of a Component.""" + + def __init__( + self, + port_name: str, + meta: Optional[Union[Input, Output]], + *, + data: Optional[Union[Output, str]] = None, + # TODO: Bug Item number: 2883405 + owner: Optional[Union["BaseComponent", "PipelineJob"]] = None, # type: ignore + binding_output: Optional["NodeOutput"] = None, + **kwargs: Any, + ): + """Initialize an Output of a component. + + :param port_name: The port_name of the output. + :type port_name: str + :param name: The name used to register NodeOutput/PipelineOutput data. + :type name: str + :param version: The version used to register NodeOutput/PipelineOutput data. + :ype version: str + :param data: The output data. Valid types include str, Output + :type data: Union[str + azure.ai.ml.entities.Output] + :param mode: The mode of the output. + :type mode: str + :param owner: The owner component of the output, used to calculate binding. + :type owner: Union[azure.ai.ml.entities.BaseNode, azure.ai.ml.entities.PipelineJob] + :param binding_output: The node output bound to pipeline output, only available for pipeline. + :type binding_output: azure.ai.ml.entities.NodeOutput + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if object cannot be successfully validated. + Details will be provided in the error message. + """ + # Allow inline output binding with string, eg: "component_out_path_1": "${{parents.outputs.job_out_data_1}}" + if data is not None and not isinstance(data, (Output, str)): + msg = "Got unexpected type for output: {}." + raise ValidationException( + message=msg.format(data), + target=ErrorTarget.PIPELINE, + no_personal_data_message=msg.format("[data]"), + ) + super().__init__(meta=meta, data=data, **kwargs) + self._port_name = port_name + self._owner = owner + self._name: Optional[str] = self._data.name if isinstance(self._data, Output) else None + self._version: Optional[str] = self._data.version if isinstance(self._data, Output) else None + + self._assert_name_and_version() + + # store original node output to be able to trace back to inner node from a pipeline output builder. + self._binding_output = binding_output + + @property + def port_name(self) -> str: + """The output port name, eg: node.outputs.port_name. + + :return: The port name + :rtype: str + """ + return self._port_name + + @property + def name(self) -> Optional[str]: + """Used in registering output data. 
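+
+        Illustrative example (names are hypothetical): setting
+        node.outputs.out_dir.name = "my_data" (optionally together with .version)
+        is how the produced output data gets registered under that name.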
+ + :return: The output name + :rtype: str + """ + return self._name + + @name.setter + def name(self, name: str) -> None: + """Assigns the name to NodeOutput/PipelineOutput and builds data according to the name. + + :param name: The new name + :type name: str + """ + self._build_default_data() + self._name = name + if isinstance(self._data, Output): + self._data.name = name + elif isinstance(self._data, InputOutputBase): + self._data._name = name + else: + raise UserErrorException( + f"We support self._data of Input, Output, InputOutputBase, NodeOutput and NodeInput," + f"but got type: {type(self._data)}." + ) + + @property + def version(self) -> Optional[str]: + """Used in registering output data. + + :return: The output data + :rtype: str + """ + return self._version + + @version.setter + def version(self, version: str) -> None: + """Assigns the version to NodeOutput/PipelineOutput and builds data according to the version. + + :param version: The new version + :type version: str + """ + self._build_default_data() + self._version = version + if isinstance(self._data, Output): + self._data.version = version + elif isinstance(self._data, InputOutputBase): + self._data._version = version + else: + raise UserErrorException( + f"We support self._data of Input, Output, InputOutputBase, NodeOutput and NodeInput," + f"but got type: {type(self._data)}." + ) + + @property + def path(self) -> Any: + # For node output path, + if self._data is not None and hasattr(self._data, "path"): + return self._data.path + return None + + @path.setter + def path(self, path: Optional[str]) -> None: + # For un-configured output, we build a default data entry for them. + self._build_default_data() + if self._data is not None and hasattr(self._data, "path"): + self._data.path = path + else: + # YAML job will have string output binding and do not support setting path for it. + msg = f"{type(self._data)} does not support setting path." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + + def _assert_name_and_version(self) -> None: + if self.name and not (re.match("^[A-Za-z0-9_-]*$", self.name) and len(self.name) <= 255): + raise UserErrorException( + f"The output name {self.name} can only contain alphanumeric characters, dashes and underscores, " + f"with a limit of 255 characters." + ) + if self.version and not self.name: + raise UserErrorException("Output name is required when output version is specified.") + + def _build_default_data(self) -> None: + """Build default data when output not configured.""" + if self._data is None: + # _meta will be None when node._component is not a Component object + # so we just leave the type inference work to backend + self._data = Output(type=None) # type: ignore[call-overload] + + def _build_data(self, data: T) -> Any: + """Build output data according to assigned input, eg: node.outputs.key = data + + :param data: The data + :type data: T + :return: `data` + :rtype: T + """ + if data is None: + return data + if not isinstance(data, (Output, str)): + msg = f"{self.__class__.__name__} only allow set {Output.__name__} object, {type(data)} is not supported." 
+ raise ValidationException( + message=msg, + target=ErrorTarget.PIPELINE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + res: T = cast(T, data) + return res + + def _to_job_output(self) -> Optional[Output]: + """Convert the output to Output, this logic will change if backend contract changes.""" + if self._data is None: + # None data means this output is not configured. + result = None + elif isinstance(self._data, str): + result = Output( + type=AssetTypes.URI_FOLDER, path=self._data, mode=self.mode, name=self.name, version=self.version + ) + elif isinstance(self._data, Output): + result = self._data + elif isinstance(self._data, PipelineOutput): + result = Output( + type=AssetTypes.URI_FOLDER, + path=self._data._data_binding(), + mode=self.mode, + name=self._data.name, + version=self._data.version, + description=self.description, + ) + else: + msg = "Got unexpected type for output: {}." + raise ValidationException( + message=msg.format(self._data), + target=ErrorTarget.PIPELINE, + no_personal_data_message=msg.format("[data]"), + ) + return result + + def _data_binding(self) -> str: + if self._owner is not None: + return f"${{{{parent.jobs.{self._owner.name}.outputs.{self._port_name}}}}}" + + return "" + + def _copy(self, owner: Any) -> "NodeOutput": + return NodeOutput( + port_name=self._port_name, + data=cast(Output, self._data), + owner=owner, + meta=self._meta, + ) + + def _deepcopy(self) -> "NodeOutput": + return NodeOutput( + port_name=self._port_name, + data=cast(Output, copy.copy(self._data)), + owner=self._owner, + meta=self._meta, + binding_output=self._binding_output, + ) + + +class PipelineInput(NodeInput, PipelineExpressionMixin): + """Define one input of a Pipeline.""" + + def __init__(self, name: str, meta: Optional[Input], group_names: Optional[List[str]] = None, **kwargs: Any): + """Initialize a PipelineInput. + + :param name: The name of the input. + :type name: str + :param meta: Metadata of this input, eg: type, min, max, etc. + :type meta: Input + :param group_names: The input parameter's group names. + :type group_names: List[str] + """ + super(PipelineInput, self).__init__(port_name=name, meta=meta, **kwargs) + self._group_names = group_names if group_names else [] + + def result(self) -> Any: + """Return original value of pipeline input. + + :return: The original value of pipeline input + :rtype: Any + + Example: + + .. code-block:: python + + @pipeline + def pipeline_func(param1): + # node1's param1 will get actual value of param1 instead of a input binding. + node1 = component_func(param1=param1.result()) + """ + + # use this to break self loop + original_data_cache: Set = set() + original_data = self._original_data + while isinstance(original_data, PipelineInput) and original_data not in original_data_cache: + original_data_cache.add(original_data) + original_data = original_data._original_data + return original_data + + def __str__(self) -> str: + return self._data_binding() + + @overload + def _build_data(self, data: Union[Model, Data]) -> Input: ... + + @overload + def _build_data(self, data: T) -> Any: ... + + def _build_data(self, data: Union[Model, Data, T]) -> Any: + """Build data according to input type. 
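+
+        Illustrative sketch: a Data/Model asset is converted to an Input pointing at
+        it, e.g. Input(type="uri_folder", path="my_data:1") for a Data asset named
+        "my_data" with version "1" (asset name, version and type here are
+        hypothetical); literal values and existing bindings are returned as-is.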
+ + :param data: The data + :type data: Union[Model, Data, T] + :return: + * Input if data is a Model or Data + * data otherwise + :rtype: Union[Input, T] + """ + if data is None: + return data + # Unidiomatic typecheck: Checks that data is _exactly_ this type, and not potentially a subtype + if type(data) is NodeInput: # pylint: disable=unidiomatic-typecheck + msg = "Can not bind input to another component's input." + raise ValidationException(message=msg, no_personal_data_message=msg, target=ErrorTarget.PIPELINE) + if isinstance(data, (PipelineInput, NodeOutput)): + # If value is input or output, it's a data binding, owner is required to convert it to + # a data binding, eg: ${{parent.inputs.xxx}} + if isinstance(data, NodeOutput) and data._owner is None: + msg = "Setting input binding {} to output without owner is not allowed." + raise ValidationException( + message=msg.format(data), + no_personal_data_message=msg.format("[data]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + return data + if isinstance(data, (Data, Model)): + # If value is Data, we convert it to an corresponding Input + return _data_to_input(data) + return data + + def _data_binding(self) -> str: + full_name = "%s.%s" % (".".join(self._group_names), self._port_name) if self._group_names else self._port_name + return f"${{{{parent.inputs.{full_name}}}}}" + + def _to_input(self) -> Optional[Union[Input, Output]]: + """Convert pipeline input to component input for pipeline component. + + :return: The component input + :rtype: Input + """ + if self._data is None: + # None data means this input is not configured. + return self._meta + data_type = self._data.type if isinstance(self._data, Input) else None + # If type is asset type, return data type without default. + # Else infer type from data and set it as default. + if data_type and data_type.lower() in AssetTypes.__dict__.values(): + if not isinstance(self._data, (int, float, str)): + result = Input(type=data_type, mode=self._data.mode) + elif type(self._data) in IOConstants.PRIMITIVE_TYPE_2_STR: + result = Input( + type=IOConstants.PRIMITIVE_TYPE_2_STR[type(self._data)], + default=self._data, + ) + else: + msg = f"Unsupported Input type {type(self._data)} detected when translate job to component." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + return result # pylint: disable=possibly-used-before-assignment + + +class PipelineOutput(NodeOutput): + """Define one output of a Pipeline.""" + + def _to_job_output(self) -> Optional[Output]: + result: Optional[Output] = None + if isinstance(self._data, Output): + # For pipeline output with type Output, always pass to backend. + result = self._data + elif self._data is None and self._meta and self._meta.type: + # For un-configured pipeline output with meta, we need to return Output with accurate type, + # so it won't default to uri_folder. + result = Output(type=self._meta.type, mode=self._meta.mode, description=self._meta.description) + else: + result = super(PipelineOutput, self)._to_job_output() + # Copy meta type to avoid built output's None type default to uri_folder. 
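+        # Illustrative: a pipeline output declared (via meta) as uri_file keeps that
+        # type here rather than silently falling back to the uri_folder default.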
+ if self.type and result is not None and not result.type: + result.type = self.type + return result + + def _data_binding(self) -> str: + return f"${{{{parent.outputs.{self._port_name}}}}}" + + def _to_output(self) -> Optional[Output]: + """Convert pipeline output to component output for pipeline component.""" + if self._data is None: + # None data means this input is not configured. + return None + if isinstance(self._meta, Output): + return self._meta + # Assign type directly as we didn't have primitive output type for now. + if not isinstance(self._data, (int, float, str)): + return Output(type=self._data.type, mode=self._data.mode) + return Output() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/mixin.py new file mode 100644 index 00000000..6c3d9357 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/mixin.py @@ -0,0 +1,623 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import copy +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobInput as RestJobInput +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobOutput as RestJobOutput +from azure.ai.ml.constants._component import ComponentJobConstants +from azure.ai.ml.entities._inputs_outputs import GroupInput, Input, Output +from azure.ai.ml.entities._util import copy_output_setting +from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException + +from ..._input_output_helpers import ( + from_rest_data_outputs, + from_rest_inputs_to_dataset_literal, + to_rest_data_outputs, + to_rest_dataset_literal_inputs, +) +from .._pipeline_job_helpers import from_dict_to_rest_io, process_sdk_component_job_io +from .attr_dict import InputsAttrDict, OutputsAttrDict, _GroupAttrDict +from .base import NodeInput, NodeOutput, PipelineInput, PipelineOutput + + +class NodeIOMixin: + """Provides ability to wrap node inputs/outputs and build data bindings + dynamically.""" + + @classmethod + def _get_supported_inputs_types(cls) -> Optional[Any]: + return None + + @classmethod + def _get_supported_outputs_types(cls) -> Optional[Any]: + return None + + @classmethod + def _validate_io(cls, value: Any, allowed_types: Optional[tuple], *, key: Optional[str] = None) -> None: + if allowed_types is None: + return + + if value is None or isinstance(value, allowed_types): + pass + else: + msg = "Expecting {} for input/output {}, got {} instead." 
+ raise ValidationException( + message=msg.format(allowed_types, key, type(value)), + no_personal_data_message=msg.format(allowed_types, "[key]", type(value)), + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _build_input( + self, + name: str, + meta: Optional[Input], + data: Optional[Union[dict, int, bool, float, str, Output, "PipelineInput", Input]], + ) -> NodeInput: + # output mode of last node should not affect input mode of next node + if isinstance(data, NodeOutput): + # Decoupled input and output + # value = copy.deepcopy(value) + data = data._deepcopy() # pylint: disable=protected-access + data.mode = None + elif isinstance(data, dict): + # Use type comparison instead of is_instance to skip _GroupAttrDict + # when loading from yaml io will be a dict, + # like {'job_data_path': '${{parent.inputs.pipeline_job_data_path}}'} + # parse dict to allowed type + data = Input(**data) + + # parameter group can be of custom type, so we don't check it here + if meta is not None and not isinstance(meta, GroupInput): + self._validate_io(data, self._get_supported_inputs_types(), key=name) + return NodeInput(port_name=name, meta=meta, data=data, owner=self) + + def _build_output(self, name: str, meta: Optional[Output], data: Optional[Union[Output, str]]) -> NodeOutput: + if isinstance(data, dict): + data = Output(**data) + + self._validate_io(data, self._get_supported_outputs_types(), key=name) + # For un-configured outputs, settings it to None, so we won't pass extra fields(eg: default mode) + return NodeOutput(port_name=name, meta=meta, data=data, owner=self) + + # pylint: disable=unused-argument + def _get_default_input_val(self, val: Any): # type: ignore + # use None value as data placeholder for unfilled inputs. + # server side will fill the default value + return None + + def _build_inputs_dict( + self, + inputs: Dict[str, Union[Input, str, bool, int, float]], + *, + input_definition_dict: Optional[dict] = None, + ) -> InputsAttrDict: + """Build an input attribute dict so user can get/set inputs by + accessing attribute, eg: node1.inputs.xxx. + + :param inputs: Provided kwargs when parameterizing component func. + :type inputs: Dict[str, Union[Input, str, bool, int, float]] + :keyword input_definition_dict: Static input definition dict. If not provided, will build inputs without meta. + :paramtype input_definition_dict: dict + :return: Built dynamic input attribute dict. + :rtype: InputsAttrDict + """ + if input_definition_dict is not None: + # TODO: validate inputs.keys() in input_definitions.keys() + input_dict = {} + for key, val in input_definition_dict.items(): + if key in inputs.keys(): + # If input is set through component functions' kwargs, create an input object with real value. + data = inputs[key] + else: + data = self._get_default_input_val(val) # pylint: disable=assignment-from-none + + val = self._build_input(name=key, meta=val, data=data) + input_dict[key] = val + else: + input_dict = {key: self._build_input(name=key, meta=None, data=val) for key, val in inputs.items()} + return InputsAttrDict(input_dict) + + def _build_outputs_dict( + self, outputs: Dict, *, output_definition_dict: Optional[dict] = None, none_data: bool = False + ) -> OutputsAttrDict: + """Build an output attribute dict so user can get/set outputs by + accessing attribute, eg: node1.outputs.xxx. + + :param outputs: Provided kwargs when parameterizing component func. + :type outputs: Dict[str, Output] + :keyword output_definition_dict: Static output definition dict. 
+ :paramtype output_definition_dict: Dict + :keyword none_data: If True, will set output data to None. + :paramtype none_data: bool + :return: Built dynamic output attribute dict. + :rtype: OutputsAttrDict + """ + if output_definition_dict is not None: + # TODO: check if we need another way to mark a un-configured output instead of just set None. + # Create None as data placeholder for all outputs. + output_dict = {} + for key, val in output_definition_dict.items(): + if key in outputs.keys(): + # If output has given value, create an output object with real value. + val = self._build_output(name=key, meta=val, data=outputs[key]) + else: + val = self._build_output(name=key, meta=val, data=None) + output_dict[key] = val + else: + output_dict = {} + for key, val in outputs.items(): + output_val = self._build_output(name=key, meta=None, data=val if not none_data else None) + output_dict[key] = output_val + return OutputsAttrDict(output_dict) + + def _build_inputs(self) -> Dict: + """Build inputs of this component to a dict dict which maps output to + actual value. + + The built input dict will have same input format as other jobs, eg: + { + "input_data": Input(path="path/to/input/data", mode="Mount"), + "input_value": 10, + "learning_rate": "${{jobs.step1.inputs.learning_rate}}" + } + + :return: The input dict + :rtype: Dict[str, Union[Input, str, bool, int, float]] + """ + inputs = {} + # pylint: disable=redefined-builtin + for name, input in self.inputs.items(): # type: ignore + if isinstance(input, _GroupAttrDict): + # Flatten group inputs into inputs dict + inputs.update(input.flatten(group_parameter_name=name)) + continue + inputs[name] = input._to_job_input() # pylint: disable=protected-access + return inputs + + def _build_outputs(self) -> Dict[str, Output]: + """Build outputs of this component to a dict which maps output to + actual value. + + The built output dict will have same output format as other jobs, eg: + { + "eval_output": "${{jobs.eval.outputs.eval_output}}" + } + + :return: The output dict + :rtype: Dict[str, Output] + """ + outputs = {} + for name, output in self.outputs.items(): # type: ignore + if isinstance(output, NodeOutput): + output = output._to_job_output() # pylint: disable=protected-access + outputs[name] = output + # Remove non-configured output + return {k: v for k, v in outputs.items() if v is not None} + + def _to_rest_inputs(self) -> Dict[str, Dict]: + """Translate input builders to rest input dicts. + + The built dictionary's format aligns with component job's input yaml, eg: + { + "input_data": {"data": {"path": "path/to/input/data"}, "mode"="Mount"}, + "input_value": 10, + "learning_rate": "${{jobs.step1.inputs.learning_rate}}" + } + + :return: The REST inputs + :rtype: Dict[str, Dict] + """ + built_inputs = self._build_inputs() + return self._input_entity_to_rest_inputs(input_entity=built_inputs) + + @classmethod + def _input_entity_to_rest_inputs(cls, input_entity: Dict[str, Input]) -> Dict[str, Dict]: + # Convert io entity to rest io objects + input_bindings, dataset_literal_inputs = process_sdk_component_job_io( + input_entity, [ComponentJobConstants.INPUT_PATTERN] + ) + + # parse input_bindings to InputLiteral(value=str(binding)) + rest_inputs = {**input_bindings, **dataset_literal_inputs} + # Note: The function will only be called from BaseNode, + # and job_type is used to enable dot in pipeline job input keys, + # so pass job_type as None directly here. 
+ rest_inputs = to_rest_dataset_literal_inputs(rest_inputs, job_type=None) + + # convert rest io to dict + rest_dataset_literal_inputs = {} + for name, val in rest_inputs.items(): + rest_dataset_literal_inputs[name] = val.as_dict() + if hasattr(val, "mode") and val.mode: + rest_dataset_literal_inputs[name].update({"mode": val.mode.value}) + return rest_dataset_literal_inputs + + def _to_rest_outputs(self) -> Dict[str, Dict]: + """Translate output builders to rest output dicts. + + The built dictionary's format aligns with component job's output yaml, eg: + {"eval_output": "${{jobs.eval.outputs.eval_output}}"} + + :return: The REST outputs + :rtype: Dict[str, Dict] + """ + built_outputs = self._build_outputs() + + # Convert io entity to rest io objects + output_bindings, data_outputs = process_sdk_component_job_io( + built_outputs, [ComponentJobConstants.OUTPUT_PATTERN] + ) + rest_data_outputs = to_rest_data_outputs(data_outputs) + + # convert rest io to dict + # parse output_bindings to {"value": binding, "type": "literal"} since there's no mode for it + rest_output_bindings = {} + for key, binding in output_bindings.items(): + rest_output_bindings[key] = {"value": binding["value"], "type": "literal"} + if "mode" in binding: + rest_output_bindings[key].update({"mode": binding["mode"].value}) + if "name" in binding: + rest_output_bindings[key].update({"name": binding["name"]}) + if "version" in binding: + rest_output_bindings[key].update({"version": binding["version"]}) + + def _rename_name_and_version(output_dict: Dict) -> Dict: + # NodeOutput can only be registered with name and version, therefore we rename here + if "asset_name" in output_dict.keys(): + output_dict["name"] = output_dict.pop("asset_name") + if "asset_version" in output_dict.keys(): + output_dict["version"] = output_dict.pop("asset_version") + return output_dict + + rest_data_outputs = {name: _rename_name_and_version(val.as_dict()) for name, val in rest_data_outputs.items()} + self._update_output_types(rest_data_outputs) + rest_data_outputs.update(rest_output_bindings) + return rest_data_outputs + + @classmethod + def _from_rest_inputs(cls, inputs: Dict) -> Dict[str, Union[Input, str, bool, int, float]]: + """Load inputs from rest inputs. + + :param inputs: The REST inputs + :type inputs: Dict[str, Union[str, dict]] + :return: Input dict + :rtype: Dict[str, Union[Input, str, bool, int, float]] + """ + + # JObject -> RestJobInput/RestJobOutput + input_bindings, rest_inputs = from_dict_to_rest_io(inputs, RestJobInput, [ComponentJobConstants.INPUT_PATTERN]) + + # RestJobInput/RestJobOutput -> Input/Output + dataset_literal_inputs = from_rest_inputs_to_dataset_literal(rest_inputs) + + return {**dataset_literal_inputs, **input_bindings} + + @classmethod + def _from_rest_outputs(cls, outputs: Dict[str, Union[str, dict]]) -> Dict: + """Load outputs from rest outputs. + + :param outputs: The REST outputs + :type outputs: Dict[str, Union[str, dict]] + :return: Output dict + :rtype: Dict[str, Output] + """ + + # JObject -> RestJobInput/RestJobOutput + output_bindings, rest_outputs = from_dict_to_rest_io( + outputs, RestJobOutput, [ComponentJobConstants.OUTPUT_PATTERN] + ) + + # RestJobInput/RestJobOutput -> Input/Output + data_outputs = from_rest_data_outputs(rest_outputs) + + return {**data_outputs, **output_bindings} + + def _update_output_types(self, rest_data_outputs: dict) -> None: + """Update output types in rest_data_outputs according to meta level output. 
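+
+        Illustrative: an output whose type is "uri_file" (or legacy "AnyFile") gets
+        its "job_output_type" set to "uri_file" so it does not fall back to the
+        uri_folder default.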
+ + :param rest_data_outputs: The REST data outputs + :type rest_data_outputs: Dict + """ + + for name, rest_output in rest_data_outputs.items(): + original_output = self.outputs[name] # type: ignore + # for configured output with meta, "correct" the output type to file to avoid the uri_folder default value + if original_output and original_output.type: + if original_output.type in ["AnyFile", "uri_file"]: + rest_output["job_output_type"] = "uri_file" + + +def flatten_dict( + dct: Optional[Dict], + _type: Union[Type["_GroupAttrDict"], Type[GroupInput]], + *, + allow_dict_fields: Optional[List[str]] = None, +) -> Dict: + """Flatten inputs/input_definitions dict for inputs dict build. + + :param dct: The dictionary to flatten + :type dct: Dict + :param _type: Either _GroupAttrDict or GroupInput (both have the method `flatten`) + :type _type: Union[Type["_GroupAttrDict"], Type[GroupInput]] + :keyword allow_dict_fields: A list of keys for dictionary values that will be included in flattened output + :paramtype allow_dict_fields: Optional[List[str]] + :return: The flattened dict + :rtype: Dict + """ + _result = {} + if dct is not None: + for key, val in dct.items(): + # to support passing dict value as parameter group + if allow_dict_fields and key in allow_dict_fields and isinstance(val, dict): + # for child dict, all values are allowed to be dict + for flattened_key, flattened_val in flatten_dict( + val, _type, allow_dict_fields=list(val.keys()) + ).items(): + _result[key + "." + flattened_key] = flattened_val + continue + val = GroupInput.custom_class_value_to_attr_dict(val) + if isinstance(val, _type): + _result.update(val.flatten(group_parameter_name=key)) + continue + _result[key] = val + return _result + + +class NodeWithGroupInputMixin(NodeIOMixin): + """This class provide build_inputs_dict for a node to use ParameterGroup as an input.""" + + @classmethod + def _validate_group_input_type( + cls, + input_definition_dict: dict, + inputs: Dict[str, Union[Input, str, bool, int, float]], + ) -> None: + """Raise error when group input receive a value not group type. + + :param input_definition_dict: The input definition dict + :type input_definition_dict: dict + :param inputs: The inputs + :type inputs: Dict[str, Union[Input, str, bool, int, float]] + """ + # Note: We put and extra validation here instead of doing it in pipeline._validate() + # due to group input will be discarded silently if assign it to a non-group parameter. + group_msg = "'%s' is defined as a parameter group but got input '%s' with type '%s'." + non_group_msg = "'%s' is defined as a parameter but got a parameter group as input." + for key, val in inputs.items(): + definition = input_definition_dict.get(key) + val = GroupInput.custom_class_value_to_attr_dict(val) + if val is None: + continue + # 1. inputs.group = 'a string' + if isinstance(definition, GroupInput) and not isinstance(val, (_GroupAttrDict, dict)): + raise ValidationException( + message=group_msg % (key, val, type(val)), + no_personal_data_message=group_msg % ("[key]", "[val]", "[type(val)]"), + target=ErrorTarget.PIPELINE, + type=ValidationErrorType.INVALID_VALUE, + ) + # 2. 
inputs.str_param = group + if not isinstance(definition, GroupInput) and isinstance(val, _GroupAttrDict): + raise ValidationException( + message=non_group_msg % key, + no_personal_data_message=non_group_msg % "[key]", + target=ErrorTarget.PIPELINE, + type=ValidationErrorType.INVALID_VALUE, + ) + + @classmethod + def _flatten_inputs_and_definition( + cls, + inputs: Dict[str, Union[Input, str, bool, int, float]], + input_definition_dict: dict, + ) -> Tuple[Dict, Dict]: + """ + Flatten all GroupInput(definition) and GroupAttrDict recursively and build input dict. + For example: + input_definition_dict = { + "group1": GroupInput( + values={ + "param1": GroupInput( + values={ + "param1_1": Input(type="str"), + } + ), + "param2": Input(type="int"), + } + ), + "group2": GroupInput( + values={ + "param3": Input(type="str"), + } + ), + } => { + "group1.param1.param1_1": Input(type="str"), + "group1.param2": Input(type="int"), + "group2.param3": Input(type="str"), + } + inputs = { + "group1": { + "param1": { + "param1_1": "value1", + }, + "param2": 2, + }, + "group2": { + "param3": "value3", + }, + } => { + "group1.param1.param1_1": "value1", + "group1.param2": 2, + "group2.param3": "value3", + } + :param inputs: The inputs + :type inputs: Dict[str, Union[Input, str, bool, int, float]] + :param input_definition_dict: The input definition dict + :type input_definition_dict: dict + :return: The flattened inputs and definition + :rtype: Tuple[Dict, Dict] + """ + group_input_names = [key for key, val in input_definition_dict.items() if isinstance(val, GroupInput)] + flattened_inputs = flatten_dict(inputs, _GroupAttrDict, allow_dict_fields=group_input_names) + flattened_definition_dict = flatten_dict(input_definition_dict, GroupInput) + return flattened_inputs, flattened_definition_dict + + def _build_inputs_dict( + self, + inputs: Dict[str, Union[Input, str, bool, int, float]], + *, + input_definition_dict: Optional[dict] = None, + ) -> InputsAttrDict: + """Build an input attribute dict so user can get/set inputs by + accessing attribute, eg: node1.inputs.xxx. + + :param inputs: Provided kwargs when parameterizing component func. + :type inputs: Dict[str, Union[Input, str, bool, int, float]] + :keyword input_definition_dict: Input definition dict from component entity. + :paramtype input_definition_dict: dict + :return: Built input attribute dict. + :rtype: InputsAttrDict + """ + + # TODO: should we support group input when there is no local input definition? 
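+        # Illustrative sketch (group/parameter names are hypothetical): for a group
+        # input "cfg" with fields "lr" and "epochs", kwargs like
+        #     {"cfg": {"lr": 0.01, "epochs": 3}}
+        # are flattened to {"cfg.lr": 0.01, "cfg.epochs": 3}, zipped with the
+        # flattened definition, and then restored into a _GroupAttrDict so they can
+        # be read back as node.inputs.cfg.lr.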
+ if input_definition_dict is not None: + # Validate group mismatch + self._validate_group_input_type(input_definition_dict, inputs) + + # Flatten inputs and definition + flattened_inputs, flattened_definition_dict = self._flatten_inputs_and_definition( + inputs, input_definition_dict + ) + # Build: zip all flattened parameter with definition + inputs = super()._build_inputs_dict(flattened_inputs, input_definition_dict=flattened_definition_dict) + return InputsAttrDict(GroupInput.restore_flattened_inputs(inputs)) + return super()._build_inputs_dict(inputs) + + +class PipelineJobIOMixin(NodeWithGroupInputMixin): + """Provides ability to wrap pipeline job inputs/outputs and build data bindings + dynamically.""" + + def _build_input(self, name: str, meta: Optional[Input], data: Any) -> "PipelineInput": + return PipelineInput(name=name, meta=meta, data=data, owner=self) + + def _build_output( + self, name: str, meta: Optional[Union[Input, Output]], data: Optional[Union[Output, str]] + ) -> "PipelineOutput": + # TODO: settings data to None for un-configured outputs so we won't passing extra fields(eg: default mode) + result = PipelineOutput(port_name=name, meta=meta, data=data, owner=self) + return result + + def _build_inputs_dict( + self, + inputs: Dict[str, Union[Input, str, bool, int, float]], + *, + input_definition_dict: Optional[dict] = None, + ) -> InputsAttrDict: + """Build an input attribute dict so user can get/set inputs by + accessing attribute, eg: node1.inputs.xxx. + + :param inputs: Provided kwargs when parameterizing component func. + :type inputs: Dict[str, Union[Input, str, bool, int, float]] + :keyword input_definition_dict: Input definition dict from component entity. + :return: Built input attribute dict. + :rtype: InputsAttrDict + """ + input_dict = super()._build_inputs_dict(inputs, input_definition_dict=input_definition_dict) + # TODO: should we do this when input_definition_dict is not None? + # TODO: should we put this in super()._build_inputs_dict? + if input_definition_dict is None: + return InputsAttrDict(GroupInput.restore_flattened_inputs(input_dict)) + return input_dict + + def _build_output_for_pipeline(self, name: str, data: Optional[Union[Output, NodeOutput]]) -> "PipelineOutput": + """Build an output object for pipeline and copy settings from source output. + + :param name: Output name. + :type name: str + :param data: Output data. + :type data: Optional[Union[Output, NodeOutput]] + :return: Built output object. + :rtype: PipelineOutput + """ + # pylint: disable=protected-access + if data is None: + # For None output, build an empty output builder + output_val = self._build_output(name=name, meta=None, data=None) + elif isinstance(data, Output): + # For output entity, build an output builder with data points to it + output_val = self._build_output(name=name, meta=data, data=data) + elif isinstance(data, NodeOutput): + # For output builder, build a new output builder and copy settings from it + output_val = self._build_output(name=name, meta=data._meta, data=None) + copy_output_setting(source=data, target=output_val) + else: + message = "Unsupported output type: {} for pipeline output: {}: {}" + raise ValidationException( + message=message.format(type(data), name, data), + no_personal_data_message=message, + target=ErrorTarget.PIPELINE, + ) + return output_val + + def _build_pipeline_outputs_dict(self, outputs: Dict) -> OutputsAttrDict: + """Build an output attribute dict without output definition metadata. 
+ For pipeline outputs, its setting should be copied from node level outputs. + + :param outputs: Node output dict or pipeline component's outputs. + :type outputs: Dict[str, Union[Output, NodeOutput]] + :return: Built dynamic output attribute dict. + :rtype: OutputsAttrDict + """ + output_dict = {} + for key, val in outputs.items(): + output_dict[key] = self._build_output_for_pipeline(name=key, data=val) + return OutputsAttrDict(output_dict) + + def _build_outputs(self) -> Dict[str, Output]: + """Build outputs of this pipeline to a dict which maps output to actual + value. + + The built dictionary's format aligns with component job's output yaml, + un-configured outputs will be None, eg: + {"eval_output": "${{jobs.eval.outputs.eval_output}}", "un_configured": None} + + :return: The output dict + :rtype: Dict[str, Output] + """ + outputs = {} + for name, output in self.outputs.items(): # type: ignore + if isinstance(output, NodeOutput): + output = output._to_job_output() # pylint: disable=protected-access + outputs[name] = output + return outputs + + def _get_default_input_val(self, val: Any): # type: ignore + # use Default value as data placeholder for unfilled inputs. + # client side need to fill the default value for dsl.pipeline + if isinstance(val, GroupInput): + # Copy default value dict for group + return copy.deepcopy(val.default) + return val.default + + def _update_output_types(self, rest_data_outputs: Dict) -> None: + """Won't clear output type for pipeline level outputs since it's required in rest object. + + :param rest_data_outputs: The REST data outputs + :type rest_data_outputs: Dict + """ + + +class AutoMLNodeIOMixin(NodeIOMixin): + """Wrap outputs of automl node and build data bindings dynamically.""" + + def __init__(self, **kwargs): # type: ignore + # add a inputs field to align with other nodes + self.inputs = {} + super(AutoMLNodeIOMixin, self).__init__(**kwargs) + if getattr(self, "outputs", None): + self._outputs = self._build_outputs_dict(self.outputs or {}) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_load_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_load_component.py new file mode 100644 index 00000000..60c4cbe7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_load_component.py @@ -0,0 +1,313 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access +from typing import Any, Callable, Dict, List, Mapping, Optional, Union, cast + +from marshmallow import INCLUDE + +from azure.ai.ml import Output +from azure.ai.ml._schema import NestedField +from azure.ai.ml._schema.pipeline.component_job import SweepSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, SOURCE_PATH_CONTEXT_KEY, CommonYamlFields +from azure.ai.ml.constants._component import ControlFlowType, DataTransferTaskType, NodeType +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.dsl._component_func import to_component_func +from azure.ai.ml.dsl._overrides_definition import OverrideDefinition +from azure.ai.ml.entities._builders import ( + BaseNode, + Command, + DataTransferCopy, + DataTransferExport, + DataTransferImport, + Import, + Parallel, + Spark, + Sweep, +) +from azure.ai.ml.entities._builders.condition_node import ConditionNode +from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode +from azure.ai.ml.entities._builders.do_while import DoWhile +from azure.ai.ml.entities._builders.parallel_for import ParallelFor +from azure.ai.ml.entities._builders.pipeline import Pipeline +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob +from azure.ai.ml.entities._util import get_type_from_spec +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class _PipelineNodeFactory: + """A class to create pipeline node instances from yaml dict or rest objects without hard-coded type check.""" + + def __init__(self) -> None: + self._create_instance_funcs: dict = {} + self._load_from_rest_object_funcs: dict = {} + + self.register_type( + _type=NodeType.COMMAND, + create_instance_func=lambda: Command.__new__(Command), + load_from_rest_object_func=Command._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=NodeType.IMPORT, + create_instance_func=lambda: Import.__new__(Import), + load_from_rest_object_func=Import._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=NodeType.PARALLEL, + create_instance_func=lambda: Parallel.__new__(Parallel), + load_from_rest_object_func=Parallel._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=NodeType.PIPELINE, + create_instance_func=lambda: Pipeline.__new__(Pipeline), + load_from_rest_object_func=Pipeline._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=NodeType.SWEEP, + create_instance_func=lambda: Sweep.__new__(Sweep), + load_from_rest_object_func=Sweep._from_rest_object, + nested_schema=NestedField(SweepSchema, unknown=INCLUDE), + ) + self.register_type( + _type=NodeType.AUTOML, + create_instance_func=None, + load_from_rest_object_func=self._automl_from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=NodeType.SPARK, + create_instance_func=lambda: Spark.__new__(Spark), + load_from_rest_object_func=Spark._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=ControlFlowType.DO_WHILE, + create_instance_func=None, + load_from_rest_object_func=DoWhile._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=ControlFlowType.IF_ELSE, + create_instance_func=None, + load_from_rest_object_func=ConditionNode._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=ControlFlowType.PARALLEL_FOR, + create_instance_func=None, + 
load_from_rest_object_func=ParallelFor._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.COPY_DATA]), + create_instance_func=lambda: DataTransferCopy.__new__(DataTransferCopy), + load_from_rest_object_func=DataTransferCopy._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.IMPORT_DATA]), + create_instance_func=lambda: DataTransferImport.__new__(DataTransferImport), + load_from_rest_object_func=DataTransferImport._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.EXPORT_DATA]), + create_instance_func=lambda: DataTransferExport.__new__(DataTransferExport), + load_from_rest_object_func=DataTransferExport._from_rest_object, + nested_schema=None, + ) + self.register_type( + _type=NodeType.FLOW_PARALLEL, + create_instance_func=lambda: Parallel.__new__(Parallel), + load_from_rest_object_func=None, + nested_schema=None, + ) + + @classmethod + def _get_func(cls, _type: str, funcs: Dict[str, Callable]) -> Callable: + if _type == NodeType._CONTAINER: + msg = ( + "Component returned by 'list' is abbreviated and can not be used directly, " + "please use result from 'get'." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + ) + _type = get_type_from_spec({CommonYamlFields.TYPE: _type}, valid_keys=funcs) + return funcs[_type] + + def get_create_instance_func(self, _type: str) -> Callable[..., BaseNode]: + """Get the function to create a new instance of the node. + + :param _type: The type of the node. + :type _type: str + :return: The create instance function + :rtype: Callable[..., BaseNode] + """ + return self._get_func(_type, self._create_instance_funcs) + + def get_load_from_rest_object_func(self, _type: str) -> Callable: + """Get the function to load a node from a rest object. + + :param _type: The type of the node. + :type _type: str + :return: The `_load_from_rest_object` function + :rtype: Callable[[Any], Union[BaseNode, AutoMLJob, ControlFlowNode]] + """ + return self._get_func(_type, self._load_from_rest_object_funcs) + + def register_type( + self, + _type: str, + *, + create_instance_func: Optional[Callable[..., Union[BaseNode, AutoMLJob]]] = None, + load_from_rest_object_func: Optional[Callable] = None, + nested_schema: Optional[Union[NestedField, List[NestedField]]] = None, + ) -> None: + """Register a type of node. + + :param _type: The type of the node. 
+ :type _type: str + :keyword create_instance_func: A function to create a new instance of the node + :paramtype create_instance_func: typing.Optional[typing.Callable[..., typing.Union[BaseNode, AutoMLJob]]] + :keyword load_from_rest_object_func: A function to load a node from a rest object + :paramtype load_from_rest_object_func: typing.Optional[typing.Callable[[Any], typing.Union[BaseNode, AutoMLJob\ + , ControlFlowNode]]] + :keyword nested_schema: schema/schemas of corresponding nested field, will be used in \ + PipelineJobSchema.jobs.value + :paramtype nested_schema: typing.Optional[typing.Union[NestedField, List[NestedField]]] + """ + if create_instance_func is not None: + self._create_instance_funcs[_type] = create_instance_func + if load_from_rest_object_func is not None: + self._load_from_rest_object_funcs[_type] = load_from_rest_object_func + if nested_schema is not None: + from azure.ai.ml._schema.core.fields import TypeSensitiveUnionField + from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentSchema + from azure.ai.ml._schema.pipeline.pipeline_job import PipelineJobSchema + + for declared_fields in [ + PipelineJobSchema._declared_fields, + PipelineComponentSchema._declared_fields, + ]: + jobs_value_field: TypeSensitiveUnionField = declared_fields["jobs"].value_field + if not isinstance(nested_schema, list): + nested_schema = [nested_schema] + for nested_field in nested_schema: + jobs_value_field.insert_type_sensitive_field(type_name=_type, field=nested_field) + + def load_from_dict(self, *, data: dict, _type: Optional[str] = None) -> Union[BaseNode, AutoMLJob]: + """Load a node from a dict. + + :keyword data: A dict containing the node's data. + :paramtype data: dict + :keyword _type: The type of the node. If not specified, it will be inferred from the data. + :paramtype _type: str + :return: The node + :rtype: Union[BaseNode, AutoMLJob] + """ + if _type is None: + _type = data[CommonYamlFields.TYPE] if CommonYamlFields.TYPE in data else NodeType.COMMAND + # todo: refine Hard code for now to support different task type for DataTransfer node + if _type == NodeType.DATA_TRANSFER: + _type = "_".join([NodeType.DATA_TRANSFER, data.get("task", " ")]) + else: + data[CommonYamlFields.TYPE] = _type + + new_instance: Union[BaseNode, AutoMLJob] = self.get_create_instance_func(_type)() + + if isinstance(new_instance, BaseNode): + # parse component + component_key = new_instance._get_component_attr_name() + if component_key in data and isinstance(data[component_key], dict): + data[component_key] = Component._load( + data=data[component_key], + yaml_path=data[component_key].pop(SOURCE_PATH_CONTEXT_KEY, None), + ) + # TODO: Bug Item number: 2883415 + new_instance.__init__(**data) # type: ignore + return new_instance + + def load_from_rest_object( + self, *, obj: dict, _type: Optional[str] = None, **kwargs: Any + ) -> Union[BaseNode, AutoMLJob, ControlFlowNode]: + """Load a node from a rest object. + + :keyword obj: A rest object containing the node's data. + :paramtype obj: dict + :keyword _type: The type of the node. If not specified, it will be inferred from the data. 
+ :paramtype _type: str + :return: The node + :rtype: Union[BaseNode, AutoMLJob, ControlFlowNode] + """ + + # TODO: Remove in PuP with native import job/component type support in MFE/Designer + if "computeId" in obj and obj["computeId"] and obj["computeId"].endswith("/" + ComputeType.ADF): + _type = NodeType.IMPORT + + if _type is None: + _type = obj[CommonYamlFields.TYPE] if CommonYamlFields.TYPE in obj else NodeType.COMMAND + # todo: refine Hard code for now to support different task type for DataTransfer node + if _type == NodeType.DATA_TRANSFER: + _type = "_".join([NodeType.DATA_TRANSFER, obj.get("task", " ")]) + else: + obj[CommonYamlFields.TYPE] = _type + + res: Union[BaseNode, AutoMLJob, ControlFlowNode] = self.get_load_from_rest_object_func(_type)(obj, **kwargs) + return res + + @classmethod + def _automl_from_rest_object(cls, node: Dict) -> AutoMLJob: + _outputs = cast(Dict[str, Union[str, dict]], node.get("outputs")) + # rest dict outputs -> Output objects + outputs = AutoMLJob._from_rest_outputs(_outputs) + # Output objects -> yaml dict outputs + parsed_outputs = {} + for key, val in outputs.items(): + if isinstance(val, Output): + val = val._to_dict() + parsed_outputs[key] = val + node["outputs"] = parsed_outputs + return AutoMLJob._load_from_dict( + node, + context={BASE_PATH_CONTEXT_KEY: "./"}, + additional_message="Failed to load automl task from backend.", + inside_pipeline=True, + ) + + +def _generate_component_function( + component_entity: Component, + override_definitions: Optional[Mapping[str, OverrideDefinition]] = None, # pylint: disable=unused-argument +) -> Callable[..., Union[Command, Parallel]]: + # Generate a function which returns a component node. + def create_component_func(**kwargs: Any) -> Union[BaseNode, AutoMLJob]: + # todo: refine Hard code for now to support different task type for DataTransfer node + _type = component_entity.type + if _type == NodeType.DATA_TRANSFER: + # TODO: Bug Item number: 2883431 + _type = "_".join([NodeType.DATA_TRANSFER, component_entity.task]) # type: ignore + if component_entity.task == DataTransferTaskType.IMPORT_DATA: # type: ignore + return pipeline_node_factory.load_from_dict( + data={"component": component_entity, "_from_component_func": True, **kwargs}, + _type=_type, + ) + return pipeline_node_factory.load_from_dict( + data={"component": component_entity, "inputs": kwargs, "_from_component_func": True}, + _type=_type, + ) + + res: Callable = to_component_func(component_entity, create_component_func) + return res + + +pipeline_node_factory = _PipelineNodeFactory() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_expression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_expression.py new file mode 100644 index 00000000..49bb8a61 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_expression.py @@ -0,0 +1,662 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import re +import tempfile +from collections import namedtuple +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast + +from azure.ai.ml._utils.utils import dump_yaml_to_file, get_all_data_binding_expressions, load_yaml +from azure.ai.ml.constants._common import AZUREML_PRIVATE_FEATURES_ENV_VAR, DefaultOpenEncoding +from azure.ai.ml.constants._component import ComponentParameterTypes, IOConstants +from azure.ai.ml.exceptions import UserErrorException + +if TYPE_CHECKING: + from azure.ai.ml.entities._builders import BaseNode + +ExpressionInput = namedtuple("ExpressionInput", ["name", "type", "value"]) +NONE_PARAMETER_TYPE = "None" + + +class PipelineExpressionOperator: + """Support operator in native Python experience.""" + + ADD = "+" + SUB = "-" + MUL = "*" + DIV = "/" + MOD = "%" + POW = "**" + FLOORDIV = "//" + LT = "<" + GT = ">" + LTE = "<=" + GTE = ">=" + EQ = "==" + NE = "!=" + AND = "&" + OR = "|" + XOR = "^" + + +_SUPPORTED_OPERATORS = { + getattr(PipelineExpressionOperator, attr) + for attr in PipelineExpressionOperator.__dict__ + if not attr.startswith("__") +} + + +def _enumerate_operation_combination() -> Dict[str, Union[str, Exception]]: + """Enumerate the result type of binary operations on types + + Leverages `eval` to validate operation and get its result type. + + :return: A dictionary that maps an operation to either: + * A result type + * An Exception + :rtype: Dict[str, Union[str, Exception]] + """ + res: Dict = {} + primitive_types_values = { + NONE_PARAMETER_TYPE: repr(None), + ComponentParameterTypes.BOOLEAN: repr(True), + ComponentParameterTypes.INTEGER: repr(1), + ComponentParameterTypes.NUMBER: repr(1.0), + ComponentParameterTypes.STRING: repr("1"), + } + for type1, operand1 in primitive_types_values.items(): + for type2, operand2 in primitive_types_values.items(): + for operator in _SUPPORTED_OPERATORS: + k = f"{type1} {operator} {type2}" + try: + eval_result = eval(f"{operand1} {operator} {operand2}") # pylint: disable=eval-used # nosec + res[k] = IOConstants.PRIMITIVE_TYPE_2_STR[type(eval_result)] + except TypeError: + error_message = ( + f"Operator '{operator}' is not supported between instances of '{type1}' and '{type2}'." + ) + res[k] = UserErrorException(message=error_message, no_personal_data_message=error_message) + return res + + +# enumerate and store as a lookup table: +# key format is "<operand1_type> <operator> <operand2_type>" +# value can be either result type as str and UserErrorException for invalid operation +_OPERATION_RESULT_TYPE_LOOKUP = _enumerate_operation_combination() + + +class PipelineExpressionMixin: + _SUPPORTED_PRIMITIVE_TYPES = (bool, int, float, str) + _SUPPORTED_PIPELINE_INPUT_TYPES = ( + ComponentParameterTypes.BOOLEAN, + ComponentParameterTypes.INTEGER, + ComponentParameterTypes.NUMBER, + ComponentParameterTypes.STRING, + ) + + def _validate_binary_operation(self, other: Any, operator: str) -> None: + from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput + + if ( + other is not None + and not isinstance(other, self._SUPPORTED_PRIMITIVE_TYPES) + and not isinstance(other, (PipelineInput, NodeOutput, PipelineExpression)) + ): + error_message = ( + f"Operator '{operator}' is not supported with {type(other)}; " + "currently only support primitive types (None, bool, int, float and str), " + "pipeline input, component output and expression." 
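+                # NOTE: the arithmetic/comparison dunders below (__add__, __lt__, __eq__, ...)
+                # all run this check first and then delegate to PipelineExpression._from_operation,
+                # which records the operands and operator in postfix form instead of evaluating them.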
+ ) + raise UserErrorException(message=error_message, no_personal_data_message=error_message) + + def __add__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.ADD) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.ADD) + + def __radd__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.ADD) + return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.ADD) + + def __sub__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.SUB) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.SUB) + + def __rsub__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.SUB) + return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.SUB) + + def __mul__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.MUL) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.MUL) + + def __rmul__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.MUL) + return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.MUL) + + def __truediv__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.DIV) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.DIV) + + def __rtruediv__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.DIV) + return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.DIV) + + def __mod__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.MOD) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.MOD) + + def __rmod__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.MOD) + return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.MOD) + + def __pow__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.POW) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.POW) + + def __rpow__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.POW) + return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.POW) + + def __floordiv__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.FLOORDIV) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.FLOORDIV) + + def __rfloordiv__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.FLOORDIV) + return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.FLOORDIV) + + def __lt__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.LT) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.LT) + + def __gt__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, 
PipelineExpressionOperator.GT) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.GT) + + def __le__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.LTE) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.LTE) + + def __ge__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.GTE) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.GTE) + + # TODO: Bug Item number: 2883354 + def __eq__(self, other: Any) -> "PipelineExpression": # type: ignore + self._validate_binary_operation(other, PipelineExpressionOperator.EQ) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.EQ) + + # TODO: Bug Item number: 2883354 + def __ne__(self, other: Any) -> "PipelineExpression": # type: ignore + self._validate_binary_operation(other, PipelineExpressionOperator.NE) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.NE) + + def __and__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.AND) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.AND) + + def __or__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.OR) + return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.OR) + + def __xor__(self, other: Any) -> "PipelineExpression": + self._validate_binary_operation(other, PipelineExpressionOperator.XOR) + return PipelineExpression._from_operation(self, None, PipelineExpressionOperator.XOR) + + def __bool__(self) -> bool: + """Python method that is used to implement truth value testing and the built-in operation bool(). + + This method is not supported as PipelineExpressionMixin is designed to record operation history, + while this method can only return False or True, leading to history breaks here. + As overloadable boolean operators PEP (refer to: https://www.python.org/dev/peps/pep-0335/) + was rejected, logical operations are also not supported. + + :return: True if not inside dsl pipeline func, raises otherwise + :rtype: bool + """ + from azure.ai.ml.dsl._pipeline_component_builder import _is_inside_dsl_pipeline_func + + # note: unexpected bool test always be checking if the object is None; + # so for non-pipeline scenarios, directly return True to avoid unexpected breaking, + # and for pipeline scenarios, will use is not None to replace bool test. + if not _is_inside_dsl_pipeline_func(): + return True + + error_message = f"Type {type(self)} is not supported for operation bool()." + raise UserErrorException(message=error_message, no_personal_data_message=error_message) + + +class PipelineExpression(PipelineExpressionMixin): + """Pipeline expression entity. + + Use PipelineExpression to support simple and trivial parameter transformation tasks with constants + or other parameters. Operations are recorded in this class during executions, and expected result + will be generated for corresponding scenario. 
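+
+    Illustrative sketch (assuming two integer pipeline inputs ``a`` and ``b`` used inside a
+    ``@dsl.pipeline`` function)::
+
+        expr = a + b * 2        # operators are recorded, not evaluated:
+                                # postfix tokens ['a', 'b', '2', '*', '+']
+        result = expr.resolve() # a data binding string for pure string concatenation,
+                                # otherwise a generated command component node wired to the operands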
+ """ + + _PIPELINE_INPUT_PREFIX = ["parent", "inputs"] + _PIPELINE_INPUT_PATTERN = re.compile(pattern=r"parent.inputs.(?P<pipeline_input_name>[^.]+)") + _PIPELINE_INPUT_NAME_GROUP = "pipeline_input_name" + # AML type to Python type, for generated Python code + _TO_PYTHON_TYPE = { + ComponentParameterTypes.BOOLEAN: bool.__name__, + ComponentParameterTypes.INTEGER: int.__name__, + ComponentParameterTypes.NUMBER: float.__name__, + ComponentParameterTypes.STRING: str.__name__, + } + + _INDENTATION = " " + _IMPORT_MLDESIGNER_LINE = "from mldesigner import command_component, Output" + _DECORATOR_LINE = "@command_component(@@decorator_parameters@@)" + _COMPONENT_FUNC_NAME = "expression_func" + _COMPONENT_FUNC_DECLARATION_LINE = ( + f"def {_COMPONENT_FUNC_NAME}(@@component_parameters@@)" " -> Output(type=@@return_type@@):" + ) + _PYTHON_CACHE_FOLDER_NAME = "__pycache__" + + def __init__(self, postfix: List[str], inputs: Dict[str, ExpressionInput]): + self._postfix = postfix + self._inputs = inputs.copy() # including PiplineInput and Output, extra stored name and type + self._result_type: Optional[str] = None + self._created_component = None + + @property + def expression(self) -> str: + """Infix expression string, wrapped with parentheses. + + :return: The infix expression + :rtype: str + """ + return self._to_infix() + + def __str__(self) -> str: + return self._to_data_binding() + + def _data_binding(self) -> str: + return self._to_data_binding() + + def _to_infix(self) -> str: + stack = [] + for token in self._postfix: + if token not in _SUPPORTED_OPERATORS: + stack.append(token) + continue + operand2, operand1 = stack.pop(), stack.pop() + stack.append(f"({operand1} {token} {operand2})") + return stack.pop() + + # pylint: disable=too-many-statements + @staticmethod + def _handle_operand( + operand: "PipelineExpression", + postfix: List[str], + expression_inputs: Dict[str, ExpressionInput], + pipeline_inputs: dict, + ) -> Tuple[List[str], Dict[str, ExpressionInput]]: + """Handle operand in expression, update postfix expression and expression inputs. + + :param operand: The operand + :type operand: "PipelineExpression" + :param postfix: + :type postfix: List[str] + :param expression_inputs: The expression inputs + :type expression_inputs: Dict[str, ExpressionInput] + :param pipeline_inputs: The pipeline inputs + :type pipeline_inputs: dict + :return: A 2-tuple of the updated postfix expression and expression inputs + :rtype: Tuple[List[str], Dict[str, ExpressionInput]] + """ + from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput + + def _update_postfix(_postfix: List[str], _old_name: str, _new_name: str) -> List[str]: + return list(map(lambda _x: _new_name if _x == _old_name else _x, _postfix)) + + def _get_or_create_input_name( + _original_name: str, + _operand: Union[PipelineInput, NodeOutput], + _expression_inputs: Dict[str, ExpressionInput], + ) -> str: + """Get or create expression input name as current operand may have appeared in expression. 
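+            If the exact same operand object has been seen before, its existing name is
+            reused; otherwise a numeric suffix is appended until the name is unique.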
+ + :param _original_name: The original name + :type _original_name: str + :param _operand: The expression operand + :type _operand: Union[PipelineInput, NodeOutput] + :param _expression_inputs: The expression inputs + :type _expression_inputs: Dict[str, ExpressionInput] + :return: The input name + :rtype: str + """ + _existing_id_to_name = {id(_v.value): _k for _k, _v in _expression_inputs.items()} + if id(_operand) in _existing_id_to_name: + return _existing_id_to_name[id(_operand)] + # use a counter to generate a unique name for current operand + _name, _counter = _original_name, 0 + while _name in _expression_inputs: + _name = f"{_original_name}_{_counter}" + _counter += 1 + return _name + + def _handle_pipeline_input( + _pipeline_input: PipelineInput, + _postfix: List[str], + _expression_inputs: Dict[str, ExpressionInput], + ) -> Tuple[List[str], dict]: + _name = _pipeline_input._port_name + # 1. use name with counter for pipeline input; 2. add component's name to component output + if _name in _expression_inputs: + _seen_input = _expression_inputs[_name] + if isinstance(_seen_input.value, PipelineInput): + _name = _get_or_create_input_name(_name, _pipeline_input, _expression_inputs) + else: + _expression_inputs.pop(_name) + _new_name = f"{_seen_input.value._owner.component.name}__{_seen_input.value._port_name}" + _postfix = _update_postfix(_postfix, _name, _new_name) + _expression_inputs[_new_name] = ExpressionInput(_new_name, _seen_input.type, _seen_input) + _postfix.append(_name) + + param_input = pipeline_inputs + for group_name in _pipeline_input._group_names: + param_input = param_input[group_name].values + _expression_inputs[_name] = ExpressionInput( + _name, param_input[_pipeline_input._port_name].type, _pipeline_input + ) + return _postfix, _expression_inputs + + def _handle_component_output( + _component_output: NodeOutput, + _postfix: List[str], + _expression_inputs: Dict[str, ExpressionInput], + ) -> Tuple[List[str], dict]: + if _component_output._meta is not None and not _component_output._meta._is_primitive_type: + error_message = ( + f"Component output {_component_output._port_name} in expression must " + f"be a primitive type with value {True!r}, " + f"got {_component_output._meta._is_primitive_type!r}" + ) + raise UserErrorException(message=error_message, no_personal_data_message=error_message) + _name = _component_output._port_name + _has_prefix = False + # "output" is the default output name for command component, add component's name as prefix + if _name == "output": + if _component_output._owner is not None and not isinstance(_component_output._owner.component, str): + _name = f"{_component_output._owner.component.name}__output" + _has_prefix = True + # following loop is expected to execute at most twice: + # 1. add component's name to output(s) + # 2. 
use name with counter + while _name in _expression_inputs: + _seen_input = _expression_inputs[_name] + if isinstance(_seen_input.value, PipelineInput): + if not _has_prefix: + if _component_output._owner is not None and not isinstance( + _component_output._owner.component, str + ): + _name = f"{_component_output._owner.component.name}__{_component_output._port_name}" + _has_prefix = True + continue + _name = _get_or_create_input_name(_name, _component_output, _expression_inputs) + else: + if not _has_prefix: + _expression_inputs.pop(_name) + _new_name = f"{_seen_input.value._owner.component.name}__{_seen_input.value._port_name}" + _postfix = _update_postfix(_postfix, _name, _new_name) + _expression_inputs[_new_name] = ExpressionInput(_new_name, _seen_input.type, _seen_input) + if _component_output._owner is not None and not isinstance( + _component_output._owner.component, str + ): + _name = f"{_component_output._owner.component.name}__{_component_output._port_name}" + _has_prefix = True + _name = _get_or_create_input_name(_name, _component_output, _expression_inputs) + _postfix.append(_name) + _expression_inputs[_name] = ExpressionInput(_name, _component_output.type, _component_output) + return _postfix, _expression_inputs + + if operand is None or isinstance(operand, PipelineExpression._SUPPORTED_PRIMITIVE_TYPES): + postfix.append(repr(operand)) + elif isinstance(operand, PipelineInput): + postfix, expression_inputs = _handle_pipeline_input(operand, postfix, expression_inputs) + elif isinstance(operand, NodeOutput): + postfix, expression_inputs = _handle_component_output(operand, postfix, expression_inputs) + elif isinstance(operand, PipelineExpression): + postfix.extend(operand._postfix.copy()) + expression_inputs.update(operand._inputs.copy()) + return postfix, expression_inputs + + @staticmethod + def _from_operation(operand1: Any, operand2: Any, operator: str) -> "PipelineExpression": + if operator not in _SUPPORTED_OPERATORS: + error_message = ( + f"Operator '{operator}' is not supported operator, " + f"currently supported operators are {','.join(_SUPPORTED_OPERATORS)}." + ) + raise UserErrorException(message=error_message, no_personal_data_message=error_message) + + # get all pipeline input types from builder stack + # TODO: check if there is pipeline input we cannot know its type (missing in `PipelineComponentBuilder.inputs`)? + from azure.ai.ml.dsl._pipeline_component_builder import _definition_builder_stack + + res = _definition_builder_stack.top() + pipeline_inputs = res.inputs if res is not None else {} + postfix: List[str] = [] + inputs: Dict[str, ExpressionInput] = {} + postfix, inputs = PipelineExpression._handle_operand(operand1, postfix, inputs, pipeline_inputs) + postfix, inputs = PipelineExpression._handle_operand(operand2, postfix, inputs, pipeline_inputs) + postfix.append(operator) + return PipelineExpression(postfix, inputs) + + @property + def _string_concatenation(self) -> bool: + """If all operands are string and operations are addition, it is a string concatenation expression. 
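+        For example (illustrative), ``str_input + "_suffix" + other_str_input`` is a string
+        concatenation and can be converted to a data binding, while ``int_input + 1`` is not
+        and ``resolve()`` creates a component for it instead.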
+ + :return: Whether this represents string concatenation + :rtype: bool + """ + for token in self._postfix: + # operator can only be "+" for string concatenation + if token in _SUPPORTED_OPERATORS: + if token != PipelineExpressionOperator.ADD: + return False + continue + # constant and PiplineInput should be type string + if token in self._inputs: + if self._inputs[token].type != ComponentParameterTypes.STRING: + return False + else: + if not isinstance(eval(token), str): # pylint: disable=eval-used # nosec + return False + return True + + def _to_data_binding(self) -> str: + """Convert operands to data binding and concatenate them in the order of postfix expression. + + :return: The data binding + :rtype: str + """ + if not self._string_concatenation: + error_message = ( + "Only string concatenation expression is supported to convert to data binding, " + f"current expression is '{self.expression}'." + ) + raise UserErrorException(message=error_message, no_personal_data_message=error_message) + + stack = [] + for token in self._postfix: + if token != PipelineExpressionOperator.ADD: + if token in self._inputs: + stack.append(self._inputs[token].value._data_binding()) + else: + stack.append(eval(token)) # pylint: disable=eval-used # nosec + continue + operand2, operand1 = stack.pop(), stack.pop() + stack.append(operand1 + operand2) + res: str = stack.pop() + return res + + def resolve(self) -> Union[str, "BaseNode"]: + """Resolve expression to data binding or component, depend on the operations. + + :return: The data binding string or the component + :rtype: Union[str, BaseNode] + """ + if self._string_concatenation: + return self._to_data_binding() + return cast(Union[str, "BaseNode"], self._create_component()) + + @staticmethod + def parse_pipeline_inputs_from_data_binding(data_binding: str) -> List[str]: + """Parse all PipelineInputs name from data binding expression. + + :param data_binding: Data binding expression + :type data_binding: str + :return: List of PipelineInput's name from given data binding expression + :rtype: List[str] + """ + pipeline_input_names = [] + for single_data_binding in get_all_data_binding_expressions( + value=data_binding, + binding_prefix=PipelineExpression._PIPELINE_INPUT_PREFIX, + is_singular=False, + ): + m = PipelineExpression._PIPELINE_INPUT_PATTERN.match(single_data_binding) + # `get_all_data_binding_expressions` should work as pre-filter, so no need to concern `m` is None + if m is not None: + pipeline_input_names.append(m.group(PipelineExpression._PIPELINE_INPUT_NAME_GROUP)) + return pipeline_input_names + + @staticmethod + def _get_operation_result_type(type1: str, operator: str, type2: str) -> str: + def _validate_operand_type(_type: str) -> None: + if _type != NONE_PARAMETER_TYPE and _type not in PipelineExpression._SUPPORTED_PIPELINE_INPUT_TYPES: + error_message = ( + f"Pipeline input type {_type!r} is not supported in expression; " + f"currently only support None, " + + ", ".join(PipelineExpression._SUPPORTED_PIPELINE_INPUT_TYPES) + + "." 
+ ) + raise UserErrorException(message=error_message, no_personal_data_message=error_message) + + _validate_operand_type(type1) + _validate_operand_type(type2) + operation = f"{type1} {operator} {type2}" + lookup_value = _OPERATION_RESULT_TYPE_LOOKUP.get(operation) + if isinstance(lookup_value, str): + return lookup_value # valid operation, return result type + _user_exception: UserErrorException = lookup_value + raise _user_exception # invalid operation, raise UserErrorException + + def _get_operand_type(self, operand: str) -> str: + if operand in self._inputs: + res: str = self._inputs[operand].type + return res + primitive_type = type(eval(operand)) # pylint: disable=eval-used # nosec + res_type: str = IOConstants.PRIMITIVE_TYPE_2_STR.get(primitive_type, NONE_PARAMETER_TYPE) + return res_type + + @property + def _component_code(self) -> str: + def _generate_function_code_lines() -> Tuple[List[str], str]: + """Return lines of code and return type. + + :return: A 2-tuple of (function body, return type name) + :rtype: Tuple[List[str], str] + """ + _inter_id, _code, _stack = 0, [], [] + _line_recorder: Dict = {} + for _token in self._postfix: + if _token not in _SUPPORTED_OPERATORS: + _type = self._get_operand_type(_token) + _stack.append((_token, _type)) + continue + _operand2, _type2 = _stack.pop() + _operand1, _type1 = _stack.pop() + _current_line = f"{_operand1} {_token} {_operand2}" + if _current_line in _line_recorder: + _inter_var, _inter_var_type = _line_recorder[_current_line] + else: + _inter_var = f"inter_var_{_inter_id}" + _inter_id += 1 + _inter_var_type = self._get_operation_result_type(_type1, _token, _type2) + _code.append(f"{self._INDENTATION}{_inter_var} = {_current_line}") + _line_recorder[_current_line] = (_inter_var, _inter_var_type) + _stack.append((_inter_var, _inter_var_type)) + _return_var, _result_type = _stack.pop() + _code.append(f"{self._INDENTATION}return {_return_var}") + return _code, _result_type + + def _generate_function_decorator_and_declaration_lines(_return_type: str) -> List[str]: + # decorator parameters + _display_name = f'{self._INDENTATION}display_name="Expression: {self.expression}",' + _decorator_parameters = "\n" + "\n".join([_display_name]) + "\n" + # component parameters + _component_parameters = [] + for _name in sorted(self._inputs): + _type = self._TO_PYTHON_TYPE[self._inputs[_name].type] + _component_parameters.append(f"{_name}: {_type}") + _component_parameters_str = ( + "\n" + + "\n".join( + [f"{self._INDENTATION}{_component_parameter}," for _component_parameter in _component_parameters] + ) + + "\n" + ) + return [ + self._IMPORT_MLDESIGNER_LINE + "\n\n", + self._DECORATOR_LINE.replace("@@decorator_parameters@@", _decorator_parameters), + self._COMPONENT_FUNC_DECLARATION_LINE.replace( + "@@component_parameters@@", _component_parameters_str + ).replace("@@return_type@@", f'"{_return_type}"'), + ] + + lines, result_type = _generate_function_code_lines() + self._result_type = result_type + code = _generate_function_decorator_and_declaration_lines(result_type) + lines + return "\n".join(code) + "\n" + + def _create_component(self) -> Any: + def _generate_python_file(_folder: Path) -> None: + _folder.mkdir() + with open(_folder / "expression_component.py", "w", encoding=DefaultOpenEncoding.WRITE) as _f: + _f.write(self._component_code) + + def _generate_yaml_file(_path: Path) -> None: + _data_folder = Path(__file__).parent / "data" + # update YAML content from template and dump + with open(_data_folder / "expression_component_template.yml", 
"r", encoding=DefaultOpenEncoding.READ) as _f: + _data = load_yaml(_f) + _data["display_name"] = f"Expression: {self.expression}" + _data["inputs"] = {} + _data["outputs"]["output"]["type"] = self._result_type + _command_inputs_items = [] + for _name in sorted(self._inputs): + _type = self._inputs[_name].type + _data["inputs"][_name] = {"type": _type} + _command_inputs_items.append(_name + '="${{inputs.' + _name + '}}"') + _command_inputs_string = " ".join(_command_inputs_items) + _command_output_string = 'output="${{outputs.output}}"' + _command = ( + "mldesigner execute --source expression_component.py --name expression_func" + " --inputs " + _command_inputs_string + " --outputs " + _command_output_string + ) + _data["command"] = _data["command"].format(command_placeholder=_command) + dump_yaml_to_file(_path, _data) + + if self._created_component is None: + tmp_folder = Path(tempfile.mkdtemp()) + code_folder = tmp_folder / "src" + yaml_path = tmp_folder / "component_spec.yml" + _generate_python_file(code_folder) + _generate_yaml_file(yaml_path) + + from azure.ai.ml import load_component + + component_func = load_component(yaml_path) + component_kwargs = {k: v.value for k, v in self._inputs.items()} + self._created_component = component_func(**component_kwargs) + if self._created_component is not None: + self._created_component.environment_variables = {AZUREML_PRIVATE_FEATURES_ENV_VAR: "true"} + return self._created_component diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_job_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_job_helpers.py new file mode 100644 index 00000000..3a7d89e7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_job_helpers.py @@ -0,0 +1,182 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import re +from typing import Dict, List, Tuple, Type, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import InputDeliveryMode +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobInput as RestJobInput +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobOutput as RestJobOutput +from azure.ai.ml._restclient.v2023_04_01_preview.models import Mpi, PyTorch, Ray, TensorFlow +from azure.ai.ml.constants._component import ComponentJobConstants +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job._input_output_helpers import ( + INPUT_MOUNT_MAPPING_FROM_REST, + INPUT_MOUNT_MAPPING_TO_REST, + OUTPUT_MOUNT_MAPPING_FROM_REST, + OUTPUT_MOUNT_MAPPING_TO_REST, +) +from azure.ai.ml.entities._util import normalize_job_input_output_type +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +def process_sdk_component_job_io( + io: Dict, + io_binding_regex_list: List[str], +) -> Tuple: + """Separates SDK ComponentJob inputs that are data bindings (i.e. string inputs prefixed with 'inputs.' or + 'outputs.') and dataset and literal inputs/outputs. + + :param io: Input or output dictionary of an SDK ComponentJob + :type io: Dict[str, Union[str, float, bool, Input]] + :param io_binding_regex_list: A list of regexes for io bindings + :type io_binding_regex_list: List[str] + :return: A tuple of dictionaries: + * One mapping inputs to REST formatted ComponentJobInput/ComponentJobOutput for data binding io. 
+ * The other dictionary contains any IO that is not a databinding that is yet to be turned into REST form + :rtype: Tuple[Dict[str, str], Dict[str, Union[str, float, bool, Input]]] + """ + io_bindings: Dict = {} + dataset_literal_io: Dict = {} + legacy_io_binding_regex_list = [ + ComponentJobConstants.LEGACY_INPUT_PATTERN, + ComponentJobConstants.LEGACY_OUTPUT_PATTERN, + ] + for io_name, io_value in io.items(): + if isinstance(io_value, (Input, Output)) and isinstance(io_value.path, str): + mode = io_value.mode + path = io_value.path + name = io_value.name if hasattr(io_value, "name") else None + version = io_value.version if hasattr(io_value, "version") else None + if any(re.match(item, path) for item in io_binding_regex_list): + # Yaml syntax requires using ${{}} to enclose inputs and outputs bindings + # io_bindings[io_name] = io_value + io_bindings.update({io_name: {"value": path}}) + # add mode to literal value for binding input + if mode: + if isinstance(io_value, Input): + io_bindings[io_name].update({"mode": INPUT_MOUNT_MAPPING_TO_REST[mode]}) + else: + io_bindings[io_name].update({"mode": OUTPUT_MOUNT_MAPPING_TO_REST[mode]}) + if name or version: + assert isinstance(io_value, Output) + if name: + io_bindings[io_name].update({"name": name}) + if version: + io_bindings[io_name].update({"version": version}) + if isinstance(io_value, Output) and io_value.name is not None: + # when the output should be registered, + # we add io_value to dataset_literal_io for further to_rest_data_outputs + dataset_literal_io[io_name] = io_value + elif any(re.match(item, path) for item in legacy_io_binding_regex_list): + new_format = path.replace("{{", "{{parent.") + msg = "{} has changed to {}, please change to use new format." + raise ValidationException( + message=msg.format(path, new_format), + no_personal_data_message=msg.format("[io_value]", "[io_value_new_format]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + else: + dataset_literal_io[io_name] = io_value + else: + # Collect non-input data inputs + dataset_literal_io[io_name] = io_value + return io_bindings, dataset_literal_io + + +def from_dict_to_rest_io( + io: Dict[str, Union[str, dict]], + rest_object_class: Union[Type[RestJobInput], Type[RestJobOutput]], + io_binding_regex_list: List[str], +) -> Tuple[Dict[str, str], Dict[str, Union[RestJobInput, RestJobOutput]]]: + """Translate rest JObject dictionary to rest inputs/outputs and bindings. + + :param io: Input or output dictionary. + :type io: Dict[str, Union[str, dict]] + :param rest_object_class: RestJobInput or RestJobOutput + :type rest_object_class: Union[Type[RestJobInput], Type[RestJobOutput]] + :param io_binding_regex_list: A list of regexes for io bindings + :type io_binding_regex_list: List[str] + :return: Map from IO name to IO bindings and Map from IO name to IO objects. 
+ :rtype: Tuple[Dict[str, str], Dict[str, Union[RestJobInput, RestJobOutput]]] + """ + io_bindings: dict = {} + rest_io_objects = {} + DIRTY_MODE_MAPPING = { + "Mount": InputDeliveryMode.READ_ONLY_MOUNT, + "RoMount": InputDeliveryMode.READ_ONLY_MOUNT, + "RwMount": InputDeliveryMode.READ_WRITE_MOUNT, + } + for key, val in io.items(): + if isinstance(val, dict): + # convert the input of camel to snake to be compatible with the Jun api + # todo: backend help convert node level input/output type + normalize_job_input_output_type(val) + + # Add casting as sometimes we got value like 1(int) + io_value = str(val.get("value", "")) + io_mode = val.get("mode", None) + io_name = val.get("name", None) + io_version = val.get("version", None) + if any(re.match(item, io_value) for item in io_binding_regex_list): + io_bindings.update({key: {"path": io_value}}) + # add mode to literal value for binding input + if io_mode: + # deal with dirty mode data submitted before + if io_mode in DIRTY_MODE_MAPPING: + io_mode = DIRTY_MODE_MAPPING[io_mode] + val["mode"] = io_mode + if io_mode in OUTPUT_MOUNT_MAPPING_FROM_REST: + io_bindings[key].update({"mode": OUTPUT_MOUNT_MAPPING_FROM_REST[io_mode]}) + else: + io_bindings[key].update({"mode": INPUT_MOUNT_MAPPING_FROM_REST[io_mode]}) + # add name and version for binding input + if io_name or io_version: + assert rest_object_class.__name__ == "JobOutput" + # current code only support dump name and version for JobOutput + # this assert can be deleted if we need to dump name/version for JobInput + if io_name: + io_bindings[key].update({"name": io_name}) + if io_version: + io_bindings[key].update({"version": io_version}) + if not io_mode and not io_name and not io_version: + io_bindings[key] = io_value + else: + if rest_object_class.__name__ == "JobOutput": + # current code only support dump name and version for JobOutput + # this condition can be deleted if we need to dump name/version for JobInput + if "name" in val.keys(): + val["asset_name"] = val.pop("name") + if "version" in val.keys(): + val["asset_version"] = val.pop("version") + rest_obj = rest_object_class.from_dict(val) + rest_io_objects[key] = rest_obj + else: + msg = "Got unsupported type of input/output: {}:" + f"{type(val)}" + raise ValidationException( + message=msg.format(val), + no_personal_data_message=msg.format("[val]"), + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) + return io_bindings, rest_io_objects + + +def from_dict_to_rest_distribution(distribution_dict: Dict) -> Union[PyTorch, Mpi, TensorFlow, Ray]: + target_type = distribution_dict["distribution_type"].lower() + if target_type == "pytorch": + return PyTorch(**distribution_dict) + if target_type == "mpi": + return Mpi(**distribution_dict) + if target_type == "tensorflow": + return TensorFlow(**distribution_dict) + if target_type == "ray": + return Ray(**distribution_dict) + msg = "Distribution type must be pytorch, mpi, tensorflow or ray: {}".format(target_type) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.PIPELINE, + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/data/expression_component_template.yml b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/data/expression_component_template.yml new file mode 100644 index 00000000..10d391aa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/data/expression_component_template.yml 
@@ -0,0 +1,16 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +type: command + +name: expression_component +version: 1 + +outputs: + output: + is_control: true + +code: ./src + +environment: azureml://registries/azureml/environments/mldesigner/labels/latest + +command: >- + {command_placeholder} diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job.py new file mode 100644 index 00000000..7ddbbc46 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job.py @@ -0,0 +1,711 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access +import itertools +import logging +import typing +from functools import partial +from pathlib import Path +from typing import Any, Dict, Generator, List, Optional, Union, cast + +from typing_extensions import Literal + +from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase +from azure.ai.ml._restclient.v2024_01_01_preview.models import PipelineJob as RestPipelineJob +from azure.ai.ml._schema import PathAwareSchema +from azure.ai.ml._schema.pipeline.pipeline_job import PipelineJobSchema +from azure.ai.ml._utils._arm_id_utils import get_resource_name_from_arm_id_safe +from azure.ai.ml._utils.utils import ( + camel_to_snake, + is_data_binding_expression, + is_private_preview_enabled, + transform_dict_keys, +) +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import AZUREML_PRIVATE_FEATURES_ENV_VAR, BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._component import ComponentSource +from azure.ai.ml.constants._job.pipeline import ValidationErrorCode +from azure.ai.ml.entities._builders import BaseNode +from azure.ai.ml.entities._builders.condition_node import ConditionNode +from azure.ai.ml.entities._builders.control_flow_node import LoopNode +from azure.ai.ml.entities._builders.import_node import Import +from azure.ai.ml.entities._builders.parallel import Parallel +from azure.ai.ml.entities._builders.pipeline import Pipeline +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._component.pipeline_component import PipelineComponent + +# from azure.ai.ml.entities._job.identity import AmlToken, Identity, ManagedIdentity, UserIdentity +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, + _BaseJobIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._inputs_outputs.group_input import GroupInput +from azure.ai.ml.entities._job._input_output_helpers import ( + from_rest_data_outputs, + from_rest_inputs_to_dataset_literal, + to_rest_data_outputs, + to_rest_dataset_literal_inputs, +) +from azure.ai.ml.entities._job.import_job import ImportJob +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._job.job_service import JobServiceBase +from azure.ai.ml.entities._job.pipeline._io import PipelineInput, PipelineJobIOMixin +from azure.ai.ml.entities._job.pipeline.pipeline_job_settings import PipelineJobSettings +from azure.ai.ml.entities._mixins import YamlTranslatableMixin +from azure.ai.ml.entities._system_data import SystemData +from 
azure.ai.ml.entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin +from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationException + +module_logger = logging.getLogger(__name__) + + +class PipelineJob(Job, YamlTranslatableMixin, PipelineJobIOMixin, PathAwareSchemaValidatableMixin): + """Pipeline job. + + You should not instantiate this class directly. Instead, you should + use the `@pipeline` decorator to create a `PipelineJob`. + + :param component: Pipeline component version. The field is mutually exclusive with 'jobs'. + :type component: Union[str, ~azure.ai.ml.entities._component.pipeline_component.PipelineComponent] + :param inputs: Inputs to the pipeline job. + :type inputs: dict[str, Union[~azure.ai.ml.entities.Input, str, bool, int, float]] + :param outputs: Outputs of the pipeline job. + :type outputs: dict[str, ~azure.ai.ml.entities.Output] + :param name: Name of the PipelineJob. Defaults to None. + :type name: str + :param description: Description of the pipeline job. Defaults to None + :type description: str + :param display_name: Display name of the pipeline job. Defaults to None + :type display_name: str + :param experiment_name: Name of the experiment the job will be created under. + If None is provided, the experiment will be set to the current directory. Defaults to None + :type experiment_name: str + :param jobs: Pipeline component node name to component object. Defaults to None + :type jobs: dict[str, ~azure.ai.ml.entities._builders.BaseNode] + :param settings: Setting of the pipeline job. Defaults to None + :type settings: ~azure.ai.ml.entities.PipelineJobSettings + :param identity: Identity that the training job will use while running on compute. Defaults to None + :type identity: Union[ + ~azure.ai.ml.entities._credentials.ManagedIdentityConfiguration, + ~azure.ai.ml.entities._credentials.AmlTokenConfiguration, + ~azure.ai.ml.entities._credentials.UserIdentityConfiguration + + ] + :param compute: Compute target name of the built pipeline. Defaults to None + :type compute: str + :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None + :type tags: dict[str, str] + :param kwargs: A dictionary of additional configuration parameters. Defaults to None + :type kwargs: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_pipeline_job_configurations.py + :start-after: [START configure_pipeline_job_and_settings] + :end-before: [END configure_pipeline_job_and_settings] + :language: python + :dedent: 8 + :caption: Shows how to create a pipeline using this class. 
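+
+    A minimal sketch (illustrative; the component path, input names, and compute name below
+    are hypothetical)::
+
+        from azure.ai.ml import dsl, load_component
+
+        component_func = load_component(source="./my_component.yml")
+
+        @dsl.pipeline()
+        def my_pipeline(input_number: int):
+            component_func(component_in_number=input_number)
+
+        pipeline_job = my_pipeline(input_number=10)
+        pipeline_job.settings.default_compute = "cpu-cluster"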
+ """ + + def __init__( + self, + *, + component: Optional[Union[str, PipelineComponent, Component]] = None, + inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None, + outputs: Optional[Dict[str, Output]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + jobs: Optional[Dict[str, BaseNode]] = None, + settings: Optional[PipelineJobSettings] = None, + identity: Optional[ + Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + compute: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> None: + # initialize io + inputs, outputs = inputs or {}, outputs or {} + if isinstance(component, PipelineComponent) and component._source in [ + ComponentSource.DSL, + ComponentSource.YAML_COMPONENT, + ]: + self._inputs = self._build_inputs_dict(inputs, input_definition_dict=component.inputs) + # for pipeline component created pipeline jobs, + # it's output should have same value with the component outputs, + # then override it with given outputs (filter out None value) + pipeline_outputs = {k: v for k, v in (outputs or {}).items() if v} + self._outputs = self._build_pipeline_outputs_dict({**component.outputs, **pipeline_outputs}) + else: + # Build inputs/outputs dict without meta when definition not available + self._inputs = self._build_inputs_dict(inputs) + # for node created pipeline jobs, + # it's output should have same value with the given outputs + self._outputs = self._build_pipeline_outputs_dict(outputs=outputs) + source = kwargs.pop("_source", ComponentSource.CLASS) + if component is None: + component = PipelineComponent( + jobs=jobs, + description=description, + display_name=display_name, + base_path=kwargs.get(BASE_PATH_CONTEXT_KEY), + _source=source, + ) + + # If component is Pipeline component, jobs will be component.jobs + self._jobs = (jobs or {}) if isinstance(component, str) else {} + + self.component: Union[PipelineComponent, str] = cast(Union[PipelineComponent, str], component) + if "type" not in kwargs: + kwargs["type"] = JobType.PIPELINE + if isinstance(component, PipelineComponent): + description = component.description if description is None else description + display_name = component.display_name if display_name is None else display_name + super(PipelineJob, self).__init__( + name=name, + description=description, + tags=tags, + display_name=display_name, + experiment_name=experiment_name, + compute=compute, + **kwargs, + ) + + self._remove_pipeline_input() + self.compute = compute + self._settings: Any = None + self.settings = settings + self.identity = identity + # TODO: remove default code & environment? + self._default_code = None + self._default_environment = None + + @property + def inputs(self) -> Dict: + """Inputs of the pipeline job. + + :return: Inputs of the pipeline job. + :rtype: dict[str, Union[~azure.ai.ml.entities.Input, str, bool, int, float]] + """ + return self._inputs + + @property + def outputs(self) -> Dict[str, Union[str, Output]]: + """Outputs of the pipeline job. + + :return: Outputs of the pipeline job. + :rtype: dict[str, Union[str, ~azure.ai.ml.entities.Output]] + """ + return self._outputs + + @property + def jobs(self) -> Dict: + """Return jobs of pipeline job. + + :return: Jobs of pipeline job. 
+ :rtype: dict + """ + res: dict = self.component.jobs if isinstance(self.component, PipelineComponent) else self._jobs + return res + + @property + def settings(self) -> Optional[PipelineJobSettings]: + """Settings of the pipeline job. + + :return: Settings of the pipeline job. + :rtype: ~azure.ai.ml.entities.PipelineJobSettings + """ + if self._settings is None: + self._settings = PipelineJobSettings() + res: Optional[PipelineJobSettings] = self._settings + return res + + @settings.setter + def settings(self, value: Union[Dict, PipelineJobSettings]) -> None: + """Set the pipeline job settings. + + :param value: The pipeline job settings. + :type value: Union[dict, ~azure.ai.ml.entities.PipelineJobSettings] + """ + if value is not None: + if isinstance(value, PipelineJobSettings): + # since PipelineJobSettings inherit _AttrDict, we need add this branch to distinguish with dict + pass + elif isinstance(value, dict): + value = PipelineJobSettings(**value) + else: + raise TypeError("settings must be PipelineJobSettings or dict but got {}".format(type(value))) + self._settings = value + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException: + return ValidationException( + message=message, + no_personal_data_message=no_personal_data_message, + target=ErrorTarget.PIPELINE, + ) + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema: + # import this to ensure that nodes are registered before schema is created. + + return PipelineJobSchema(context=context) + + @classmethod + def _get_skip_fields_in_schema_validation(cls) -> typing.List[str]: + # jobs validations are done in _customized_validate() + return ["component", "jobs"] + + @property + def _skip_required_compute_missing_validation(self) -> Literal[True]: + return True + + def _validate_compute_is_set(self) -> MutableValidationResult: + validation_result = self._create_empty_validation_result() + if self.compute is not None: + return validation_result + if self.settings is not None and self.settings.default_compute is not None: + return validation_result + + if not isinstance(self.component, str): + validation_result.merge_with(self.component._validate_compute_is_set()) + return validation_result + + def _customized_validate(self) -> MutableValidationResult: + """Validate that all provided inputs and parameters are valid for current pipeline and components in it. + + :return: The validation result + :rtype: MutableValidationResult + """ + validation_result = super(PipelineJob, self)._customized_validate() + + if isinstance(self.component, PipelineComponent): + # Merge with pipeline component validate result for structure validation. 
+ # Skip top level parameter missing type error + validation_result.merge_with( + self.component._customized_validate(), + condition_skip=lambda x: x.error_code == ValidationErrorCode.PARAMETER_TYPE_UNKNOWN + and x.yaml_path.startswith("inputs"), + ) + # Validate compute + validation_result.merge_with(self._validate_compute_is_set()) + # Validate Input + validation_result.merge_with(self._validate_input()) + # Validate initialization & finalization jobs + validation_result.merge_with(self._validate_init_finalize_job()) + + return validation_result + + def _validate_input(self) -> MutableValidationResult: + validation_result = self._create_empty_validation_result() + if not isinstance(self.component, str): + # TODO(1979547): refine this logic: not all nodes have `_get_input_binding_dict` method + used_pipeline_inputs = set( + itertools.chain( + *[ + self.component._get_input_binding_dict(node if not isinstance(node, LoopNode) else node.body)[0] + for node in self.jobs.values() + if not isinstance(node, ConditionNode) + # condition node has no inputs + ] + ) + ) + # validate inputs + if not isinstance(self.component, Component): + return validation_result + for key, meta in self.component.inputs.items(): + if key not in used_pipeline_inputs: # pylint: disable=possibly-used-before-assignment + # Only validate inputs certainly used. + continue + # raise error when required input with no default value not set + if ( + self.inputs.get(key, None) is None # input not provided + and meta.optional is not True # and it's required + and meta.default is None # and it does not have default + ): + name = self.name or self.display_name + name = f"{name!r} " if name else "" + validation_result.append_error( + yaml_path=f"inputs.{key}", + message=f"Required input {key!r} for pipeline {name}not provided.", + ) + return validation_result + + def _validate_init_finalize_job(self) -> MutableValidationResult: # pylint: disable=too-many-statements + from azure.ai.ml.entities._job.pipeline._io import InputOutputBase, _GroupAttrDict + + validation_result = self._create_empty_validation_result() + # subgraph (PipelineComponent) should not have on_init/on_finalize set + for job_name, job in self.jobs.items(): + if job.type != "pipeline": + continue + if job.settings.on_init: + validation_result.append_error( + yaml_path=f"jobs.{job_name}.settings.on_init", + message="On_init is not supported for pipeline component.", + ) + if job.settings.on_finalize: + validation_result.append_error( + yaml_path=f"jobs.{job_name}.settings.on_finalize", + message="On_finalize is not supported for pipeline component.", + ) + + on_init = None + on_finalize = None + + if self.settings is not None: + # quick return if neither on_init nor on_finalize is set + if self.settings.on_init is None and self.settings.on_finalize is None: + return validation_result + + on_init, on_finalize = self.settings.on_init, self.settings.on_finalize + + append_on_init_error = partial(validation_result.append_error, "settings.on_init") + append_on_finalize_error = partial(validation_result.append_error, "settings.on_finalize") + # on_init and on_finalize cannot be same + if on_init == on_finalize: + append_on_init_error(f"Invalid on_init job {on_init}, it should be different from on_finalize.") + append_on_finalize_error(f"Invalid on_finalize job {on_finalize}, it should be different from on_init.") + # pipeline should have at least one normal node + if len(set(self.jobs.keys()) - {on_init, on_finalize}) == 0: + 
validation_result.append_error(yaml_path="jobs", message="No other job except for on_init/on_finalize job.") + + def _is_control_flow_node(_validate_job_name: str) -> bool: + from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode + + _validate_job = self.jobs[_validate_job_name] + return issubclass(type(_validate_job), ControlFlowNode) + + def _is_isolated_job(_validate_job_name: str) -> bool: + def _try_get_data_bindings( + _name: str, _input_output_data: Union["_GroupAttrDict", "InputOutputBase"] + ) -> Optional[List]: + """Try to get data bindings from input/output data, return None if not found. + :param _name: The name to use when flattening GroupAttrDict + :type _name: str + :param _input_output_data: The input/output data + :type _input_output_data: Union[_GroupAttrDict, str, InputOutputBase] + :return: A list of data bindings, or None if not found + :rtype: Optional[List[str]] + """ + # handle group input + if GroupInput._is_group_attr_dict(_input_output_data): + _new_input_output_data: _GroupAttrDict = cast(_GroupAttrDict, _input_output_data) + # flatten to avoid nested cases + flattened_values: List[Input] = list(_new_input_output_data.flatten(_name).values()) + # handle invalid empty group + if len(flattened_values) == 0: + return None + return [_value.path for _value in flattened_values] + _input_output_data = _input_output_data._data + if isinstance(_input_output_data, str): + return [_input_output_data] + if not hasattr(_input_output_data, "_data_binding"): + return None + return [_input_output_data._data_binding()] + + _validate_job = self.jobs[_validate_job_name] + # no input to validate job + for _input_name in _validate_job.inputs: + _data_bindings = _try_get_data_bindings(_input_name, _validate_job.inputs[_input_name]) + if _data_bindings is None: + continue + for _data_binding in _data_bindings: + if is_data_binding_expression(_data_binding, ["parent", "jobs"]): + return False + # no output from validate job - iterate other jobs input(s) to validate + for _job_name, _job in self.jobs.items(): + # exclude control flow node as it does not have inputs + if _is_control_flow_node(_job_name): + continue + for _input_name in _job.inputs: + _data_bindings = _try_get_data_bindings(_input_name, _job.inputs[_input_name]) + if _data_bindings is None: + continue + for _data_binding in _data_bindings: + if is_data_binding_expression(_data_binding, ["parent", "jobs", _validate_job_name]): + return False + return True + + # validate on_init + if on_init is not None: + if on_init not in self.jobs: + append_on_init_error(f"On_init job name {on_init} not exists in jobs.") + else: + if _is_control_flow_node(on_init): + append_on_init_error("On_init job should not be a control flow node.") + elif not _is_isolated_job(on_init): + append_on_init_error("On_init job should not have connection to other execution node.") + # validate on_finalize + if on_finalize is not None: + if on_finalize not in self.jobs: + append_on_finalize_error(f"On_finalize job name {on_finalize} not exists in jobs.") + else: + if _is_control_flow_node(on_finalize): + append_on_finalize_error("On_finalize job should not be a control flow node.") + elif not _is_isolated_job(on_finalize): + append_on_finalize_error("On_finalize job should not have connection to other execution node.") + return validation_result + + def _remove_pipeline_input(self) -> None: + """Remove None pipeline input.If not remove, it will pass "None" to backend.""" + redundant_pipeline_inputs = [] + for pipeline_input_name, 
pipeline_input in self._inputs.items():
+ if isinstance(pipeline_input, PipelineInput) and pipeline_input._data is None:
+ redundant_pipeline_inputs.append(pipeline_input_name)
+ for redundant_pipeline_input in redundant_pipeline_inputs:
+ self._inputs.pop(redundant_pipeline_input)
+
+ def _check_private_preview_features(self) -> None:
+ """Check whether private preview features are included in the pipeline.
+
+ Raise an exception if the private preview environment variable is not set.
+ """
+ if not is_private_preview_enabled():
+ error_msg = (
+ "{} is a private preview feature, "
+ f"please set environment variable {AZUREML_PRIVATE_FEATURES_ENV_VAR} to true to use it."
+ )
+ # check for nodes that are not supported
+ for _, node in self.jobs.items():
+ # TODO: Remove in PuP
+ if isinstance(node, (ImportJob, Import)):
+ msg = error_msg.format("Import job in pipeline")
+ raise UserErrorException(message=msg, no_personal_data_message=msg)
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Pipeline":
+ """Translate this pipeline job to a pipeline node when loading the schema.
+
+ (Writing a pipeline job as a node in YAML is not currently supported.)
+
+ :param context: Context of the pipeline job YAML file.
+ :type context: dict
+ :return: Translated pipeline node.
+ :rtype: Pipeline
+ """
+ component = self._to_component(context, **kwargs)
+
+ return Pipeline(
+ component=component,
+ compute=self.compute,
+ # Need to supply the inputs with double curly.
+ inputs=self.inputs,
+ outputs=self.outputs,
+ description=self.description,
+ tags=self.tags,
+ display_name=self.display_name,
+ properties=self.properties,
+ )
+
+ def _to_rest_object(self) -> JobBase:
+ """Build the current parameterized pipeline instance into a pipeline job object before submission.
+
+ :return: Rest pipeline job.
+ :rtype: JobBase
+ """
+ # Check if there are private preview features in it
+ self._check_private_preview_features()
+
+ # Build the inputs to dict. Handle both value & binding assignment.
+ # Example: {
+ # "input_data": {"data": {"path": "path/to/input/data"}, "mode": "Mount"},
+ # "input_value": 10,
+ # "learning_rate": "${{jobs.step1.inputs.learning_rate}}"
+ # }
+ built_inputs = self._build_inputs()
+
+ # Build the outputs to dict
+ # example: {"eval_output": "${{jobs.eval.outputs.eval_output}}"}
+ built_outputs = self._build_outputs()
+
+ if self.settings is not None:
+ settings_dict = self.settings._to_dict()
+
+ if isinstance(self.component, PipelineComponent):
+ source = self.component._source
+ # Build the jobs to dict
+ rest_component_jobs = self.component._build_rest_component_jobs()
+ else:
+ source = ComponentSource.REMOTE_WORKSPACE_JOB
+ rest_component_jobs = {}
+ # add _source on pipeline job.settings
+ if "_source" not in settings_dict: # pylint: disable=possibly-used-before-assignment
+ settings_dict.update({"_source": source})
+
+ # TODO: Revisit this logic when multiple types of component jobs are supported
+ rest_compute = self.compute
+ # This will be resolved in job_operations _resolve_arm_id_or_upload_dependencies.
+ component_id = self.component if isinstance(self.component, str) else self.component.id
+
+ # TODO remove it in the future.
+ # MFE does not support passing None or empty input values. Remove the empty inputs from the pipeline job.
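+ # e.g. assuming built_inputs == {"lr": 0.01, "data": None, "tag": ""}, only {"lr": 0.01} is kept below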
+ built_inputs = {k: v for k, v in built_inputs.items() if v is not None and v != ""} + + pipeline_job = RestPipelineJob( + compute_id=rest_compute, + component_id=component_id, + display_name=self.display_name, + tags=self.tags, + description=self.description, + properties=self.properties, + experiment_name=self.experiment_name, + jobs=rest_component_jobs, + inputs=to_rest_dataset_literal_inputs(built_inputs, job_type=self.type), + outputs=to_rest_data_outputs(built_outputs), + settings=settings_dict, + services={k: v._to_rest_object() for k, v in self.services.items()} if self.services else None, + identity=self.identity._to_job_rest_object() if self.identity else None, + ) + + rest_job = JobBase(properties=pipeline_job) + rest_job.name = self.name + return rest_job + + @classmethod + def _load_from_rest(cls, obj: JobBase) -> "PipelineJob": + """Build a pipeline instance from rest pipeline object. + + :param obj: The REST Pipeline Object + :type obj: JobBase + :return: pipeline job. + :rtype: PipelineJob + """ + properties: RestPipelineJob = obj.properties + # Workaround for BatchEndpoint as these fields are not filled in + # Unpack the inputs + from_rest_inputs = from_rest_inputs_to_dataset_literal(properties.inputs) or {} + from_rest_outputs = from_rest_data_outputs(properties.outputs) or {} + # Unpack the component jobs + sub_nodes = PipelineComponent._resolve_sub_nodes(properties.jobs) if properties.jobs else {} + # backend may still store Camel settings, eg: DefaultDatastore, translate them to snake when load back + settings_dict = transform_dict_keys(properties.settings, camel_to_snake) if properties.settings else None + settings_sdk = PipelineJobSettings(**settings_dict) if settings_dict else PipelineJobSettings() + # Create component or use component id + if getattr(properties, "component_id", None): + component = properties.component_id + else: + component = PipelineComponent._load_from_rest_pipeline_job( + { + "inputs": from_rest_inputs, + "outputs": from_rest_outputs, + "display_name": properties.display_name, + "description": properties.description, + "jobs": sub_nodes, + } + ) + + job = PipelineJob( + component=component, + inputs=from_rest_inputs, + outputs=from_rest_outputs, + name=obj.name, + id=obj.id, + jobs=sub_nodes, + display_name=properties.display_name, + tags=properties.tags, + properties=properties.properties, + experiment_name=properties.experiment_name, + status=properties.status, + creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + services=JobServiceBase._from_rest_job_services(properties.services) if properties.services else None, + compute=get_resource_name_from_arm_id_safe(properties.compute_id), + settings=settings_sdk, + identity=( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + ) + + return job + + def _to_dict(self) -> Dict: + res: dict = self._dump_for_validation() + return res + + @classmethod + def _component_items_from_path(cls, data: Dict) -> Generator: + if "jobs" in data: + for node_name, job_instance in data["jobs"].items(): + potential_component_path = job_instance["component"] if "component" in job_instance else None + if isinstance(potential_component_path, str) and potential_component_path.startswith("file:"): + yield node_name, potential_component_path + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "PipelineJob": + path_first_occurrence: dict = {} + 
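+ # path_first_occurrence: local component path (file:...) -> name of the first node that referenced it
+ # component_first_occurrence: later duplicate node name -> that first node, so each component is only resolved once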
component_first_occurrence = {} + for node_name, component_path in cls._component_items_from_path(data): + if component_path in path_first_occurrence: + component_first_occurrence[node_name] = path_first_occurrence[component_path] + # set components to be replaced here may break the validation logic + else: + path_first_occurrence[component_path] = node_name + + # use this instead of azure.ai.ml.entities._util.load_from_dict to avoid parsing + loaded_schema = cls._create_schema_for_validation(context=context).load(data, **kwargs) + + # replace repeat component with first occurrence to reduce arm id resolution + # current load yaml file logic is in azure.ai.ml._schema.core.schema.YamlFileSchema.load_from_file + # is it possible to load the same yaml file only once in 1 pipeline loading? + for node_name, first_occurrence in component_first_occurrence.items(): + job = loaded_schema["jobs"][node_name] + job._component = loaded_schema["jobs"][first_occurrence].component + # For Parallel job, should also align task attribute which is usually from component.task + if isinstance(job, Parallel): + job.task = job._component.task + # parallel.task.code is based on parallel._component.base_path, so need to update it + job._base_path = job._component.base_path + return PipelineJob( + base_path=context[BASE_PATH_CONTEXT_KEY], + _source=ComponentSource.YAML_JOB, + **loaded_schema, + ) + + def __str__(self) -> str: + try: + res_to_yaml: str = self._to_yaml() + return res_to_yaml + except BaseException: # pylint: disable=W0718 + res: str = super(PipelineJob, self).__str__() + return res + + def _get_telemetry_values(self) -> Dict: + telemetry_values: dict = super()._get_telemetry_values() + if isinstance(self.component, PipelineComponent): + telemetry_values.update(self.component._get_telemetry_values()) + else: + telemetry_values.update({"source": ComponentSource.REMOTE_WORKSPACE_JOB}) + telemetry_values.pop("is_anonymous") + return telemetry_values + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "PipelineComponent": + """Translate a pipeline job to pipeline component. + + :param context: Context of pipeline job YAML file. + :type context: dict + :return: Translated pipeline component. + :rtype: PipelineComponent + """ + ignored_keys = PipelineComponent._check_ignored_keys(self) + if ignored_keys: + name = self.name or self.display_name + name = f"{name!r} " if name else "" + module_logger.warning("%s ignored when translating PipelineJob %sto PipelineComponent.", ignored_keys, name) + pipeline_job_dict = kwargs.get("pipeline_job_dict", {}) + context = context or {BASE_PATH_CONTEXT_KEY: Path("./")} + + # Create anonymous pipeline component with default version as 1 + return PipelineComponent( + base_path=context[BASE_PATH_CONTEXT_KEY], + display_name=self.display_name, + inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict), + outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict), + jobs=self.jobs, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job_settings.py new file mode 100644 index 00000000..0fe41e2e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job_settings.py @@ -0,0 +1,75 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from typing import Any, Dict, Generator, Optional + +from azure.ai.ml.entities._job.pipeline._attr_dict import _AttrDict + + +class PipelineJobSettings(_AttrDict): + """Settings of PipelineJob. + + :param default_datastore: The default datastore of the pipeline. + :type default_datastore: str + :param default_compute: The default compute target of the pipeline. + :type default_compute: str + :param continue_on_step_failure: Flag indicating whether to continue pipeline execution if a step fails. + :type continue_on_step_failure: bool + :param force_rerun: Flag indicating whether to force rerun pipeline execution. + :type force_rerun: bool + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_pipeline_job_configurations.py + :start-after: [START configure_pipeline_job_and_settings] + :end-before: [END configure_pipeline_job_and_settings] + :language: python + :dedent: 8 + :caption: Shows how to set pipeline properties using this class. + """ + + def __init__( + self, + default_datastore: Optional[str] = None, + default_compute: Optional[str] = None, + continue_on_step_failure: Optional[bool] = None, + force_rerun: Optional[bool] = None, + **kwargs: Any + ) -> None: + self._init = True + super().__init__() + self.default_compute: Any = default_compute + self.default_datastore: Any = default_datastore + self.continue_on_step_failure = continue_on_step_failure + self.force_rerun = force_rerun + self.on_init = kwargs.get("on_init", None) + self.on_finalize = kwargs.get("on_finalize", None) + for k, v in kwargs.items(): + setattr(self, k, v) + self._init = False + + def _get_valid_keys(self) -> Generator[str, Any, None]: + for k, v in self.__dict__.items(): + if v is None: + continue + # skip private attributes inherited from _AttrDict + if k in ["_logger", "_allowed_keys", "_init", "_key_restriction"]: + continue + yield k + + def _to_dict(self) -> Dict: + result = {} + for k in self._get_valid_keys(): + result[k] = self.__dict__[k] + result.update(self._get_attrs()) + return result + + def _initializing(self) -> bool: + return self._init + + def __bool__(self) -> bool: + for _ in self._get_valid_keys(): + return True + # _attr_dict will return False if no extra attributes are set + return self.__len__() > 0 diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/queue_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/queue_settings.py new file mode 100644 index 00000000..5b51fb6e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/queue_settings.py @@ -0,0 +1,87 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import Any, Dict, Optional, Union + +from ..._restclient.v2023_04_01_preview.models import QueueSettings as RestQueueSettings +from ..._utils._experimental import experimental +from ..._utils.utils import is_data_binding_expression +from ...constants._job.job import JobPriorityValues, JobTierNames +from ...entities._mixins import DictMixin, RestTranslatableMixin +from ...exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +module_logger = logging.getLogger(__name__) + + +@experimental +class QueueSettings(RestTranslatableMixin, DictMixin): + """Queue settings for a pipeline job. + + :ivar job_tier: Enum to determine the job tier. 
Possible values include: "Spot", "Basic", + "Standard", "Premium", "Null". + :vartype job_tier: str or ~azure.mgmt.machinelearningservices.models.JobTier + :ivar priority: Controls the priority of the job on a compute. + :vartype priority: str + :keyword job_tier: The job tier. Accepted values are "Spot", "Basic", "Standard", and "Premium". + :paramtype job_tier: Optional[Literal]] + :keyword priority: The priority of the job on a compute. Accepted values are "low", "medium", and "high". + Defaults to "medium". + :paramtype priority: Optional[Literal] + :keyword kwargs: Additional properties for QueueSettings. + :paramtype kwargs: Optional[dict] + """ + + def __init__( + self, # pylint: disable=unused-argument + *, + job_tier: Optional[str] = None, + priority: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.job_tier = job_tier + self.priority = priority + + def _to_rest_object(self) -> RestQueueSettings: + self._validate() + job_tier = JobTierNames.ENTITY_TO_REST.get(self.job_tier.lower(), None) if self.job_tier else None + priority = JobPriorityValues.ENTITY_TO_REST.get(self.priority.lower(), None) if self.priority else None + return RestQueueSettings(job_tier=job_tier, priority=priority) + + @classmethod + def _from_rest_object(cls, obj: Union[Dict[str, Any], RestQueueSettings, None]) -> Optional["QueueSettings"]: + if obj is None: + return None + if isinstance(obj, dict): + queue_settings = RestQueueSettings.from_dict(obj) + return cls._from_rest_object(queue_settings) + job_tier = JobTierNames.REST_TO_ENTITY.get(obj.job_tier, None) if obj.job_tier else None + priority = JobPriorityValues.REST_TO_ENTITY.get(obj.priority, None) if hasattr(obj, "priority") else None + return cls(job_tier=job_tier, priority=priority) + + def _validate(self) -> None: + for key, enum_class in [("job_tier", JobTierNames), ("priority", JobPriorityValues)]: + value = getattr(self, key) + if is_data_binding_expression(value): + msg = ( + f"do not support data binding expression on {key} as it involves value mapping " + f"when transformed to rest object, but received '{value}'." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + valid_keys = list(enum_class.ENTITY_TO_REST.keys()) # type: ignore[attr-defined] + if value and value.lower() not in valid_keys: + msg = f"{key} should be one of {valid_keys}, but received '{value}'." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/resource_configuration.py new file mode 100644 index 00000000..a10d4a66 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/resource_configuration.py @@ -0,0 +1,98 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import json +import logging +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ResourceConfiguration as RestResourceConfiguration +from azure.ai.ml.constants._job.job import JobComputePropertyFields +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +class ResourceConfiguration(RestTranslatableMixin, DictMixin): + """Resource configuration for a job. + + This class should not be instantiated directly. Instead, use its subclasses. + + :keyword instance_count: The number of instances to use for the job. + :paramtype instance_count: Optional[int] + :keyword instance_type: The type of instance to use for the job. + :paramtype instance_type: Optional[str] + :keyword properties: The resource's property dictionary. + :paramtype properties: Optional[dict[str, Any]] + """ + + def __init__( + self, # pylint: disable=unused-argument + *, + instance_count: Optional[int] = None, + instance_type: Optional[str] = None, + properties: Optional[Dict[str, Any]] = None, + **kwargs: Any + ) -> None: + self.instance_count = instance_count + self.instance_type = instance_type + self.properties = {} + if properties is not None: + for key, value in properties.items(): + if key == JobComputePropertyFields.AISUPERCOMPUTER: + self.properties[JobComputePropertyFields.SINGULARITY.lower()] = value + else: + self.properties[key] = value + + def _to_rest_object(self) -> RestResourceConfiguration: + serialized_properties = {} + if self.properties: + for key, value in self.properties.items(): + try: + if ( + key.lower() == JobComputePropertyFields.SINGULARITY.lower() + or key.lower() == JobComputePropertyFields.AISUPERCOMPUTER.lower() + ): + # Map Singularity -> AISupercomputer in SDK until MFE does mapping + key = JobComputePropertyFields.AISUPERCOMPUTER + # recursively convert Ordered Dict to dictionary + serialized_properties[key] = json.loads(json.dumps(value)) + except Exception: # pylint: disable=W0718 + pass + return RestResourceConfiguration( + instance_count=self.instance_count, + instance_type=self.instance_type, + properties=serialized_properties, + ) + + @classmethod + def _from_rest_object( # pylint: disable=arguments-renamed + cls, rest_obj: Optional[RestResourceConfiguration] + ) -> Optional["ResourceConfiguration"]: + if rest_obj is None: + return None + return ResourceConfiguration( + instance_count=rest_obj.instance_count, + instance_type=rest_obj.instance_type, + properties=rest_obj.properties, + deserialize_properties=True, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ResourceConfiguration): + return NotImplemented + return self.instance_count == other.instance_count and self.instance_type == other.instance_type + + def __ne__(self, other: object) -> bool: + if not isinstance(other, ResourceConfiguration): + return NotImplemented + return not self.__eq__(other) + + def _merge_with(self, other: "ResourceConfiguration") -> None: + if other: + if other.instance_count: + self.instance_count = other.instance_count + if other.instance_type: + self.instance_type = other.instance_type + if other.properties: + self.properties = other.properties diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/service_instance.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/service_instance.py new file mode 100644 index 00000000..0e5ba6c6 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/service_instance.py @@ -0,0 +1,59 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import Any, Dict, Optional + +from azure.ai.ml._restclient.runhistory.models import ServiceInstanceResult +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +class ServiceInstance(RestTranslatableMixin, DictMixin): + """Service Instance Result. + + :keyword type: The type of service. + :paramtype type: Optional[str] + :keyword port: The port used by the service. + :paramtype port: Optional[int] + :keyword status: The status of the service. + :paramtype status: Optional[str] + :keyword error: The error message. + :paramtype error: Optional[str] + :keyword endpoint: The service endpoint. + :paramtype endpoint: Optional[str] + :keyword properties: The service instance's properties. + :paramtype properties: Optional[dict[str, str]] + """ + + def __init__( + self, # pylint: disable=unused-argument + *, + type: Optional[str] = None, # pylint: disable=redefined-builtin + port: Optional[int] = None, + status: Optional[str] = None, + error: Optional[str] = None, + endpoint: Optional[str] = None, + properties: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: + self.type = type + self.port = port + self.status = status + self.error = error + self.endpoint = endpoint + self.properties = properties + + @classmethod + # pylint: disable=arguments-differ + def _from_rest_object(cls, obj: ServiceInstanceResult, node_index: int) -> "ServiceInstance": # type: ignore + return cls( + type=obj.type, + port=obj.port, + status=obj.status, + error=obj.error.error.message if obj.error and obj.error.error else None, + endpoint=obj.endpoint.replace("<nodeIndex>", str(node_index)) if obj.endpoint else obj.endpoint, + properties=obj.properties, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_helpers.py new file mode 100644 index 00000000..d3fdf9dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_helpers.py @@ -0,0 +1,210 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=protected-access +import re +from typing import Any + +from azure.ai.ml.constants import InputOutputModes +from azure.ai.ml.constants._component import ComponentJobConstants +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job.pipeline._io import NodeInput, NodeOutput +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +def _validate_spark_configurations(obj: Any) -> None: + # skip validation when component of node is from remote + if hasattr(obj, "component") and isinstance(obj.component, str): + return + if obj.dynamic_allocation_enabled in ["True", "true", True]: + if ( + obj.driver_cores is None + or obj.driver_memory is None + or obj.executor_cores is None + or obj.executor_memory is None + ): + msg = ( + "spark.driver.cores, spark.driver.memory, spark.executor.cores and spark.executor.memory are " + "mandatory fields." 
+ ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + if obj.dynamic_allocation_min_executors is None or obj.dynamic_allocation_max_executors is None: + msg = ( + "spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors are required " + "when dynamic allocation is enabled." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + if not ( + obj.dynamic_allocation_min_executors > 0 + and obj.dynamic_allocation_min_executors <= obj.dynamic_allocation_max_executors + ): + msg = ( + "Dynamic min executors should be bigger than 0 and min executors should be equal or less than " + "max executors." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + if obj.executor_instances and ( + obj.executor_instances > obj.dynamic_allocation_max_executors + or obj.executor_instances < obj.dynamic_allocation_min_executors + ): + msg = ( + "Executor instances must be a valid non-negative integer and must be between " + "spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors" + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + else: + if ( + obj.driver_cores is None + or obj.driver_memory is None + or obj.executor_cores is None + or obj.executor_memory is None + or obj.executor_instances is None + ): + msg = ( + "spark.driver.cores, spark.driver.memory, spark.executor.cores, spark.executor.memory and " + "spark.executor.instances are mandatory fields." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + if obj.dynamic_allocation_min_executors is not None or obj.dynamic_allocation_max_executors is not None: + msg = "Should not specify min or max executors when dynamic allocation is disabled." 
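+ # min/max executor bounds are only meaningful when spark.dynamicAllocation.enabled is true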
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + + +def _validate_compute_or_resources(compute: Any, resources: Any) -> None: + # if resources is set, then ensure it is valid before + # checking mutual exclusiveness against compute existence + if compute is None and resources is None: + msg = "One of either compute or resources must be specified for Spark job" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + if compute and resources: + msg = "Only one of either compute or resources may be specified for Spark job" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + + +# Only "direct" mode is supported for spark job inputs and outputs +# pylint: disable=no-else-raise, too-many-boolean-expressions +def _validate_input_output_mode(inputs: Any, outputs: Any) -> None: + for input_name, input_value in inputs.items(): + if isinstance(input_value, Input) and input_value.mode != InputOutputModes.DIRECT: + # For standalone job input + msg = "Input '{}' is using '{}' mode, only '{}' is supported for Spark job" + raise ValidationException( + message=msg.format(input_name, input_value.mode, InputOutputModes.DIRECT), + no_personal_data_message=msg.format("[input_name]", "[input_value.mode]", "direct"), + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + elif ( + isinstance(input_value, NodeInput) + and ( + isinstance(input_value._data, Input) + and not ( + isinstance(input_value._data.path, str) + and bool(re.search(ComponentJobConstants.INPUT_PATTERN, input_value._data.path)) + ) + and input_value._data.mode != InputOutputModes.DIRECT + ) + and (isinstance(input_value._meta, Input) and input_value._meta.mode != InputOutputModes.DIRECT) + ): + # For node input in pipeline job, client side can only validate node input which isn't bound to pipeline + # input or node output. + # 1. If node input is bound to pipeline input, we can't get pipeline level input mode in node level + # validate. Even if we can judge through component input mode (_meta), we should note that pipeline level + # input mode has higher priority than component level. so component input can be set "Mount", but it can + # run successfully when pipeline input is "Direct". + # 2. If node input is bound to last node output, input mode should be decoupled with output mode, so we + # always get None mode in node level. In this case, if we define correct "Direct" mode in component yaml, + # component level mode will take effect and run successfully. Otherwise, it need to set mode in node level + # like input1: path: ${{parent.jobs.sample_word.outputs.output1}} mode: direct. 
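+ # illustrative node-level override in pipeline YAML (names are placeholders):
+ #   inputs:
+ #     input1:
+ #       path: ${{parent.jobs.sample_word.outputs.output1}}
+ #       mode: direct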
+ msg = "Input '{}' is using '{}' mode, only '{}' is supported for Spark job" + raise ValidationException( + message=msg.format( + input_name, input_value._data.mode or input_value._meta.mode, InputOutputModes.DIRECT + ), + no_personal_data_message=msg.format("[input_name]", "[input_value.mode]", "direct"), + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + + for output_name, output_value in outputs.items(): + if ( + isinstance(output_value, Output) + and output_name != "default" + and output_value.mode != InputOutputModes.DIRECT + ): + # For standalone job output + msg = "Output '{}' is using '{}' mode, only '{}' is supported for Spark job" + raise ValidationException( + message=msg.format(output_name, output_value.mode, InputOutputModes.DIRECT), + no_personal_data_message=msg.format("[output_name]", "[output_value.mode]", "direct"), + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + elif ( + isinstance(output_value, NodeOutput) + and output_name != "default" + and ( + isinstance(output_value._data, Output) + and not ( + isinstance(output_value._data.path, str) + and bool(re.search(ComponentJobConstants.OUTPUT_PATTERN, output_value._data.path)) + ) + and output_value._data.mode != InputOutputModes.DIRECT + ) + and (isinstance(output_value._meta, Output) and output_value._meta.mode != InputOutputModes.DIRECT) + ): + # For node output in pipeline job, client side can only validate node output which isn't bound to pipeline + # output. + # 1. If node output is bound to pipeline output, we can't get pipeline level output mode in node level + # validate. Even if we can judge through component output mode (_meta), we should note that pipeline level + # output mode has higher priority than component level. so component output can be set "upload", but it + # can run successfully when pipeline output is "Direct". + msg = "Output '{}' is using '{}' mode, only '{}' is supported for Spark job" + raise ValidationException( + message=msg.format( + output_name, output_value._data.mode or output_value._meta.mode, InputOutputModes.DIRECT + ), + no_personal_data_message=msg.format("[output_name]", "[output_value.mode]", "direct"), + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job.py new file mode 100644 index 00000000..10930fb4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job.py @@ -0,0 +1,393 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access, too-many-instance-attributes + +import copy +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from marshmallow import INCLUDE + +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase +from azure.ai.ml._restclient.v2023_04_01_preview.models import SparkJob as RestSparkJob +from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema +from azure.ai.ml._schema.job.parameterized_spark import CONF_KEY_MAP +from azure.ai.ml._schema.job.spark_job import SparkJobSchema +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.constants._job.job import SparkConfKey +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, + _BaseJobIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job._input_output_helpers import ( + from_rest_data_outputs, + from_rest_inputs_to_dataset_literal, + to_rest_data_outputs, + to_rest_dataset_literal_inputs, + validate_inputs_for_args, +) +from azure.ai.ml.entities._job.parameterized_spark import ParameterizedSpark +from azure.ai.ml.entities._util import load_from_dict + +from ..._schema import NestedField, UnionField +from .job import Job +from .job_io_mixin import JobIOMixin +from .spark_helpers import _validate_compute_or_resources, _validate_input_output_mode, _validate_spark_configurations +from .spark_job_entry import SparkJobEntry +from .spark_job_entry_mixin import SparkJobEntryMixin +from .spark_resource_configuration import SparkResourceConfiguration + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities import SparkComponent + from azure.ai.ml.entities._builders import Spark + +module_logger = logging.getLogger(__name__) + + +class SparkJob(Job, ParameterizedSpark, JobIOMixin, SparkJobEntryMixin): + """A standalone Spark job. + + :keyword driver_cores: The number of cores to use for the driver process, only in cluster mode. + :paramtype driver_cores: Optional[int] + :keyword driver_memory: The amount of memory to use for the driver process, formatted as strings with a size unit + suffix ("k", "m", "g" or "t") (e.g. "512m", "2g"). + :paramtype driver_memory: Optional[str] + :keyword executor_cores: The number of cores to use on each executor. + :paramtype executor_cores: Optional[int] + :keyword executor_memory: The amount of memory to use per executor process, formatted as strings with a size unit + suffix ("k", "m", "g" or "t") (e.g. "512m", "2g"). + :paramtype executor_memory: Optional[str] + :keyword executor_instances: The initial number of executors. + :paramtype executor_instances: Optional[int] + :keyword dynamic_allocation_enabled: Whether to use dynamic resource allocation, which scales the number of + executors registered with this application up and down based on the workload. + :paramtype dynamic_allocation_enabled: Optional[bool] + :keyword dynamic_allocation_min_executors: The lower bound for the number of executors if dynamic allocation is + enabled. + :paramtype dynamic_allocation_min_executors: Optional[int] + :keyword dynamic_allocation_max_executors: The upper bound for the number of executors if dynamic allocation is + enabled. 
+ :paramtype dynamic_allocation_max_executors: Optional[int] + :keyword inputs: The mapping of input data bindings used in the job. + :paramtype inputs: Optional[dict[str, ~azure.ai.ml.Input]] + :keyword outputs: The mapping of output data bindings used in the job. + :paramtype outputs: Optional[dict[str, ~azure.ai.ml.Output]] + :keyword compute: The compute resource the job runs on. + :paramtype compute: Optional[str] + :keyword identity: The identity that the Spark job will use while running on compute. + :paramtype identity: Optional[Union[dict[str, str], ~azure.ai.ml.ManagedIdentityConfiguration, + ~azure.ai.ml.AmlTokenConfiguration, ~azure.ai.ml.UserIdentityConfiguration]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_job_configuration] + :end-before: [END spark_job_configuration] + :language: python + :dedent: 8 + :caption: Configuring a SparkJob. + """ + + def __init__( + self, + *, + driver_cores: Optional[Union[int, str]] = None, + driver_memory: Optional[str] = None, + executor_cores: Optional[Union[int, str]] = None, + executor_memory: Optional[str] = None, + executor_instances: Optional[Union[int, str]] = None, + dynamic_allocation_enabled: Optional[Union[bool, str]] = None, + dynamic_allocation_min_executors: Optional[Union[int, str]] = None, + dynamic_allocation_max_executors: Optional[Union[int, str]] = None, + inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None, + outputs: Optional[Dict[str, Output]] = None, + compute: Optional[str] = None, + identity: Optional[ + Union[Dict[str, str], ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + resources: Optional[Union[Dict, SparkResourceConfiguration]] = None, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = JobType.SPARK + + super().__init__(**kwargs) + self.conf: Dict = self.conf or {} + self.properties_sparkJob = self.properties or {} + self.driver_cores = driver_cores + self.driver_memory = driver_memory + self.executor_cores = executor_cores + self.executor_memory = executor_memory + self.executor_instances = executor_instances + self.dynamic_allocation_enabled = dynamic_allocation_enabled + self.dynamic_allocation_min_executors = dynamic_allocation_min_executors + self.dynamic_allocation_max_executors = dynamic_allocation_max_executors + self.inputs = inputs # type: ignore[assignment] + self.outputs = outputs # type: ignore[assignment] + self.compute = compute + self.resources = resources + self.identity = identity + if self.executor_instances is None and str(self.dynamic_allocation_enabled).lower() == "true": + self.executor_instances = self.dynamic_allocation_min_executors + + @property + def resources(self) -> Optional[Union[Dict, SparkResourceConfiguration]]: + """The compute resource configuration for the job. + + :return: The compute resource configuration for the job. + :rtype: Optional[~azure.ai.ml.entities.SparkResourceConfiguration] + """ + return self._resources + + @resources.setter + def resources(self, value: Optional[Union[Dict[str, str], SparkResourceConfiguration]]) -> None: + """Sets the compute resource configuration for the job. + + :param value: The compute resource configuration for the job. 
+ :type value: Optional[Union[dict[str, str], ~azure.ai.ml.entities.SparkResourceConfiguration]] + """ + if isinstance(value, dict): + value = SparkResourceConfiguration(**value) + self._resources = value + + @property + def identity( + self, + ) -> Optional[Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]]: + """The identity that the Spark job will use while running on compute. + + :return: The identity that the Spark job will use while running on compute. + :rtype: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration]] + """ + return self._identity + + @identity.setter + def identity( + self, + value: Optional[ + Union[Dict[str, str], ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ], + ) -> None: + """Sets the identity that the Spark job will use while running on compute. + + :param value: The identity that the Spark job will use while running on compute. + :type value: Optional[Union[dict[str, str], ~azure.ai.ml.ManagedIdentityConfiguration, + ~azure.ai.ml.AmlTokenConfiguration, ~azure.ai.ml.UserIdentityConfiguration]] + """ + if isinstance(value, dict): + identify_schema = UnionField( + [ + NestedField(ManagedIdentitySchema, unknown=INCLUDE), + NestedField(AMLTokenIdentitySchema, unknown=INCLUDE), + NestedField(UserIdentitySchema, unknown=INCLUDE), + ] + ) + value = identify_schema._deserialize(value=value, attr=None, data=None) + self._identity = value + + def _to_dict(self) -> Dict: + res: dict = SparkJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def filter_conf_fields(self) -> Dict[str, str]: + """Filters out the fields of the conf attribute that are not among the Spark configuration fields + listed in ~azure.ai.ml._schema.job.parameterized_spark.CONF_KEY_MAP and returns them in their own dictionary. + + :return: A dictionary of the conf fields that are not Spark configuration fields. 
+ :rtype: dict[str, str] + """ + if self.conf is None: + return {} + data_conf = {} + for conf_key, conf_val in self.conf.items(): + if not conf_key in CONF_KEY_MAP: + data_conf[conf_key] = conf_val + return data_conf + + def _to_rest_object(self) -> JobBase: + self._validate() + conf = { + **(self.filter_conf_fields()), + "spark.driver.cores": self.driver_cores, + "spark.driver.memory": self.driver_memory, + "spark.executor.cores": self.executor_cores, + "spark.executor.memory": self.executor_memory, + } + if self.dynamic_allocation_enabled in ["True", "true", True]: + conf["spark.dynamicAllocation.enabled"] = True + conf["spark.dynamicAllocation.minExecutors"] = self.dynamic_allocation_min_executors + conf["spark.dynamicAllocation.maxExecutors"] = self.dynamic_allocation_max_executors + if self.executor_instances is not None: + conf["spark.executor.instances"] = self.executor_instances + + properties = RestSparkJob( + experiment_name=self.experiment_name, + display_name=self.display_name, + description=self.description, + tags=self.tags, + code_id=self.code, + entry=self.entry._to_rest_object() if self.entry is not None and not isinstance(self.entry, dict) else None, + py_files=self.py_files, + jars=self.jars, + files=self.files, + archives=self.archives, + identity=( + self.identity._to_job_rest_object() if self.identity and not isinstance(self.identity, dict) else None + ), + conf=conf, + properties=self.properties_sparkJob, + environment_id=self.environment, + inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=self.type), + outputs=to_rest_data_outputs(self.outputs), + args=self.args, + compute_id=self.compute, + resources=( + self.resources._to_rest_object() if self.resources and not isinstance(self.resources, Dict) else None + ), + ) + result = JobBase(properties=properties) + result.name = self.name + return result + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "SparkJob": + loaded_data = load_from_dict(SparkJobSchema, data, context, additional_message, **kwargs) + return SparkJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + @classmethod + def _load_from_rest(cls, obj: JobBase) -> "SparkJob": + rest_spark_job: RestSparkJob = obj.properties + rest_spark_conf = copy.copy(rest_spark_job.conf) or {} + spark_job = SparkJob( + name=obj.name, + entry=SparkJobEntry._from_rest_object(rest_spark_job.entry), + experiment_name=rest_spark_job.experiment_name, + id=obj.id, + display_name=rest_spark_job.display_name, + description=rest_spark_job.description, + tags=rest_spark_job.tags, + properties=rest_spark_job.properties, + services=rest_spark_job.services, + status=rest_spark_job.status, + creation_context=obj.system_data, + code=rest_spark_job.code_id, + compute=rest_spark_job.compute_id, + environment=rest_spark_job.environment_id, + identity=( + _BaseJobIdentityConfiguration._from_rest_object(rest_spark_job.identity) + if rest_spark_job.identity + else None + ), + args=rest_spark_job.args, + conf=rest_spark_conf, + driver_cores=rest_spark_conf.get( + SparkConfKey.DRIVER_CORES, None + ), # copy fields from conf into the promote attribute in spark + driver_memory=rest_spark_conf.get(SparkConfKey.DRIVER_MEMORY, None), + executor_cores=rest_spark_conf.get(SparkConfKey.EXECUTOR_CORES, None), + executor_memory=rest_spark_conf.get(SparkConfKey.EXECUTOR_MEMORY, None), + executor_instances=rest_spark_conf.get(SparkConfKey.EXECUTOR_INSTANCES, None), + 
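+ # dynamic allocation settings are likewise promoted from the raw conf dict onto the job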
dynamic_allocation_enabled=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_ENABLED, None), + dynamic_allocation_min_executors=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_MIN_EXECUTORS, None), + dynamic_allocation_max_executors=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_MAX_EXECUTORS, None), + resources=SparkResourceConfiguration._from_rest_object(rest_spark_job.resources), + inputs=from_rest_inputs_to_dataset_literal(rest_spark_job.inputs), + outputs=from_rest_data_outputs(rest_spark_job.outputs), + ) + return spark_job + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "SparkComponent": + """Translate a spark job to component. + + :param context: Context of spark job YAML file. + :type context: dict + :return: Translated spark component. + :rtype: SparkComponent + """ + from azure.ai.ml.entities import SparkComponent + + pipeline_job_dict = kwargs.get("pipeline_job_dict", {}) + context = context or {BASE_PATH_CONTEXT_KEY: Path("./")} + + # Create anonymous spark component with default version as 1 + return SparkComponent( + tags=self.tags, + is_anonymous=True, + base_path=context[BASE_PATH_CONTEXT_KEY], + description=self.description, + code=self.code, + entry=self.entry, + py_files=self.py_files, + jars=self.jars, + files=self.files, + archives=self.archives, + driver_cores=self.driver_cores, + driver_memory=self.driver_memory, + executor_cores=self.executor_cores, + executor_memory=self.executor_memory, + executor_instances=self.executor_instances, + dynamic_allocation_enabled=self.dynamic_allocation_enabled, + dynamic_allocation_min_executors=self.dynamic_allocation_min_executors, + dynamic_allocation_max_executors=self.dynamic_allocation_max_executors, + conf=self.conf, + properties=self.properties_sparkJob, + environment=self.environment, + inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict), + outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict), + args=self.args, + ) + + def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Spark": + """Translate a spark job to a pipeline node. + + :param context: Context of spark job YAML file. + :type context: dict + :return: Translated spark component. + :rtype: Spark + """ + from azure.ai.ml.entities._builders import Spark + + component = self._to_component(context, **kwargs) + + return Spark( + display_name=self.display_name, + description=self.description, + tags=self.tags, + # code, entry, py_files, jars, files, archives, environment and args are static and not allowed to be + # overwritten. And we will always get them from component. + component=component, + identity=self.identity, + driver_cores=self.driver_cores, + driver_memory=self.driver_memory, + executor_cores=self.executor_cores, + executor_memory=self.executor_memory, + executor_instances=self.executor_instances, + dynamic_allocation_enabled=self.dynamic_allocation_enabled, + dynamic_allocation_min_executors=self.dynamic_allocation_min_executors, + dynamic_allocation_max_executors=self.dynamic_allocation_max_executors, + conf=self.conf, + inputs=self.inputs, # type: ignore[arg-type] + outputs=self.outputs, # type: ignore[arg-type] + compute=self.compute, + resources=self.resources, + properties=self.properties_sparkJob, + ) + + def _validate(self) -> None: + # TODO: make spark job schema validatable? 
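+ # checks below: resources shape, compute/resources mutual exclusivity, direct-only input/output modes,
+ # Spark conf consistency, entry type, and that args only reference declared inputs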
+ if self.resources and not isinstance(self.resources, Dict): + self.resources._validate() + _validate_compute_or_resources(self.compute, self.resources) + _validate_input_output_mode(self.inputs, self.outputs) + _validate_spark_configurations(self) + self._validate_entry() + + if self.args: + validate_inputs_for_args(self.args, self.inputs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry.py new file mode 100644 index 00000000..ed8d3ca7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry.py @@ -0,0 +1,59 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=redefined-builtin + +from typing import Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import SparkJobEntry as RestSparkJobEntry +from azure.ai.ml._restclient.v2023_04_01_preview.models import SparkJobPythonEntry, SparkJobScalaEntry +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class SparkJobEntryType: + """Type of Spark job entry. Possibilities are Python file entry or Scala class entry.""" + + SPARK_JOB_FILE_ENTRY = "SparkJobPythonEntry" + SPARK_JOB_CLASS_ENTRY = "SparkJobScalaEntry" + + +class SparkJobEntry(RestTranslatableMixin): + """Entry for Spark job. + + :keyword entry: The file or class entry point. + :paramtype entry: str + :keyword type: The entry type. Accepted values are SparkJobEntryType.SPARK_JOB_FILE_ENTRY or + SparkJobEntryType.SPARK_JOB_CLASS_ENTRY. Defaults to SparkJobEntryType.SPARK_JOB_FILE_ENTRY. + :paramtype type: ~azure.ai.ml.entities.SparkJobEntryType + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_component_definition] + :end-before: [END spark_component_definition] + :language: python + :dedent: 8 + :caption: Creating SparkComponent. + """ + + def __init__(self, *, entry: str, type: str = SparkJobEntryType.SPARK_JOB_FILE_ENTRY) -> None: + self.entry_type = type + self.entry = entry + + @classmethod + def _from_rest_object(cls, obj: Union[SparkJobPythonEntry, SparkJobScalaEntry]) -> Optional["SparkJobEntry"]: + if obj is None: + return None + if isinstance(obj, dict): + obj = RestSparkJobEntry.from_dict(obj) + if obj.spark_job_entry_type == SparkJobEntryType.SPARK_JOB_FILE_ENTRY: + return SparkJobEntry( + entry=obj.__dict__.get("file", None), + type=SparkJobEntryType.SPARK_JOB_FILE_ENTRY, + ) + return SparkJobEntry(entry=obj.class_name, type=SparkJobEntryType.SPARK_JOB_CLASS_ENTRY) + + def _to_rest_object(self) -> Union[SparkJobPythonEntry, SparkJobScalaEntry]: + if self.entry_type == SparkJobEntryType.SPARK_JOB_FILE_ENTRY: + return SparkJobPythonEntry(file=self.entry) + return SparkJobScalaEntry(class_name=self.entry) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry_mixin.py new file mode 100644 index 00000000..2a1ff549 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry_mixin.py @@ -0,0 +1,64 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import re +from typing import Any, Dict, Optional, Union, cast + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + +from .spark_job_entry import SparkJobEntry, SparkJobEntryType + + +class SparkJobEntryMixin: + CODE_ID_RE_PATTERN = re.compile( + ( + r"\/subscriptions\/(?P<subscription>[\w,-]+)\/resourceGroups\/(?P<resource_group>[\w,-]+)" + r"\/providers\/Microsoft\.MachineLearningServices\/workspaces\/(?P<workspace>[\w,-]+)" + r"\/codes\/(?P<code_id>[\w,-]+)" # fmt: skip + ) + ) + + def __init__(self, **kwargs: Any): + self._entry = None + self.entry = kwargs.get("entry", None) + + @property + def entry(self) -> Optional[Union[Dict[str, str], SparkJobEntry]]: + return self._entry + + @entry.setter + def entry(self, value: Optional[Union[Dict[str, str], SparkJobEntry]]) -> None: + if isinstance(value, dict): + if value.get("file", None): + _entry = cast(str, value.get("file")) + self._entry = SparkJobEntry(entry=_entry, type=SparkJobEntryType.SPARK_JOB_FILE_ENTRY) + return + if value.get("class_name", None): + _entry = cast(str, value.get("class_name")) + self._entry = SparkJobEntry(entry=_entry, type=SparkJobEntryType.SPARK_JOB_CLASS_ENTRY) + return + self._entry = value + + def _validate_entry(self) -> None: + if self.entry is None: + # Entry is a required field for local component and when we load a remote job, component now is an arm_id, + # entry is from node level returned from service. Entry is only None when we reference an existing + # component with a function and the referenced component is in remote with name and version. + return + if not isinstance(self.entry, SparkJobEntry): + msg = f"Unsupported type {type(self.entry)} detected when validate entry, entry should be SparkJobEntry." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + if self.entry.entry_type == SparkJobEntryType.SPARK_JOB_CLASS_ENTRY: + msg = "Classpath is not supported, please use 'file' to define the entry file." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_resource_configuration.py new file mode 100644 index 00000000..138fc7ed --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_resource_configuration.py @@ -0,0 +1,91 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + SparkResourceConfiguration as RestSparkResourceConfiguration, +) +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException + + +class SparkResourceConfiguration(RestTranslatableMixin, DictMixin): + """Compute resource configuration for Spark component or job. + + :keyword instance_type: The type of VM to be used by the compute target. + :paramtype instance_type: Optional[str] + :keyword runtime_version: The Spark runtime version. + :paramtype runtime_version: Optional[str] + + .. 
admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_resource_configuration] + :end-before: [END spark_resource_configuration] + :language: python + :dedent: 8 + :caption: Configuring a SparkJob with SparkResourceConfiguration. + """ + + instance_type_list = [ + "standard_e4s_v3", + "standard_e8s_v3", + "standard_e16s_v3", + "standard_e32s_v3", + "standard_e64s_v3", + ] + + def __init__(self, *, instance_type: Optional[str] = None, runtime_version: Optional[str] = None) -> None: + self.instance_type = instance_type + self.runtime_version = runtime_version + + def _to_rest_object(self) -> RestSparkResourceConfiguration: + return RestSparkResourceConfiguration(instance_type=self.instance_type, runtime_version=self.runtime_version) + + @classmethod + def _from_rest_object( + cls, obj: Union[dict, None, RestSparkResourceConfiguration] + ) -> Optional["SparkResourceConfiguration"]: + if obj is None: + return None + if isinstance(obj, dict): + return SparkResourceConfiguration(**obj) + return SparkResourceConfiguration(instance_type=obj.instance_type, runtime_version=obj.runtime_version) + + def _validate(self) -> None: + # TODO: below logic is duplicated to SparkResourceConfigurationSchema, maybe make SparkJob schema validatable + if self.instance_type is None or self.instance_type == "": + msg = "Instance type must be specified for SparkResourceConfiguration" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + if self.instance_type.lower() not in self.instance_type_list: + msg = "Instance type must be specified for the list of {}".format(",".join(self.instance_type_list)) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SPARK_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SparkResourceConfiguration): + return NotImplemented + return self.instance_type == other.instance_type and self.runtime_version == other.runtime_version + + def __ne__(self, other: object) -> bool: + if not isinstance(other, SparkResourceConfiguration): + return NotImplemented + return not self.__eq__(other) + + def _merge_with(self, other: "SparkResourceConfiguration") -> None: + if other: + if other.instance_type: + self.instance_type = other.instance_type + if other.runtime_version: + self.runtime_version = other.runtime_version diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
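# [Editor's note] Illustrative usage sketch; editorial, not part of the vendored files.
# Only values from instance_type_list above pass _validate(); the runtime version
# string "3.4" is an assumed example value.
from azure.ai.ml.entities._job.spark_resource_configuration import SparkResourceConfiguration

resources = SparkResourceConfiguration(instance_type="standard_e4s_v3", runtime_version="3.4")
resources._validate()  # raises ValidationException for a missing or unlisted instance type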
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/early_termination_policy.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/early_termination_policy.py new file mode 100644 index 00000000..b1b928fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/early_termination_policy.py @@ -0,0 +1,191 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC +from typing import Any, Optional, cast + +from azure.ai.ml._restclient.v2023_04_01_preview.models import BanditPolicy as RestBanditPolicy +from azure.ai.ml._restclient.v2023_04_01_preview.models import EarlyTerminationPolicy as RestEarlyTerminationPolicy +from azure.ai.ml._restclient.v2023_04_01_preview.models import EarlyTerminationPolicyType +from azure.ai.ml._restclient.v2023_04_01_preview.models import MedianStoppingPolicy as RestMedianStoppingPolicy +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + TruncationSelectionPolicy as RestTruncationSelectionPolicy, +) +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class EarlyTerminationPolicy(ABC, RestTranslatableMixin): + def __init__( + self, + *, + delay_evaluation: int, + evaluation_interval: int, + ): + self.type = None + self.delay_evaluation = delay_evaluation + self.evaluation_interval = evaluation_interval + + @classmethod + def _from_rest_object(cls, obj: RestEarlyTerminationPolicy) -> Optional["EarlyTerminationPolicy"]: + if not obj: + return None + + policy: Any = None + if obj.policy_type == EarlyTerminationPolicyType.BANDIT: + policy = BanditPolicy._from_rest_object(obj) # pylint: disable=protected-access + + if obj.policy_type == EarlyTerminationPolicyType.MEDIAN_STOPPING: + policy = MedianStoppingPolicy._from_rest_object(obj) # pylint: disable=protected-access + + if obj.policy_type == EarlyTerminationPolicyType.TRUNCATION_SELECTION: + policy = TruncationSelectionPolicy._from_rest_object(obj) # pylint: disable=protected-access + + return cast(Optional["EarlyTerminationPolicy"], policy) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EarlyTerminationPolicy): + raise NotImplementedError + res: bool = self._to_rest_object() == other._to_rest_object() + return res + + +class BanditPolicy(EarlyTerminationPolicy): + """Defines an early termination policy based on slack criteria and a frequency and delay interval for evaluation. + + :keyword delay_evaluation: Number of intervals by which to delay the first evaluation. Defaults to 0. + :paramtype delay_evaluation: int + :keyword evaluation_interval: Interval (number of runs) between policy evaluations. Defaults to 0. + :paramtype evaluation_interval: int + :keyword slack_amount: Absolute distance allowed from the best performing run. Defaults to 0. + :paramtype slack_amount: float + :keyword slack_factor: Ratio of the allowed distance from the best performing run. Defaults to 0. + :paramtype slack_factor: float + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_bandit_policy] + :end-before: [END configure_sweep_job_bandit_policy] + :language: python + :dedent: 8 + :caption: Configuring BanditPolicy early termination of a hyperparameter sweep on a Command job. + """ + + def __init__( + self, + *, + delay_evaluation: int = 0, + evaluation_interval: int = 0, + slack_amount: float = 0, + slack_factor: float = 0, + ) -> None: + super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval) + self.type = EarlyTerminationPolicyType.BANDIT.lower() + self.slack_factor = slack_factor + self.slack_amount = slack_amount + + def _to_rest_object(self) -> RestBanditPolicy: + return RestBanditPolicy( + delay_evaluation=self.delay_evaluation, + evaluation_interval=self.evaluation_interval, + slack_factor=self.slack_factor, + slack_amount=self.slack_amount, + ) + + @classmethod + def _from_rest_object(cls, obj: RestBanditPolicy) -> "BanditPolicy": + return cls( + delay_evaluation=obj.delay_evaluation, + evaluation_interval=obj.evaluation_interval, + slack_factor=obj.slack_factor, + slack_amount=obj.slack_amount, + ) + + +class MedianStoppingPolicy(EarlyTerminationPolicy): + """Defines an early termination policy based on a running average of the primary metric of all runs. + + :keyword delay_evaluation: Number of intervals by which to delay the first evaluation. Defaults to 0. + :paramtype delay_evaluation: int + :keyword evaluation_interval: Interval (number of runs) between policy evaluations. Defaults to 1. + :paramtype evaluation_interval: int + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_median_stopping_policy] + :end-before: [END configure_sweep_job_median_stopping_policy] + :language: python + :dedent: 8 + :caption: Configuring an early termination policy for a hyperparameter sweep job using MedianStoppingPolicy + """ + + def __init__( + self, + *, + delay_evaluation: int = 0, + evaluation_interval: int = 1, + ) -> None: + super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval) + self.type = camel_to_snake(EarlyTerminationPolicyType.MEDIAN_STOPPING) + + def _to_rest_object(self) -> RestMedianStoppingPolicy: + return RestMedianStoppingPolicy( + delay_evaluation=self.delay_evaluation, evaluation_interval=self.evaluation_interval + ) + + @classmethod + def _from_rest_object(cls, obj: RestMedianStoppingPolicy) -> "MedianStoppingPolicy": + return cls( + delay_evaluation=obj.delay_evaluation, + evaluation_interval=obj.evaluation_interval, + ) + + +class TruncationSelectionPolicy(EarlyTerminationPolicy): + """Defines an early termination policy that cancels a given percentage of runs at each evaluation interval. + + :keyword delay_evaluation: Number of intervals by which to delay the first evaluation. Defaults to 0. + :paramtype delay_evaluation: int + :keyword evaluation_interval: Interval (number of runs) between policy evaluations. Defaults to 0. + :paramtype evaluation_interval: int + :keyword truncation_percentage: The percentage of runs to cancel at each evaluation interval. Defaults to 0. + :paramtype truncation_percentage: int + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_truncation_selection_policy] + :end-before: [END configure_sweep_job_truncation_selection_policy] + :language: python + :dedent: 8 + :caption: Configuring an early termination policy for a hyperparameter sweep job + using TruncationStoppingPolicy + """ + + def __init__( + self, + *, + delay_evaluation: int = 0, + evaluation_interval: int = 0, + truncation_percentage: int = 0, + ) -> None: + super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval) + self.type = camel_to_snake(EarlyTerminationPolicyType.TRUNCATION_SELECTION) + self.truncation_percentage = truncation_percentage + + def _to_rest_object(self) -> RestTruncationSelectionPolicy: + return RestTruncationSelectionPolicy( + delay_evaluation=self.delay_evaluation, + evaluation_interval=self.evaluation_interval, + truncation_percentage=self.truncation_percentage, + ) + + @classmethod + def _from_rest_object(cls, obj: RestTruncationSelectionPolicy) -> "TruncationSelectionPolicy": + return cls( + delay_evaluation=obj.delay_evaluation, + evaluation_interval=obj.evaluation_interval, + truncation_percentage=obj.truncation_percentage, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/objective.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/objective.py new file mode 100644 index 00000000..45e13332 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/objective.py @@ -0,0 +1,53 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Optional + +from azure.ai.ml._restclient.v2023_08_01_preview.models import Objective as RestObjective +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class Objective(RestTranslatableMixin): + """Optimization objective. + + :param goal: Defines supported metric goals for hyperparameter tuning. Accepted values + are: "minimize", "maximize". + :type goal: str + :param primary_metric: The name of the metric to optimize. + :type primary_metric: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_bayesian_sampling_algorithm] + :end-before: [END configure_sweep_job_bayesian_sampling_algorithm] + :language: python + :dedent: 8 + :caption: Assigning an objective to a SweepJob. + """ + + def __init__(self, goal: Optional[str], primary_metric: Optional[str] = None) -> None: + """Optimization objective. + + :param goal: Defines supported metric goals for hyperparameter tuning. Acceptable values + are: "minimize" or "maximize". + :type goal: str + :param primary_metric: The name of the metric to optimize. 
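# [Editor's note] Illustrative sketch of the early-termination policies defined in
# early_termination_policy.py above; editorial, not part of the vendored source.
from azure.ai.ml.entities._job.sweep.early_termination_policy import (
    BanditPolicy,
    MedianStoppingPolicy,
)

# Allow at most 20% slack relative to the best-performing trial, evaluated every
# 2 intervals after an initial delay of 5 intervals.
bandit = BanditPolicy(slack_factor=0.2, evaluation_interval=2, delay_evaluation=5)
# Or cancel trials whose primary metric falls below the running median.
median = MedianStoppingPolicy(delay_evaluation=5, evaluation_interval=1)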
+ :type primary_metric: str + """ + if goal is not None: + self.goal = goal.lower() + self.primary_metric = primary_metric + + def _to_rest_object(self) -> RestObjective: + return RestObjective( + goal=self.goal, + primary_metric=self.primary_metric, + ) + + @classmethod + def _from_rest_object(cls, obj: RestObjective) -> Optional["Objective"]: + if not obj: + return None + + return cls(goal=obj.goal, primary_metric=obj.primary_metric) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/parameterized_sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/parameterized_sweep.py new file mode 100644 index 00000000..5d69201f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/parameterized_sweep.py @@ -0,0 +1,341 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Any, Dict, List, Optional, Type, Union + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from ..job_limits import SweepJobLimits +from ..job_resource_configuration import JobResourceConfiguration +from ..queue_settings import QueueSettings +from .early_termination_policy import ( + BanditPolicy, + EarlyTerminationPolicy, + EarlyTerminationPolicyType, + MedianStoppingPolicy, + TruncationSelectionPolicy, +) +from .objective import Objective +from .sampling_algorithm import ( + BayesianSamplingAlgorithm, + GridSamplingAlgorithm, + RandomSamplingAlgorithm, + RestBayesianSamplingAlgorithm, + RestGridSamplingAlgorithm, + RestRandomSamplingAlgorithm, + RestSamplingAlgorithm, + SamplingAlgorithm, + SamplingAlgorithmType, +) + +SAMPLING_ALGORITHM_TO_REST_CONSTRUCTOR: Dict[SamplingAlgorithmType, Type[RestSamplingAlgorithm]] = { + SamplingAlgorithmType.RANDOM: RestRandomSamplingAlgorithm, + SamplingAlgorithmType.GRID: RestGridSamplingAlgorithm, + SamplingAlgorithmType.BAYESIAN: RestBayesianSamplingAlgorithm, +} + +SAMPLING_ALGORITHM_CONSTRUCTOR: Dict[SamplingAlgorithmType, Type[SamplingAlgorithm]] = { + SamplingAlgorithmType.RANDOM: RandomSamplingAlgorithm, + SamplingAlgorithmType.GRID: GridSamplingAlgorithm, + SamplingAlgorithmType.BAYESIAN: BayesianSamplingAlgorithm, +} + + +class ParameterizedSweep: # pylint:disable=too-many-instance-attributes + """Shared logic for standalone and pipeline sweep job.""" + + def __init__( + self, + limits: Optional[SweepJobLimits] = None, + sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None, + objective: Optional[Union[Dict, Objective]] = None, + early_termination: Optional[Any] = None, + search_space: Optional[Dict] = None, + queue_settings: Optional[QueueSettings] = None, + resources: Optional[Union[dict, JobResourceConfiguration]] = None, + ) -> None: + """ + :param limits: Limits for sweep job. + :type limits: ~azure.ai.ml.sweep.SweepJobLimits + :param sampling_algorithm: Sampling algorithm for sweep job. + :type sampling_algorithm: ~azure.ai.ml.sweep.SamplingAlgorithm + :param objective: Objective for sweep job. + :type objective: ~azure.ai.ml.sweep.Objective + :param early_termination: Early termination policy for sweep job. + :type early_termination: ~azure.ai.ml.entities._job.sweep.early_termination_policy.EarlyTerminationPolicy + :param search_space: Search space for sweep job. 
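# [Editor's note] Editorial sketch of the Objective entity defined in objective.py above;
# the metric name "accuracy" is an assumed example. The goal string is lower-cased by the
# constructor.
from azure.ai.ml.entities._job.sweep.objective import Objective

objective = Objective(goal="maximize", primary_metric="accuracy")
# ParameterizedSweep (below) also accepts the plain-dict form and converts it itself:
# objective={"goal": "maximize", "primary_metric": "accuracy"}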
+ :type search_space: Dict[str, Union[ + ~azure.ai.ml.sweep.Choice, + ~azure.ai.ml.sweep.LogNormal, + ~azure.ai.ml.sweep.LogUniform, + ~azure.ai.ml.sweep.Normal, + ~azure.ai.ml.sweep.QLogNormal, + ~azure.ai.ml.sweep.QLogUniform, + ~azure.ai.ml.sweep.QNormal, + ~azure.ai.ml.sweep.QUniform, + ~azure.ai.ml.sweep.Randint, + ~azure.ai.ml.sweep.Uniform + + ]] + :param queue_settings: Queue settings for sweep job. + :type queue_settings: ~azure.ai.ml.entities.QueueSettings + :param resources: Compute Resource configuration for the job. + :type resources: ~azure.ai.ml.entities.ResourceConfiguration + """ + self.sampling_algorithm = sampling_algorithm + self.early_termination = early_termination # type: ignore[assignment] + self._limits = limits + self.search_space = search_space + self.queue_settings = queue_settings + self.objective: Optional[Objective] = None + self.resources = resources + + if isinstance(objective, Dict): + self.objective = Objective(**objective) + else: + self.objective = objective + + @property + def resources(self) -> Optional[Union[dict, JobResourceConfiguration]]: + """Resources for sweep job. + + :returns: Resources for sweep job. + :rtype: ~azure.ai.ml.entities.ResourceConfiguration + """ + return self._resources + + @resources.setter + def resources(self, value: Optional[Union[dict, JobResourceConfiguration]]) -> None: + """Set Resources for sweep job. + + :param value: Compute Resource configuration for the job. + :type value: ~azure.ai.ml.entities.ResourceConfiguration + """ + if isinstance(value, dict): + value = JobResourceConfiguration(**value) + self._resources = value + + @property + def limits(self) -> Optional[SweepJobLimits]: + """Limits for sweep job. + + :returns: Limits for sweep job. + :rtype: ~azure.ai.ml.sweep.SweepJobLimits + """ + return self._limits + + @limits.setter + def limits(self, value: SweepJobLimits) -> None: + """Set limits for sweep job. + + :param value: Limits for sweep job. + :type value: ~azure.ai.ml.sweep.SweepJobLimits + """ + if not isinstance(value, SweepJobLimits): + msg = f"limits must be SweepJobLimits but get {type(value)} instead" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SWEEP_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + self._limits = value + + def set_resources( + self, + *, + instance_type: Optional[Union[str, List[str]]] = None, + instance_count: Optional[int] = None, + locations: Optional[List[str]] = None, + properties: Optional[Dict] = None, + docker_args: Optional[str] = None, + shm_size: Optional[str] = None, + ) -> None: + """Set resources for Sweep. + + :keyword instance_type: The instance type to use for the job. + :paramtype instance_type: Optional[Union[str, List[str]]] + :keyword instance_count: The number of instances to use for the job. + :paramtype instance_count: Optional[int] + :keyword locations: The locations to use for the job. + :paramtype locations: Optional[List[str]] + :keyword properties: The properties for the job. + :paramtype properties: Optional[Dict] + :keyword docker_args: The docker arguments for the job. + :paramtype docker_args: Optional[str] + :keyword shm_size: The shared memory size for the job. 
+ :paramtype shm_size: Optional[str] + """ + if self.resources is None: + self.resources = JobResourceConfiguration() + + if not isinstance(self.resources, dict): + if locations is not None: + self.resources.locations = locations + if instance_type is not None: + self.resources.instance_type = instance_type + if instance_count is not None: + self.resources.instance_count = instance_count + if properties is not None: + self.resources.properties = properties + if docker_args is not None: + self.resources.docker_args = docker_args + if shm_size is not None: + self.resources.shm_size = shm_size + + def set_limits( + self, + *, + max_concurrent_trials: Optional[int] = None, + max_total_trials: Optional[int] = None, + timeout: Optional[int] = None, + trial_timeout: Optional[int] = None, + ) -> None: + """Set limits for Sweep node. Leave parameters as None if you don't want to update corresponding values. + + :keyword max_concurrent_trials: maximum concurrent trial number. + :paramtype max_concurrent_trials: int + :keyword max_total_trials: maximum total trial number. + :paramtype max_total_trials: int + :keyword timeout: total timeout in seconds for sweep node + :paramtype timeout: int + :keyword trial_timeout: timeout in seconds for each trial + :paramtype trial_timeout: int + """ + # Looks related to https://github.com/pylint-dev/pylint/issues/3502, still an open issue + # pylint:disable=attribute-defined-outside-init + if self._limits is None: + self._limits = SweepJobLimits( + max_concurrent_trials=max_concurrent_trials, + max_total_trials=max_total_trials, + timeout=timeout, + trial_timeout=trial_timeout, + ) + else: + if self.limits is not None: + if max_concurrent_trials is not None: + self.limits.max_concurrent_trials = max_concurrent_trials + if max_total_trials is not None: + self.limits.max_total_trials = max_total_trials + if timeout is not None: + self.limits.timeout = timeout + if trial_timeout is not None: + self.limits.trial_timeout = trial_timeout + + def set_objective(self, *, goal: Optional[str] = None, primary_metric: Optional[str] = None) -> None: + """Set the sweep object.. Leave parameters as None if you don't want to update corresponding values. + + :keyword goal: Defines supported metric goals for hyperparameter tuning. Acceptable values are: + "minimize" and "maximize". + :paramtype goal: str + :keyword primary_metric: Name of the metric to optimize. + :paramtype primary_metric: str + """ + + if self.objective is not None: + if goal: + self.objective.goal = goal + if primary_metric: + self.objective.primary_metric = primary_metric + else: + self.objective = Objective(goal=goal, primary_metric=primary_metric) + + @property + def sampling_algorithm(self) -> Optional[Union[str, SamplingAlgorithm]]: + """Sampling algorithm for sweep job. + + :returns: Sampling algorithm for sweep job. + :rtype: ~azure.ai.ml.sweep.SamplingAlgorithm + """ + return self._sampling_algorithm + + @sampling_algorithm.setter + def sampling_algorithm(self, value: Optional[Union[SamplingAlgorithm, str]] = None) -> None: + """Set sampling algorithm for sweep job. + + :param value: Sampling algorithm for sweep job. 
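# [Editor's note] Editorial sketch of the helper setters defined on ParameterizedSweep
# above; concrete sweep entities (for example the SweepJob defined later in this diff)
# inherit them. ParameterizedSweep is instantiated directly here only to keep the
# sketch self-contained; the limit values are assumed examples.
from azure.ai.ml.entities._job.sweep.parameterized_sweep import ParameterizedSweep

ps = ParameterizedSweep()
ps.set_limits(max_total_trials=20, max_concurrent_trials=4, timeout=7200)
ps.set_objective(goal="maximize", primary_metric="accuracy")
# sampling_algorithm also accepts either a SamplingAlgorithm object or a string name,
# normalized and validated by the setter continued just below.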
+ :type value: ~azure.ai.ml.sweep.SamplingAlgorithm + """ + if value is None: + self._sampling_algorithm = None + elif isinstance(value, SamplingAlgorithm) or ( + isinstance(value, str) and value.lower().capitalize() in SAMPLING_ALGORITHM_CONSTRUCTOR + ): + self._sampling_algorithm = value + else: + msg = f"unsupported sampling algorithm: {value}" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SWEEP_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _get_rest_sampling_algorithm(self) -> RestSamplingAlgorithm: + # TODO: self.sampling_algorithm will always return SamplingAlgorithm + if isinstance(self.sampling_algorithm, SamplingAlgorithm): + return self.sampling_algorithm._to_rest_object() # pylint: disable=protected-access + + if isinstance(self.sampling_algorithm, str): + return SAMPLING_ALGORITHM_CONSTRUCTOR[ # pylint: disable=protected-access + SamplingAlgorithmType(self.sampling_algorithm.lower().capitalize()) + ]()._to_rest_object() + + msg = f"Received unsupported value {self._sampling_algorithm} as the sampling algorithm" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SWEEP_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + @property + def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]: + """Early termination policy for sweep job. + + :returns: Early termination policy for sweep job. + :rtype: ~azure.ai.ml.entities._job.sweep.early_termination_policy.EarlyTerminationPolicy + """ + return self._early_termination + + @early_termination.setter + def early_termination(self, value: Any) -> None: + """Set early termination policy for sweep job. + + :param value: Early termination policy for sweep job. 
+ :type value: ~azure.ai.ml.entities._job.sweep.early_termination_policy.EarlyTerminationPolicy + """ + self._early_termination: Optional[Union[str, EarlyTerminationPolicy]] + if value is None: + self._early_termination = None + elif isinstance(value, EarlyTerminationPolicy): + self._early_termination = value + elif isinstance(value, str): + value = value.lower().capitalize() + if value == EarlyTerminationPolicyType.BANDIT: + self._early_termination = BanditPolicy() + elif value == EarlyTerminationPolicyType.MEDIAN_STOPPING: + self._early_termination = MedianStoppingPolicy() + elif value == EarlyTerminationPolicyType.TRUNCATION_SELECTION: + self._early_termination = TruncationSelectionPolicy() + else: + msg = f"Received unsupported value {value} as the early termination policy" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SWEEP_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + else: + msg = f"Received unsupported value of type {type(value)} as the early termination policy" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SWEEP_JOB, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sampling_algorithm.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sampling_algorithm.py new file mode 100644 index 00000000..d0bf795d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sampling_algorithm.py @@ -0,0 +1,141 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC +from typing import Any, Optional, Union, cast + +from azure.ai.ml._restclient.v2023_08_01_preview.models import ( + BayesianSamplingAlgorithm as RestBayesianSamplingAlgorithm, +) +from azure.ai.ml._restclient.v2023_08_01_preview.models import GridSamplingAlgorithm as RestGridSamplingAlgorithm +from azure.ai.ml._restclient.v2023_08_01_preview.models import RandomSamplingAlgorithm as RestRandomSamplingAlgorithm +from azure.ai.ml._restclient.v2023_08_01_preview.models import SamplingAlgorithm as RestSamplingAlgorithm +from azure.ai.ml._restclient.v2023_08_01_preview.models import SamplingAlgorithmType +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class SamplingAlgorithm(ABC, RestTranslatableMixin): + """Base class for sampling algorithms. + + This class should not be instantiated directly. Instead, use one of its subclasses. 
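# [Editor's note] Editorial sketch of the concrete sampling algorithms defined below in
# this same file; the rule and seed values are assumed examples.
from azure.ai.ml.entities._job.sweep.sampling_algorithm import (
    BayesianSamplingAlgorithm,
    GridSamplingAlgorithm,
    RandomSamplingAlgorithm,
)

random_sobol = RandomSamplingAlgorithm(rule="sobol", seed=42)
grid = GridSamplingAlgorithm()
bayesian = BayesianSamplingAlgorithm()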
+ """ + + def __init__(self) -> None: + self.type = None + + @classmethod + def _from_rest_object(cls, obj: RestSamplingAlgorithm) -> Optional["SamplingAlgorithm"]: + if not obj: + return None + + sampling_algorithm: Any = None + if obj.sampling_algorithm_type == SamplingAlgorithmType.RANDOM: + sampling_algorithm = RandomSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access + + if obj.sampling_algorithm_type == SamplingAlgorithmType.GRID: + sampling_algorithm = GridSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access + + if obj.sampling_algorithm_type == SamplingAlgorithmType.BAYESIAN: + sampling_algorithm = BayesianSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access + + return cast(Optional["SamplingAlgorithm"], sampling_algorithm) + + +class RandomSamplingAlgorithm(SamplingAlgorithm): + """Random Sampling Algorithm. + + :keyword rule: The specific type of random algorithm. Accepted values are: "random" and "sobol". + :type rule: str + :keyword seed: The seed for random number generation. + :paramtype seed: int + :keyword logbase: A positive number or the number "e" in string format to be used as the base for log + based random sampling. + :paramtype logbase: Union[float, str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_random_sampling_algorithm] + :end-before: [END configure_sweep_job_random_sampling_algorithm] + :language: python + :dedent: 8 + :caption: Assigning a random sampling algorithm for a SweepJob + """ + + def __init__( + self, + *, + rule: Optional[str] = None, + seed: Optional[int] = None, + logbase: Optional[Union[float, str]] = None, + ) -> None: + super().__init__() + self.type = SamplingAlgorithmType.RANDOM.lower() + self.rule = rule + self.seed = seed + self.logbase = logbase + + def _to_rest_object(self) -> RestRandomSamplingAlgorithm: + return RestRandomSamplingAlgorithm( + rule=self.rule, + seed=self.seed, + logbase=self.logbase, + ) + + @classmethod + def _from_rest_object(cls, obj: RestRandomSamplingAlgorithm) -> "RandomSamplingAlgorithm": + return cls( + rule=obj.rule, + seed=obj.seed, + logbase=obj.logbase, + ) + + +class GridSamplingAlgorithm(SamplingAlgorithm): + """Grid Sampling Algorithm. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_grid_sampling_algorithm] + :end-before: [END configure_sweep_job_grid_sampling_algorithm] + :language: python + :dedent: 8 + :caption: Assigning a grid sampling algorithm for a SweepJob + """ + + def __init__(self) -> None: + super().__init__() + self.type = SamplingAlgorithmType.GRID.lower() + + def _to_rest_object(self) -> RestGridSamplingAlgorithm: + return RestGridSamplingAlgorithm() + + @classmethod + def _from_rest_object(cls, obj: RestGridSamplingAlgorithm) -> "GridSamplingAlgorithm": + return cls() + + +class BayesianSamplingAlgorithm(SamplingAlgorithm): + """Bayesian Sampling Algorithm. + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_bayesian_sampling_algorithm] + :end-before: [END configure_sweep_job_bayesian_sampling_algorithm] + :language: python + :dedent: 8 + :caption: Assigning a Bayesian sampling algorithm for a SweepJob + """ + + def __init__(self) -> None: + super().__init__() + self.type = SamplingAlgorithmType.BAYESIAN.lower() + + def _to_rest_object(self) -> RestBayesianSamplingAlgorithm: + return RestBayesianSamplingAlgorithm() + + @classmethod + def _from_rest_object(cls, obj: RestBayesianSamplingAlgorithm) -> "BayesianSamplingAlgorithm": + return cls() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/search_space.py new file mode 100644 index 00000000..bbc08d98 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/search_space.py @@ -0,0 +1,393 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from abc import ABC +from typing import Any, List, Optional, Union + +from azure.ai.ml.constants._common import TYPE +from azure.ai.ml.constants._job.sweep import SearchSpace +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException + + +class SweepDistribution(ABC, RestTranslatableMixin): + """Base class for sweep distribution configuration. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :keyword type: Type of distribution. + :paramtype type: str + """ + + def __init__(self, *, type: Optional[str] = None) -> None: # pylint: disable=redefined-builtin + self.type = type + + @classmethod + def _from_rest_object(cls, obj: List) -> "SweepDistribution": + mapping = { + SearchSpace.CHOICE: Choice, + SearchSpace.NORMAL: Normal, + SearchSpace.LOGNORMAL: LogNormal, + SearchSpace.QNORMAL: QNormal, + SearchSpace.QLOGNORMAL: QLogNormal, + SearchSpace.RANDINT: Randint, + SearchSpace.UNIFORM: Uniform, + SearchSpace.QUNIFORM: QUniform, + SearchSpace.LOGUNIFORM: LogUniform, + SearchSpace.QLOGUNIFORM: QLogUniform, + } + + ss_class: Any = mapping.get(obj[0], None) + if ss_class: + res: SweepDistribution = ss_class._from_rest_object(obj) + return res + + msg = f"Unknown search space type: {obj[0]}" + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SWEEP_JOB, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SweepDistribution): + return NotImplemented + res: bool = self._to_rest_object() == other._to_rest_object() + return res + + +class Choice(SweepDistribution): + """Choice distribution configuration. + + :param values: List of values to choose from. + :type values: list[Union[float, str, dict]] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_choice_loguniform] + :end-before: [END configure_sweep_job_choice_loguniform] + :language: python + :dedent: 8 + :caption: Using Choice distribution to set values for a hyperparameter sweep + """ + + def __init__(self, values: Optional[List[Union[float, str, dict]]] = None, **kwargs: Any) -> None: + kwargs.setdefault(TYPE, SearchSpace.CHOICE) + super().__init__(**kwargs) + self.values = values + + def _to_rest_object(self) -> List: + items: List = [] + if self.values is not None: + for value in self.values: + if isinstance(value, dict): + rest_dict = {} + for k, v in value.items(): + if isinstance(v, SweepDistribution): + rest_dict[k] = v._to_rest_object() + else: + rest_dict[k] = v + items.append(rest_dict) + else: + items.append(value) + return [self.type, [items]] + + @classmethod + def _from_rest_object(cls, obj: List) -> "Choice": + rest_values = obj[1][0] + from_rest_values = [] + for rest_value in rest_values: + if isinstance(rest_value, dict): + from_rest_dict = {} + for k, v in rest_value.items(): + try: + # first assume that any dictionary value is a valid distribution (i.e. normal, uniform, etc) + # and try to deserialize it into a the correct SDK distribution object + from_rest_dict[k] = SweepDistribution._from_rest_object(v) + except Exception: # pylint: disable=W0718 + # if an exception is raised, assume that the value was not a valid distribution and use the + # value as it is for deserialization + from_rest_dict[k] = v + from_rest_values.append(from_rest_dict) + else: + from_rest_values.append(rest_value) + return Choice(values=from_rest_values) # type: ignore[arg-type] + + +class Normal(SweepDistribution): + """Normal distribution configuration. + + :param mu: Mean of the distribution. + :type mu: float + :param sigma: Standard deviation of the distribution. + :type sigma: float + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_randint_normal] + :end-before: [END configure_sweep_job_randint_normal] + :language: python + :dedent: 8 + :caption: Configuring Normal distributions for a hyperparameter sweep on a Command job. + """ + + def __init__(self, mu: Optional[float] = None, sigma: Optional[float] = None, **kwargs: Any) -> None: + kwargs.setdefault(TYPE, SearchSpace.NORMAL) + super().__init__(**kwargs) + self.mu = mu + self.sigma = sigma + + def _to_rest_object(self) -> List: + return [self.type, [self.mu, self.sigma]] + + @classmethod + def _from_rest_object(cls, obj: List) -> "Normal": + return cls(mu=obj[1][0], sigma=obj[1][1]) + + +class LogNormal(Normal): + """LogNormal distribution configuration. + + :param mu: Mean of the log of the distribution. + :type mu: float + :param sigma: Standard deviation of the log of the distribution. + :type sigma: float + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_lognormal_qlognormal] + :end-before: [END configure_sweep_job_lognormal_qlognormal] + :language: python + :dedent: 8 + :caption: Configuring LogNormal distributions for a hyperparameter sweep on a Command job. + """ + + def __init__(self, mu: Optional[float] = None, sigma: Optional[float] = None, **kwargs: Any) -> None: + kwargs.setdefault(TYPE, SearchSpace.LOGNORMAL) + super().__init__(mu=mu, sigma=sigma, **kwargs) + + +class QNormal(Normal): + """QNormal distribution configuration. 
+ + :param mu: Mean of the distribution. + :type mu: float + :param sigma: Standard deviation of the distribution. + :type sigma: float + :param q: Quantization factor. + :type q: int + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_qloguniform_qnormal] + :end-before: [END configure_sweep_job_qloguniform_qnormal] + :language: python + :dedent: 8 + :caption: Configuring QNormal distributions for a hyperparameter sweep on a Command job. + """ + + def __init__( + self, mu: Optional[float] = None, sigma: Optional[float] = None, q: Optional[int] = None, **kwargs: Any + ) -> None: + kwargs.setdefault(TYPE, SearchSpace.QNORMAL) + super().__init__(mu=mu, sigma=sigma, **kwargs) + self.q = q + + def _to_rest_object(self) -> List: + return [self.type, [self.mu, self.sigma, self.q]] + + @classmethod + def _from_rest_object(cls, obj: List) -> "QNormal": + return cls(mu=obj[1][0], sigma=obj[1][1], q=obj[1][2]) + + +class QLogNormal(QNormal): + """QLogNormal distribution configuration. + + :param mu: Mean of the log of the distribution. + :type mu: Optional[float] + :param sigma: Standard deviation of the log of the distribution. + :type sigma: Optional[float] + :param q: Quantization factor. + :type q: Optional[int] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_lognormal_qlognormal] + :end-before: [END configure_sweep_job_lognormal_qlognormal] + :language: python + :dedent: 8 + :caption: Configuring QLogNormal distributions for a hyperparameter sweep on a Command job. + """ + + def __init__( + self, mu: Optional[float] = None, sigma: Optional[float] = None, q: Optional[int] = None, **kwargs: Any + ) -> None: + kwargs.setdefault(TYPE, SearchSpace.QLOGNORMAL) + super().__init__(mu=mu, sigma=sigma, q=q, **kwargs) + + +class Randint(SweepDistribution): + """Randint distribution configuration. + + :param upper: Upper bound of the distribution. + :type upper: int + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_randint_normal] + :end-before: [END configure_sweep_job_randint_normal] + :language: python + :dedent: 8 + :caption: Configuring Randint distributions for a hyperparameter sweep on a Command job. + """ + + def __init__(self, upper: Optional[int] = None, **kwargs: Any) -> None: + kwargs.setdefault(TYPE, SearchSpace.RANDINT) + super().__init__(**kwargs) + self.upper = upper + + def _to_rest_object(self) -> List: + return [self.type, [self.upper]] + + @classmethod + def _from_rest_object(cls, obj: List) -> "Randint": + return cls(upper=obj[1][0]) + + +class Uniform(SweepDistribution): + """ + + Uniform distribution configuration. + + :param min_value: Minimum value of the distribution. + :type min_value: float + :param max_value: Maximum value of the distribution. + :type max_value: float + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_uniform] + :end-before: [END configure_sweep_job_uniform] + :language: python + :dedent: 8 + :caption: Configuring Uniform distributions for learning rates and momentum + during a hyperparameter sweep on a Command job. 
+ """ + + def __init__(self, min_value: Optional[float] = None, max_value: Optional[float] = None, **kwargs: Any) -> None: + kwargs.setdefault(TYPE, SearchSpace.UNIFORM) + super().__init__(**kwargs) + self.min_value = min_value + self.max_value = max_value + + def _to_rest_object(self) -> List: + return [self.type, [self.min_value, self.max_value]] + + @classmethod + def _from_rest_object(cls, obj: List) -> "Uniform": + return cls(min_value=obj[1][0], max_value=obj[1][1]) + + +class LogUniform(Uniform): + """LogUniform distribution configuration. + + :param min_value: Minimum value of the log of the distribution. + :type min_value: float + :param max_value: Maximum value of the log of the distribution. + :type max_value: float + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_choice_loguniform] + :end-before: [END configure_sweep_job_choice_loguniform] + :language: python + :dedent: 8 + :caption: Configuring a LogUniform distribution for a hyperparameter sweep job learning rate + """ + + def __init__(self, min_value: Optional[float] = None, max_value: Optional[float] = None, **kwargs: Any) -> None: + kwargs.setdefault(TYPE, SearchSpace.LOGUNIFORM) + super().__init__(min_value=min_value, max_value=max_value, **kwargs) + + +class QUniform(Uniform): + """QUniform distribution configuration. + + :param min_value: Minimum value of the distribution. + :type min_value: Optional[Union[int, float]] + :param max_value: Maximum value of the distribution. + :type max_value: Optional[Union[int, float]] + :param q: Quantization factor. + :type q: Optional[int] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_truncation_selection_policy] + :end-before: [END configure_sweep_job_truncation_selection_policy] + :language: python + :dedent: 8 + :caption: Configuring QUniform distributions for a hyperparameter sweep on a Command job. + """ + + def __init__( + self, + min_value: Optional[Union[int, float]] = None, + max_value: Optional[Union[int, float]] = None, + q: Optional[int] = None, + **kwargs: Any, + ) -> None: + kwargs.setdefault(TYPE, SearchSpace.QUNIFORM) + super().__init__(min_value=min_value, max_value=max_value, **kwargs) + self.q = q + + def _to_rest_object(self) -> List: + return [self.type, [self.min_value, self.max_value, self.q]] + + @classmethod + def _from_rest_object(cls, obj: List) -> "QUniform": + return cls(min_value=obj[1][0], max_value=obj[1][1], q=obj[1][2]) + + +class QLogUniform(QUniform): + """QLogUniform distribution configuration. + + :param min_value: Minimum value of the log of the distribution. + :type min_value: Optional[float] + :param max_value: Maximum value of the log of the distribution. + :type max_value: Optional[float] + :param q: Quantization factor. + :type q: Optional[int] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_qloguniform_qnormal] + :end-before: [END configure_sweep_job_qloguniform_qnormal] + :language: python + :dedent: 8 + :caption: Configuring QLogUniform distributions for a hyperparameter sweep on a Command job. 
+ """ + + def __init__( + self, + min_value: Optional[float] = None, + max_value: Optional[float] = None, + q: Optional[int] = None, + **kwargs: Any, + ) -> None: + kwargs.setdefault(TYPE, SearchSpace.QLOGUNIFORM) + super().__init__(min_value=min_value, max_value=max_value, q=q, **kwargs) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sweep_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sweep_job.py new file mode 100644 index 00000000..0a99bb39 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sweep_job.py @@ -0,0 +1,361 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from typing import Any, Dict, NoReturn, Optional, Union + +from azure.ai.ml._restclient.v2023_08_01_preview.models import JobBase +from azure.ai.ml._restclient.v2023_08_01_preview.models import SweepJob as RestSweepJob +from azure.ai.ml._restclient.v2023_08_01_preview.models import TrialComponent +from azure.ai.ml._schema._sweep.sweep_job import SweepJobSchema +from azure.ai.ml._utils.utils import map_single_brackets_and_warn +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE +from azure.ai.ml.entities._component.command_component import CommandComponent +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, + _BaseJobIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job._input_output_helpers import ( + from_rest_data_outputs, + from_rest_inputs_to_dataset_literal, + to_rest_data_outputs, + to_rest_dataset_literal_inputs, + validate_inputs_for_command, + validate_key_contains_allowed_characters, +) +from azure.ai.ml.entities._job.command_job import CommandJob +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin +from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration +from azure.ai.ml.entities._job.sweep.sampling_algorithm import SamplingAlgorithm +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException + +# from ..identity import AmlToken, Identity, ManagedIdentity, UserIdentity +from ..job_limits import SweepJobLimits +from ..parameterized_command import ParameterizedCommand +from ..queue_settings import QueueSettings +from .early_termination_policy import ( + BanditPolicy, + EarlyTerminationPolicy, + MedianStoppingPolicy, + TruncationSelectionPolicy, +) +from .objective import Objective +from .parameterized_sweep import ParameterizedSweep +from .search_space import ( + Choice, + LogNormal, + LogUniform, + Normal, + QLogNormal, + QLogUniform, + QNormal, + QUniform, + Randint, + SweepDistribution, + Uniform, +) + +module_logger = logging.getLogger(__name__) + + +class SweepJob(Job, ParameterizedSweep, JobIOMixin): + """Sweep job for hyperparameter tuning. + + .. note:: + For sweep jobs, inputs, outputs, and parameters are accessible as environment variables using the prefix + ``AZUREML_SWEEP_``. For example, if you have a parameter named "learning_rate", you can access it as + ``AZUREML_SWEEP_learning_rate``. 
+ + :keyword name: Name of the job. + :paramtype name: str + :keyword display_name: Display name of the job. + :paramtype display_name: str + :keyword description: Description of the job. + :paramtype description: str + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :paramtype tags: dict[str, str] + :keyword properties: The asset property dictionary. + :paramtype properties: dict[str, str] + :keyword experiment_name: Name of the experiment the job will be created under. If None is provided, + job will be created under experiment 'Default'. + :paramtype experiment_name: str + :keyword identity: Identity that the training job will use while running on compute. + :paramtype identity: Union[ + ~azure.ai.ml.ManagedIdentityConfiguration, + ~azure.ai.ml.AmlTokenConfiguration, + ~azure.ai.ml.UserIdentityConfiguration + + ] + + :keyword inputs: Inputs to the command. + :paramtype inputs: dict + :keyword outputs: Mapping of output data bindings used in the job. + :paramtype outputs: dict[str, ~azure.ai.ml.Output] + :keyword sampling_algorithm: The hyperparameter sampling algorithm to use over the `search_space`. Defaults to + "random". + + :paramtype sampling_algorithm: str + :keyword search_space: Dictionary of the hyperparameter search space. The key is the name of the hyperparameter + and the value is the parameter expression. + + :paramtype search_space: Dict + :keyword objective: Metric to optimize for. + :paramtype objective: Objective + :keyword compute: The compute target the job runs on. + :paramtype compute: str + :keyword trial: The job configuration for each trial. Each trial will be provided with a different combination + of hyperparameter values that the system samples from the search_space. + + :paramtype trial: Union[ + ~azure.ai.ml.entities.CommandJob, + ~azure.ai.ml.entities.CommandComponent + + ] + + :keyword early_termination: The early termination policy to use. A trial job is canceled + when the criteria of the specified policy are met. If omitted, no early termination policy will be applied. + + :paramtype early_termination: Union[ + ~azure.mgmt.machinelearningservices.models.BanditPolicy, + ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy, + ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy + + ] + + :keyword limits: Limits for the sweep job. + :paramtype limits: ~azure.ai.ml.entities.SweepJobLimits + :keyword queue_settings: Queue settings for the job. + :paramtype queue_settings: ~azure.ai.ml.entities.QueueSettings + :keyword resources: Compute Resource configuration for the job. + :paramtype resources: Optional[Union[~azure.ai.ml.entities.ResourceConfiguration] + :keyword kwargs: A dictionary of additional configuration parameters. + :paramtype kwargs: dict + + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_sweep_configurations.py + :start-after: [START configure_sweep_job_bayesian_sampling_algorithm] + :end-before: [END configure_sweep_job_bayesian_sampling_algorithm] + :language: python + :dedent: 8 + :caption: Creating a SweepJob + """ + + def __init__( + self, + *, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + display_name: Optional[str] = None, + experiment_name: Optional[str] = None, + identity: Optional[ + Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration] + ] = None, + inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None, + outputs: Optional[Dict] = None, + compute: Optional[str] = None, + limits: Optional[SweepJobLimits] = None, + sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None, + search_space: Optional[ + Dict[ + str, + Union[ + Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform + ], + ] + ] = None, + objective: Optional[Objective] = None, + trial: Optional[Union[CommandJob, CommandComponent]] = None, + early_termination: Optional[ + Union[EarlyTerminationPolicy, BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy] + ] = None, + queue_settings: Optional[QueueSettings] = None, + resources: Optional[Union[dict, JobResourceConfiguration]] = None, + **kwargs: Any, + ) -> None: + kwargs[TYPE] = JobType.SWEEP + + Job.__init__( + self, + name=name, + description=description, + tags=tags, + display_name=display_name, + experiment_name=experiment_name, + compute=compute, + **kwargs, + ) + self.inputs = inputs # type: ignore[assignment] + self.outputs = outputs # type: ignore[assignment] + self.trial = trial + self.identity = identity + + ParameterizedSweep.__init__( + self, + limits=limits, + sampling_algorithm=sampling_algorithm, + objective=objective, + early_termination=early_termination, + search_space=search_space, + queue_settings=queue_settings, + resources=resources, + ) + + def _to_dict(self) -> Dict: + res: dict = SweepJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + def _to_rest_object(self) -> JobBase: + self._override_missing_properties_from_trial() + if self.trial is not None: + self.trial.command = map_single_brackets_and_warn(self.trial.command) + + if self.search_space is not None: + search_space = {param: space._to_rest_object() for (param, space) in self.search_space.items()} + + if self.trial is not None: + validate_inputs_for_command(self.trial.command, self.inputs) + for key in search_space.keys(): # pylint: disable=possibly-used-before-assignment + validate_key_contains_allowed_characters(key) + + if self.trial is not None: + trial_component = TrialComponent( + code_id=self.trial.code, + distribution=( + self.trial.distribution._to_rest_object() + if self.trial.distribution and not isinstance(self.trial.distribution, Dict) + else None + ), + environment_id=self.trial.environment, + command=self.trial.command, + environment_variables=self.trial.environment_variables, + resources=( + self.trial.resources._to_rest_object() + if self.trial.resources and not isinstance(self.trial.resources, Dict) + else None + ), + ) + + sweep_job = RestSweepJob( + display_name=self.display_name, + description=self.description, + experiment_name=self.experiment_name, + search_space=search_space, + sampling_algorithm=self._get_rest_sampling_algorithm() if self.sampling_algorithm else None, + limits=self.limits._to_rest_object() if self.limits 
else None, + early_termination=( + self.early_termination._to_rest_object() + if self.early_termination and not isinstance(self.early_termination, str) + else None + ), + properties=self.properties, + compute_id=self.compute, + objective=self.objective._to_rest_object() if self.objective else None, + trial=trial_component, # pylint: disable=possibly-used-before-assignment + tags=self.tags, + inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=self.type), + outputs=to_rest_data_outputs(self.outputs), + identity=self.identity._to_job_rest_object() if self.identity else None, + queue_settings=self.queue_settings._to_rest_object() if self.queue_settings else None, + resources=( + self.resources._to_rest_object() if self.resources and not isinstance(self.resources, dict) else None + ), + ) + + if not sweep_job.resources and sweep_job.trial.resources: + sweep_job.resources = sweep_job.trial.resources + + sweep_job_resource = JobBase(properties=sweep_job) + sweep_job_resource.name = self.name + return sweep_job_resource + + def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> NoReturn: + msg = "no sweep component entity" + raise JobException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SWEEP_JOB, + error_category=ErrorCategory.USER_ERROR, + ) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "SweepJob": + loaded_schema = load_from_dict(SweepJobSchema, data, context, additional_message, **kwargs) + loaded_schema["trial"] = ParameterizedCommand(**(loaded_schema["trial"])) + sweep_job = SweepJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_schema) + return sweep_job + + @classmethod + def _load_from_rest(cls, obj: JobBase) -> "SweepJob": + properties: RestSweepJob = obj.properties + + # Unpack termination schema + early_termination = EarlyTerminationPolicy._from_rest_object(properties.early_termination) + + # Unpack sampling algorithm + sampling_algorithm = SamplingAlgorithm._from_rest_object(properties.sampling_algorithm) + + trial = ParameterizedCommand._load_from_sweep_job(obj.properties) + # Compute also appears in both layers of the yaml, but only one of the REST. 
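        # [Editor's note] In other words: YAML allows `compute` on both the sweep job and its
        # trial, while the REST SweepJob exposes a single `compute_id`. When the sweep-level
        # value is missing, _override_missing_properties_from_trial (below) backfills it from
        # trial.compute before the job is serialized for the service.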
+ # This should be a required field in one place, but cannot be if its optional in two + + _search_space = {} + for param, dist in properties.search_space.items(): + _search_space[param] = SweepDistribution._from_rest_object(dist) + + return SweepJob( + name=obj.name, + id=obj.id, + display_name=properties.display_name, + description=properties.description, + properties=properties.properties, + tags=properties.tags, + experiment_name=properties.experiment_name, + services=properties.services, + status=properties.status, + creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + trial=trial, # type: ignore[arg-type] + compute=properties.compute_id, + sampling_algorithm=sampling_algorithm, + search_space=_search_space, # type: ignore[arg-type] + limits=SweepJobLimits._from_rest_object(properties.limits), + early_termination=early_termination, + objective=Objective._from_rest_object(properties.objective) if properties.objective else None, + inputs=from_rest_inputs_to_dataset_literal(properties.inputs), + outputs=from_rest_data_outputs(properties.outputs), + identity=( + _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None + ), + queue_settings=properties.queue_settings, + resources=properties.resources if hasattr(properties, "resources") else None, + ) + + def _override_missing_properties_from_trial(self) -> None: + if not isinstance(self.trial, CommandJob): + return + + if not self.compute: + self.compute = self.trial.compute + if not self.inputs: + self.inputs = self.trial.inputs + if not self.outputs: + self.outputs = self.trial.outputs + + has_trial_limits_timeout = self.trial.limits and self.trial.limits.timeout + if has_trial_limits_timeout and not self.limits: + time_out = self.trial.limits.timeout if self.trial.limits is not None else None + self.limits = SweepJobLimits(trial_timeout=time_out) + elif has_trial_limits_timeout and self.limits is not None and not self.limits.trial_timeout: + self.limits.trial_timeout = self.trial.limits.timeout if self.trial.limits is not None else None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/to_rest_functions.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/to_rest_functions.py new file mode 100644 index 00000000..472cbc91 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/to_rest_functions.py @@ -0,0 +1,82 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from functools import singledispatch +from pathlib import Path +from typing import Any + +from azure.ai.ml._restclient.v2023_08_01_preview.models import JobBase as JobBaseData +from azure.ai.ml._restclient.v2025_01_01_preview.models import JobBase as JobBaseData202501 +from azure.ai.ml.constants._common import DEFAULT_EXPERIMENT_NAME +from azure.ai.ml.entities._builders.command import Command +from azure.ai.ml.entities._builders.pipeline import Pipeline +from azure.ai.ml.entities._builders.spark import Spark +from azure.ai.ml.entities._builders.sweep import Sweep +from azure.ai.ml.entities._job.job_name_generator import generate_job_name + +from .import_job import ImportJob +from .job import Job + + +def generate_defaults(job: Job, rest_job: JobBaseData) -> None: + # Default name to a generated user friendly name. 
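    # [Editor's note] generate_defaults mutates rest_job in place and is called by every
    # to_rest_job_object overload below, so name, display_name and experiment_name are
    # never left empty: the name is generated, display_name falls back to that name, and
    # experiment_name falls back to the current folder name (or DEFAULT_EXPERIMENT_NAME).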
+ if not job.name: + rest_job.name = generate_job_name() + + if not job.display_name: + rest_job.properties.display_name = rest_job.name + + # Default experiment to current folder name or "Default" + if not job.experiment_name: + rest_job.properties.experiment_name = Path("./").resolve().stem.replace(" ", "") or DEFAULT_EXPERIMENT_NAME + + +@singledispatch +def to_rest_job_object(something: Any) -> JobBaseData: + raise NotImplementedError() + + +@to_rest_job_object.register(Job) +def _(job: Job) -> JobBaseData: + # TODO: Bug Item number: 2883432 + rest_job = job._to_rest_object() # type: ignore + generate_defaults(job, rest_job) + return rest_job + + +@to_rest_job_object.register(Command) +def _(command: Command) -> JobBaseData202501: + rest_job = command._to_job()._to_rest_object() + generate_defaults(command, rest_job) + return rest_job + + +@to_rest_job_object.register(Sweep) +def _(sweep: Sweep) -> JobBaseData: + rest_job = sweep._to_job()._to_rest_object() + generate_defaults(sweep, rest_job) + return rest_job + + +@to_rest_job_object.register(Pipeline) +def _(pipeline: Pipeline) -> JobBaseData: + rest_job = pipeline._to_job()._to_rest_object() + generate_defaults(pipeline, rest_job) + return rest_job + + +@to_rest_job_object.register(Spark) +def _(spark: Spark) -> JobBaseData: + rest_job = spark._to_job()._to_rest_object() + generate_defaults(spark, rest_job) + return rest_job + + +@to_rest_job_object.register(ImportJob) +def _(importJob: ImportJob) -> JobBaseData: + rest_job = importJob._to_rest_object() + generate_defaults(importJob, rest_job) + return rest_job diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_load_functions.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_load_functions.py new file mode 100644 index 00000000..81417792 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_load_functions.py @@ -0,0 +1,1103 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=too-many-lines + +import logging +import warnings +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Union, cast + +from marshmallow import ValidationError +from pydash import objects + +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import load_yaml +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.entities._assets._artifacts._package.model_package import ModelPackage +from azure.ai.ml.entities._assets._artifacts.code import Code +from azure.ai.ml.entities._assets._artifacts.data import Data +from azure.ai.ml.entities._assets._artifacts.feature_set import FeatureSet +from azure.ai.ml.entities._assets._artifacts.index import Index +from azure.ai.ml.entities._assets._artifacts.model import Model +from azure.ai.ml.entities._assets.environment import Environment +from azure.ai.ml.entities._autogen_entities.models import MarketplaceSubscription, ServerlessEndpoint +from azure.ai.ml.entities._component.command_component import CommandComponent +from azure.ai.ml.entities._component.component import Component +from azure.ai.ml.entities._component.parallel_component import ParallelComponent +from azure.ai.ml.entities._component.pipeline_component import PipelineComponent +from azure.ai.ml.entities._compute.compute import Compute +from azure.ai.ml.entities._datastore.datastore import Datastore +from azure.ai.ml.entities._deployment.batch_deployment import BatchDeployment +from azure.ai.ml.entities._deployment.model_batch_deployment import ModelBatchDeployment +from azure.ai.ml.entities._deployment.online_deployment import OnlineDeployment +from azure.ai.ml.entities._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment +from azure.ai.ml.entities._endpoint.batch_endpoint import BatchEndpoint +from azure.ai.ml.entities._endpoint.online_endpoint import OnlineEndpoint +from azure.ai.ml.entities._feature_set.feature_set_backfill_request import FeatureSetBackfillRequest +from azure.ai.ml.entities._feature_store.feature_store import FeatureStore +from azure.ai.ml.entities._feature_store_entity.feature_store_entity import FeatureStoreEntity +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._registry.registry import Registry +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._schedule.schedule import Schedule +from azure.ai.ml.entities._validation import PathAwareSchemaValidatableMixin, ValidationResultBuilder +from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection +from azure.ai.ml.entities._workspace.workspace import Workspace +from azure.ai.ml.entities._workspace._ai_workspaces.capability_host import CapabilityHost +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +module_logger = logging.getLogger(__name__) + +_DEFAULT_RELATIVE_ORIGIN = "./" + + +def load_common( + cls: Any, + source: Union[str, PathLike, IO[AnyStr]], + relative_origin: Optional[str] = None, + params_override: Optional[list] = None, + **kwargs: Any, +) -> Resource: + """Private function to load a yaml file to an entity object. + + :param cls: The entity class type. + :type cls: type[Resource] + :param source: A source of yaml. 
+ :type source: Union[str, PathLike, IO[AnyStr]] + :param relative_origin: The origin of to be used when deducing + the relative locations of files referenced in the parsed yaml. + Must be provided, and is assumed to be assigned by other internal + functions that call this. + :type relative_origin: str + :param params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :type params_override: List[Dict] + :return: The loaded resource + :rtype: Resource + """ + + path = kwargs.pop("path", None) + # Check for deprecated path input, either named or as first unnamed input + if source is None and path is not None: + source = path + warnings.warn( + "the 'path' input for load functions is deprecated. Please use 'source' instead.", DeprecationWarning + ) + + if relative_origin is None: + if isinstance(source, (str, PathLike)): + relative_origin = str(source) + else: + try: + relative_origin = source.name + except AttributeError: # input is a stream or something + relative_origin = _DEFAULT_RELATIVE_ORIGIN + + params_override = params_override or [] + yaml_dict = _try_load_yaml_dict(source) + + # pylint: disable=protected-access + cls, type_str = cls._resolve_cls_and_type(data=yaml_dict, params_override=params_override) + + try: + return _load_common_raising_marshmallow_error(cls, yaml_dict, relative_origin, params_override, **kwargs) + except ValidationError as e: + if issubclass(cls, PathAwareSchemaValidatableMixin): + validation_result = ValidationResultBuilder.from_validation_error(e, source_path=relative_origin) + schema = cls._create_schema_for_validation(context={BASE_PATH_CONTEXT_KEY: Path.cwd()}) + if type_str is None: + additional_message = "" + else: + additional_message = ( + f"If you are trying to configure an entity that is not " + f"of type {type_str}, please specify the correct " + f"type in the 'type' property." + ) + + def build_error(message: str, _: Any) -> ValidationError: + from azure.ai.ml.entities._util import decorate_validation_error + + return ValidationError( + message=decorate_validation_error( + schema=schema.__class__, + pretty_error=message, + additional_message=additional_message, + ), + ) + + validation_result.try_raise(error_func=build_error) + raise e + + +def _try_load_yaml_dict(source: Union[str, PathLike, IO[AnyStr]]) -> dict: + yaml_dict = load_yaml(source) + if yaml_dict is None: # This happens when a YAML is empty. + msg = "Target yaml file is empty" + raise ValidationException( + message=msg, + target=ErrorTarget.COMPONENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.CANNOT_PARSE, + ) + if not isinstance(yaml_dict, dict): # This happens when a YAML file is mal formatted. 
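+ # For example, a YAML document whose top level is a list or a bare scalar parses
+ # successfully but does not produce the mapping these load functions require.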
+ msg = "Expect dict but get {} after parsing yaml file" + raise ValidationException( + message=msg.format(type(yaml_dict)), + target=ErrorTarget.COMPONENT, + no_personal_data_message=msg.format(type(yaml_dict)), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.CANNOT_PARSE, + ) + return yaml_dict + + +def _load_common_raising_marshmallow_error( + cls: Any, + yaml_dict: Dict, + relative_origin: Optional[Union[PathLike, str, IO[AnyStr]]], + params_override: Optional[list] = None, + **kwargs: Any, +) -> Resource: + # pylint: disable=protected-access + res: Resource = cls._load(data=yaml_dict, yaml_path=relative_origin, params_override=params_override, **kwargs) + return res + + +def add_param_overrides(data, param_overrides) -> None: + if param_overrides is not None: + for override in param_overrides: + for param, val in override.items(): + # Check that none of the intermediary levels are string references (azureml/file) + param_tokens = param.split(".") + test_layer = data + for layer in param_tokens: + if test_layer is None: + continue + if isinstance(test_layer, str): + # pylint: disable=broad-exception-raised + raise Exception(f"Cannot use '--set' on properties defined by reference strings: --set {param}") + test_layer = test_layer.get(layer, None) + objects.set_(data, param, val) + + +def load_from_autogen_entity(cls, source: Union[str, PathLike, IO[AnyStr]], **kwargs): + loaded_dict = _try_load_yaml_dict(source) + add_param_overrides(loaded_dict, param_overrides=kwargs.get("params_override", None)) + entity = cls(loaded_dict) + try: + entity._validate() # pylint: disable=protected-access + except ValueError as e: + validation_result = ValidationResultBuilder.from_single_message(singular_error_message=str(e)) + validation_result.try_raise() + return entity + + +def load_job( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Job: + """Constructs a Job object from a YAML file. + + :param source: A path to a local YAML file or an already-open file object containing a job configuration. + If the source is a path, it will be opened and read. If the source is an open file, the file will be read + directly. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing + the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if + source is a file or file path input. Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: Optional[str] + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Job cannot be successfully validated. + Details will be provided in the error message. + :return: A loaded Job object. + :rtype: ~azure.ai.ml.entities.Job + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START load_job] + :end-before: [END load_job] + :language: python + :dedent: 8 + :caption: Loading a Job from a YAML config file. 
+ """ + return cast(Job, load_common(Job, source, relative_origin, params_override, **kwargs)) + + +@experimental +def load_index( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Index: + """Constructs a Index object from a YAML file. + + :param source: A path to a local YAML file or an already-open file object containing an index configuration. + If the source is a path, it will be opened and read. If the source is an open file, the file will be read + directly. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing + the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if + source is a file or file path input. Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: Optional[str] + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Index cannot be successfully validated. + Details will be provided in the error message. + :return: A loaded Index object. + :rtype: ~azure.ai.ml.entities.Index + """ + return cast(Index, load_common(Index, source, relative_origin, params_override, **kwargs)) + + +@experimental +def load_serverless_endpoint( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, # pylint: disable=unused-argument + **kwargs: Any, +) -> ServerlessEndpoint: + return load_from_autogen_entity(ServerlessEndpoint, source, **kwargs) + + +@experimental +def load_marketplace_subscription( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, # pylint: disable=unused-argument + **kwargs: Any, +) -> MarketplaceSubscription: + return load_from_autogen_entity(MarketplaceSubscription, source, **kwargs) + + +def load_workspace( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Workspace: + """Load a workspace object from a yaml file. This includes workspace sub-classes + like hubs and projects. + + :param source: The local yaml source of a workspace. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Loaded workspace object. + :rtype: Workspace + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START load_workspace] + :end-before: [END load_workspace] + :language: python + :dedent: 8 + :caption: Loading a Workspace from a YAML config file. + """ + return cast(Workspace, load_common(Workspace, source, relative_origin, params_override, **kwargs)) + + +def load_registry( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Registry: + """Load a registry object from a yaml file. + + :param source: The local yaml source of a registry. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Loaded registry object. + :rtype: Registry + """ + return cast(Registry, load_common(Registry, source, relative_origin, params_override, **kwargs)) + + +def load_datastore( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Datastore: + """Construct a datastore object from a yaml file. + + :param source: The local yaml source of a datastore. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Datastore cannot be successfully validated. + Details will be provided in the error message. + :return: Loaded datastore object. + :rtype: Datastore + """ + return cast(Datastore, load_common(Datastore, source, relative_origin, params_override, **kwargs)) + + +def load_code( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Code: + """Construct a code object from a yaml file. + + :param source: The local yaml source of a code object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. 
+ If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: Optional[str] + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Code cannot be successfully validated. + Details will be provided in the error message. + :return: Loaded code object. + :rtype: ~azure.ai.ml.entities._assets._artifacts.code.Code + """ + return cast(Code, load_common(Code, source, relative_origin, params_override, **kwargs)) + + +def load_compute( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict[str, str]]] = None, + **kwargs: Any, +) -> Compute: + """Construct a compute object from a yaml file. + + :param source: The local yaml source of a compute. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: Optional[str] + :keyword params_override: Optional parameters to override in the loaded yaml. + :paramtype params_override: Optional[List[Dict[str, str]] + :return: Loaded compute object. + :rtype: ~azure.ai.ml.entities.Compute + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START load_compute] + :end-before: [END load_compute] + :language: python + :dedent: 8 + :caption: Loading a Compute object from a YAML file and overriding its description. + """ + return cast(Compute, load_common(Compute, source, relative_origin, params_override, **kwargs)) + + +def load_component( + source: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Union[CommandComponent, ParallelComponent, PipelineComponent]: + """Load component from local or remote to a component function. + + :param source: The local yaml source of a component. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. 
+ Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: A Component object + :rtype: Union[CommandComponent, ParallelComponent, PipelineComponent] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_component_configurations.py + :start-after: [START configure_load_component] + :end-before: [END configure_load_component] + :language: python + :dedent: 8 + :caption: Loading a Component object from a YAML file, overriding its version to "1.0.2", and + registering it remotely. + """ + + client = kwargs.pop("client", None) + name = kwargs.pop("name", None) + version = kwargs.pop("version", None) + + if source: + component_entity = load_common(Component, source, relative_origin, params_override, **kwargs) + elif client and name and version: + component_entity = client.components.get(name, version) + else: + msg = "One of (client, name, version), (source) should be provided." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + return cast(Union[CommandComponent, ParallelComponent, PipelineComponent], component_entity) + + +def load_model( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Model: + """Constructs a Model object from a YAML file. + + :param source: A path to a local YAML file or an already-open file object containing a job configuration. + If the source is a path, it will be opened and read. If the source is an open file, the file will be read + directly. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing + the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if + source is a file or file path input. Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: Optional[str] + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Job cannot be successfully validated. + Details will be provided in the error message. + :return: A loaded Model object. + :rtype: ~azure.ai.ml.entities.Model + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START load_model] + :end-before: [END load_model] + :language: python + :dedent: 8 + :caption: Loading a Model from a YAML config file, overriding the name and version parameters. + """ + return cast(Model, load_common(Model, source, relative_origin, params_override, **kwargs)) + + +def load_data( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Data: + """Construct a data object from yaml file. + + :param source: The local yaml source of a data object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. 
+ An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Data cannot be successfully validated. + Details will be provided in the error message. + :return: Constructed Data or DataImport object. + :rtype: Data + """ + return cast(Data, load_common(Data, source, relative_origin, params_override, **kwargs)) + + +def load_environment( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Environment: + """Construct a environment object from yaml file. + + :param source: The local yaml source of an environment. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Environment cannot be successfully validated. + Details will be provided in the error message. + :return: Constructed environment object. + :rtype: Environment + """ + return cast(Environment, load_common(Environment, source, relative_origin, params_override, **kwargs)) + + +def load_online_deployment( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> OnlineDeployment: + """Construct a online deployment object from yaml file. + + :param source: The local yaml source of an online deployment object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. 
+ :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Online Deployment cannot be successfully validated. + Details will be provided in the error message. + :return: Constructed online deployment object. + :rtype: OnlineDeployment + """ + return cast(OnlineDeployment, load_common(OnlineDeployment, source, relative_origin, params_override, **kwargs)) + + +def load_batch_deployment( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> BatchDeployment: + """Construct a batch deployment object from yaml file. + + :param source: The local yaml source of a batch deployment object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Constructed batch deployment object. + :rtype: BatchDeployment + """ + return cast(BatchDeployment, load_common(BatchDeployment, source, relative_origin, params_override, **kwargs)) + + +def load_model_batch_deployment( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> ModelBatchDeployment: + """Construct a model batch deployment object from yaml file. + + :param source: The local yaml source of a batch deployment object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Constructed model batch deployment object. 
+ :rtype: ModelBatchDeployment + """ + return cast( + ModelBatchDeployment, load_common(ModelBatchDeployment, source, relative_origin, params_override, **kwargs) + ) + + +def load_pipeline_component_batch_deployment( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> PipelineComponentBatchDeployment: + """Construct a pipeline component batch deployment object from yaml file. + + :param source: The local yaml source of a batch deployment object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Constructed pipeline component batch deployment object. + :rtype: PipelineComponentBatchDeployment + """ + return cast( + PipelineComponentBatchDeployment, + load_common(PipelineComponentBatchDeployment, source, relative_origin, params_override, **kwargs), + ) + + +def load_online_endpoint( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> OnlineEndpoint: + """Construct a online endpoint object from yaml file. + + :param source: The local yaml source of an online endpoint object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Online Endpoint cannot be successfully validated. + Details will be provided in the error message. + :return: Constructed online endpoint object. + :rtype: OnlineEndpoint + """ + return cast(OnlineEndpoint, load_common(OnlineEndpoint, source, relative_origin, params_override, **kwargs)) + + +def load_batch_endpoint( + source: Union[str, PathLike, IO[AnyStr]], + relative_origin: Optional[str] = None, + *, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> BatchEndpoint: + """Construct a batch endpoint object from yaml file. + + :param source: The local yaml source of a batch endpoint object. 
Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :param relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :type relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Constructed batch endpoint object. + :rtype: BatchEndpoint + """ + return cast(BatchEndpoint, load_common(BatchEndpoint, source, relative_origin, params_override, **kwargs)) + + +def load_connection( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> WorkspaceConnection: + """Construct a connection object from yaml file. + + :param source: The local yaml source of a connection object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Constructed connection object. + :rtype: Connection + + """ + return cast( + WorkspaceConnection, load_common(WorkspaceConnection, source, relative_origin, params_override, **kwargs) + ) + + +# Unlike other aspects of connections, this wasn't made experimental, and thus couldn't just be replaced +# During the renaming from 'workspace connection' to just 'connection'. +def load_workspace_connection( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + **kwargs: Any, +) -> WorkspaceConnection: + """Deprecated - use 'load_connection' instead. Construct a connection object from yaml file. + + :param source: The local yaml source of a connection object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. 
+ :paramtype relative_origin: str + + :return: Constructed connection object. + :rtype: Connection + + """ + warnings.warn( + "the 'load_workspace_connection' function is deprecated. Use 'load_connection' instead.", DeprecationWarning + ) + return load_connection(source, relative_origin=relative_origin, **kwargs) + + +def load_schedule( + source: Union[str, PathLike, IO[AnyStr]], + relative_origin: Optional[str] = None, + *, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> Schedule: + """Construct a schedule object from yaml file. + + :param source: The local yaml source of a schedule object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :param relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :type relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Constructed schedule object. + :rtype: Schedule + """ + return cast(Schedule, load_common(Schedule, source, relative_origin, params_override, **kwargs)) + + +def load_feature_store( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> FeatureStore: + """Load a feature store object from a yaml file. + + :param source: The local yaml source of a feature store. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :return: Loaded feature store object. + :rtype: FeatureStore + """ + return cast(FeatureStore, load_common(FeatureStore, source, relative_origin, params_override, **kwargs)) + + +def load_feature_set( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> FeatureSet: + """Construct a FeatureSet object from yaml file. + + :param source: The local yaml source of a FeatureSet object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. 
+ If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureSet cannot be successfully validated. + Details will be provided in the error message. + :return: Constructed FeatureSet object. + :rtype: FeatureSet + """ + return cast(FeatureSet, load_common(FeatureSet, source, relative_origin, params_override, **kwargs)) + + +def load_feature_store_entity( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> FeatureStoreEntity: + """Construct a FeatureStoreEntity object from yaml file. + + :param source: The local yaml source of a FeatureStoreEntity object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureStoreEntity cannot be successfully validated. + Details will be provided in the error message. + :return: Constructed FeatureStoreEntity object. + :rtype: FeatureStoreEntity + """ + return cast(FeatureStoreEntity, load_common(FeatureStoreEntity, source, relative_origin, params_override, **kwargs)) + + +@experimental +def load_model_package( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> ModelPackage: + """Constructs a ModelPackage object from a YAML file. + + :param source: A path to a local YAML file or an already-open file object containing a job configuration. + If the source is a path, it will be opened and read. If the source is an open file, the file will be read + directly. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing + the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if + source is a file or file path input. Defaults to "./" if the source is a stream input with no name value. 
+ :paramtype relative_origin: Optional[str] + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Job cannot be successfully validated. + Details will be provided in the error message. + :return: A loaded ModelPackage object. + :rtype: ~azure.ai.ml.entities.ModelPackage + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START load_model_package] + :end-before: [END load_model_package] + :language: python + :dedent: 8 + :caption: Loading a ModelPackage from a YAML config file. + """ + return cast(ModelPackage, load_common(ModelPackage, source, relative_origin, params_override, **kwargs)) + + +def load_feature_set_backfill_request( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> FeatureSetBackfillRequest: + """Construct a FeatureSetBackfillRequest object from yaml file. + + :param source: The local yaml source of a FeatureSetBackfillRequest object. Must be either a + path to a local file, or an already-open file. + If the source is a path, it will be open and read. + An exception is raised if the file does not exist. + If the source is an open file, the file will be read directly, + and an exception is raised if the file is not readable. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The origin to be used when deducing + the relative locations of files referenced in the parsed yaml. + Defaults to the inputted source's directory if it is a file or file path input. + Defaults to "./" if the source is a stream input with no name value. + :type relative_origin: str + :keyword params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureSetBackfillRequest + cannot be successfully validated. Details will be provided in the error message. + :return: Constructed FeatureSetBackfillRequest object. + :rtype: FeatureSetBackfillRequest + """ + return cast( + FeatureSetBackfillRequest, + load_common(FeatureSetBackfillRequest, source, relative_origin, params_override, **kwargs), + ) + + +def load_capability_host( + source: Union[str, PathLike, IO[AnyStr]], + *, + relative_origin: Optional[str] = None, + params_override: Optional[List[Dict]] = None, + **kwargs: Any, +) -> CapabilityHost: + """Constructs a CapabilityHost object from a YAML file. + + :param source: A path to a local YAML file or an already-open file object containing a capabilityhost configuration. + If the source is a path, it will be opened and read. If the source is an open file, the file will be read + directly. + :type source: Union[PathLike, str, io.TextIOWrapper] + :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing + the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if + source is a file or file path input. Defaults to "./" if the source is a stream input with no name value. + :paramtype relative_origin: Optional[str] + :keyword params_override: Fields to overwrite on top of the yaml file. 
+ Format is [{"field1": "value1"}, {"field2": "value2"}] + :paramtype params_override: List[Dict] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if CapabilityHost cannot be successfully validated. + Details will be provided in the error message. + :return: Loaded CapabilityHost object. + :rtype: ~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost + + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_capability_host.py + :start-after: [START load_capability_host] + :end-before: [END load_capability_host] + :language: python + :dedent: 8 + :caption: Loading a capabilityhost from a YAML config file. + """ + return cast(CapabilityHost, load_common(CapabilityHost, source, relative_origin, params_override, **kwargs)) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_mixins.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_mixins.py new file mode 100644 index 00000000..5b7306f9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_mixins.py @@ -0,0 +1,163 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import abstractmethod +from typing import Any, Dict, Iterator, Optional + +from azure.ai.ml._utils.utils import dump_yaml + + +class RestTranslatableMixin: + def _to_rest_object(self) -> Any: + pass + + @classmethod + def _from_rest_object(cls, obj: Any) -> Any: + pass + + +class DictMixin(object): + def __contains__(self, item: Any) -> bool: + return self.__dict__.__contains__(item) + + def __iter__(self) -> Iterator[str]: + return self.__dict__.__iter__() + + def __setitem__(self, key: Any, item: Any) -> None: + self.__dict__[key] = item + + def __getitem__(self, key: Any) -> Any: + return self.__dict__[key] + + def __repr__(self) -> str: + return str(self) + + def __len__(self) -> int: + return len(self.keys()) + + def __delitem__(self, key: Any) -> None: + self.__dict__[key] = None + + def __eq__(self, other: Any) -> bool: + """Compare objects by comparing all attributes. + + :param other: The other object + :type other: Any + :return: True if both object are the same class and have matching __dict__, False otherwise + :rtype: bool + """ + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other: Any) -> bool: + """Compare objects by comparing all attributes. 
+ + :param other: The other object + :type other: Any + :return: not self.__eq__(other) + :rtype: bool + """ + return not self.__eq__(other) + + def __str__(self) -> str: + return str({k: v for k, v in self.__dict__.items() if not k.startswith("_") and v is not None}) + + def has_key(self, k: Any) -> bool: + return k in self.__dict__ + + def update(self, *args: Any, **kwargs: Any) -> None: + return self.__dict__.update(*args, **kwargs) + + def keys(self) -> list: + return [k for k in self.__dict__ if not k.startswith("_")] + + def values(self) -> list: + return [v for k, v in self.__dict__.items() if not k.startswith("_")] + + def items(self) -> list: + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] + + def get(self, key: Any, default: Optional[Any] = None) -> Any: + if key in self.__dict__: + return self.__dict__[key] + return default + + +class TelemetryMixin: + # pylint: disable-next=docstring-missing-param + def _get_telemetry_values(self, *args: Any, **kwargs: Any) -> Dict: # pylint: disable=unused-argument + """Return the telemetry values of object. + + :return: The telemetry values + :rtype: Dict + """ + return {} + + +class YamlTranslatableMixin: + @abstractmethod + def _to_dict(self) -> Dict: + """Dump the object into a dictionary.""" + + def _to_ordered_dict_for_yaml_dump(self) -> Dict: + """Dump the object into a dictionary with a specific key order. + + :return: The ordered dict + :rtype: Dict + """ + order_keys = [ + "$schema", + "name", + "version", + "display_name", + "description", + "tags", + "type", + "inputs", + "outputs", + "command", + "environment", + "code", + "resources", + "limits", + "schedule", + "jobs", + ] + nested_keys = ["component", "trial"] + + def _sort_dict_according_to_list(order_keys: Any, dict_value: Any) -> dict: + for nested_key in nested_keys: + if nested_key in dict_value and isinstance(dict_value[nested_key], dict): + dict_value[nested_key] = _sort_dict_according_to_list(order_keys, dict_value[nested_key]) + if "jobs" in dict_value: + for node_name, node in dict_value["jobs"].items(): + dict_value["jobs"][node_name] = _sort_dict_according_to_list(order_keys, node) + difference = list(set(dict_value.keys()).difference(set(order_keys))) + # keys not in order_keys will be put at the end of the list in the order of alphabetic + order_keys.extend(sorted(difference)) + return dict( + sorted( + dict_value.items(), + key=lambda dict_value_: order_keys.index(dict_value_[0]), + ) + ) + + return _sort_dict_according_to_list(order_keys, self._to_dict()) + + def _to_yaml(self) -> str: + """Dump the object content into a sorted yaml string. + + :return: YAML formatted string + :rtype: str + """ + return str(dump_yaml(self._to_ordered_dict_for_yaml_dump(), sort_keys=False)) + + +class LocalizableMixin: + def _localize(self, base_path: str) -> None: + """Called on an asset got from service to clean up remote attributes like id, creation_context, etc. + + :param base_path: The base path + :type base_path: str + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/alert_notification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/alert_notification.py new file mode 100644 index 00000000..2df0d055 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/alert_notification.py @@ -0,0 +1,54 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import List, Optional + +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + EmailMonitoringAlertNotificationSettings, + EmailNotificationEnableType, + NotificationSetting, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class AlertNotification(RestTranslatableMixin): + """Alert notification configuration for monitoring jobs + + :keyword emails: A list of email addresses that will receive notifications for monitoring alerts. + Defaults to None. + :paramtype emails: Optional[List[str]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_monitor_definition] + :end-before: [END spark_monitor_definition] + :language: python + :dedent: 8 + :caption: Configuring alert notifications for a monitored job. + """ + + def __init__( + self, + *, + emails: Optional[List[str]] = None, + ) -> None: + self.emails = emails + + def _to_rest_object( + self, + ) -> EmailMonitoringAlertNotificationSettings: + return EmailMonitoringAlertNotificationSettings( + email_notification_setting=NotificationSetting( + emails=self.emails, + email_on=[ + EmailNotificationEnableType.JOB_FAILED, + EmailNotificationEnableType.JOB_COMPLETED, + ], + ) + ) + + @classmethod + def _from_rest_object(cls, obj: EmailMonitoringAlertNotificationSettings) -> "AlertNotification": + return cls(emails=obj.email_notification_setting.emails) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/compute.py new file mode 100644 index 00000000..ff91a814 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/compute.py @@ -0,0 +1,55 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2023_06_01_preview.models import AmlTokenComputeIdentity, MonitorServerlessSparkCompute +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + + +class ServerlessSparkCompute: + """Serverless Spark compute. + + :param runtime_version: The runtime version of the compute. + :type runtime_version: str + :param instance_type: The instance type of the compute. 
+ :type instance_type: str + """ + + def __init__( + self, + *, + runtime_version: str, + instance_type: str, + ): + self.runtime_version = runtime_version + self.instance_type = instance_type + + def _to_rest_object(self) -> MonitorServerlessSparkCompute: + self._validate() + return MonitorServerlessSparkCompute( + runtime_version=self.runtime_version, + instance_type=self.instance_type, + compute_identity=AmlTokenComputeIdentity( + compute_identity_type="AmlToken", + ), + ) + + @classmethod + def _from_rest_object(cls, obj: MonitorServerlessSparkCompute) -> "ServerlessSparkCompute": + return cls( + runtime_version=obj.runtime_version, + instance_type=obj.instance_type, + ) + + def _validate(self) -> None: + if self.runtime_version != "3.4": + msg = "Compute runtime version must be 3.4" + err = ValidationException( + message=msg, + target=ErrorTarget.MODEL_MONITORING, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + log_and_raise_error(err) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/definition.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/definition.py new file mode 100644 index 00000000..3b81be1e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/definition.py @@ -0,0 +1,162 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Optional, Union + +from typing_extensions import Literal + +from azure.ai.ml._restclient.v2023_06_01_preview.models import AzMonMonitoringAlertNotificationSettings +from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitorDefinition as RestMonitorDefinition +from azure.ai.ml.constants._monitoring import ( + AZMONITORING, + DEFAULT_DATA_DRIFT_SIGNAL_NAME, + DEFAULT_DATA_QUALITY_SIGNAL_NAME, + DEFAULT_PREDICTION_DRIFT_SIGNAL_NAME, + DEFAULT_TOKEN_USAGE_SIGNAL_NAME, + MonitorTargetTasks, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._monitoring.alert_notification import AlertNotification +from azure.ai.ml.entities._monitoring.compute import ServerlessSparkCompute +from azure.ai.ml.entities._monitoring.signals import ( + CustomMonitoringSignal, + DataDriftSignal, + DataQualitySignal, + FeatureAttributionDriftSignal, + GenerationSafetyQualitySignal, + GenerationTokenStatisticsSignal, + MonitoringSignal, + PredictionDriftSignal, +) +from azure.ai.ml.entities._monitoring.target import MonitoringTarget + + +class MonitorDefinition(RestTranslatableMixin): + """Monitor definition + + :keyword compute: The Spark resource configuration to be associated with the monitor + :paramtype compute: ~azure.ai.ml.entities.SparkResourceConfiguration + :keyword monitoring_target: The ARM ID object associated with the model or deployment that is being monitored. + :paramtype monitoring_target: Optional[~azure.ai.ml.entities.MonitoringTarget] + :keyword monitoring_signals: The dictionary of signals to monitor. The key is the name of the signal and the value + is the DataSignal object. Accepted values for the DataSignal objects are DataDriftSignal, DataQualitySignal, + PredictionDriftSignal, FeatureAttributionDriftSignal, and CustomMonitoringSignal. 
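+    # Illustrative sketch with hypothetical values; it assumes MonitoringTarget accepts the
+    # ml_task and endpoint_deployment_id keywords used in _from_rest_object below.
+    #     definition = MonitorDefinition(
+    #         compute=ServerlessSparkCompute(runtime_version="3.4", instance_type="standard_e4s_v3"),
+    #         monitoring_target=MonitoringTarget(
+    #             ml_task="classification",
+    #             endpoint_deployment_id="azureml:my-endpoint:my-deployment",
+    #         ),
+    #         monitoring_signals={"my_data_drift": DataDriftSignal()},
+    #         alert_notification=AlertNotification(emails=["abc@example.com"]),
+    #     )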
+ :paramtype monitoring_signals: Optional[Dict[str, Union[~azure.ai.ml.entities.DataDriftSignal + , ~azure.ai.ml.entities.DataQualitySignal, ~azure.ai.ml.entities.PredictionDriftSignal + , ~azure.ai.ml.entities.FeatureAttributionDriftSignal + , ~azure.ai.ml.entities.CustomMonitoringSignal + , ~azure.ai.ml.entities.GenerationSafetyQualitySignal + , ~azure.ai.ml.entities.GenerationTokenStatisticsSignal + , ~azure.ai.ml.entities.ModelPerformanceSignal]]] + :keyword alert_notification: The alert configuration for the monitor. + :paramtype alert_notification: Optional[Union[Literal['azmonitoring'], ~azure.ai.ml.entities.AlertNotification]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_monitor_definition] + :end-before: [END spark_monitor_definition] + :language: python + :dedent: 8 + :caption: Creating Monitor definition. + """ + + def __init__( + self, + *, + compute: ServerlessSparkCompute, + monitoring_target: Optional[MonitoringTarget] = None, + monitoring_signals: Dict[ + str, + Union[ + DataDriftSignal, + DataQualitySignal, + PredictionDriftSignal, + FeatureAttributionDriftSignal, + CustomMonitoringSignal, + GenerationSafetyQualitySignal, + GenerationTokenStatisticsSignal, + ], + ] = None, # type: ignore[assignment] + alert_notification: Optional[Union[Literal["azmonitoring"], AlertNotification]] = None, + ) -> None: + self.compute = compute + self.monitoring_target = monitoring_target + self.monitoring_signals = monitoring_signals + self.alert_notification = alert_notification + + def _to_rest_object(self, **kwargs: Any) -> RestMonitorDefinition: + default_data_window_size = kwargs.get("default_data_window_size") + ref_data_window_size = kwargs.get("ref_data_window_size") + rest_alert_notification = None + if self.alert_notification: + if isinstance(self.alert_notification, str) and self.alert_notification.lower() == AZMONITORING: + rest_alert_notification = AzMonMonitoringAlertNotificationSettings() + else: + if not isinstance(self.alert_notification, str): + rest_alert_notification = self.alert_notification._to_rest_object() + + if self.monitoring_signals is not None: + _signals = { + signal_name: signal._to_rest_object( + default_data_window_size=default_data_window_size, + ref_data_window_size=ref_data_window_size, + ) + for signal_name, signal in self.monitoring_signals.items() + } + return RestMonitorDefinition( + compute_configuration=self.compute._to_rest_object(), + monitoring_target=self.monitoring_target._to_rest_object() if self.monitoring_target else None, + signals=_signals, # pylint: disable=possibly-used-before-assignment + alert_notification_setting=rest_alert_notification, + ) + + @classmethod + def _from_rest_object( + cls, # pylint: disable=unused-argument + obj: RestMonitorDefinition, + **kwargs: Any, + ) -> "MonitorDefinition": + from_rest_alert_notification: Any = None + if obj.alert_notification_setting: + if isinstance(obj.alert_notification_setting, AzMonMonitoringAlertNotificationSettings): + from_rest_alert_notification = AZMONITORING + else: + from_rest_alert_notification = AlertNotification._from_rest_object(obj.alert_notification_setting) + + _monitoring_signals = {} + for signal_name, signal in obj.signals.items(): + _monitoring_signals[signal_name] = MonitoringSignal._from_rest_object(signal) + + return cls( + compute=ServerlessSparkCompute._from_rest_object(obj.compute_configuration), + monitoring_target=( + MonitoringTarget( + 
endpoint_deployment_id=obj.monitoring_target.deployment_id, ml_task=obj.monitoring_target.task_type + ) + if obj.monitoring_target + else None + ), + monitoring_signals=_monitoring_signals, # type: ignore[arg-type] + alert_notification=from_rest_alert_notification, + ) + + def _populate_default_signal_information(self) -> None: + if ( + isinstance(self.monitoring_target, MonitoringTarget) + and self.monitoring_target.ml_task is not None + and self.monitoring_target.ml_task.lower() + == MonitorTargetTasks.QUESTION_ANSWERING.lower() # type: ignore[union-attr] + ): + self.monitoring_signals = { + DEFAULT_TOKEN_USAGE_SIGNAL_NAME: GenerationTokenStatisticsSignal._get_default_token_statistics_signal(), + } + else: + self.monitoring_signals = { + DEFAULT_DATA_DRIFT_SIGNAL_NAME: DataDriftSignal._get_default_data_drift_signal(), + DEFAULT_PREDICTION_DRIFT_SIGNAL_NAME: PredictionDriftSignal._get_default_prediction_drift_signal(), + DEFAULT_DATA_QUALITY_SIGNAL_NAME: DataQualitySignal._get_default_data_quality_signal(), + } diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/input_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/input_data.py new file mode 100644 index 00000000..10d80531 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/input_data.py @@ -0,0 +1,206 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import datetime +from typing import Dict, Optional + +import isodate + +from azure.ai.ml._restclient.v2023_06_01_preview.models import FixedInputData as RestFixedInputData +from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringInputDataBase as RestMonitorInputBase +from azure.ai.ml._restclient.v2023_06_01_preview.models import StaticInputData as RestStaticInputData +from azure.ai.ml._restclient.v2023_06_01_preview.models import TrailingInputData as RestTrailingInputData +from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel +from azure.ai.ml.constants._monitoring import MonitorDatasetContext, MonitorInputDataType +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class MonitorInputData(RestTranslatableMixin): + """Monitor input data. + + :keyword type: Specifies the type of monitoring input data. + :paramtype type: MonitorInputDataType + :keyword input_dataset: Input data used by the monitor + :paramtype input_dataset: Optional[~azure.ai.ml.Input] + :keyword dataset_context: The context of the input dataset. Accepted values are "model_inputs", + "model_outputs", "training", "test", "validation", and "ground_truth". + :paramtype dataset_context: Optional[Union[str, ~azure.ai.ml.constants.MonitorDatasetContext]] + :keyword target_column_name: The target column in the given input dataset. + :paramtype target_column_name: Optional[str] + :keyword pre_processing_component: The ARM (Azure Resource Manager) resource ID of the component resource used to + preprocess the data. 
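+    # Illustrative sketch with a hypothetical URI; the concrete subclasses below
+    # (FixedInputData, TrailingInputData, StaticInputData) are what the monitoring
+    # entities construct internally.
+    #     fixed = FixedInputData(
+    #         data_context="model_inputs",  # one of the documented dataset contexts
+    #         job_type="uri_folder",
+    #         uri="azureml://datastores/workspaceblobstore/paths/example-data/",
+    #     )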
+ :paramtype pre_processing_component: Optional[str] + """ + + def __init__( + self, + *, + type: Optional[MonitorInputDataType] = None, + data_context: Optional[MonitorDatasetContext] = None, + target_columns: Optional[Dict] = None, + job_type: Optional[str] = None, + uri: Optional[str] = None, + ): + self.type = type + self.data_context = data_context + self.target_columns = target_columns + self.job_type = job_type + self.uri = uri + + @classmethod + def _from_rest_object(cls, obj: RestMonitorInputBase) -> Optional["MonitorInputData"]: + if obj.input_data_type == MonitorInputDataType.FIXED: + return FixedInputData._from_rest_object(obj) + if obj.input_data_type == MonitorInputDataType.TRAILING: + return TrailingInputData._from_rest_object(obj) + if obj.input_data_type == MonitorInputDataType.STATIC: + return StaticInputData._from_rest_object(obj) + + return None + + +class FixedInputData(MonitorInputData): + """ + :ivar type: Specifies the type of monitoring input data. Set automatically to "Fixed" for this class. + :var type: MonitorInputDataType + """ + + def __init__( + self, + *, + data_context: Optional[MonitorDatasetContext] = None, + target_columns: Optional[Dict] = None, + job_type: Optional[str] = None, + uri: Optional[str] = None, + ): + super().__init__( + type=MonitorInputDataType.FIXED, + data_context=data_context, + target_columns=target_columns, + job_type=job_type, + uri=uri, + ) + + def _to_rest_object(self) -> RestFixedInputData: + return RestFixedInputData( + data_context=camel_to_snake(self.data_context), + columns=self.target_columns, + job_input_type=self.job_type, + uri=self.uri, + ) + + @classmethod + def _from_rest_object(cls, obj: RestFixedInputData) -> "FixedInputData": + return cls( + data_context=camel_to_snake(obj.data_context), + target_columns=obj.columns, + job_type=obj.job_input_type, + uri=obj.uri, + ) + + +class TrailingInputData(MonitorInputData): + """ + :ivar type: Specifies the type of monitoring input data. Set automatically to "Trailing" for this class. 
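+    # Illustrative sketch with a hypothetical URI; window_size and window_offset are
+    # ISO 8601 durations, matching the isodate handling in _from_rest_object below.
+    #     trailing = TrailingInputData(
+    #         data_context="model_inputs",
+    #         job_type="uri_folder",
+    #         uri="azureml://datastores/workspaceblobstore/paths/example-data/",
+    #         window_size="P7D",
+    #         window_offset="P0D",
+    #     )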
+ :var type: MonitorInputDataType + """ + + def __init__( + self, + *, + data_context: Optional[MonitorDatasetContext] = None, + target_columns: Optional[Dict] = None, + job_type: Optional[str] = None, + uri: Optional[str] = None, + window_size: Optional[str] = None, + window_offset: Optional[str] = None, + pre_processing_component_id: Optional[str] = None, + ): + super().__init__( + type=MonitorInputDataType.TRAILING, + data_context=data_context, + target_columns=target_columns, + job_type=job_type, + uri=uri, + ) + self.window_size = window_size + self.window_offset = window_offset + self.pre_processing_component_id = pre_processing_component_id + + def _to_rest_object(self) -> RestTrailingInputData: + return RestTrailingInputData( + data_context=camel_to_snake(self.data_context), + columns=self.target_columns, + job_input_type=self.job_type, + uri=self.uri, + window_size=self.window_size, + window_offset=self.window_offset, + preprocessing_component_id=self.pre_processing_component_id, + ) + + @classmethod + def _from_rest_object(cls, obj: RestTrailingInputData) -> "TrailingInputData": + return cls( + data_context=snake_to_camel(obj.data_context), + target_columns=obj.columns, + job_type=obj.job_input_type, + uri=obj.uri, + window_size=str(isodate.duration_isoformat(obj.window_size)), + window_offset=str(isodate.duration_isoformat(obj.window_offset)), + pre_processing_component_id=obj.preprocessing_component_id, + ) + + +class StaticInputData(MonitorInputData): + """ + :ivar type: Specifies the type of monitoring input data. Set automatically to "Static" for this class. + :var type: MonitorInputDataType + """ + + def __init__( + self, + *, + data_context: Optional[MonitorDatasetContext] = None, + target_columns: Optional[Dict] = None, + job_type: Optional[str] = None, + uri: Optional[str] = None, + pre_processing_component_id: Optional[str] = None, + window_start: Optional[str] = None, + window_end: Optional[str] = None, + ): + super().__init__( + type=MonitorInputDataType.STATIC, + data_context=data_context, + target_columns=target_columns, + job_type=job_type, + uri=uri, + ) + self.pre_processing_component_id = pre_processing_component_id + self.window_start = window_start + self.window_end = window_end + + def _to_rest_object(self) -> RestStaticInputData: + return RestStaticInputData( + data_context=camel_to_snake(self.data_context), + columns=self.target_columns, + job_input_type=self.job_type, + uri=self.uri, + preprocessing_component_id=self.pre_processing_component_id, + window_start=datetime.datetime.strptime(str(self.window_start), "%Y-%m-%d"), + window_end=datetime.datetime.strptime(str(self.window_end), "%Y-%m-%d"), + ) + + @classmethod + def _from_rest_object(cls, obj: RestStaticInputData) -> "StaticInputData": + return cls( + data_context=snake_to_camel(obj.data_context), + target_columns=obj.columns, + job_type=obj.job_input_type, + uri=obj.uri, + pre_processing_component_id=obj.preprocessing_component_id, + window_start=str(datetime.datetime.strftime(obj.window_start, "%Y-%m-%d")), + window_end=datetime.datetime.strftime(obj.window_end, "%Y-%m-%d"), + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/schedule.py new file mode 100644 index 00000000..f23c4e3e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/schedule.py @@ -0,0 +1,175 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft 
Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, Optional, Union, cast + +from azure.ai.ml._restclient.v2023_06_01_preview.models import CreateMonitorAction, RecurrenceFrequency +from azure.ai.ml._restclient.v2023_06_01_preview.models import Schedule as RestSchedule +from azure.ai.ml._restclient.v2023_06_01_preview.models import ScheduleProperties +from azure.ai.ml._schema.monitoring.schedule import MonitorScheduleSchema +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ScheduleType +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._monitoring.definition import MonitorDefinition +from azure.ai.ml.entities._schedule.schedule import Schedule +from azure.ai.ml.entities._schedule.trigger import CronTrigger, RecurrenceTrigger, TriggerBase +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict + +module_logger = logging.getLogger(__name__) + + +class MonitorSchedule(Schedule, RestTranslatableMixin): + """Monitor schedule. + + :keyword name: The schedule name. + :paramtype name: str + :keyword trigger: The schedule trigger. + :paramtype trigger: Union[~azure.ai.ml.entities.CronTrigger, ~azure.ai.ml.entities.RecurrenceTrigger] + :keyword create_monitor: The schedule action monitor definition. + :paramtype create_monitor: ~azure.ai.ml.entities.MonitorDefinition + :keyword display_name: The display name of the schedule. + :paramtype display_name: Optional[str] + :keyword description: A description of the schedule. + :paramtype description: Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :paramtype tags: Optional[dict[str, str]] + :keyword properties: The job property dictionary. 
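+    # Illustrative sketch with hypothetical names/values; it assumes RecurrenceTrigger
+    # exposes the frequency and interval keywords read by _to_rest_object below.
+    #     schedule = MonitorSchedule(
+    #         name="credit-default-monitor",
+    #         trigger=RecurrenceTrigger(frequency="day", interval=1),
+    #         create_monitor=definition,  # a MonitorDefinition as built in definition.py
+    #     )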
+ :paramtype properties: Optional[dict[str, str]] + """ + + def __init__( + self, + *, + name: str, + trigger: Optional[Union[CronTrigger, RecurrenceTrigger]], + create_monitor: MonitorDefinition, + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + super().__init__( + name=name, + trigger=trigger, + display_name=display_name, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + self.create_monitor = create_monitor + self._type = ScheduleType.MONITOR + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "MonitorSchedule": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + return cls( + base_path=cast(Dict, context[BASE_PATH_CONTEXT_KEY]), + **load_from_dict(MonitorScheduleSchema, data, context, **kwargs), + ) + + def _to_rest_object(self) -> RestSchedule: + if self.tags is not None: + tags = { + **self.tags, + } + # default data window size is calculated based on the trigger frequency + # by default 7 days if user provides incorrect recurrence frequency + # or a cron expression + default_data_window_size = "P7D" + ref_data_window_size = "P14D" + if isinstance(self.trigger, RecurrenceTrigger): + frequency = self.trigger.frequency.lower() + interval = self.trigger.interval + if frequency == RecurrenceFrequency.MINUTE.lower() or frequency == RecurrenceFrequency.HOUR.lower(): + default_data_window_size = "P1D" + ref_data_window_size = "P2D" + elif frequency == RecurrenceFrequency.DAY.lower(): + default_data_window_size = f"P{interval}D" + ref_data_window_size = f"P{interval * 2}D" + elif frequency == RecurrenceFrequency.WEEK.lower(): + default_data_window_size = f"P{interval * 7}D" + ref_data_window_size = f"P{(interval * 7) * 2}D" + elif frequency == RecurrenceFrequency.MONTH.lower(): + default_data_window_size = f"P{interval * 30}D" + ref_data_window_size = f"P{(interval * 30) * 2}D" + + return RestSchedule( + properties=ScheduleProperties( + description=self.description, + properties=self.properties, + tags=tags, # pylint: disable=possibly-used-before-assignment + action=CreateMonitorAction( + monitor_definition=self.create_monitor._to_rest_object( + default_data_window_size=default_data_window_size, ref_data_window_size=ref_data_window_size + ) + ), + display_name=self.display_name, + is_enabled=self._is_enabled, + trigger=self.trigger._to_rest_object() if self.trigger is not None else None, + ) + ) + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the asset content into a file in YAML format. + + :param dest: The local path or file stream to write the YAML content to. + If dest is a file path, a new file will be created. + If dest is an open file, the file will be written to directly. + :type dest: Union[PathLike, str, IO[AnyStr]] + :raises FileExistsError: Raised if dest is a file path and the file already exists. + :raises IOError: Raised if dest is an open file and the file is not writable. 
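+        # Illustrative usage with a hypothetical file name:
+        #     schedule.dump("./monitor_schedule.yml")
+        # The schedule is serialized through MonitorScheduleSchema and written with
+        # dump_yaml_to_file, as shown in the method body below.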
+ """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def _to_dict(self) -> Dict: + res: dict = MonitorScheduleSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _from_rest_object(cls, obj: RestSchedule) -> "MonitorSchedule": + properties = obj.properties + return cls( + trigger=TriggerBase._from_rest_object(properties.trigger), + create_monitor=MonitorDefinition._from_rest_object( + properties.action.monitor_definition, tags=obj.properties.tags + ), + name=obj.name, + id=obj.id, + display_name=properties.display_name, + description=properties.description, + tags=properties.tags, + properties=properties.properties, + provisioning_state=properties.provisioning_state, + is_enabled=properties.is_enabled, + creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None, + ) + + def _create_default_monitor_definition(self) -> None: + self.create_monitor._populate_default_signal_information() + + def _set_baseline_data_trailing_tags_for_signal(self, signal_name: str) -> None: + if self.tags is not None: + self.tags[f"{signal_name}.baselinedata.datarange.type"] = "Trailing" + self.tags[f"{signal_name}.baselinedata.datarange.window_size"] = "P7D" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/signals.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/signals.py new file mode 100644 index 00000000..5a9e1df7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/signals.py @@ -0,0 +1,1338 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
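+# [Illustrative note, not part of the upstream source] MonitorSchedule._to_rest_object in
+# schedule.py above derives default data windows from the trigger: a weekly
+# RecurrenceTrigger with interval=1 yields default_data_window_size="P7D" and
+# ref_data_window_size="P14D", a daily trigger with interval=3 yields "P3D"/"P6D",
+# minute/hour triggers yield "P1D"/"P2D", and cron expressions keep the "P7D"/"P14D"
+# fallback. Those values feed the signals defined in this module.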
+# --------------------------------------------------------- + +# pylint: disable=protected-access, too-many-lines + +import datetime +from typing import Any, Dict, List, Optional, Union + +import isodate +from typing_extensions import Literal + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2023_06_01_preview.models import AllFeatures as RestAllFeatures +from azure.ai.ml._restclient.v2023_06_01_preview.models import CustomMonitoringSignal as RestCustomMonitoringSignal +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + DataDriftMonitoringSignal as RestMonitoringDataDriftSignal, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + DataQualityMonitoringSignal as RestMonitoringDataQualitySignal, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + FeatureAttributionDriftMonitoringSignal as RestFeatureAttributionDriftMonitoringSignal, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import FeatureSubset as RestFeatureSubset +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + GenerationSafetyQualityMonitoringSignal as RestGenerationSafetyQualityMonitoringSignal, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + GenerationTokenStatisticsSignal as RestGenerationTokenStatisticsSignal, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import ModelPerformanceSignal as RestModelPerformanceSignal +from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringDataSegment as RestMonitoringDataSegment +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + MonitoringFeatureFilterBase as RestMonitoringFeatureFilterBase, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringInputDataBase as RestMonitoringInputData +from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringNotificationMode +from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringSignalBase as RestMonitoringSignalBase +from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringSignalType +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + MonitoringWorkspaceConnection as RestMonitoringWorkspaceConnection, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + PredictionDriftMonitoringSignal as RestPredictionDriftMonitoringSignal, +) +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + TopNFeaturesByAttribution as RestTopNFeaturesByAttribution, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._monitoring import ( + ALL_FEATURES, + MonitorDatasetContext, + MonitorFeatureDataType, + MonitorSignalType, +) +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job._input_output_helpers import ( + from_rest_inputs_to_dataset_literal, + to_rest_dataset_literal_inputs, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin +from azure.ai.ml.entities._monitoring.input_data import FixedInputData, StaticInputData, TrailingInputData +from azure.ai.ml.entities._monitoring.thresholds import ( + CustomMonitoringMetricThreshold, + DataDriftMetricThreshold, + DataQualityMetricThreshold, + FeatureAttributionDriftMetricThreshold, + GenerationSafetyQualityMonitoringMetricThreshold, + GenerationTokenStatisticsMonitorMetricThreshold, + MetricThreshold, + ModelPerformanceMetricThreshold, + PredictionDriftMetricThreshold, +) +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, 
ValidationErrorType, ValidationException + + +class DataSegment(RestTranslatableMixin): + """Data segment for monitoring. + + :keyword feature_name: The feature to segment the data on. + :paramtype feature_name: str + :keyword feature_values: A list of values for the given segmented feature to filter. + :paramtype feature_values: List[str] + """ + + def __init__( + self, + *, + feature_name: Optional[str] = None, + feature_values: Optional[List[str]] = None, + ) -> None: + self.feature_name = feature_name + self.feature_values = feature_values + + def _to_rest_object(self) -> RestMonitoringDataSegment: + return RestMonitoringDataSegment(feature=self.feature_name, values=self.feature_values) + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringDataSegment) -> "DataSegment": + return cls( + feature_name=obj.feature, + feature_values=obj.values, + ) + + +class MonitorFeatureFilter(RestTranslatableMixin): + """Monitor feature filter + + :keyword top_n_feature_importance: The number of top features to include. Defaults to 10. + :paramtype top_n_feature_importance: int + """ + + def __init__( + self, + *, + top_n_feature_importance: int = 10, + ) -> None: + self.top_n_feature_importance = top_n_feature_importance + + def _to_rest_object(self) -> RestTopNFeaturesByAttribution: + return RestTopNFeaturesByAttribution( + top=self.top_n_feature_importance, + ) + + @classmethod + def _from_rest_object(cls, obj: RestTopNFeaturesByAttribution) -> "MonitorFeatureFilter": + return cls(top_n_feature_importance=obj.top) + + +class BaselineDataRange: + """Baseline data range for monitoring. + + This class is used when initializing a data_window for a ReferenceData object. + For trailing input, set lookback_window_size and lookback_window_offset to a desired value. + For static input, set window_start and window_end to a desired value. + """ + + def __init__( + self, + *, + window_start: Optional[str] = None, + window_end: Optional[str] = None, + lookback_window_size: Optional[str] = None, + lookback_window_offset: Optional[str] = None, + ): + self.window_start = window_start + self.window_end = window_end + self.lookback_window_size = lookback_window_size + self.lookback_window_offset = lookback_window_offset + + +class ProductionData(RestTranslatableMixin): + """Production Data + + :param input_data: The data for which drift will be calculated + :type Input: ~azure.ai.ml.entities._input_outputs + :param data_context: The context of the input dataset. Possible values + include: model_inputs, model_outputs, training, test, validation, ground_truth + :type MonitorDatasetContext: ~azure.ai.ml.constants.MonitorDatasetContext + :param pre_processing_component: ARM resource ID of the component resource used to + preprocess the data. + :type pre_processing_component: string + :param data_window: The number of days or a time frame that a singal monitor looks back over the target. 
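+    # Illustrative sketch with a hypothetical URI; ProductionData only accepts a trailing
+    # window (see _validate below), expressed through BaselineDataRange.
+    #     production_data = ProductionData(
+    #         input_data=Input(path="azureml:my-production-data:1", type="uri_folder"),
+    #         data_context="model_inputs",
+    #         data_window=BaselineDataRange(lookback_window_size="P7D", lookback_window_offset="P0D"),
+    #     )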
+ :type data_window_size: BaselineDataRange + """ + + def __init__( + self, + *, + input_data: Input, + data_context: Optional[MonitorDatasetContext] = None, + pre_processing_component: Optional[str] = None, + data_window: Optional[BaselineDataRange] = None, + data_column_names: Optional[Dict[str, str]] = None, + ): + self.input_data = input_data + self.data_context = data_context + self.pre_processing_component = pre_processing_component + self.data_window = data_window + self.data_column_names = data_column_names + + def _to_rest_object(self, **kwargs: Any) -> RestMonitoringInputData: + self._validate() + default_data_window_size = kwargs.get("default_data_window_size") + if self.data_window is None: + self.data_window = BaselineDataRange( + lookback_window_size=default_data_window_size, lookback_window_offset="P0D" + ) + if self.data_window.lookback_window_size in ["default", None]: + self.data_window.lookback_window_size = default_data_window_size + uri = self.input_data.path + job_type = self.input_data.type + monitoring_input_data = TrailingInputData( + data_context=self.data_context, + target_columns=self.data_column_names, + job_type=job_type, + uri=uri, + pre_processing_component_id=self.pre_processing_component, + window_size=self.data_window.lookback_window_size, + window_offset=( + self.data_window.lookback_window_offset + if self.data_window.lookback_window_offset is not None + else "P0D" + ), + ) + return monitoring_input_data._to_rest_object() + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringInputData) -> "ProductionData": + data_window = BaselineDataRange( + lookback_window_size=isodate.duration_isoformat(obj.window_size), + lookback_window_offset=isodate.duration_isoformat(obj.window_offset), + ) + return cls( + input_data=Input( + path=obj.uri, + type=obj.job_input_type, + ), + data_context=obj.data_context, + pre_processing_component=obj.preprocessing_component_id, + data_window=data_window, + data_column_names=obj.columns, + ) + + def _validate(self) -> None: + if self.data_window: + if self.data_window.window_start or self.data_window.window_end: + msg = "ProductionData only accepts lookback_window_size and lookback_window_offset." + err = ValidationException( + message=msg, + target=ErrorTarget.MODEL_MONITORING, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + log_and_raise_error(err) + + +class ReferenceData(RestTranslatableMixin): + """Reference Data + + :param input_data: The data for which drift will be calculated + :type Input: ~azure.ai.ml.entities._input_outputs + :param data_context: The context of the input dataset. Possible values + include: model_inputs, model_outputs, training, test, validation, ground_truth + :type MonitorDatasetContext: ~azure.ai.ml.constants.MonitorDatasetContext + :param pre_processing_component: ARM resource ID of the component resource used to + preprocess the data. + :type pre_processing_component: string + :param target_column_name: The name of the target column in the dataset. + :type target_column_name: string + :param data_window: The number of days or a time frame that a single monitor looks back over the target. 
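+    # Illustrative sketch with a hypothetical URI and dates; a static window uses
+    # "YYYY-MM-DD" strings, matching the strptime format in StaticInputData.
+    #     reference_data = ReferenceData(
+    #         input_data=Input(path="azureml:my-training-data:1", type="mltable"),
+    #         data_context="training",
+    #         data_window=BaselineDataRange(window_start="2024-01-01", window_end="2024-01-31"),
+    #     )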
+ :type data_window_size: BaselineDataRange + """ + + def __init__( + self, + *, + input_data: Input, + data_context: Optional[MonitorDatasetContext] = None, + pre_processing_component: Optional[str] = None, + data_window: Optional[BaselineDataRange] = None, + data_column_names: Optional[Dict[str, str]] = None, + ): + self.input_data = input_data + self.data_context = data_context + self.pre_processing_component = pre_processing_component + self.data_window = data_window + self.data_column_names = data_column_names + + def _to_rest_object(self, **kwargs: Any) -> RestMonitoringInputData: + default_data_window = kwargs.get("default_data_window") + ref_data_window_size = kwargs.get("ref_data_window_size") + if self.data_window is not None: + if self.data_window.lookback_window_size is not None: + if self.data_window.lookback_window_size == "default": + self.data_window.lookback_window_size = ref_data_window_size + if self.data_window.lookback_window_offset == "default": + self.data_window.lookback_window_offset = default_data_window + return TrailingInputData( + data_context=self.data_context, + target_columns=self.data_column_names, + job_type=self.input_data.type, + uri=self.input_data.path, + pre_processing_component_id=self.pre_processing_component, + window_size=self.data_window.lookback_window_size, + window_offset=( + self.data_window.lookback_window_offset + if self.data_window.lookback_window_offset is not None + else "P0D" + ), + )._to_rest_object() + if self.data_window.window_start is not None and self.data_window.window_end is not None: + return StaticInputData( + data_context=self.data_context, + target_columns=self.data_column_names, + job_type=self.input_data.type, + uri=self.input_data.path, + pre_processing_component_id=self.pre_processing_component, + window_start=self.data_window.window_start, + window_end=self.data_window.window_end, + )._to_rest_object() + + return FixedInputData( + data_context=self.data_context, + target_columns=self.data_column_names, + job_type=self.input_data.type, + uri=self.input_data.path, + )._to_rest_object() + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringInputData) -> "ReferenceData": + data_window = None + if obj.input_data_type == "Static": + data_window = BaselineDataRange( + window_start=datetime.datetime.strftime(obj.window_start, "%Y-%m-%d"), + window_end=datetime.datetime.strftime(obj.window_end, "%Y-%m-%d"), + ) + if obj.input_data_type == "Trailing": + data_window = BaselineDataRange( + lookback_window_size=isodate.duration_isoformat(obj.window_size), + lookback_window_offset=isodate.duration_isoformat(obj.window_offset), + ) + + return cls( + input_data=Input( + path=obj.uri, + type=obj.job_input_type, + ), + data_context=obj.data_context, + pre_processing_component=obj.preprocessing_component_id if obj.input_data_type != "Fixed" else None, + data_window=data_window, + data_column_names=obj.columns, + ) + + +class MonitoringSignal(RestTranslatableMixin): + """ + Base class for monitoring signals. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :keyword baseline_dataset: The baseline dataset definition for monitor input. + :paramtype baseline_dataset: ~azure.ai.ml.entities.MonitorInputData + :keyword metric_thresholds: The metric thresholds for the signal. 
+ :paramtype metric_thresholds: Union[ + ~azure.ai.ml.entities.DataDriftMetricThreshold, + ~azure.ai.ml.entities.DataQualityMetricThreshold, + ~azure.ai.ml.entities.PredictionDriftMetricThreshold, + ~azure.ai.ml.entities.FeatureAttributionDriftMetricThreshold, + ~azure.ai.ml.entities.CustomMonitoringMetricThreshold, + ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold, + List[Union[ + ~azure.ai.ml.entities.DataDriftMetricThreshold, + ~azure.ai.ml.entities.DataQualityMetricThreshold, + ~azure.ai.ml.entities.PredictionDriftMetricThreshold, + ~azure.ai.ml.entities.FeatureAttributionDriftMetricThreshold, + ~azure.ai.ml.entities.CustomMonitoringMetricThreshold, + ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold, + + ]]] + :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. + :paramtype alert_enabled: bool + """ + + def __init__( + self, + *, + production_data: Optional[ProductionData] = None, + reference_data: Optional[ReferenceData] = None, + metric_thresholds: Optional[Union[MetricThreshold, List[MetricThreshold]]], + properties: Optional[Dict[str, str]] = None, + alert_enabled: bool = False, + ): + self.production_data = production_data + self.reference_data = reference_data + self.metric_thresholds = metric_thresholds + self.alert_enabled = alert_enabled + self.properties = properties + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringSignalBase) -> Optional[ # pylint: disable=too-many-return-statements + Union[ + "DataDriftSignal", + "DataQualitySignal", + "PredictionDriftSignal", + "ModelPerformanceSignal", + "FeatureAttributionDriftSignal", + "CustomMonitoringSignal", + "GenerationSafetyQualitySignal", + "GenerationTokenStatisticsSignal", + ] + ]: + if obj.signal_type == MonitoringSignalType.DATA_DRIFT: + return DataDriftSignal._from_rest_object(obj) + if obj.signal_type == MonitoringSignalType.DATA_QUALITY: + return DataQualitySignal._from_rest_object(obj) + if obj.signal_type == MonitoringSignalType.PREDICTION_DRIFT: + return PredictionDriftSignal._from_rest_object(obj) + if obj.signal_type == "ModelPerformanceSignalBase": + return ModelPerformanceSignal._from_rest_object(obj) + if obj.signal_type == MonitoringSignalType.FEATURE_ATTRIBUTION_DRIFT: + return FeatureAttributionDriftSignal._from_rest_object(obj) + if obj.signal_type == MonitoringSignalType.CUSTOM: + return CustomMonitoringSignal._from_rest_object(obj) + if obj.signal_type == MonitoringSignalType.GENERATION_SAFETY_QUALITY: + return GenerationSafetyQualitySignal._from_rest_object(obj) + if obj.signal_type == MonitoringSignalType.MODEL_PERFORMANCE: + return ModelPerformanceSignal._from_rest_object(obj) + if obj.signal_type == MonitoringSignalType.GENERATION_TOKEN_STATISTICS: + return GenerationTokenStatisticsSignal._from_rest_object(obj) + + return None + + +class DataSignal(MonitoringSignal): + """Base class for data signals. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :keyword baseline_dataset: The baseline dataset definition for monitor input. + :paramtype baseline_dataset: ~azure.ai.ml.entities.MonitorInputData + :keyword features: The features to include in the signal. + :paramtype features: Union[List[str], ~azure.ai.ml.entities.MonitorFeatureFilter, Literal[ALL_FEATURES]] + :keyword metric_thresholds: The metric thresholds for the signal. 
+ :paramtype metric_thresholds: List[Union[ + ~azure.ai.ml.entities.DataDriftMetricThreshold, + ~azure.ai.ml.entities.DataQualityMetricThreshold, + ~azure.ai.ml.entities.PredictionDriftMetricThreshold, + ~azure.ai.ml.entities.FeatureAttributionDriftMetricThreshold, + ~azure.ai.ml.entities.CustomMonitoringMetricThreshold, + ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold, + + ]] + :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. + :paramtype alert_enabled: bool + """ + + def __init__( + self, + *, + production_data: Optional[ProductionData] = None, + reference_data: Optional[ReferenceData] = None, + features: Optional[Union[List[str], MonitorFeatureFilter, Literal["all_features"]]] = None, + feature_type_override: Optional[Dict[str, Union[str, MonitorFeatureDataType]]] = None, + metric_thresholds: Optional[Union[MetricThreshold, List[MetricThreshold]]], + alert_enabled: bool = False, + properties: Optional[Dict[str, str]] = None, + ): + super().__init__( + production_data=production_data, + reference_data=reference_data, + metric_thresholds=metric_thresholds, + alert_enabled=alert_enabled, + properties=properties, + ) + self.features = features + self.feature_type_override = feature_type_override + + +class DataDriftSignal(DataSignal): + """Data drift signal. + + :ivar type: The type of the signal, set to "data_drift" for this class. + :vartype type: str + :param production_data: The data for which drift will be calculated + :paramtype production_data: ~azure.ai.ml.entities.ProductionData + :param reference_data: The data to calculate drift against + :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData + :param metric_thresholds: Metrics to calculate and their associated thresholds + :paramtype metric_thresholds: ~azure.ai.ml.entities.DataDriftMetricThreshold + :param alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. + :paramtype alert_enabled: bool + :param data_segment: The data segment used for scoping on a subset of the data population. + :paramtype data_segment: ~azure.ai.ml.entities.DataSegment + :keyword features: The feature filter identifying which feature(s) to calculate drift over. + :paramtype features: Union[List[str], ~azure.ai.ml.entities.MonitorFeatureFilter, Literal['all_features']] + :param feature_type_override: Dictionary of features and what they should be overridden to. + :paramtype feature_type_override: dict[str, str] + :param properties: Dictionary of additional properties. 
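+    # Illustrative sketch; production_data and reference_data would be built as in the
+    # ProductionData and ReferenceData sketches above, and every argument is optional.
+    #     drift_signal = DataDriftSignal(
+    #         production_data=production_data,
+    #         reference_data=reference_data,
+    #         features=MonitorFeatureFilter(top_n_feature_importance=10),
+    #         alert_enabled=True,
+    #     )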
+ :paramtype properties: dict[str, str] + """ + + def __init__( + self, + *, + production_data: Optional[ProductionData] = None, + reference_data: Optional[ReferenceData] = None, + features: Optional[Union[List[str], MonitorFeatureFilter, Literal["all_features"]]] = None, + feature_type_override: Optional[Dict[str, Union[str, MonitorFeatureDataType]]] = None, + metric_thresholds: Optional[Union[DataDriftMetricThreshold, List[MetricThreshold]]] = None, + alert_enabled: bool = False, + data_segment: Optional[DataSegment] = None, + properties: Optional[Dict[str, str]] = None, + ): + super().__init__( + production_data=production_data, + reference_data=reference_data, + metric_thresholds=metric_thresholds, + features=features, + feature_type_override=feature_type_override, + alert_enabled=alert_enabled, + properties=properties, + ) + self.type = MonitorSignalType.DATA_DRIFT + self.data_segment = data_segment + + def _to_rest_object(self, **kwargs: Any) -> RestMonitoringDataDriftSignal: + default_data_window_size = kwargs.get("default_data_window_size") + ref_data_window_size = kwargs.get("ref_data_window_size") + if self.production_data is not None and self.production_data.data_window is None: + self.production_data.data_window = BaselineDataRange(lookback_window_size=default_data_window_size) + rest_features = _to_rest_features(self.features) if self.features else None + return RestMonitoringDataDriftSignal( + production_data=( + self.production_data._to_rest_object(default_data_window_size=default_data_window_size) + if self.production_data is not None + else None + ), + reference_data=( + self.reference_data._to_rest_object( + default_data_window=default_data_window_size, ref_data_window_size=ref_data_window_size + ) + if self.reference_data is not None + else None + ), + features=rest_features, + feature_data_type_override=self.feature_type_override, + metric_thresholds=( + self.metric_thresholds._to_rest_object() + if isinstance(self.metric_thresholds, MetricThreshold) + else None + ), + mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED, + data_segment=self.data_segment._to_rest_object() if self.data_segment else None, + properties=self.properties, + ) + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringDataDriftSignal) -> "DataDriftSignal": + return cls( + production_data=ProductionData._from_rest_object(obj.production_data), + reference_data=ReferenceData._from_rest_object(obj.reference_data), + features=_from_rest_features(obj.features), + feature_type_override=obj.feature_data_type_override, + metric_thresholds=DataDriftMetricThreshold._from_rest_object(obj.metric_thresholds), + alert_enabled=( + False + if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED) + else MonitoringNotificationMode.ENABLED + ), + data_segment=DataSegment._from_rest_object(obj.data_segment) if obj.data_segment else None, + properties=obj.properties, + ) + + @classmethod + def _get_default_data_drift_signal(cls) -> "DataDriftSignal": + return cls( + features=ALL_FEATURES, # type: ignore[arg-type] + metric_thresholds=DataDriftMetricThreshold._get_default_thresholds(), + ) + + +class PredictionDriftSignal(MonitoringSignal): + """Prediction drift signal. + + :ivar type: The type of the signal, set to "prediction_drift" for this class. 
+ :vartype type: str + :param production_data: The data for which drift will be calculated + :paramtype production_data: ~azure.ai.ml.entities.ProductionData + :param reference_data: The data to calculate drift against + :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData + :param metric_thresholds: Metrics to calculate and their associated thresholds + :paramtype metric_thresholds: ~azure.ai.ml.entities.DataDriftMetricThreshold + :param alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. + :paramtype alert_enabled: bool + :param properties: Dictionary of additional properties. + :paramtype properties: dict[str, str] + """ + + def __init__( + self, + *, + production_data: Optional[ProductionData] = None, + reference_data: Optional[ReferenceData] = None, + metric_thresholds: PredictionDriftMetricThreshold, + alert_enabled: bool = False, + properties: Optional[Dict[str, str]] = None, + ): + super().__init__( + production_data=production_data, + reference_data=reference_data, + metric_thresholds=metric_thresholds, + alert_enabled=alert_enabled, + properties=properties, + ) + self.type = MonitorSignalType.PREDICTION_DRIFT + + def _to_rest_object(self, **kwargs: Any) -> RestPredictionDriftMonitoringSignal: + default_data_window_size = kwargs.get("default_data_window_size") + ref_data_window_size = kwargs.get("ref_data_window_size") + if self.production_data is not None and self.production_data.data_window is None: + self.production_data.data_window = BaselineDataRange(lookback_window_size=default_data_window_size) + return RestPredictionDriftMonitoringSignal( + production_data=( + self.production_data._to_rest_object(default_data_window_size=default_data_window_size) + if self.production_data is not None + else None + ), + reference_data=( + self.reference_data._to_rest_object( + default_data_window=default_data_window_size, ref_data_window_size=ref_data_window_size + ) + if self.reference_data is not None + else None + ), + metric_thresholds=( + self.metric_thresholds._to_rest_object() + if isinstance(self.metric_thresholds, MetricThreshold) + else None + ), + properties=self.properties, + mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED, + model_type="classification", + ) + + @classmethod + def _from_rest_object(cls, obj: RestPredictionDriftMonitoringSignal) -> "PredictionDriftSignal": + return cls( + production_data=ProductionData._from_rest_object(obj.production_data), + reference_data=ReferenceData._from_rest_object(obj.reference_data), + metric_thresholds=PredictionDriftMetricThreshold._from_rest_object(obj.metric_thresholds), + alert_enabled=( + False + if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED) + else MonitoringNotificationMode.ENABLED + ), + properties=obj.properties, + ) + + @classmethod + def _get_default_prediction_drift_signal(cls) -> "PredictionDriftSignal": + return cls( + metric_thresholds=PredictionDriftMetricThreshold._get_default_thresholds(), + ) + + +class DataQualitySignal(DataSignal): + """Data quality signal + + :ivar type: The type of the signal. Set to "data_quality" for this class. 
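+    # Illustrative sketch; "all_features" matches the Literal accepted by the features
+    # keyword, and the remaining values are hypothetical.
+    #     quality_signal = DataQualitySignal(
+    #         production_data=production_data,
+    #         reference_data=reference_data,
+    #         features="all_features",
+    #     )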
+ :vartype type: str + :param production_data: The data for which drift will be calculated + :paramtype production_data: ~azure.ai.ml.entities.ProductionData + :param reference_data: The data to calculate drift against + :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData + :param metric_thresholds: Metrics to calculate and their associated thresholds + :paramtype metric_thresholds: ~azure.ai.ml.entities.DataDriftMetricThreshold + :param alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. + :paramtype alert_enabled: bool + :keyword features: The feature filter identifying which feature(s) to calculate drift over. + :paramtype features: Union[List[str], ~azure.ai.ml.entities.MonitorFeatureFilter, Literal['all_features']] + :param feature_type_override: Dictionary of features and what they should be overridden to. + :paramtype feature_type_override: dict[str, str] + :param properties: Dictionary of additional properties. + :paramtype properties: dict[str, str] + """ + + def __init__( + self, + *, + production_data: Optional[ProductionData] = None, + reference_data: Optional[ReferenceData] = None, + features: Optional[Union[List[str], MonitorFeatureFilter, Literal["all_features"]]] = None, + feature_type_override: Optional[Dict[str, Union[str, MonitorFeatureDataType]]] = None, + metric_thresholds: Optional[Union[MetricThreshold, List[MetricThreshold]]] = None, + alert_enabled: bool = False, + properties: Optional[Dict[str, str]] = None, + ): + super().__init__( + production_data=production_data, + reference_data=reference_data, + metric_thresholds=metric_thresholds, + features=features, + feature_type_override=feature_type_override, + alert_enabled=alert_enabled, + properties=properties, + ) + self.type = MonitorSignalType.DATA_QUALITY + + def _to_rest_object(self, **kwargs: Any) -> RestMonitoringDataQualitySignal: + default_data_window_size = kwargs.get("default_data_window_size") + ref_data_window_size = kwargs.get("ref_data_window_size") + if self.production_data is not None and self.production_data.data_window is None: + self.production_data.data_window = BaselineDataRange( + lookback_window_size=default_data_window_size, + ) + rest_features = _to_rest_features(self.features) if self.features else None + rest_metrics = ( + # TODO: Bug Item number: 2883365 + _to_rest_data_quality_metrics( + self.metric_thresholds.numerical, self.metric_thresholds.categorical # type: ignore + ) + if isinstance(self.metric_thresholds, MetricThreshold) + else None + ) + return RestMonitoringDataQualitySignal( + production_data=( + self.production_data._to_rest_object(default_data_window_size=default_data_window_size) + if self.production_data is not None + else None + ), + reference_data=( + self.reference_data._to_rest_object( + default_data_window=default_data_window_size, ref_data_window_size=ref_data_window_size + ) + if self.reference_data is not None + else None + ), + features=rest_features, + feature_data_type_override=self.feature_type_override, + metric_thresholds=rest_metrics, + mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED, + properties=self.properties, + ) + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringDataQualitySignal) -> "DataQualitySignal": + return cls( + production_data=ProductionData._from_rest_object(obj.production_data), + reference_data=ReferenceData._from_rest_object(obj.reference_data), + features=_from_rest_features(obj.features), + 
feature_type_override=obj.feature_data_type_override, + metric_thresholds=DataQualityMetricThreshold._from_rest_object(obj.metric_thresholds), + alert_enabled=( + False + if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED) + else MonitoringNotificationMode.ENABLED + ), + properties=obj.properties, + ) + + @classmethod + def _get_default_data_quality_signal( + cls, + ) -> "DataQualitySignal": + return cls( + features=ALL_FEATURES, # type: ignore[arg-type] + metric_thresholds=DataQualityMetricThreshold._get_default_thresholds(), + ) + + +@experimental +class FADProductionData(RestTranslatableMixin): + """Feature Attribution Production Data + + :keyword input_data: Input data used by the monitor. + :paramtype input_data: ~azure.ai.ml.Input + :keyword data_context: The context of the input dataset. Accepted values are "model_inputs", + "model_outputs", "training", "test", "validation", and "ground_truth". + :paramtype data_context: ~azure.ai.ml.constants._monitoring + :keyword data_column_names: The names of the columns in the input data. + :paramtype data_column_names: Dict[str, str] + :keyword pre_processing_component: The ARM (Azure Resource Manager) resource ID of the component resource used to + preprocess the data. + :paramtype pre_processing_component: string + :param data_window: The number of days or a time frame that a singal monitor looks back over the target. + :type data_window: BaselineDataRange + """ + + def __init__( + self, + *, + input_data: Input, + data_context: Optional[MonitorDatasetContext] = None, + data_column_names: Optional[Dict[str, str]] = None, + pre_processing_component: Optional[str] = None, + data_window: Optional[BaselineDataRange] = None, + ): + self.input_data = input_data + self.data_context = data_context + self.data_column_names = data_column_names + self.pre_processing_component = pre_processing_component + self.data_window = data_window + + def _to_rest_object(self, **kwargs: Any) -> RestMonitoringInputData: + default_data_window_size = kwargs.get("default") + if self.data_window is None: + self.data_window = BaselineDataRange( + lookback_window_size=default_data_window_size, lookback_window_offset="P0D" + ) + if self.data_window.lookback_window_size == "default": + self.data_window.lookback_window_size = default_data_window_size + uri = self.input_data.path + job_type = self.input_data.type + monitoring_input_data = TrailingInputData( + data_context=self.data_context, + target_columns=self.data_column_names, + job_type=job_type, + uri=uri, + pre_processing_component_id=self.pre_processing_component, + window_size=self.data_window.lookback_window_size, + window_offset=( + self.data_window.lookback_window_offset + if self.data_window.lookback_window_offset is not None + else "P0D" + ), + ) + return monitoring_input_data._to_rest_object() + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringInputData) -> "FADProductionData": + data_window = BaselineDataRange( + lookback_window_size=isodate.duration_isoformat(obj.window_size), + lookback_window_offset=isodate.duration_isoformat(obj.window_offset), + ) + return cls( + input_data=Input( + path=obj.uri, + type=obj.job_input_type, + ), + data_context=obj.data_context, + data_column_names=obj.columns, + pre_processing_component=obj.preprocessing_component_id, + data_window=data_window, + ) + + +@experimental +class FeatureAttributionDriftSignal(RestTranslatableMixin): + """Feature attribution drift signal + + :ivar type: The type of the signal. 
Set to "feature_attribution_drift" for this class. + :vartype type: str + :keyword production_data: The data for which drift will be calculated. + :paratype production_data: ~azure.ai.ml.entities.FADProductionData + :keyword reference_data: The data to calculate drift against. + :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData + :keyword metric_thresholds: Metrics to calculate and their + associated thresholds. + :paramtype metric_thresholds: ~azure.ai.ml.entities.FeatureAttributionDriftMetricThreshold + :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. + :paramtype alert_enabled: bool + """ + + def __init__( + self, + *, + production_data: Optional[List[FADProductionData]] = None, + reference_data: ReferenceData, + metric_thresholds: FeatureAttributionDriftMetricThreshold, + alert_enabled: bool = False, + properties: Optional[Dict[str, str]] = None, + ): + self.production_data = production_data + self.reference_data = reference_data + self.metric_thresholds = metric_thresholds + self.alert_enabled = alert_enabled + self.properties = properties + self.type = MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT + + def _to_rest_object(self, **kwargs: Any) -> RestFeatureAttributionDriftMonitoringSignal: + default_window_size = kwargs.get("default_data_window_size") + ref_data_window_size = kwargs.get("ref_data_window_size") + return RestFeatureAttributionDriftMonitoringSignal( + production_data=( + [data._to_rest_object(default=default_window_size) for data in self.production_data] + if self.production_data is not None + else None + ), + reference_data=self.reference_data._to_rest_object( + default_data_window=default_window_size, ref_data_window_size=ref_data_window_size + ), + metric_threshold=self.metric_thresholds._to_rest_object(), + mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED, + properties=self.properties, + ) + + @classmethod + def _from_rest_object(cls, obj: RestFeatureAttributionDriftMonitoringSignal) -> "FeatureAttributionDriftSignal": + return cls( + production_data=[FADProductionData._from_rest_object(data) for data in obj.production_data], + reference_data=ReferenceData._from_rest_object(obj.reference_data), + metric_thresholds=FeatureAttributionDriftMetricThreshold._from_rest_object(obj.metric_threshold), + alert_enabled=( + False + if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED) + else MonitoringNotificationMode.ENABLED + ), + properties=obj.properties, + ) + + +@experimental +class ModelPerformanceSignal(RestTranslatableMixin): + """Model performance signal. + + :keyword baseline_dataset: The data to calculate performance against. + :paramtype baseline_dataset: ~azure.ai.ml.entities.MonitorInputData + :keyword metric_thresholds: A list of metrics to calculate and their + associated thresholds. + :paramtype metric_thresholds: ~azure.ai.ml.entities.ModelPerformanceMetricThreshold + :keyword model_type: The model type. + :paramtype model_type: ~azure.ai.ml.constants.MonitorModelType + :keyword data_segment: The data segment to calculate performance against. + :paramtype data_segment: ~azure.ai.ml.entities.DataSegment + :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. 
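+    # Illustrative sketch with hypothetical values; ModelPerformanceMetricThreshold is
+    # defined in thresholds.py (not shown in this diff), so its arguments are elided.
+    #     perf_signal = ModelPerformanceSignal(
+    #         production_data=production_data,
+    #         reference_data=reference_data,
+    #         metric_thresholds=ModelPerformanceMetricThreshold(...),  # threshold config elided
+    #         data_segment=DataSegment(feature_name="region", feature_values=["east", "west"]),
+    #     )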
+ :paramtype alert_enabled: bool + """ + + def __init__( + self, + *, + production_data: ProductionData, + reference_data: ReferenceData, + metric_thresholds: ModelPerformanceMetricThreshold, + data_segment: Optional[DataSegment] = None, + alert_enabled: bool = False, + properties: Optional[Dict[str, str]] = None, + ) -> None: + self.production_data = production_data + self.reference_data = reference_data + self.metric_thresholds = metric_thresholds + self.alert_enabled = alert_enabled + self.type = MonitorSignalType.MODEL_PERFORMANCE + self.data_segment = data_segment + self.properties = properties + + def _to_rest_object(self, **kwargs: Any) -> RestModelPerformanceSignal: + default_data_window_size = kwargs.get("default_data_window_size") + ref_data_window_size = kwargs.get("ref_data_window_size") + if self.properties is None: + self.properties = {} + self.properties["azureml.modelmonitor.model_performance_thresholds"] = self.metric_thresholds._to_str_object() + if self.production_data.data_window is None: + self.production_data.data_window = BaselineDataRange( + lookback_window_size=default_data_window_size, + ) + return RestModelPerformanceSignal( + production_data=[self.production_data._to_rest_object(default_data_window_size=default_data_window_size)], + reference_data=self.reference_data._to_rest_object( + default_data_window_size=default_data_window_size, ref_data_window_size=ref_data_window_size + ), + metric_threshold=self.metric_thresholds._to_rest_object(), + data_segment=self.data_segment._to_rest_object() if self.data_segment else None, + mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED, + properties=self.properties, + ) + + @classmethod + def _from_rest_object(cls, obj: RestModelPerformanceSignal) -> "ModelPerformanceSignal": + return cls( + production_data=ProductionData._from_rest_object(obj.production_data[0]), + reference_data=ReferenceData._from_rest_object(obj.reference_data), + metric_thresholds=ModelPerformanceMetricThreshold._from_rest_object(obj.metric_threshold), + data_segment=DataSegment._from_rest_object(obj.data_segment) if obj.data_segment else None, + alert_enabled=( + False + if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED) + else MonitoringNotificationMode.ENABLED + ), + ) + + +@experimental +class Connection(RestTranslatableMixin): + """Monitoring Connection + + :param environment_variables: A dictionary of environment variables to set for the workspace. + :paramtype environment_variables: Optional[dict[str, str]] + :param secret_config: A dictionary of secrets to set for the workspace. + :paramtype secret_config: Optional[dict[str, str]] + """ + + def __init__( + self, + *, + environment_variables: Optional[Dict[str, str]] = None, + secret_config: Optional[Dict[str, str]] = None, + ): + self.environment_variables = environment_variables + self.secret_config = secret_config + + def _to_rest_object(self) -> RestMonitoringWorkspaceConnection: + return RestMonitoringWorkspaceConnection( + environment_variables=self.environment_variables, + secrets=self.secret_config, + ) + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringWorkspaceConnection) -> "Connection": + return cls( + environment_variables=obj.environment_variables, + secret_config=obj.secrets, + ) + + +@experimental +class CustomMonitoringSignal(RestTranslatableMixin): + """Custom monitoring signal. + + :ivar type: The type of the signal. Set to "custom" for this class. 
+ :vartype type: str + :keyword input_data: A dictionary of input datasets for monitoring. + Each key is the component input port name, and its value is the data asset. + :paramtype input_data: Optional[dict[str, ~azure.ai.ml.entities.ReferenceData]] + :keyword metric_thresholds: A list of metrics to calculate and their + associated thresholds. + :paramtype metric_thresholds: List[~azure.ai.ml.entities.CustomMonitoringMetricThreshold] + :keyword inputs: + :paramtype inputs: Optional[dict[str, ~azure.ai.ml.entities.Input]] + :keyword component_id: The ARM (Azure Resource Manager) ID of the component resource used to + calculate the custom metrics. + :paramtype component_id: str + :keyword connection: Specify connection with environment variables and secret configs. + :paramtype connection: Optional[~azure.ai.ml.entities.WorkspaceConnection] + :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. + :paramtype alert_enabled: bool + :keyword properties: A dictionary of custom properties for the signal. + :paramtype properties: Optional[dict[str, str]] + """ + + def __init__( + self, + *, + inputs: Optional[Dict[str, Input]] = None, + metric_thresholds: List[CustomMonitoringMetricThreshold], + component_id: str, + connection: Optional[Connection] = None, + input_data: Optional[Dict[str, ReferenceData]] = None, + alert_enabled: bool = False, + properties: Optional[Dict[str, str]] = None, + ): + self.type = MonitorSignalType.CUSTOM + self.inputs = inputs + self.metric_thresholds = metric_thresholds + self.component_id = component_id + self.alert_enabled = alert_enabled + self.input_data = input_data + self.properties = properties + self.connection = connection + + def _to_rest_object(self, **kwargs: Any) -> RestCustomMonitoringSignal: # pylint:disable=unused-argument + if self.connection is None: + self.connection = Connection() + return RestCustomMonitoringSignal( + component_id=self.component_id, + metric_thresholds=[threshold._to_rest_object() for threshold in self.metric_thresholds], + inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=None) if self.inputs else None, + input_assets=( + {asset_name: asset_value._to_rest_object() for asset_name, asset_value in self.input_data.items()} + if self.input_data + else None + ), + workspace_connection=self.connection._to_rest_object(), + mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED, + properties=self.properties, + ) + + @classmethod + def _from_rest_object(cls, obj: RestCustomMonitoringSignal) -> "CustomMonitoringSignal": + return cls( + inputs=from_rest_inputs_to_dataset_literal(obj.inputs) if obj.inputs else None, + input_data={key: ReferenceData._from_rest_object(data) for key, data in obj.input_assets.items()}, + metric_thresholds=[ + CustomMonitoringMetricThreshold._from_rest_object(metric) for metric in obj.metric_thresholds + ], + component_id=obj.component_id, + alert_enabled=( + False + if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED) + else MonitoringNotificationMode.ENABLED + ), + properties=obj.properties, + connection=Connection._from_rest_object(obj.workspace_connection), + ) + + +@experimental +class LlmData(RestTranslatableMixin): + """LLM Request Response Data + + :param input_data: Input data used by the monitor. + :paramtype input_data: ~azure.ai.ml.entities.Input + :param data_column_names: The names of columns in the input data. 
+    :paramtype data_column_names: Dict[str, str]
+    :param data_window: The number of days or a time frame that the signal monitor looks back over the target.
+    :type data_window: BaselineDataRange
+    """
+
+    def __init__(
+        self,
+        *,
+        input_data: Input,
+        data_column_names: Optional[Dict[str, str]] = None,
+        data_window: Optional[BaselineDataRange] = None,
+    ):
+        self.input_data = input_data
+        self.data_column_names = data_column_names
+        self.data_window = data_window
+
+    def _to_rest_object(self, **kwargs: Any) -> RestMonitoringInputData:
+        if self.data_window is None:
+            self.data_window = BaselineDataRange(
+                lookback_window_size=kwargs.get("default"),
+            )
+        return TrailingInputData(
+            target_columns=self.data_column_names,
+            job_type=self.input_data.type,
+            uri=self.input_data.path,
+            window_size=self.data_window.lookback_window_size,
+            window_offset=(
+                self.data_window.lookback_window_offset
+                if self.data_window.lookback_window_offset is not None
+                else "P0D"
+            ),
+        )._to_rest_object()
+
+    @classmethod
+    def _from_rest_object(cls, obj: RestMonitoringInputData) -> "LlmData":
+        data_window = BaselineDataRange(
+            lookback_window_size=isodate.duration_isoformat(obj.window_size),
+            lookback_window_offset=isodate.duration_isoformat(obj.window_offset),
+        )
+        return cls(
+            input_data=Input(
+                path=obj.uri,
+                type=obj.job_input_type,
+            ),
+            data_column_names=obj.columns,
+            data_window=data_window,
+        )
+
+
+@experimental
+class GenerationSafetyQualitySignal(RestTranslatableMixin):
+    """Generation Safety Quality monitoring signal.
+
+    :ivar type: The type of the signal. Set to "generationsafetyquality" for this class.
+    :vartype type: str
+    :keyword production_data: A list of input datasets for monitoring.
+    :paramtype production_data: Optional[List[~azure.ai.ml.entities.LlmData]]
+    :keyword metric_thresholds: Metrics to calculate and their associated thresholds.
+    :paramtype metric_thresholds: ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold
+    :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+    :paramtype alert_enabled: bool
+    :keyword connection_id: Gets or sets the connection ID used to connect to the
+        content generation endpoint.
+    :paramtype connection_id: str
+    :keyword properties: The properties of the signal.
+    :paramtype properties: Dict[str, str]
+    :keyword sampling_rate: The sample rate of the target data, should be greater
+        than 0 and at most 1.
+ :paramtype sampling_rate: float + """ + + def __init__( + self, + *, + production_data: Optional[List[LlmData]] = None, + connection_id: Optional[str] = None, + metric_thresholds: GenerationSafetyQualityMonitoringMetricThreshold, + alert_enabled: bool = False, + properties: Optional[Dict[str, str]] = None, + sampling_rate: Optional[float] = None, + ): + self.type = MonitorSignalType.GENERATION_SAFETY_QUALITY + self.production_data = production_data + self.connection_id = connection_id + self.metric_thresholds = metric_thresholds + self.alert_enabled = alert_enabled + self.properties = properties + self.sampling_rate = sampling_rate + + def _to_rest_object(self, **kwargs: Any) -> RestGenerationSafetyQualityMonitoringSignal: + data_window_size = kwargs.get("default_data_window_size") + return RestGenerationSafetyQualityMonitoringSignal( + production_data=( + [data._to_rest_object(default=data_window_size) for data in self.production_data] + if self.production_data is not None + else None + ), + workspace_connection_id=self.connection_id, + metric_thresholds=self.metric_thresholds._to_rest_object(), + mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED, + properties=self.properties, + sampling_rate=self.sampling_rate, + ) + + @classmethod + def _from_rest_object(cls, obj: RestGenerationSafetyQualityMonitoringSignal) -> "GenerationSafetyQualitySignal": + return cls( + production_data=[LlmData._from_rest_object(data) for data in obj.production_data], + connection_id=obj.workspace_connection_id, + metric_thresholds=GenerationSafetyQualityMonitoringMetricThreshold._from_rest_object(obj.metric_thresholds), + alert_enabled=( + False + if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED) + else MonitoringNotificationMode.ENABLED + ), + properties=obj.properties, + sampling_rate=obj.sampling_rate, + ) + + +@experimental +class GenerationTokenStatisticsSignal(RestTranslatableMixin): + """Generation token statistics signal definition. + + :ivar type: The type of the signal. Set to "generationtokenstatisticssignal" for this class. + :vartype type: str + :keyword production_data: input dataset for monitoring. + :paramtype input_dataset: Optional[~azure.ai.ml.entities.LlmData] + :keyword metric_thresholds: Metrics to calculate and their associated thresholds. Defaults to App Traces + :paramtype metric_thresholds: Optional[~azure.ai.ml.entities.GenerationTokenStatisticsMonitorMetricThreshold] + :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False. + :paramtype alert_enabled: bool + :keyword properties: The properties of the signal + :paramtype properties: Optional[Dict[str, str]] + :keyword sampling_rate: The sample rate of the target data, should be greater + than 0 and at most 1. + :paramtype sampling_rate: float + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_genAI_monitors_configuration.py + :start-after: [START default_monitoring] + :end-before: [END default_monitoring] + :language: python + :dedent: 8 + :caption: Set Token Statistics Monitor. 
+ """ + + def __init__( + self, + *, + production_data: Optional[LlmData] = None, + metric_thresholds: Optional[GenerationTokenStatisticsMonitorMetricThreshold] = None, + alert_enabled: bool = False, + properties: Optional[Dict[str, str]] = None, + sampling_rate: Optional[float] = None, + ): + self.type = MonitorSignalType.GENERATION_TOKEN_STATISTICS + self.production_data = production_data + self.metric_thresholds = metric_thresholds + self.alert_enabled = alert_enabled + self.properties = properties + self.sampling_rate = sampling_rate + + def _to_rest_object(self, **kwargs: Any) -> RestGenerationTokenStatisticsSignal: + data_window_size = kwargs.get("default_data_window_size") + return RestGenerationTokenStatisticsSignal( + production_data=( + self.production_data._to_rest_object(default=data_window_size) + if self.production_data is not None + else None + ), + metric_thresholds=( + self.metric_thresholds._to_rest_object() + if self.metric_thresholds + else GenerationTokenStatisticsMonitorMetricThreshold._get_default_thresholds()._to_rest_object() + ), + mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED, + properties=self.properties, + sampling_rate=self.sampling_rate if self.sampling_rate else 0.1, + ) + + @classmethod + def _from_rest_object(cls, obj: RestGenerationTokenStatisticsSignal) -> "GenerationTokenStatisticsSignal": + return cls( + production_data=LlmData._from_rest_object(obj.production_data), + metric_thresholds=GenerationTokenStatisticsMonitorMetricThreshold._from_rest_object(obj.metric_thresholds), + alert_enabled=( + False + if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED) + else MonitoringNotificationMode.ENABLED + ), + properties=obj.properties, + sampling_rate=obj.sampling_rate, + ) + + @classmethod + def _get_default_token_statistics_signal(cls) -> "GenerationTokenStatisticsSignal": + return cls( + metric_thresholds=GenerationTokenStatisticsMonitorMetricThreshold._get_default_thresholds(), + sampling_rate=0.1, + ) + + +def _from_rest_features( + obj: RestMonitoringFeatureFilterBase, +) -> Optional[Union[List[str], MonitorFeatureFilter, Literal["all_features"]]]: + if isinstance(obj, RestTopNFeaturesByAttribution): + return MonitorFeatureFilter(top_n_feature_importance=obj.top) + if isinstance(obj, RestFeatureSubset): + _restFeatureSubset: List[str] = obj.features + return _restFeatureSubset + if isinstance(obj, RestAllFeatures): + _restAllFeatures: Literal["all_features"] = ALL_FEATURES # type: ignore[assignment] + return _restAllFeatures + + return None + + +def _to_rest_features( + features: Union[List[str], MonitorFeatureFilter, Literal["all_features"]] +) -> RestMonitoringFeatureFilterBase: + rest_features = None + if isinstance(features, list): + rest_features = RestFeatureSubset(features=features) + elif isinstance(features, MonitorFeatureFilter): + rest_features = features._to_rest_object() + elif isinstance(features, str) and features == ALL_FEATURES: + rest_features = RestAllFeatures() + return rest_features + + +def _to_rest_num_cat_metrics(numerical_metrics: Any, categorical_metrics: Any) -> List: + metrics = [] + if numerical_metrics is not None: + metrics.append(numerical_metrics._to_rest_object()) + + if categorical_metrics is not None: + metrics.append(categorical_metrics._to_rest_object()) + + return metrics + + +def _to_rest_data_quality_metrics(numerical_metrics: Any, categorical_metrics: Any) -> List: + metric_thresholds: List = [] + if numerical_metrics is not 
None: + metric_thresholds = metric_thresholds + numerical_metrics._to_rest_object() + + if categorical_metrics is not None: + metric_thresholds = metric_thresholds + categorical_metrics._to_rest_object() + + return metric_thresholds diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/target.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/target.py new file mode 100644 index 00000000..73a11895 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/target.py @@ -0,0 +1,55 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional, Union + +from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringTarget as RestMonitoringTarget +from azure.ai.ml.constants._monitoring import MonitorTargetTasks + + +class MonitoringTarget: + """Monitoring target. + + :keyword ml_task: Type of task. Allowed values: Classification, Regression, and QuestionAnswering + :paramtype ml_task: Optional[Union[str, MonitorTargetTasks]] + :keyword endpoint_deployment_id: The ARM ID of the target deployment. Mutually exclusive with model_id. + :paramtype endpoint_deployment_id: Optional[str] + :keyword model_id: ARM ID of the target model ID. Mutually exclusive with endpoint_deployment_id. + :paramtype model_id: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_spark_configurations.py + :start-after: [START spark_monitor_definition] + :end-before: [END spark_monitor_definition] + :language: python + :dedent: 8 + :caption: Setting a monitoring target using endpoint_deployment_id. + """ + + def __init__( + self, + *, + ml_task: Optional[Union[str, MonitorTargetTasks]] = None, + endpoint_deployment_id: Optional[str] = None, + model_id: Optional[str] = None, + ): + self.endpoint_deployment_id = endpoint_deployment_id + self.model_id = model_id + self.ml_task = ml_task + + def _to_rest_object(self) -> RestMonitoringTarget: + return RestMonitoringTarget( + task_type=self.ml_task if self.ml_task else "classification", + deployment_id=self.endpoint_deployment_id, + model_id=self.model_id, + ) + + @classmethod + def _from_rest_object(cls, obj: RestMonitoringTarget) -> "MonitoringTarget": + return cls( + ml_task=obj.task_type, + endpoint_deployment_id=obj.endpoint_deployment_id, + model_id=obj.model_id, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/thresholds.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/thresholds.py new file mode 100644 index 00000000..3e1c33b5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/thresholds.py @@ -0,0 +1,954 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument, protected-access + +from typing import Any, Dict, List, Optional, Tuple + +from azure.ai.ml._restclient.v2023_06_01_preview.models import ( + CategoricalDataDriftMetricThreshold, + CategoricalDataQualityMetricThreshold, + CategoricalPredictionDriftMetricThreshold, + ClassificationModelPerformanceMetricThreshold, + CustomMetricThreshold, + DataDriftMetricThresholdBase, + DataQualityMetricThresholdBase, + FeatureAttributionMetricThreshold, + GenerationSafetyQualityMetricThreshold, + GenerationTokenStatisticsMetricThreshold, + ModelPerformanceMetricThresholdBase, + MonitoringThreshold, + NumericalDataDriftMetricThreshold, + NumericalDataQualityMetricThreshold, + NumericalPredictionDriftMetricThreshold, + PredictionDriftMetricThresholdBase, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel +from azure.ai.ml.constants._monitoring import MonitorFeatureType, MonitorMetricName +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class MetricThreshold(RestTranslatableMixin): + def __init__(self, *, threshold: Optional[float] = None): + self.data_type: Any = None + self.metric_name: Optional[str] = None + self.threshold = threshold + + +class NumericalDriftMetrics(RestTranslatableMixin): + """Numerical Drift Metrics + + :param jensen_shannon_distance: The Jensen-Shannon distance between the two distributions + :paramtype jensen_shannon_distance: float + :param normalized_wasserstein_distance: The normalized Wasserstein distance between the two distributions + :paramtype normalized_wasserstein_distance: float + :param population_stability_index: The population stability index between the two distributions + :paramtype population_stability_index: float + :param two_sample_kolmogorov_smirnov_test: The two sample Kolmogorov-Smirnov test between the two distributions + :paramtype two_sample_kolmogorov_smirnov_test: float + """ + + def __init__( + self, + *, + jensen_shannon_distance: Optional[float] = None, + normalized_wasserstein_distance: Optional[float] = None, + population_stability_index: Optional[float] = None, + two_sample_kolmogorov_smirnov_test: Optional[float] = None, + metric: Optional[str] = None, + metric_threshold: Optional[float] = None, + ): + self.jensen_shannon_distance = jensen_shannon_distance + self.normalized_wasserstein_distance = normalized_wasserstein_distance + self.population_stability_index = population_stability_index + self.two_sample_kolmogorov_smirnov_test = two_sample_kolmogorov_smirnov_test + self.metric = metric + self.metric_threshold = metric_threshold + + def _find_name_and_threshold(self) -> Tuple: + metric_name = None + threshold = None + if self.jensen_shannon_distance: + metric_name = MonitorMetricName.JENSEN_SHANNON_DISTANCE + threshold = MonitoringThreshold(value=self.jensen_shannon_distance) + elif self.normalized_wasserstein_distance: + metric_name = MonitorMetricName.NORMALIZED_WASSERSTEIN_DISTANCE + threshold = MonitoringThreshold(value=self.normalized_wasserstein_distance) + elif self.population_stability_index: + metric_name = MonitorMetricName.POPULATION_STABILITY_INDEX + threshold = MonitoringThreshold(value=self.population_stability_index) + elif self.two_sample_kolmogorov_smirnov_test: + metric_name = MonitorMetricName.TWO_SAMPLE_KOLMOGOROV_SMIRNOV_TEST + threshold = MonitoringThreshold(value=self.two_sample_kolmogorov_smirnov_test) + + return metric_name, 
threshold + + @classmethod + # pylint: disable=arguments-differ + def _from_rest_object(cls, metric_name: str, threshold: Optional[float]) -> "NumericalDriftMetrics": # type: ignore + metric_name = camel_to_snake(metric_name) + if metric_name == MonitorMetricName.JENSEN_SHANNON_DISTANCE: + return cls(jensen_shannon_distance=threshold) + if metric_name == MonitorMetricName.NORMALIZED_WASSERSTEIN_DISTANCE: + return cls(normalized_wasserstein_distance=threshold) + if metric_name == MonitorMetricName.POPULATION_STABILITY_INDEX: + return cls(population_stability_index=threshold) + if metric_name == MonitorMetricName.TWO_SAMPLE_KOLMOGOROV_SMIRNOV_TEST: + return cls(two_sample_kolmogorov_smirnov_test=threshold) + return cls() + + @classmethod + def _get_default_thresholds(cls) -> "NumericalDriftMetrics": + return cls( + normalized_wasserstein_distance=0.1, + ) + + +class CategoricalDriftMetrics(RestTranslatableMixin): + """Categorical Drift Metrics + + :param jensen_shannon_distance: The Jensen-Shannon distance between the two distributions + :paramtype jensen_shannon_distance: float + :param population_stability_index: The population stability index between the two distributions + :paramtype population_stability_index: float + :param pearsons_chi_squared_test: The Pearson's Chi-Squared test between the two distributions + :paramtype pearsons_chi_squared_test: float + """ + + def __init__( + self, + *, + jensen_shannon_distance: Optional[float] = None, + population_stability_index: Optional[float] = None, + pearsons_chi_squared_test: Optional[float] = None, + ): + self.jensen_shannon_distance = jensen_shannon_distance + self.population_stability_index = population_stability_index + self.pearsons_chi_squared_test = pearsons_chi_squared_test + + def _find_name_and_threshold(self) -> Tuple: + metric_name = None + threshold = None + if self.jensen_shannon_distance: + metric_name = MonitorMetricName.JENSEN_SHANNON_DISTANCE + threshold = MonitoringThreshold(value=self.jensen_shannon_distance) + if self.population_stability_index and threshold is None: + metric_name = MonitorMetricName.POPULATION_STABILITY_INDEX + threshold = MonitoringThreshold(value=self.population_stability_index) + if self.pearsons_chi_squared_test and threshold is None: + metric_name = MonitorMetricName.PEARSONS_CHI_SQUARED_TEST + threshold = MonitoringThreshold(value=self.pearsons_chi_squared_test) + + return metric_name, threshold + + @classmethod + # pylint: disable=arguments-differ + def _from_rest_object( # type: ignore + cls, metric_name: str, threshold: Optional[float] + ) -> "CategoricalDriftMetrics": + metric_name = camel_to_snake(metric_name) + if metric_name == MonitorMetricName.JENSEN_SHANNON_DISTANCE: + return cls(jensen_shannon_distance=threshold) + if metric_name == MonitorMetricName.POPULATION_STABILITY_INDEX: + return cls(population_stability_index=threshold) + if metric_name == MonitorMetricName.PEARSONS_CHI_SQUARED_TEST: + return cls(pearsons_chi_squared_test=threshold) + return cls() + + @classmethod + def _get_default_thresholds(cls) -> "CategoricalDriftMetrics": + return cls( + jensen_shannon_distance=0.1, + ) + + +class DataDriftMetricThreshold(MetricThreshold): + """Data drift metric threshold + + :param numerical: Numerical drift metrics + :paramtype numerical: ~azure.ai.ml.entities.NumericalDriftMetrics + :param categorical: Categorical drift metrics + :paramtype categorical: ~azure.ai.ml.entities.CategoricalDriftMetrics + """ + + def __init__( + self, + *, + data_type: Optional[MonitorFeatureType] = None, 
+ threshold: Optional[float] = None, + metric: Optional[str] = None, + numerical: Optional[NumericalDriftMetrics] = None, + categorical: Optional[CategoricalDriftMetrics] = None, + ): + super().__init__(threshold=threshold) + self.data_type = data_type + self.metric = metric + self.numerical = numerical + self.categorical = categorical + + def _to_rest_object(self) -> DataDriftMetricThresholdBase: + thresholds = [] + if self.numerical: + num_metric_name, num_threshold = self.numerical._find_name_and_threshold() + thresholds.append( + NumericalDataDriftMetricThreshold( + metric=snake_to_camel(num_metric_name), + threshold=num_threshold, + ) + ) + if self.categorical: + cat_metric_name, cat_threshold = self.categorical._find_name_and_threshold() + thresholds.append( + CategoricalDataDriftMetricThreshold( + metric=snake_to_camel(cat_metric_name), + threshold=cat_threshold, + ) + ) + + return thresholds + + @classmethod + def _from_rest_object(cls, obj: DataDriftMetricThresholdBase) -> "DataDriftMetricThreshold": + num = None + cat = None + for threshold in obj: + if threshold.data_type == "Numerical": + num = NumericalDriftMetrics()._from_rest_object( # pylint: disable=protected-access + threshold.metric, threshold.threshold.value if threshold.threshold else None + ) + elif threshold.data_type == "Categorical": + cat = CategoricalDriftMetrics()._from_rest_object( # pylint: disable=protected-access + threshold.metric, threshold.threshold.value if threshold.threshold else None + ) + + return cls( + numerical=num, + categorical=cat, + ) + + @classmethod + def _get_default_thresholds(cls) -> "DataDriftMetricThreshold": + return cls( + numerical=NumericalDriftMetrics._get_default_thresholds(), + categorical=CategoricalDriftMetrics._get_default_thresholds(), + ) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DataDriftMetricThreshold): + return NotImplemented + return self.numerical == other.numerical and self.categorical == other.categorical + + +class PredictionDriftMetricThreshold(MetricThreshold): + """Prediction drift metric threshold + + :param numerical: Numerical drift metrics + :paramtype numerical: ~azure.ai.ml.entities.NumericalDriftMetrics + :param categorical: Categorical drift metrics + :paramtype categorical: ~azure.ai.ml.entities.CategoricalDriftMetrics + """ + + def __init__( + self, + *, + data_type: Optional[MonitorFeatureType] = None, + threshold: Optional[float] = None, + numerical: Optional[NumericalDriftMetrics] = None, + categorical: Optional[CategoricalDriftMetrics] = None, + ): + super().__init__(threshold=threshold) + self.data_type = data_type + self.numerical = numerical + self.categorical = categorical + + def _to_rest_object(self) -> PredictionDriftMetricThresholdBase: + thresholds = [] + if self.numerical: + num_metric_name, num_threshold = self.numerical._find_name_and_threshold() + thresholds.append( + NumericalPredictionDriftMetricThreshold( + metric=snake_to_camel(num_metric_name), + threshold=num_threshold, + ) + ) + if self.categorical: + cat_metric_name, cat_threshold = self.categorical._find_name_and_threshold() + thresholds.append( + CategoricalPredictionDriftMetricThreshold( + metric=snake_to_camel(cat_metric_name), + threshold=cat_threshold, + ) + ) + + return thresholds + + @classmethod + def _from_rest_object(cls, obj: PredictionDriftMetricThresholdBase) -> "PredictionDriftMetricThreshold": + num = None + cat = None + for threshold in obj: + if threshold.data_type == "Numerical": + num = NumericalDriftMetrics()._from_rest_object( # 
pylint: disable=protected-access
+                    threshold.metric, threshold.threshold.value if threshold.threshold else None
+                )
+            elif threshold.data_type == "Categorical":
+                cat = CategoricalDriftMetrics()._from_rest_object(  # pylint: disable=protected-access
+                    threshold.metric, threshold.threshold.value if threshold.threshold else None
+                )
+
+        return cls(
+            numerical=num,
+            categorical=cat,
+        )
+
+    @classmethod
+    def _get_default_thresholds(cls) -> "PredictionDriftMetricThreshold":
+        return cls(
+            numerical=NumericalDriftMetrics._get_default_thresholds(),
+            categorical=CategoricalDriftMetrics._get_default_thresholds(),
+        )
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, PredictionDriftMetricThreshold):
+            return NotImplemented
+        return (
+            self.data_type == other.data_type
+            and self.metric_name == other.metric_name
+            and self.threshold == other.threshold
+        )
+
+
+class DataQualityMetricsNumerical(RestTranslatableMixin):
+    """Data Quality Numerical Metrics
+
+    :param null_value_rate: The null value rate
+    :paramtype null_value_rate: float
+    :param data_type_error_rate: The data type error rate
+    :paramtype data_type_error_rate: float
+    :param out_of_bounds_rate: The out of bounds rate
+    :paramtype out_of_bounds_rate: float
+    """
+
+    def __init__(
+        self,
+        *,
+        null_value_rate: Optional[float] = None,
+        data_type_error_rate: Optional[float] = None,
+        out_of_bounds_rate: Optional[float] = None,
+    ):
+        self.null_value_rate = null_value_rate
+        self.data_type_error_rate = data_type_error_rate
+        self.out_of_bounds_rate = out_of_bounds_rate
+
+    def _to_rest_object(self) -> List[NumericalDataQualityMetricThreshold]:
+        metric_thresholds = []
+        if self.null_value_rate is not None:
+            metric_name = MonitorMetricName.NULL_VALUE_RATE
+            threshold = MonitoringThreshold(value=self.null_value_rate)
+            metric_thresholds.append(
+                NumericalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+            )
+        if self.data_type_error_rate is not None:
+            metric_name = MonitorMetricName.DATA_TYPE_ERROR_RATE
+            threshold = MonitoringThreshold(value=self.data_type_error_rate)
+            metric_thresholds.append(
+                NumericalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+            )
+        if self.out_of_bounds_rate is not None:
+            metric_name = MonitorMetricName.OUT_OF_BOUND_RATE
+            threshold = MonitoringThreshold(value=self.out_of_bounds_rate)
+            metric_thresholds.append(
+                NumericalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+            )
+
+        return metric_thresholds
+
+    @classmethod
+    def _from_rest_object(cls, obj: List) -> "DataQualityMetricsNumerical":
+        null_value_rate_val = None
+        data_type_error_rate_val = None
+        out_of_bounds_rate_val = None
+        for thresholds in obj:
+            if thresholds.metric in ("NullValueRate", "nullValueRate"):
+                null_value_rate_val = thresholds.threshold.value
+            if thresholds.metric in ("DataTypeErrorRate", "dataTypeErrorRate"):
+                data_type_error_rate_val = thresholds.threshold.value
+            if thresholds.metric in ("OutOfBoundsRate", "outOfBoundsRate"):
+                out_of_bounds_rate_val = thresholds.threshold.value
+        return cls(
+            null_value_rate=null_value_rate_val,
+            data_type_error_rate=data_type_error_rate_val,
+            out_of_bounds_rate=out_of_bounds_rate_val,
+        )
+
+    @classmethod
+    def _get_default_thresholds(cls) -> "DataQualityMetricsNumerical":
+        return cls(
+            null_value_rate=0.0,
+            data_type_error_rate=0.0,
+            out_of_bounds_rate=0.0,
+        )
+
+
+class DataQualityMetricsCategorical(RestTranslatableMixin):
+    """Data Quality Categorical
Metrics
+
+    :param null_value_rate: The null value rate
+    :paramtype null_value_rate: float
+    :param data_type_error_rate: The data type error rate
+    :paramtype data_type_error_rate: float
+    :param out_of_bounds_rate: The out of bounds rate
+    :paramtype out_of_bounds_rate: float
+    """
+
+    def __init__(
+        self,
+        *,
+        null_value_rate: Optional[float] = None,
+        data_type_error_rate: Optional[float] = None,
+        out_of_bounds_rate: Optional[float] = None,
+    ):
+        self.null_value_rate = null_value_rate
+        self.data_type_error_rate = data_type_error_rate
+        self.out_of_bounds_rate = out_of_bounds_rate
+
+    def _to_rest_object(self) -> List[CategoricalDataQualityMetricThreshold]:
+        metric_thresholds = []
+        if self.null_value_rate is not None:
+            metric_name = MonitorMetricName.NULL_VALUE_RATE
+            threshold = MonitoringThreshold(value=self.null_value_rate)
+            metric_thresholds.append(
+                CategoricalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+            )
+        if self.data_type_error_rate is not None:
+            metric_name = MonitorMetricName.DATA_TYPE_ERROR_RATE
+            threshold = MonitoringThreshold(value=self.data_type_error_rate)
+            metric_thresholds.append(
+                CategoricalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+            )
+        if self.out_of_bounds_rate is not None:
+            metric_name = MonitorMetricName.OUT_OF_BOUND_RATE
+            threshold = MonitoringThreshold(value=self.out_of_bounds_rate)
+            metric_thresholds.append(
+                CategoricalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+            )
+
+        return metric_thresholds
+
+    @classmethod
+    def _from_rest_object(cls, obj: List) -> "DataQualityMetricsCategorical":
+        null_value_rate_val = None
+        data_type_error_rate_val = None
+        out_of_bounds_rate_val = None
+        for thresholds in obj:
+            if thresholds.metric in ("NullValueRate", "nullValueRate"):
+                null_value_rate_val = thresholds.threshold.value
+            if thresholds.metric in ("DataTypeErrorRate", "dataTypeErrorRate"):
+                data_type_error_rate_val = thresholds.threshold.value
+            if thresholds.metric in ("OutOfBoundsRate", "outOfBoundsRate"):
+                out_of_bounds_rate_val = thresholds.threshold.value
+        return cls(
+            null_value_rate=null_value_rate_val,
+            data_type_error_rate=data_type_error_rate_val,
+            out_of_bounds_rate=out_of_bounds_rate_val,
+        )
+
+    @classmethod
+    def _get_default_thresholds(cls) -> "DataQualityMetricsCategorical":
+        return cls(
+            null_value_rate=0.0,
+            data_type_error_rate=0.0,
+            out_of_bounds_rate=0.0,
+        )
+
+
+class DataQualityMetricThreshold(MetricThreshold):
+    """Data quality metric threshold
+
+    :param numerical: Numerical data quality metrics
+    :paramtype numerical: ~azure.ai.ml.entities.DataQualityMetricsNumerical
+    :param categorical: Categorical data quality metrics
+    :paramtype categorical: ~azure.ai.ml.entities.DataQualityMetricsCategorical
+    """
+
+    def __init__(
+        self,
+        *,
+        data_type: Optional[MonitorFeatureType] = None,
+        threshold: Optional[float] = None,
+        metric_name: Optional[str] = None,
+        numerical: Optional[DataQualityMetricsNumerical] = None,
+        categorical: Optional[DataQualityMetricsCategorical] = None,
+    ):
+        super().__init__(threshold=threshold)
+        self.data_type = data_type
+        self.metric_name = metric_name
+        self.numerical = numerical
+        self.categorical = categorical
+
+    def _to_rest_object(self) -> DataQualityMetricThresholdBase:
+        thresholds: list = []
+        if self.numerical:
+            thresholds = thresholds + (
+                DataQualityMetricsNumerical(  # pylint: disable=protected-access
null_value_rate=self.numerical.null_value_rate,
+                    data_type_error_rate=self.numerical.data_type_error_rate,
+                    out_of_bounds_rate=self.numerical.out_of_bounds_rate,
+                )._to_rest_object()
+            )
+        if self.categorical:
+            thresholds = (
+                thresholds
+                + (
+                    DataQualityMetricsCategorical(  # pylint: disable=protected-access
+                        null_value_rate=self.categorical.null_value_rate,
+                        data_type_error_rate=self.categorical.data_type_error_rate,
+                        out_of_bounds_rate=self.categorical.out_of_bounds_rate,
+                    )._to_rest_object()
+                )
+                if self.categorical is not None
+                else thresholds
+            )
+        return thresholds
+
+    @classmethod
+    def _from_rest_object(cls, obj: DataQualityMetricThresholdBase) -> "DataQualityMetricThreshold":
+        num = []
+        cat = []
+        for threshold in obj:
+            if threshold.data_type == "Numerical":
+                num.append(threshold)
+            elif threshold.data_type == "Categorical":
+                cat.append(threshold)
+
+        num_from_rest = DataQualityMetricsNumerical()._from_rest_object(num)  # pylint: disable=protected-access
+        cat_from_rest = DataQualityMetricsCategorical()._from_rest_object(cat)  # pylint: disable=protected-access
+        return cls(
+            numerical=num_from_rest,
+            categorical=cat_from_rest,
+        )
+
+    @classmethod
+    def _get_default_thresholds(cls) -> "DataQualityMetricThreshold":
+        return cls(
+            numerical=DataQualityMetricsNumerical()._get_default_thresholds(),  # pylint: disable=protected-access
+            categorical=DataQualityMetricsCategorical()._get_default_thresholds(),  # pylint: disable=protected-access
+        )
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, DataQualityMetricThreshold):
+            return NotImplemented
+        return (
+            self.data_type == other.data_type
+            and self.metric_name == other.metric_name
+            and self.threshold == other.threshold
+        )
+
+
+@experimental
+class FeatureAttributionDriftMetricThreshold(MetricThreshold):
+    """Feature attribution drift metric threshold
+
+    :param normalized_discounted_cumulative_gain: The threshold value for metric.
+ :paramtype normalized_discounted_cumulative_gain: float + """ + + def __init__( + self, *, normalized_discounted_cumulative_gain: Optional[float] = None, threshold: Optional[float] = None + ): + super().__init__(threshold=threshold) + self.data_type = MonitorFeatureType.ALL_FEATURE_TYPES + self.metric_name = MonitorMetricName.NORMALIZED_DISCOUNTED_CUMULATIVE_GAIN + self.normalized_discounted_cumulative_gain = normalized_discounted_cumulative_gain + + def _to_rest_object(self) -> FeatureAttributionMetricThreshold: + return FeatureAttributionMetricThreshold( + metric=snake_to_camel(self.metric_name), + threshold=( + MonitoringThreshold(value=self.normalized_discounted_cumulative_gain) + if self.normalized_discounted_cumulative_gain + else None + ), + ) + + @classmethod + def _from_rest_object(cls, obj: FeatureAttributionMetricThreshold) -> "FeatureAttributionDriftMetricThreshold": + return cls(normalized_discounted_cumulative_gain=obj.threshold.value if obj.threshold else None) + + +@experimental +class ModelPerformanceClassificationThresholds(RestTranslatableMixin): + def __init__( + self, + *, + accuracy: Optional[float] = None, + precision: Optional[float] = None, + recall: Optional[float] = None, + ): + self.accuracy = accuracy + self.precision = precision + self.recall = recall + + def _to_str_object(self, **kwargs): + thresholds = [] + if self.accuracy: + thresholds.append( + '{"modelType":"classification","metric":"Accuracy","threshold":{"value":' + f"{self.accuracy}" + "}}" + ) + if self.precision: + thresholds.append( + '{"modelType":"classification","metric":"Precision","threshold":{"value":' + f"{self.precision}" + "}}" + ) + if self.recall: + thresholds.append( + '{"modelType":"classification","metric":"Recall","threshold":{"value":' + f"{self.recall}" + "}}" + ) + + if not thresholds: + return None + + return ", ".join(thresholds) + + @classmethod + def _from_rest_object(cls, obj) -> "ModelPerformanceClassificationThresholds": + return cls( + accuracy=obj.threshold.value if obj.threshold else None, + ) + + +@experimental +class ModelPerformanceRegressionThresholds(RestTranslatableMixin): + def __init__( + self, + *, + mean_absolute_error: Optional[float] = None, + mean_squared_error: Optional[float] = None, + root_mean_squared_error: Optional[float] = None, + ): + self.mean_absolute_error = mean_absolute_error + self.mean_squared_error = mean_squared_error + self.root_mean_squared_error = root_mean_squared_error + + def _to_str_object(self, **kwargs): + thresholds = [] + if self.mean_absolute_error: + thresholds.append( + '{"modelType":"regression","metric":"MeanAbsoluteError","threshold":{"value":' + + f"{self.mean_absolute_error}" + + "}}" + ) + if self.mean_squared_error: + thresholds.append( + '{"modelType":"regression","metric":"MeanSquaredError","threshold":{"value":' + + f"{self.mean_squared_error}" + + "}}" + ) + if self.root_mean_squared_error: + thresholds.append( + '{"modelType":"regression","metric":"RootMeanSquaredError","threshold":{"value":' + + f"{self.root_mean_squared_error}" + + "}}" + ) + + if not thresholds: + return None + + return ", ".join(thresholds) + + +@experimental +class ModelPerformanceMetricThreshold(RestTranslatableMixin): + def __init__( + self, + *, + classification: Optional[ModelPerformanceClassificationThresholds] = None, + regression: Optional[ModelPerformanceRegressionThresholds] = None, + ): + self.classification = classification + self.regression = regression + + def _to_str_object(self, **kwargs): + thresholds = [] + if 
self.classification: + thresholds.append(self.classification._to_str_object(**kwargs)) + if self.regression: + thresholds.append(self.regression._to_str_object(**kwargs)) + + if not thresholds: + return None + if len(thresholds) == 2: + result = "[" + ", ".join(thresholds) + "]" + else: + result = "[" + thresholds[0] + "]" + return result + + def _to_rest_object(self, **kwargs) -> ModelPerformanceMetricThresholdBase: + threshold = MonitoringThreshold(value=0.9) + return ClassificationModelPerformanceMetricThreshold( + metric="Accuracy", + threshold=threshold, + ) + + @classmethod + def _from_rest_object(cls, obj: ModelPerformanceMetricThresholdBase) -> "ModelPerformanceMetricThreshold": + return cls( + classification=ModelPerformanceClassificationThresholds._from_rest_object(obj), + regression=None, + ) + + +@experimental +class CustomMonitoringMetricThreshold(MetricThreshold): + """Feature attribution drift metric threshold + + :param metric_name: The metric to calculate + :type metric_name: str + :param threshold: The threshold value. If None, a default value will be set + depending on the selected metric. + :type threshold: float + """ + + def __init__( + self, + *, + metric_name: Optional[str], + threshold: Optional[float] = None, + ): + super().__init__(threshold=threshold) + self.metric_name = metric_name + + def _to_rest_object(self) -> CustomMetricThreshold: + return CustomMetricThreshold( + metric=self.metric_name, + threshold=MonitoringThreshold(value=self.threshold) if self.threshold is not None else None, + ) + + @classmethod + def _from_rest_object(cls, obj: CustomMetricThreshold) -> "CustomMonitoringMetricThreshold": + return cls(metric_name=obj.metric, threshold=obj.threshold.value if obj.threshold else None) + + +@experimental +class GenerationSafetyQualityMonitoringMetricThreshold(RestTranslatableMixin): # pylint: disable=name-too-long + """Generation safety quality metric threshold + + :param groundedness: The groundedness metric threshold + :paramtype groundedness: Dict[str, float] + :param relevance: The relevance metric threshold + :paramtype relevance: Dict[str, float] + :param coherence: The coherence metric threshold + :paramtype coherence: Dict[str, float] + :param fluency: The fluency metric threshold + :paramtype fluency: Dict[str, float] + :param similarity: The similarity metric threshold + :paramtype similarity: Dict[str, float] + """ + + def __init__( + self, + *, + groundedness: Optional[Dict[str, float]] = None, + relevance: Optional[Dict[str, float]] = None, + coherence: Optional[Dict[str, float]] = None, + fluency: Optional[Dict[str, float]] = None, + similarity: Optional[Dict[str, float]] = None, + ): + self.groundedness = groundedness + self.relevance = relevance + self.coherence = coherence + self.fluency = fluency + self.similarity = similarity + + def _to_rest_object(self) -> GenerationSafetyQualityMetricThreshold: + metric_thresholds = [] + if self.groundedness: + if "acceptable_groundedness_score_per_instance" in self.groundedness: + acceptable_threshold = MonitoringThreshold( + value=self.groundedness["acceptable_groundedness_score_per_instance"] + ) + else: + acceptable_threshold = MonitoringThreshold(value=3) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AcceptableGroundednessScorePerInstance", threshold=acceptable_threshold + ) + ) + aggregated_threshold = MonitoringThreshold(value=self.groundedness["aggregated_groundedness_pass_rate"]) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + 
metric="AggregatedGroundednessPassRate", threshold=aggregated_threshold + ) + ) + if self.relevance: + if "acceptable_relevance_score_per_instance" in self.relevance: + acceptable_threshold = MonitoringThreshold( + value=self.relevance["acceptable_relevance_score_per_instance"] + ) + else: + acceptable_threshold = MonitoringThreshold(value=3) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AcceptableRelevanceScorePerInstance", threshold=acceptable_threshold + ) + ) + aggregated_threshold = MonitoringThreshold(value=self.relevance["aggregated_relevance_pass_rate"]) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AggregatedRelevancePassRate", threshold=aggregated_threshold + ) + ) + if self.coherence: + if "acceptable_coherence_score_per_instance" in self.coherence: + acceptable_threshold = MonitoringThreshold( + value=self.coherence["acceptable_coherence_score_per_instance"] + ) + else: + acceptable_threshold = MonitoringThreshold(value=3) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AcceptableCoherenceScorePerInstance", threshold=acceptable_threshold + ) + ) + aggregated_threshold = MonitoringThreshold(value=self.coherence["aggregated_coherence_pass_rate"]) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AggregatedCoherencePassRate", threshold=aggregated_threshold + ) + ) + if self.fluency: + if "acceptable_fluency_score_per_instance" in self.fluency: + acceptable_threshold = MonitoringThreshold(value=self.fluency["acceptable_fluency_score_per_instance"]) + else: + acceptable_threshold = MonitoringThreshold(value=3) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AcceptableFluencyScorePerInstance", threshold=acceptable_threshold + ) + ) + aggregated_threshold = MonitoringThreshold(value=self.fluency["aggregated_fluency_pass_rate"]) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AggregatedFluencyPassRate", threshold=aggregated_threshold + ) + ) + if self.similarity: + if "acceptable_similarity_score_per_instance" in self.similarity: + acceptable_threshold = MonitoringThreshold( + value=self.similarity["acceptable_similarity_score_per_instance"] + ) + else: + acceptable_threshold = MonitoringThreshold(value=3) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AcceptableSimilarityScorePerInstance", threshold=acceptable_threshold + ) + ) + aggregated_threshold = MonitoringThreshold(value=self.similarity["aggregated_similarity_pass_rate"]) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="AggregatedSimilarityPassRate", threshold=aggregated_threshold + ) + ) + return metric_thresholds + + @classmethod + def _from_rest_object( + cls, obj: GenerationSafetyQualityMetricThreshold + ) -> "GenerationSafetyQualityMonitoringMetricThreshold": + groundedness = {} + relevance = {} + coherence = {} + fluency = {} + similarity = {} + + for threshold in obj: + if threshold.metric == "AcceptableGroundednessScorePerInstance": + groundedness["acceptable_groundedness_score_per_instance"] = threshold.threshold.value + if threshold.metric == "AcceptableRelevanceScorePerInstance": + relevance["acceptable_relevance_score_per_instance"] = threshold.threshold.value + if threshold.metric == "AcceptableCoherenceScorePerInstance": + coherence["acceptable_coherence_score_per_instance"] = threshold.threshold.value + if threshold.metric == 
"AcceptableFluencyScorePerInstance": + fluency["acceptable_fluency_score_per_instance"] = threshold.threshold.value + if threshold.metric == "AcceptableSimilarityScorePerInstance": + similarity["acceptable_similarity_score_per_instance"] = threshold.threshold.value + if threshold.metric == "AggregatedGroundednessPassRate": + groundedness["aggregated_groundedness_pass_rate"] = threshold.threshold.value + if threshold.metric == "AggregatedRelevancePassRate": + relevance["aggregated_relevance_pass_rate"] = threshold.threshold.value + if threshold.metric == "AggregatedCoherencePassRate": + coherence["aggregated_coherence_pass_rate"] = threshold.threshold.value + if threshold.metric == "AggregatedFluencyPassRate": + fluency["aggregated_fluency_pass_rate"] = threshold.threshold.value + if threshold.metric == "AggregatedSimilarityPassRate": + similarity["aggregated_similarity_pass_rate"] = threshold.threshold.value + + return cls( + groundedness=groundedness if groundedness else None, + relevance=relevance if relevance else None, + coherence=coherence if coherence else None, + fluency=fluency if fluency else None, + similarity=similarity if similarity else None, + ) + + +@experimental +class GenerationTokenStatisticsMonitorMetricThreshold(RestTranslatableMixin): # pylint: disable=name-too-long + """Generation token statistics metric threshold definition. + + All required parameters must be populated in order to send to Azure. + + :ivar metric: Required. [Required] Gets or sets the feature attribution metric to calculate. + Possible values include: "TotalTokenCount", "TotalTokenCountPerGroup". + :vartype metric: str or + ~azure.mgmt.machinelearningservices.models.GenerationTokenStatisticsMetric + :ivar threshold: Gets or sets the threshold value. + If null, a default value will be set depending on the selected metric. 
+ :vartype threshold: ~azure.mgmt.machinelearningservices.models.MonitoringThreshold + """ + + def __init__( + self, + *, + totaltoken: Optional[Dict[str, float]] = None, + ): + self.totaltoken = totaltoken + + def _to_rest_object(self) -> GenerationSafetyQualityMetricThreshold: + metric_thresholds = [] + if self.totaltoken: + if "total_token_count" in self.totaltoken: + acceptable_threshold = MonitoringThreshold(value=self.totaltoken["total_token_count"]) + else: + acceptable_threshold = MonitoringThreshold(value=3) + metric_thresholds.append( + GenerationTokenStatisticsMetricThreshold(metric="TotalTokenCount", threshold=acceptable_threshold) + ) + acceptable_threshold_per_group = MonitoringThreshold(value=self.totaltoken["total_token_count_per_group"]) + metric_thresholds.append( + GenerationSafetyQualityMetricThreshold( + metric="TotalTokenCountPerGroup", threshold=acceptable_threshold_per_group + ) + ) + return metric_thresholds + + @classmethod + def _from_rest_object( + cls, obj: GenerationTokenStatisticsMetricThreshold + ) -> "GenerationTokenStatisticsMonitorMetricThreshold": + totaltoken = {} + for threshold in obj: + if threshold.metric == "TotalTokenCount": + totaltoken["total_token_count"] = threshold.threshold.value + if threshold.metric == "TotalTokenCountPerGroup": + totaltoken["total_token_count_per_group"] = threshold.threshold.value + + return cls( + totaltoken=totaltoken if totaltoken else None, + ) + + @classmethod + def _get_default_thresholds(cls) -> "GenerationTokenStatisticsMonitorMetricThreshold": + return cls(totaltoken={"total_token_count": 0, "total_token_count_per_group": 0}) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/notification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/notification.py new file mode 100644 index 00000000..91380870 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/notification.py @@ -0,0 +1,33 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import List, Optional + +from azure.ai.ml._restclient.v2023_02_01_preview.models import NotificationSetting as RestNotificationSetting +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class Notification(RestTranslatableMixin): + """Configuration for notification. + + :param email_on: Send email notification to user on specified notification type. Accepted values are + "JobCompleted", "JobFailed", and "JobCancelled". + :type email_on: Optional[list[str]] + :param: The email recipient list which. Note that this parameter has a character limit of 499 which + includes all of the recipient strings and each comma seperator. 
+ :paramtype emails: Optional[list[str]] + """ + + def __init__(self, *, email_on: Optional[List[str]] = None, emails: Optional[List[str]] = None) -> None: + self.email_on = email_on + self.emails = emails + + def _to_rest_object(self) -> RestNotificationSetting: + return RestNotificationSetting(email_on=self.email_on, emails=self.emails) + + @classmethod + def _from_rest_object(cls, obj: RestNotificationSetting) -> Optional["Notification"]: + if not obj: + return None + return Notification(email_on=obj.email_on, emails=obj.emails) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry.py new file mode 100644 index 00000000..a01e70d3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry.py @@ -0,0 +1,231 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Union + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ManagedServiceIdentity as RestManagedServiceIdentity +from azure.ai.ml._restclient.v2022_10_01_preview.models import ( + ManagedServiceIdentityType as RestManagedServiceIdentityType, +) +from azure.ai.ml._restclient.v2022_10_01_preview.models import Registry as RestRegistry +from azure.ai.ml._restclient.v2022_10_01_preview.models import RegistryProperties +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY +from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._util import load_from_dict + +from .registry_support_classes import RegistryRegionDetails + +CONTAINER_REGISTRY = "container_registry" +REPLICATION_LOCATIONS = "replication_locations" +INTELLECTUAL_PROPERTY = "intellectual_property" + + +class Registry(Resource): + def __init__( + self, + *, + name: str, + location: str, + identity: Optional[IdentityConfiguration] = None, + tags: Optional[Dict[str, str]] = None, + public_network_access: Optional[str] = None, + discovery_url: Optional[str] = None, + intellectual_property: Optional[IntellectualProperty] = None, + managed_resource_group: Optional[str] = None, + mlflow_registry_uri: Optional[str] = None, + replication_locations: Optional[List[RegistryRegionDetails]], + **kwargs: Any, + ): + """Azure ML registry. + + :param name: Name of the registry. Must be globally unique and is immutable. + :type name: str + :param location: The location this registry resource is located in. 
+ :type location: str + :param identity: registry's System Managed Identity + :type identity: ManagedServiceIdentity + :param tags: Tags of the registry. + :type tags: dict + :param public_network_access: Whether to allow public endpoint connectivity. + :type public_network_access: str + :param discovery_url: Backend service base url for the registry. + :type discovery_url: str + :param intellectual_property: **Experimental** Intellectual property publisher. + :type intellectual_property: ~azure.ai.ml.entities.IntellectualProperty + :param managed_resource_group: Managed resource group created for the registry. + :type managed_resource_group: str + :param mlflow_registry_uri: Ml flow tracking uri for the registry. + :type mlflow_registry_uri: str + :param region_details: Details of each region the registry is in. + :type region_details: List[RegistryRegionDetails] + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + super().__init__(name=name, tags=tags, **kwargs) + + # self.display_name = name # Do we need a top-level visible name value? + self.location = location + self.identity = identity + self.replication_locations = replication_locations + self.public_network_access = public_network_access + self.intellectual_property = intellectual_property + self.managed_resource_group = managed_resource_group + self.discovery_url = discovery_url + self.mlflow_registry_uri = mlflow_registry_uri + self.container_registry = None + + def dump( + self, + dest: Union[str, PathLike, IO[AnyStr]], + **kwargs: Any, + ) -> None: + """Dump the registry spec into a file in yaml format. + + :param dest: Path to a local file as the target, new file will be created, raises exception if the file exists. + :type dest: str + """ + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False) + + # The internal structure of the registry object is closer to how it's + # represented by the registry API, which differs from how registries + # are represented in YAML. This function converts those differences. + def _to_dict(self) -> Dict: + # JIT import to avoid experimental warnings on unrelated calls + from azure.ai.ml._schema.registry.registry import RegistrySchema + + schema = RegistrySchema(context={BASE_PATH_CONTEXT_KEY: "./"}) + + # Grab the first acr account of the first region and set that + # as the system-wide container registry. + # Although support for multiple ACRs per region, as well as + # different ACRs per region technically exist according to the + # API schema, we do not want to surface that as an option, + # since the use cases for variable/multiple ACRs are extremely + # limited, and would probably just confuse most users. 
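+ # Note that this only mirrors the first region's first ACR entry onto
+ # self.container_registry for serialization; the underlying
+ # replication_locations list itself is left unchanged.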
+ if self.replication_locations and len(self.replication_locations) > 0: + if self.replication_locations[0].acr_config and len(self.replication_locations[0].acr_config) > 0: + self.container_registry = self.replication_locations[0].acr_config[0] # type: ignore[assignment] + + res: dict = schema.dump(self) + return res + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Registry": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + # JIT import to avoid experimental warnings on unrelated calls + from azure.ai.ml._schema.registry.registry import RegistrySchema + + loaded_schema = load_from_dict(RegistrySchema, data, context, **kwargs) + cls._convert_yaml_dict_to_entity_input(loaded_schema) + return Registry(**loaded_schema) + + @classmethod + def _from_rest_object(cls, rest_obj: RestRegistry) -> Optional["Registry"]: + if not rest_obj: + return None + real_registry = rest_obj.properties + + # Convert from api name region_details to user-shown name "replication locations" + replication_locations = [] + if real_registry and real_registry.region_details: + replication_locations = [ + RegistryRegionDetails._from_rest_object(details) for details in real_registry.region_details + ] + identity = None + if rest_obj.identity and isinstance(rest_obj.identity, RestManagedServiceIdentity): + identity = IdentityConfiguration._from_rest_object(rest_obj.identity) + return Registry( + name=rest_obj.name, + identity=identity, + id=rest_obj.id, + tags=rest_obj.tags, + location=rest_obj.location, + public_network_access=real_registry.public_network_access, + discovery_url=real_registry.discovery_url, + intellectual_property=( + IntellectualProperty(publisher=real_registry.intellectual_property_publisher) + if real_registry.intellectual_property_publisher + else None + ), + managed_resource_group=real_registry.managed_resource_group, + mlflow_registry_uri=real_registry.ml_flow_registry_uri, + replication_locations=replication_locations, # type: ignore[arg-type] + ) + + # There are differences between what our registry validation schema + # accepts, and how we actually represent things internally. + # This is mostly due to the compromise required to balance + # the actual shape of registries as they're defined by + # autorest with how the spec wanted users to be able to + # configure them. This function should eventually be + @classmethod + def _convert_yaml_dict_to_entity_input( + cls, + input: Dict, # pylint: disable=redefined-builtin + ) -> None: + # pop container_registry value. + global_acr_exists = False + if CONTAINER_REGISTRY in input: + acr_input = input.pop(CONTAINER_REGISTRY) + global_acr_exists = True + for region_detail in input[REPLICATION_LOCATIONS]: + # Apply container_registry as acr_config of each region detail + if global_acr_exists: + if not hasattr(region_detail, "acr_details") or len(region_detail.acr_details) == 0: + region_detail.acr_config = [acr_input] # pylint: disable=(possibly-used-before-assignment + + def _to_rest_object(self) -> RestRegistry: + """Build current parameterized schedule instance to a registry object before submission. + + :return: Rest registry. 
+ :rtype: RestRegistry + """ + identity = RestManagedServiceIdentity(type=RestManagedServiceIdentityType.SYSTEM_ASSIGNED) + replication_locations = [] + if self.replication_locations: + replication_locations = [details._to_rest_object() for details in self.replication_locations] + # Notes about this construction. + # RestRegistry.properties.tags: this property exists due to swagger inheritance + # issues, don't actually use it, use top level RestRegistry.tags instead + # RestRegistry.properties.managed_resource_group_tags: Registries create a + # managed resource group to manage their internal sub-resources. + # We always want the tags on this MRG to match those of the registry itself + # to keep janitor policies aligned. + return RestRegistry( + name=self.name, + location=self.location, + identity=identity, + tags=self.tags, + properties=RegistryProperties( + public_network_access=self.public_network_access, + discovery_url=self.discovery_url, + intellectual_property_publisher=( + (self.intellectual_property.publisher) if self.intellectual_property else None + ), + managed_resource_group=self.managed_resource_group, + ml_flow_registry_uri=self.mlflow_registry_uri, + region_details=replication_locations, + managed_resource_group_tags=self.tags, + ), + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry_support_classes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry_support_classes.py new file mode 100644 index 00000000..810c5df5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry_support_classes.py @@ -0,0 +1,273 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint:disable=protected-access,no-else-return + +from copy import deepcopy +from functools import reduce +from typing import List, Optional, Union + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2022_10_01_preview.models import AcrDetails as RestAcrDetails +from azure.ai.ml._restclient.v2022_10_01_preview.models import ArmResourceId as RestArmResourceId +from azure.ai.ml._restclient.v2022_10_01_preview.models import RegistryRegionArmDetails as RestRegistryRegionArmDetails +from azure.ai.ml._restclient.v2022_10_01_preview.models import StorageAccountDetails as RestStorageAccountDetails +from azure.ai.ml._restclient.v2022_10_01_preview.models import SystemCreatedAcrAccount as RestSystemCreatedAcrAccount +from azure.ai.ml._restclient.v2022_10_01_preview.models import ( + SystemCreatedStorageAccount as RestSystemCreatedStorageAccount, +) +from azure.ai.ml._restclient.v2022_10_01_preview.models import UserCreatedAcrAccount as RestUserCreatedAcrAccount +from azure.ai.ml.constants._registry import StorageAccountType +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .util import _make_rest_user_storage_from_id + + +# This exists despite not being used by the schema validator because this entire +# class is an output only value from the API. +class SystemCreatedAcrAccount: + def __init__( + self, + *, + acr_account_sku: str, + arm_resource_id: Optional[str] = None, + ): + """Azure ML ACR account. + + :param acr_account_sku: The storage account service tier. Currently + only Premium is a valid option for registries. 
+ :type acr_account_sku: str + :param arm_resource_id: Resource ID of the ACR account. + :type arm_resource_id: str. Default value is None. + """ + self.acr_account_sku = acr_account_sku + self.arm_resource_id = arm_resource_id + + # acr should technically be a union between str and SystemCreatedAcrAccount, + # but python doesn't accept self class references apparently. + # Class method instead of normal function to accept possible + # string input. + @classmethod + def _to_rest_object(cls, acr: Union[str, "SystemCreatedAcrAccount"]) -> RestAcrDetails: + if hasattr(acr, "acr_account_sku") and acr.acr_account_sku is not None: + # SKU enum requires input to be a capitalized word, + # so we format the input to be acceptable as long as spelling is + # correct. + acr_account_sku = acr.acr_account_sku.capitalize() + # We DO NOT want to set the arm_resource_id. The backend provides very + # unhelpful errors if you provide an empty/null/invalid resource ID, + # and ignores the value otherwise. It's better to avoid setting it in + # the conversion in this direction at all. + return RestAcrDetails( + system_created_acr_account=RestSystemCreatedAcrAccount( + acr_account_sku=acr_account_sku, + ) + ) + else: + return RestAcrDetails( + user_created_acr_account=RestUserCreatedAcrAccount(arm_resource_id=RestArmResourceId(resource_id=acr)) + ) + + @classmethod + def _from_rest_object(cls, rest_obj: RestAcrDetails) -> Optional["Union[str, SystemCreatedAcrAccount]"]: + if not rest_obj: + return None + if hasattr(rest_obj, "system_created_acr_account") and rest_obj.system_created_acr_account is not None: + resource_id = None + if rest_obj.system_created_acr_account.arm_resource_id: + resource_id = rest_obj.system_created_acr_account.arm_resource_id.resource_id + return SystemCreatedAcrAccount( + acr_account_sku=rest_obj.system_created_acr_account.acr_account_sku, + arm_resource_id=resource_id, + ) + elif hasattr(rest_obj, "user_created_acr_account") and rest_obj.user_created_acr_account is not None: + res: Optional[str] = rest_obj.user_created_acr_account.arm_resource_id.resource_id + return res + else: + return None + + +class SystemCreatedStorageAccount: + def __init__( + self, + *, + storage_account_hns: bool, + storage_account_type: Optional[StorageAccountType], + arm_resource_id: Optional[str] = None, + replicated_ids: Optional[List[str]] = None, + replication_count: int = 1, + ): + """ + :param arm_resource_id: Resource ID of the storage account. + :type arm_resource_id: str + :param storage_account_hns: Whether or not this storage account + has hierarchical namespaces enabled. + :type storage_account_hns: bool + :param storage_account_type: Allowed values: "Standard_LRS", + "Standard_GRS, "Standard_RAGRS", "Standard_ZRS", "Standard_GZRS", + "Standard_RAGZRS", "Premium_LRS", "Premium_ZRS" + :type storage_account_type: StorageAccountType + :param replication_count: The number of replicas of this storage account + that should be created. Defaults to 1. Values less than 1 are invalid. + :type replication_count: int + :param replicated_ids: If this storage was replicated, then this is a + list of all storage IDs with these settings for this registry. + Defaults to none for un-replicated storage accounts. 
+ :type replicated_ids: List[str] + """ + self.arm_resource_id = arm_resource_id + self.storage_account_hns = storage_account_hns + self.storage_account_type = storage_account_type + self.replication_count = replication_count + self.replicated_ids = replicated_ids + + +# Per-region information for registries. +class RegistryRegionDetails: + def __init__( + self, + *, + acr_config: Optional[List[Union[str, SystemCreatedAcrAccount]]] = None, + location: Optional[str] = None, + storage_config: Optional[Union[List[str], SystemCreatedStorageAccount]] = None, + ): + """Details for each region a registry is in. + + :param acr_details: List of ACR account details. Each value can either be a + single string representing the arm_resource_id of a user-created + acr_details object, or a entire SystemCreatedAcrAccount object. + :type acr_details: List[Union[str, SystemCreatedAcrAccount]] + :param location: The location where the registry exists. + :type location: str + :param storage_account_details: List of storage accounts. Each value + can either be a single string representing the arm_resource_id of + a user-created storage account, or an entire + SystemCreatedStorageAccount object. + :type storage_account_details: Union[List[str], SystemCreatedStorageAccount] + """ + self.acr_config = acr_config + self.location = location + self.storage_config = storage_config + + @classmethod + def _from_rest_object(cls, rest_obj: RestRegistryRegionArmDetails) -> Optional["RegistryRegionDetails"]: + if not rest_obj: + return None + converted_acr_details = [] + if rest_obj.acr_details: + converted_acr_details = [SystemCreatedAcrAccount._from_rest_object(acr) for acr in rest_obj.acr_details] + storages: Optional[Union[List[str], SystemCreatedStorageAccount]] = [] + if rest_obj.storage_account_details: + storages = cls._storage_config_from_rest_object(rest_obj.storage_account_details) + + return RegistryRegionDetails( + acr_config=converted_acr_details, # type: ignore[arg-type] + location=rest_obj.location, + storage_config=storages, + ) + + def _to_rest_object(self) -> RestRegistryRegionArmDetails: + converted_acr_details = [] + if self.acr_config: + converted_acr_details = [SystemCreatedAcrAccount._to_rest_object(acr) for acr in self.acr_config] + storages = [] + if self.storage_config: + storages = self._storage_config_to_rest_object() + return RestRegistryRegionArmDetails( + acr_details=converted_acr_details, + location=self.location, + storage_account_details=storages, + ) + + def _storage_config_to_rest_object(self) -> List[RestStorageAccountDetails]: + storage = self.storage_config + # storage_config can either be a single system-created storage account, + # or list of user-inputted id's. + if ( + storage is not None + and not isinstance(storage, list) + and hasattr(storage, "storage_account_type") + and storage.storage_account_type is not None + ): + # We DO NOT want to set the arm_resource_id. The backend provides very + # unhelpful errors if you provide an empty/null/invalid resource ID, + # and ignores the value otherwise. It's better to avoid setting it in + # the conversion in this direction at all. + # We don't bother processing storage_account_type because the + # rest version is case insensitive. 
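+ # A single system-created storage config is expanded below into
+ # `replication_count` identical REST entries; _storage_config_from_rest_object
+ # collapses such duplicates back into one SystemCreatedStorageAccount.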
+ account = RestStorageAccountDetails( + system_created_storage_account=RestSystemCreatedStorageAccount( + storage_account_hns_enabled=storage.storage_account_hns, + storage_account_type=storage.storage_account_type, + ) + ) + # duplicate this value based on the replication_count + count = storage.replication_count + if count < 1: + raise ValueError(f"Replication count cannot be less than 1. Value was: {count}.") + return [deepcopy(account) for _ in range(0, count)] + elif storage is not None and not isinstance(storage, SystemCreatedStorageAccount) and len(storage) > 0: + return [_make_rest_user_storage_from_id(user_id=user_id) for user_id in storage] + else: + return [] + + @classmethod + def _storage_config_from_rest_object( + cls, rest_configs: Optional[List] + ) -> Optional[Union[List[str], SystemCreatedStorageAccount]]: + if not rest_configs: + return None + num_configs = len(rest_configs) + if num_configs == 0: + return None + system_created_count = reduce( + # TODO: Bug Item number: 2883323 + lambda x, y: int(x) + int(y), # type: ignore + [ + hasattr(config, "system_created_storage_account") and config.system_created_storage_account is not None + for config in rest_configs + ], + ) + # configs should be mono-typed. Either they're all system created + # or all user created. + if system_created_count == num_configs: + # System created case - assume all elements are duplicates + # of a single storage configuration. + # Convert back into a single local representation by + # combining id's into a list, and using the first element's + # account type and hns. + first_config = rest_configs[0].system_created_storage_account + resource_id = None + if first_config.arm_resource_id: + resource_id = first_config.arm_resource_id.resource_id + # account for ids of duplicated if they exist + replicated_ids = None + if num_configs > 1: + replicated_ids = [ + config.system_created_storage_account.arm_resource_id.resource_id for config in rest_configs + ] + return SystemCreatedStorageAccount( + storage_account_hns=first_config.storage_account_hns_enabled, + storage_account_type=( + (StorageAccountType(first_config.storage_account_type.lower())) + if first_config.storage_account_type + else None + ), + arm_resource_id=resource_id, + replication_count=num_configs, + replicated_ids=replicated_ids, + ) + elif system_created_count == 0: + return [config.user_created_storage_account.arm_resource_id.resource_id for config in rest_configs] + else: + msg = f"""tried reading in a registry whose storage accounts were not + mono-managed or user-created. {system_created_count} out of {num_configs} were managed.""" + err = ValidationException( + message=msg, + target=ErrorTarget.REGISTRY, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + log_and_raise_error(err) + return None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/util.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/util.py new file mode 100644 index 00000000..18f56169 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/util.py @@ -0,0 +1,17 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ArmResourceId as RestArmResourceId +from azure.ai.ml._restclient.v2022_10_01_preview.models import StorageAccountDetails as RestStorageAccountDetails +from azure.ai.ml._restclient.v2022_10_01_preview.models import ( + UserCreatedStorageAccount as RestUserCreatedStorageAccount, +) + + +def _make_rest_user_storage_from_id(*, user_id: str) -> RestStorageAccountDetails: + return RestStorageAccountDetails( + user_created_storage_account=RestUserCreatedStorageAccount( + arm_resource_id=RestArmResourceId(resource_id=user_id) + ) + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_resource.py new file mode 100644 index 00000000..d20eaeff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_resource.py @@ -0,0 +1,194 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +import abc +import os +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union, cast + +from msrest import Serializer + +from azure.ai.ml._restclient.v2022_10_01 import models +from azure.ai.ml._telemetry.logging_handler import in_jupyter_notebook +from azure.ai.ml._utils.utils import dump_yaml + +from ..constants._common import BASE_PATH_CONTEXT_KEY +from ._system_data import SystemData + + +class Resource(abc.ABC): + """Base class for entity classes. + + Resource is an abstract object that serves as a base for creating resources. It contains common properties and + methods for all resources. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :param name: The name of the resource. + :type name: str + :param description: The description of the resource. + :type description: Optional[str] + :param tags: Tags can be added, removed, and updated. + :type tags: Optional[dict] + :param properties: The resource's property dictionary. + :type properties: Optional[dict] + :keyword print_as_yaml: Specifies if the the resource should print out as a YAML-formatted object. If False, + the resource will print out in a more-compact style. By default, the YAML output is only used in Jupyter + notebooks. Be aware that some bookkeeping values are shown only in the non-YAML output. + :paramtype print_as_yaml: bool + """ + + def __init__( + self, + name: Optional[str], + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + self.name = name + self.description = description + self.tags: Optional[Dict] = dict(tags) if tags else {} + self.properties = dict(properties) if properties else {} + # Conditional assignment to prevent entity bloat when unused. 
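+ # _print_as_yaml switches __str__ between YAML-formatted output and the
+ # compact repr-style output; YAML is also used automatically when running
+ # inside a Jupyter notebook.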
+ self._print_as_yaml = kwargs.pop("print_as_yaml", False) + + # Hide read only properties in kwargs + self._id = kwargs.pop("id", None) + self.__source_path: Union[str, PathLike] = kwargs.pop("source_path", "") + self._base_path = kwargs.pop(BASE_PATH_CONTEXT_KEY, None) or os.getcwd() # base path should never be None + self._creation_context: Optional[SystemData] = kwargs.pop("creation_context", None) + client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} + self._serialize = Serializer(client_models) + self._serialize.client_side_validation = False + super().__init__(**kwargs) + + @property + def _source_path(self) -> Union[str, PathLike]: + # source path is added to display file location for validation error messages + # usually, base_path = Path(source_path).parent if source_path else os.getcwd() + return self.__source_path + + @_source_path.setter + def _source_path(self, value: Union[str, PathLike]) -> None: + self.__source_path = Path(value).as_posix() + + @property + def id(self) -> Optional[str]: + """The resource ID. + + :return: The global ID of the resource, an Azure Resource Manager (ARM) ID. + :rtype: Optional[str] + """ + if self._id is None: + return None + return str(self._id) + + @property + def creation_context(self) -> Optional[SystemData]: + """The creation context of the resource. + + :return: The creation metadata for the resource. + :rtype: Optional[~azure.ai.ml.entities.SystemData] + """ + return cast(Optional[SystemData], self._creation_context) + + @property + def base_path(self) -> str: + """The base path of the resource. + + :return: The base path of the resource. + :rtype: str + """ + return self._base_path + + @abc.abstractmethod + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> Any: + """Dump the object content into a file. + + :param dest: The local path or file stream to write the YAML content to. + If dest is a file path, a new file will be created. + If dest is an open file, the file will be written to directly. + :type dest: Union[PathLike, str, IO[AnyStr]] + """ + + @classmethod + # pylint: disable=unused-argument + def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] = None) -> Tuple: + """Resolve the class to use for deserializing the data. Return current class if no override is provided. + + :param data: Data to deserialize. + :type data: dict + :param params_override: Parameters to override, defaults to None + :type params_override: typing.Optional[list] + :return: Class to use for deserializing the data & its "type". Type will be None if no override is provided. + :rtype: tuple[class, typing.Optional[str]] + """ + return cls, None + + @classmethod + @abc.abstractmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Resource": + """Construct a resource object from a file. @classmethod. + + :param cls: Indicates that this is a class method. 
+ :type cls: class + :param data: Path to a local file as the source, defaults to None + :type data: typing.Optional[typing.Dict] + :param yaml_path: Path to a yaml file as the source, defaults to None + :type yaml_path: typing.Optional[typing.Union[typing.PathLike, str]] + :param params_override: Parameters to override, defaults to None + :type params_override: typing.Optional[list] + :return: Resource + :rtype: Resource + """ + + # pylint: disable:unused-argument + def _get_arm_resource( + self, + # pylint: disable=unused-argument + **kwargs: Any, + ) -> Dict: + """Get arm resource. + + :return: Resource + :rtype: dict + """ + from azure.ai.ml._arm_deployments.arm_helper import get_template + + # pylint: disable=no-member + template = get_template(resource_type=self._arm_type) # type: ignore + # pylint: disable=no-member + template["copy"]["name"] = f"{self._arm_type}Deployment" # type: ignore + return dict(template) + + def _get_arm_resource_and_params(self, **kwargs: Any) -> List: + """Get arm resource and parameters. + + :return: Resource and parameters + :rtype: dict + """ + resource = self._get_arm_resource(**kwargs) + # pylint: disable=no-member + param = self._to_arm_resource_param(**kwargs) # type: ignore + return [(resource, param)] + + def __repr__(self) -> str: + var_dict = {k.strip("_"): v for (k, v) in vars(self).items()} + return f"{self.__class__.__name__}({var_dict})" + + def __str__(self) -> str: + if self._print_as_yaml or in_jupyter_notebook(): + # pylint: disable=no-member + yaml_serialized = self._to_dict() # type: ignore + return str(dump_yaml(yaml_serialized, default_flow_style=False)) + return self.__repr__() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/schedule.py new file mode 100644 index 00000000..93867a9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/schedule.py @@ -0,0 +1,513 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access +import logging +import typing +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union + +from typing_extensions import Literal + +from azure.ai.ml._restclient.v2023_06_01_preview.models import JobBase as RestJobBase +from azure.ai.ml._restclient.v2023_06_01_preview.models import JobScheduleAction +from azure.ai.ml._restclient.v2023_06_01_preview.models import PipelineJob as RestPipelineJob +from azure.ai.ml._restclient.v2023_06_01_preview.models import Schedule as RestSchedule +from azure.ai.ml._restclient.v2023_06_01_preview.models import ScheduleActionType as RestScheduleActionType +from azure.ai.ml._restclient.v2023_06_01_preview.models import ScheduleProperties +from azure.ai.ml._restclient.v2024_01_01_preview.models import TriggerRunSubmissionDto as RestTriggerRunSubmissionDto +from azure.ai.ml._schema.schedule.schedule import JobScheduleSchema +from azure.ai.ml._utils.utils import camel_to_snake, dump_yaml_to_file, is_private_preview_enabled +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import ARM_ID_PREFIX, BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ScheduleType +from azure.ai.ml.entities._job.command_job import CommandJob +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob +from azure.ai.ml.entities._job.spark_job import SparkJob +from azure.ai.ml.entities._mixins import RestTranslatableMixin, TelemetryMixin, YamlTranslatableMixin +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml.entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin + +from ...exceptions import ErrorCategory, ErrorTarget, ScheduleException, ValidationException +from .._builders import BaseNode +from .trigger import CronTrigger, RecurrenceTrigger, TriggerBase + +module_logger = logging.getLogger(__name__) + + +class Schedule(YamlTranslatableMixin, PathAwareSchemaValidatableMixin, Resource): + """Schedule object used to create and manage schedules. + + This class should not be instantiated directly. Instead, please use the subclasses. + + :keyword name: The name of the schedule. + :paramtype name: str + :keyword trigger: The schedule trigger configuration. + :paramtype trigger: Union[~azure.ai.ml.entities.CronTrigger, ~azure.ai.ml.entities.RecurrenceTrigger] + :keyword display_name: The display name of the schedule. + :paramtype display_name: Optional[str] + :keyword description: The description of the schedule. + :paramtype description: Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :paramtype tags: Optional[dict]] + :keyword properties: A dictionary of properties to associate with the schedule. + :paramtype properties: Optional[dict[str, str]] + :keyword kwargs: Additional keyword arguments passed to the Resource constructor. 
+ :paramtype kwargs: dict + """ + + def __init__( + self, + *, + name: str, + trigger: Optional[Union[CronTrigger, RecurrenceTrigger]], + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + is_enabled = kwargs.pop("is_enabled", None) + provisioning_state = kwargs.pop("provisioning_state", None) + super().__init__(name=name, description=description, tags=tags, properties=properties, **kwargs) + self.trigger = trigger + self.display_name = display_name + self._is_enabled: bool = is_enabled + self._provisioning_state: str = provisioning_state + self._type: Any = None + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the schedule content into a file in YAML format. + + :param dest: The local path or file stream to write the YAML content to. + If dest is a file path, a new file will be created. + If dest is an open file, the file will be written to directly. + :type dest: Union[PathLike, str, IO[AnyStr]] + :raises FileExistsError: Raised if dest is a file path and the file already exists. + :raises IOError: Raised if dest is an open file and the file is not writable. + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException: + return ValidationException( + message=message, + no_personal_data_message=no_personal_data_message, + target=ErrorTarget.SCHEDULE, + ) + + @classmethod + def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] = None) -> Tuple: + from azure.ai.ml.entities._data_import.schedule import ImportDataSchedule + from azure.ai.ml.entities._monitoring.schedule import MonitorSchedule + + if "create_monitor" in data: + return MonitorSchedule, None + if "import_data" in data: + return ImportDataSchedule, None + return JobSchedule, None + + @property + def create_job(self) -> Any: # pylint: disable=useless-return + """The create_job entity associated with the schedule if exists.""" + module_logger.warning("create_job is not a valid property of %s", str(type(self))) + # return None here just to be explicit + return None + + @create_job.setter + def create_job(self, value: Any) -> None: # pylint: disable=unused-argument + """Set the create_job entity associated with the schedule if exists. + + :param value: The create_job entity associated with the schedule if exists. + :type value: Any + """ + module_logger.warning("create_job is not a valid property of %s", str(type(self))) + + @property + def is_enabled(self) -> bool: + """Specifies if the schedule is enabled or not. + + :return: True if the schedule is enabled, False otherwise. + :rtype: bool + """ + return self._is_enabled + + @property + def provisioning_state(self) -> str: + """Returns the schedule's provisioning state. The possible values include + "Creating", "Updating", "Deleting", "Succeeded", "Failed", "Canceled". + + :return: The schedule's provisioning state. + :rtype: str + """ + return self._provisioning_state + + @property + def type(self) -> Optional[str]: + """The schedule type. Accepted values are 'job' and 'monitor'. + + :return: The schedule type. 
+ :rtype: str + """ + return self._type + + def _to_dict(self) -> Dict: + res: dict = self._dump_for_validation() + return res + + @classmethod + def _from_rest_object(cls, obj: RestSchedule) -> "Schedule": + from azure.ai.ml.entities._data_import.schedule import ImportDataSchedule + from azure.ai.ml.entities._monitoring.schedule import MonitorSchedule + + if obj.properties.action.action_type == RestScheduleActionType.CREATE_JOB: + return JobSchedule._from_rest_object(obj) + if obj.properties.action.action_type == RestScheduleActionType.CREATE_MONITOR: + res_monitor_schedule: Schedule = MonitorSchedule._from_rest_object(obj) + return res_monitor_schedule + if obj.properties.action.action_type == RestScheduleActionType.IMPORT_DATA: + res_data_schedule: Schedule = ImportDataSchedule._from_rest_object(obj) + return res_data_schedule + msg = f"Unsupported schedule type {obj.properties.action.action_type}" + raise ScheduleException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SCHEDULE, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + +class JobSchedule(RestTranslatableMixin, Schedule, TelemetryMixin): + """Class for managing job schedules. + + :keyword name: The name of the schedule. + :paramtype name: str + :keyword trigger: The trigger configuration for the schedule. + :paramtype trigger: Union[~azure.ai.ml.entities.CronTrigger, ~azure.ai.ml.entities.RecurrenceTrigger] + :keyword create_job: The job definition or an existing job name. + :paramtype create_job: Union[~azure.ai.ml.entities.Job, str] + :keyword display_name: The display name of the schedule. + :paramtype display_name: Optional[str] + :keyword description: The description of the schedule. + :paramtype description: Optional[str] + :keyword tags: Tag dictionary. Tags can be added, removed, and updated. + :paramtype tags: Optional[dict[str, str]] + :keyword properties: A dictionary of properties to associate with the schedule. + :paramtype properties: Optional[dict[str, str]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_schedule_configuration] + :end-before: [END job_schedule_configuration] + :language: python + :dedent: 8 + :caption: Configuring a JobSchedule. + """ + + def __init__( + self, + *, + name: str, + trigger: Optional[Union[CronTrigger, RecurrenceTrigger]], + create_job: Union[Job, str], + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict] = None, + properties: Optional[Dict] = None, + **kwargs: Any, + ) -> None: + super().__init__( + name=name, + trigger=trigger, + display_name=display_name, + description=description, + tags=tags, + properties=properties, + **kwargs, + ) + self._create_job = create_job + self._type = ScheduleType.JOB + + @property + def create_job(self) -> Union[Job, str]: + """Return the job associated with the schedule. + + :return: The job definition or an existing job name. + :rtype: Union[~azure.ai.ml.entities.Job, str] + """ + return self._create_job + + @create_job.setter + def create_job(self, value: Union[Job, str]) -> None: + """Sets the job that will be run when the schedule is triggered. + + :param value: The job definition or an existing job name. 
+ :type value: Union[~azure.ai.ml.entities.Job, str] + """ + self._create_job = value + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "JobSchedule": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + return JobSchedule( + base_path=context[BASE_PATH_CONTEXT_KEY], + **load_from_dict(JobScheduleSchema, data, context, **kwargs), + ) + + @classmethod + def _load_from_rest_dict( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "JobSchedule": + """ + Load job schedule from rest object dict. + + This function is added because the user-faced schema is different from the rest one. + + For example: + + user yaml create_job is a file reference with updates(not a job definition): + + .. code-block:: yaml + + create_job: + job: ./job.yaml + inputs: + input: 10 + + while what we get from rest will be a complete job definition: + + .. code-block:: yaml + + create_job: + name: xx + jobs: + node1: ... + inputs: + input: .. + + :param data: The REST object to convert + :type data: Optional[Dict] + :param yaml_path: The yaml path + :type yaml_path: Optional[Union[PathLike str]] + :param params_override: A list of parameter overrides + :type params_override: Optional[list] + :return: The job schedule + :rtype: JobSchedule + """ + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + create_job_key = "create_job" + if create_job_key not in data: + msg = "Job definition for schedule '{}' can not be None." + raise ScheduleException( + message=msg.format(data["name"]), + no_personal_data_message=msg.format("[name]"), + target=ErrorTarget.JOB, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + # Load the job definition separately + create_job_data = data.pop(create_job_key) + # Save the id for remote job reference before load job, as data dict will be changed + job_id = create_job_data.get("id") + if isinstance(job_id, str) and job_id.startswith(ARM_ID_PREFIX): + job_id = job_id[len(ARM_ID_PREFIX) :] + create_job = Job._load( + data=create_job_data, + **kwargs, + ) + # Set id manually as it is a dump only field in schema + create_job._id = job_id + schedule = JobSchedule( + base_path=context[BASE_PATH_CONTEXT_KEY], + **load_from_dict(JobScheduleSchema, data, context, **kwargs), + ) + schedule.create_job = create_job + return schedule + + @classmethod + def _create_schema_for_validation(cls, context: Any) -> JobScheduleSchema: + return JobScheduleSchema(context=context) + + def _customized_validate(self) -> MutableValidationResult: + """Validate the resource with customized logic. + + :return: The validation result + :rtype: MutableValidationResult + """ + if isinstance(self.create_job, PipelineJob): + return self.create_job._validate() + return self._create_empty_validation_result() + + @classmethod + def _get_skip_fields_in_schema_validation(cls) -> typing.List[str]: + """Get the fields that should be skipped in schema validation. + + Override this method to add customized validation logic. 
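+
+ In a subclass, a hypothetical override might look like the sketch below
+ (the extra field name is illustrative only):
+
+ .. code-block:: python
+
+ @classmethod
+ def _get_skip_fields_in_schema_validation(cls) -> typing.List[str]:
+ # Skip the embedded job plus a field that is validated elsewhere.
+ return ["create_job", "some_custom_field"]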
+ + :return: The list of fields to skip in schema validation + :rtype: typing.List[str] + """ + return ["create_job"] + + @classmethod + def _from_rest_object(cls, obj: RestSchedule) -> "JobSchedule": + properties = obj.properties + action: JobScheduleAction = properties.action + if action.job_definition is None: + msg = "Job definition for schedule '{}' can not be None." + raise ScheduleException( + message=msg.format(obj.name), + no_personal_data_message=msg.format("[name]"), + target=ErrorTarget.JOB, + error_category=ErrorCategory.SYSTEM_ERROR, + ) + if camel_to_snake(action.job_definition.job_type) not in [JobType.PIPELINE, JobType.COMMAND, JobType.SPARK]: + msg = f"Unsupported job type {action.job_definition.job_type} for schedule '{{}}'." + raise ScheduleException( + message=msg.format(obj.name), + no_personal_data_message=msg.format("[name]"), + target=ErrorTarget.JOB, + # Classified as user_error as we may support other type afterwards. + error_category=ErrorCategory.USER_ERROR, + ) + # Wrap job definition with JobBase for Job._from_rest_object call. + create_job = RestJobBase(properties=action.job_definition) + # id is a readonly field so set it after init. + # TODO: Add this support after source job id move to JobBaseProperties + if hasattr(action.job_definition, "source_job_id"): + create_job.id = action.job_definition.source_job_id + create_job = Job._from_rest_object(create_job) + return cls( + trigger=TriggerBase._from_rest_object(properties.trigger), + create_job=create_job, + name=obj.name, + display_name=properties.display_name, + description=properties.description, + tags=properties.tags, + properties=properties.properties, + provisioning_state=properties.provisioning_state, + is_enabled=properties.is_enabled, + creation_context=SystemData._from_rest_object(obj.system_data), + ) + + def _to_rest_object(self) -> RestSchedule: + """Build current parameterized schedule instance to a schedule object before submission. + + :return: Rest schedule. + :rtype: RestSchedule + """ + if isinstance(self.create_job, BaseNode): + self.create_job = self.create_job._to_job() + private_enabled = is_private_preview_enabled() + if isinstance(self.create_job, PipelineJob): + job_definition = self.create_job._to_rest_object().properties + # Set the source job id, as it is used only for schedule scenario. + job_definition.source_job_id = self.create_job.id + elif private_enabled and isinstance(self.create_job, (CommandJob, SparkJob)): + job_definition = self.create_job._to_rest_object().properties + # TODO: Merge this branch with PipelineJob after source job id move to JobBaseProperties + # job_definition.source_job_id = self.create_job.id + elif isinstance(self.create_job, str): # arm id reference + # TODO: Update this after source job id move to JobBaseProperties + # Rest pipeline job will hold a 'Default' as experiment_name, + # MFE will add default if None, so pass an empty string here. + job_definition = RestPipelineJob(source_job_id=self.create_job, experiment_name="") + else: + msg = "Unsupported job type '{}' in schedule {}." 
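+ # no_personal_data_message mirrors the user-facing message but substitutes
+ # placeholders for the job type and schedule name.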
+ raise ValidationException( + message=msg.format(type(self.create_job).__name__, self.name), + no_personal_data_message=msg.format("[type]", "[name]"), + target=ErrorTarget.SCHEDULE, + error_category=ErrorCategory.USER_ERROR, + ) + return RestSchedule( + properties=ScheduleProperties( + description=self.description, + properties=self.properties, + tags=self.tags, + action=JobScheduleAction(job_definition=job_definition), + display_name=self.display_name, + is_enabled=self._is_enabled, + trigger=self.trigger._to_rest_object() if self.trigger is not None else None, + ) + ) + + def __str__(self) -> str: + try: + res_yaml: str = self._to_yaml() + return res_yaml + except BaseException: # pylint: disable=W0718 + res_jobSchedule: str = super(JobSchedule, self).__str__() + return res_jobSchedule + + # pylint: disable-next=docstring-missing-param + def _get_telemetry_values(self, *args: Any, **kwargs: Any) -> Dict[Literal["trigger_type"], str]: + """Return the telemetry values of schedule. + + :return: A dictionary with telemetry values + :rtype: Dict[Literal["trigger_type"], str] + """ + return {"trigger_type": type(self.trigger).__name__} + + +class ScheduleTriggerResult: + """Schedule trigger result returned by trigger an enabled schedule once. + + This class shouldn't be instantiated directly. Instead, it is used as the return type of schedule trigger. + + :ivar str job_name: + :ivar str schedule_action_type: + """ + + def __init__(self, **kwargs): + self.job_name = kwargs.get("job_name", None) + self.schedule_action_type = kwargs.get("schedule_action_type", None) + + @classmethod + def _from_rest_object(cls, obj: RestTriggerRunSubmissionDto) -> "ScheduleTriggerResult": + """Construct a ScheduleJob from a rest object. + + :param obj: The rest object to construct from. + :type obj: ~azure.ai.ml._restclient.v2024_01_01_preview.models.TriggerRunSubmissionDto + :return: The constructed ScheduleJob. + :rtype: ScheduleTriggerResult + """ + return cls( + schedule_action_type=obj.schedule_action_type, + job_name=obj.submission_id, + ) + + def _to_dict(self) -> dict: + """Convert the object to a dictionary. + :return: The dictionary representation of the object. + :rtype: dict + """ + return { + "job_name": self.job_name, + "schedule_action_type": self.schedule_action_type, + } diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/trigger.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/trigger.py new file mode 100644 index 00000000..855aac9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/trigger.py @@ -0,0 +1,290 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access +import logging +from abc import ABC +from datetime import datetime +from typing import List, Optional, Union + +from azure.ai.ml._restclient.v2023_04_01_preview.models import CronTrigger as RestCronTrigger +from azure.ai.ml._restclient.v2023_04_01_preview.models import RecurrenceSchedule as RestRecurrencePattern +from azure.ai.ml._restclient.v2023_04_01_preview.models import RecurrenceTrigger as RestRecurrenceTrigger +from azure.ai.ml._restclient.v2023_04_01_preview.models import TriggerBase as RestTriggerBase +from azure.ai.ml._restclient.v2023_04_01_preview.models import TriggerType as RestTriggerType +from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel +from azure.ai.ml.constants import TimeZone +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +module_logger = logging.getLogger(__name__) + + +class TriggerBase(RestTranslatableMixin, ABC): + """Base class of Trigger. + + This class should not be instantiated directly. Instead, use one of its subclasses. + + :keyword type: The type of trigger. + :paramtype type: str + :keyword start_time: Specifies the start time of the schedule in ISO 8601 format. + :paramtype start_time: Optional[Union[str, datetime]] + :keyword end_time: Specifies the end time of the schedule in ISO 8601 format. + Note that end_time is not supported for compute schedules. + :paramtype end_time: Optional[Union[str, datetime]] + :keyword time_zone: The time zone where the schedule will run. Defaults to UTC(+00:00). + Note that this applies to the start_time and end_time. + :paramtype time_zone: ~azure.ai.ml.constants.TimeZone + """ + + def __init__( + self, + *, + type: str, # pylint: disable=redefined-builtin + start_time: Optional[Union[str, datetime]] = None, + end_time: Optional[Union[str, datetime]] = None, + time_zone: Union[str, TimeZone] = TimeZone.UTC, + ) -> None: + super().__init__() + self.type = type + self.start_time = start_time + self.end_time = end_time + self.time_zone = time_zone + + @classmethod + def _from_rest_object(cls, obj: RestTriggerBase) -> Optional[Union["CronTrigger", "RecurrenceTrigger"]]: + if obj.trigger_type == RestTriggerType.RECURRENCE: + return RecurrenceTrigger._from_rest_object(obj) + if obj.trigger_type == RestTriggerType.CRON: + return CronTrigger._from_rest_object(obj) + + return None + + +class RecurrencePattern(RestTranslatableMixin): + """Recurrence pattern for a job schedule. + + :keyword hours: The number of hours for the recurrence schedule pattern. + :paramtype hours: Union[int, List[int]] + :keyword minutes: The number of minutes for the recurrence schedule pattern. + :paramtype minutes: Union[int, List[int]] + :keyword week_days: A list of days of the week for the recurrence schedule pattern. + Acceptable values include: "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday" + :type week_days: Optional[Union[str, List[str]]] + :keyword month_days: A list of days of the month for the recurrence schedule pattern. + :paramtype month_days: Optional[Union[int, List[int]]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_schedule_configuration] + :end-before: [END job_schedule_configuration] + :language: python + :dedent: 8 + :caption: Configuring a JobSchedule to use a RecurrencePattern. 
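+
+ A minimal inline sketch as well (the hours, minutes, and week days are illustrative only):
+
+ .. code-block:: python
+
+ from azure.ai.ml.entities import RecurrencePattern
+
+ # Recur on the listed hours and minutes, Monday through Friday.
+ pattern = RecurrencePattern(
+ hours=[9, 17],
+ minutes=[0, 30],
+ week_days=["monday", "tuesday", "wednesday", "thursday", "friday"],
+ )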
+ """ + + def __init__( + self, + *, + hours: Union[int, List[int]], + minutes: Union[int, List[int]], + week_days: Optional[Union[str, List[str]]] = None, + month_days: Optional[Union[int, List[int]]] = None, + ) -> None: + self.hours = hours + self.minutes = minutes + self.week_days = week_days + self.month_days = month_days + + def _to_rest_object(self) -> RestRecurrencePattern: + return RestRecurrencePattern( + hours=[self.hours] if not isinstance(self.hours, list) else self.hours, + minutes=[self.minutes] if not isinstance(self.minutes, list) else self.minutes, + week_days=[self.week_days] if self.week_days and not isinstance(self.week_days, list) else self.week_days, + month_days=( + [self.month_days] if self.month_days and not isinstance(self.month_days, list) else self.month_days + ), + ) + + def _to_rest_compute_pattern_object(self) -> RestRecurrencePattern: + # This function is added because we can't make compute trigger to use same class + # with schedule from service side. + if self.month_days: + module_logger.warning("'month_days' is ignored for not supported on compute recurrence schedule.") + return RestRecurrencePattern( + hours=[self.hours] if not isinstance(self.hours, list) else self.hours, + minutes=[self.minutes] if not isinstance(self.minutes, list) else self.minutes, + week_days=[self.week_days] if self.week_days and not isinstance(self.week_days, list) else self.week_days, + ) + + @classmethod + def _from_rest_object(cls, obj: RestRecurrencePattern) -> "RecurrencePattern": + return cls( + hours=obj.hours, + minutes=obj.minutes, + week_days=obj.week_days, + month_days=obj.month_days if hasattr(obj, "month_days") else None, + ) + + +class CronTrigger(TriggerBase): + """Cron Trigger for a job schedule. + + :keyword expression: The cron expression of schedule, following NCronTab format. + :paramtype expression: str + :keyword start_time: The start time for the trigger. If using a datetime object, leave the tzinfo as None and use + the ``time_zone`` parameter to specify a time zone if needed. If using a string, use the format + YYYY-MM-DDThh:mm:ss. Defaults to running the first workload instantly and continuing future workloads + based on the schedule. If the start time is in the past, the first workload is run at the next calculated run + time. + :paramtype start_time: Optional[Union[str, datetime]] + :keyword end_time: The start time for the trigger. If using a datetime object, leave the tzinfo as None and use + the ``time_zone`` parameter to specify a time zone if needed. If using a string, use the format + YYYY-MM-DDThh:mm:ss. Note that end_time is not supported for compute schedules. + :paramtype end_time: Optional[Union[str, datetime]] + :keyword time_zone: The time zone where the schedule will run. Defaults to UTC(+00:00). + Note that this applies to the start_time and end_time. + :paramtype time_zone: Union[str, ~azure.ai.ml.constants.TimeZone] + :raises Exception: Raised if end_time is in the past. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START cron_trigger_configuration] + :end-before: [END cron_trigger_configuration] + :language: python + :dedent: 8 + :caption: Configuring a CronTrigger. 
+ """ + + def __init__( + self, + *, + expression: str, + start_time: Optional[Union[str, datetime]] = None, + end_time: Optional[Union[str, datetime]] = None, + time_zone: Union[str, TimeZone] = TimeZone.UTC, + ) -> None: + super().__init__( + type=RestTriggerType.CRON, + start_time=start_time, + end_time=end_time, + time_zone=time_zone, + ) + self.expression = expression + + def _to_rest_object(self) -> RestCronTrigger: # v2022_12_01.models.CronTrigger + return RestCronTrigger( + trigger_type=self.type, + expression=self.expression, + start_time=self.start_time, + end_time=self.end_time, + time_zone=self.time_zone, + ) + + def _to_rest_compute_cron_object(self) -> RestCronTrigger: # v2022_12_01_preview.models.CronTrigger + # This function is added because we can't make compute trigger to use same class + # with schedule from service side. + if self.end_time: + module_logger.warning("'end_time' is ignored for not supported on compute schedule.") + return RestCronTrigger( + expression=self.expression, + start_time=self.start_time, + time_zone=self.time_zone, + ) + + @classmethod + def _from_rest_object(cls, obj: RestCronTrigger) -> "CronTrigger": + return cls( + expression=obj.expression, + start_time=obj.start_time, + end_time=obj.end_time, + time_zone=obj.time_zone, + ) + + +class RecurrenceTrigger(TriggerBase): + """Recurrence trigger for a job schedule. + + :keyword start_time: Specifies the start time of the schedule in ISO 8601 format. + :paramtype start_time: Optional[Union[str, datetime]] + :keyword end_time: Specifies the end time of the schedule in ISO 8601 format. + Note that end_time is not supported for compute schedules. + :paramtype end_time: Optional[Union[str, datetime]] + :keyword time_zone: The time zone where the schedule will run. Defaults to UTC(+00:00). + Note that this applies to the start_time and end_time. + :paramtype time_zone: Union[str, ~azure.ai.ml.constants.TimeZone] + :keyword frequency: Specifies the frequency that the schedule should be triggered with. + Possible values include: "minute", "hour", "day", "week", "month". + :type frequency: str + :keyword interval: Specifies the interval in conjunction with the frequency that the schedule should be triggered + with. + :paramtype interval: int + :keyword schedule: Specifies the recurrence pattern. + :paramtype schedule: Optional[~azure.ai.ml.entities.RecurrencePattern] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_schedule_configuration] + :end-before: [END job_schedule_configuration] + :language: python + :dedent: 8 + :caption: Configuring a JobSchedule to trigger recurrence every 4 weeks. 
+ """ + + def __init__( + self, + *, + frequency: str, + interval: int, + schedule: Optional[RecurrencePattern] = None, + start_time: Optional[Union[str, datetime]] = None, + end_time: Optional[Union[str, datetime]] = None, + time_zone: Union[str, TimeZone] = TimeZone.UTC, + ) -> None: + super().__init__( + type=RestTriggerType.RECURRENCE, + start_time=start_time, + end_time=end_time, + time_zone=time_zone, + ) + # Create empty pattern as schedule is required in rest model + self.schedule = schedule if schedule else RecurrencePattern(hours=[], minutes=[]) + self.frequency = frequency + self.interval = interval + + def _to_rest_object(self) -> RestRecurrenceTrigger: # v2022_12_01.models.RecurrenceTrigger + return RestRecurrenceTrigger( + frequency=snake_to_camel(self.frequency), + interval=self.interval, + schedule=self.schedule._to_rest_object(), + start_time=self.start_time, + end_time=self.end_time, + time_zone=self.time_zone, + ) + + def _to_rest_compute_recurrence_object(self) -> RestRecurrenceTrigger: + # v2022_12_01_preview.models.RecurrenceTrigger + # This function is added because we can't make compute trigger to use same class + # with schedule from service side. + if self.end_time: + module_logger.warning("'end_time' is ignored for not supported on compute schedule.") + return RestRecurrenceTrigger( + frequency=snake_to_camel(self.frequency), + interval=self.interval, + schedule=self.schedule._to_rest_compute_pattern_object(), + start_time=self.start_time, + time_zone=self.time_zone, + ) + + @classmethod + def _from_rest_object(cls, obj: RestRecurrenceTrigger) -> "RecurrenceTrigger": + return cls( + frequency=camel_to_snake(obj.frequency), + interval=obj.interval, + schedule=RecurrencePattern._from_rest_object(obj.schedule) if obj.schedule else None, + start_time=obj.start_time, + end_time=obj.end_time, + time_zone=obj.time_zone, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_system_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_system_data.py new file mode 100644 index 00000000..05020da2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_system_data.py @@ -0,0 +1,77 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from typing import Any + +from azure.ai.ml._restclient.v2022_10_01.models import SystemData as RestSystemData +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class SystemData(RestTranslatableMixin): + """Metadata related to the creation and most recent modification of a resource. + + :ivar created_by: The identity that created the resource. + :vartype created_by: str + :ivar created_by_type: The type of identity that created the resource. Possible values include: + "User", "Application", "ManagedIdentity", "Key". + :vartype created_by_type: str or ~azure.ai.ml.entities.CreatedByType + :ivar created_at: The timestamp of resource creation (UTC). + :vartype created_at: ~datetime.datetime + :ivar last_modified_by: The identity that last modified the resource. + :vartype last_modified_by: str + :ivar last_modified_by_type: The type of identity that last modified the resource. Possible + values include: "User", "Application", "ManagedIdentity", "Key". + :vartype last_modified_by_type: str or ~azure.ai.ml.entities.CreatedByType + :ivar last_modified_at: The timestamp of resource last modification (UTC). 
+ :vartype last_modified_at: ~datetime.datetime + :keyword created_by: The identity that created the resource. + :paramtype created_by: str + :keyword created_by_type: The type of identity that created the resource. Accepted values are + "User", "Application", "ManagedIdentity", "Key". + :paramtype created_by_type: Union[str, ~azure.ai.ml.entities.CreatedByType] + :keyword created_at: The timestamp of resource creation (UTC). + :paramtype created_at: datetime + :keyword last_modified_by: The identity that last modified the resource. + :paramtype last_modified_by: str + :keyword last_modified_by_type: The type of identity that last modified the resource. Accepted values are + "User", "Application", "ManagedIdentity", "Key". + :paramtype last_modified_by_type: Union[str, ~azure.ai.ml.entities.CreatedByType] + :keyword last_modified_at: The timestamp of resource last modification in UTC. + :paramtype last_modified_at: datetime + """ + + def __init__(self, **kwargs: Any) -> None: + self.created_by = kwargs.get("created_by", None) + self.created_by_type = kwargs.get("created_by_type", None) + self.created_at = kwargs.get("created_at", None) + self.last_modified_by = kwargs.get("last_modified_by", None) + self.last_modified_by_type = kwargs.get("last_modified_by_type", None) + self.last_modified_at = kwargs.get("last_modified_at", None) + + @classmethod + def _from_rest_object(cls, obj: RestSystemData) -> "SystemData": + return cls( + created_by=obj.created_by, + created_at=obj.created_at, + created_by_type=obj.created_by_type, + last_modified_by=obj.last_modified_by, + last_modified_by_type=obj.last_modified_by_type, + last_modified_at=obj.last_modified_at, + ) + + def _to_rest_object(self) -> RestSystemData: + return RestSystemData( + created_by=self.created_by, + created_at=self.created_at, + created_by_type=self.created_by_type, + last_modified_by=self.last_modified_by, + last_modified_by_type=self.last_modified_by_type, + last_modified_at=self.last_modified_at, + ) + + def _to_dict(self) -> dict: + from azure.ai.ml._schema.job.creation_context import CreationContextSchema + + return CreationContextSchema().dump(self) # pylint: disable=no-member diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_util.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_util.py new file mode 100644 index 00000000..c487be6e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_util.py @@ -0,0 +1,645 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +import copy +import hashlib +import json +import os +import shutil +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, cast, overload +from unittest import mock + +import msrest +from marshmallow.exceptions import ValidationError + +from .._restclient.v2022_02_01_preview.models import JobInputType as JobInputType02 +from .._restclient.v2023_04_01_preview.models import JobInput as RestJobInput +from .._restclient.v2023_04_01_preview.models import JobInputType as JobInputType10 +from .._restclient.v2023_04_01_preview.models import JobOutput as RestJobOutput +from .._schema._datastore import AzureBlobSchema, AzureDataLakeGen1Schema, AzureDataLakeGen2Schema, AzureFileSchema +from .._schema._deployment.batch.batch_deployment import BatchDeploymentSchema +from .._schema._deployment.online.online_deployment import ( + KubernetesOnlineDeploymentSchema, + ManagedOnlineDeploymentSchema, +) +from .._schema._endpoint.batch.batch_endpoint import BatchEndpointSchema +from .._schema._endpoint.online.online_endpoint import KubernetesOnlineEndpointSchema, ManagedOnlineEndpointSchema +from .._schema._sweep import SweepJobSchema +from .._schema.assets.data import DataSchema +from .._schema.assets.environment import EnvironmentSchema +from .._schema.assets.model import ModelSchema +from .._schema.component.command_component import CommandComponentSchema +from .._schema.component.parallel_component import ParallelComponentSchema +from .._schema.compute.aml_compute import AmlComputeSchema +from .._schema.compute.compute_instance import ComputeInstanceSchema +from .._schema.compute.virtual_machine_compute import VirtualMachineComputeSchema +from .._schema.job import CommandJobSchema, ParallelJobSchema +from .._schema.pipeline.pipeline_job import PipelineJobSchema +from .._schema.schedule.schedule import JobScheduleSchema +from .._schema.workspace import WorkspaceSchema +from .._utils.utils import is_internal_component_data, try_enable_internal_components +from ..constants._common import ( + REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT, + CommonYamlFields, + YAMLRefDocLinks, + YAMLRefDocSchemaNames, +) +from ..constants._component import NodeType +from ..constants._endpoint import EndpointYamlFields +from ..entities._mixins import RestTranslatableMixin +from ..exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._inputs_outputs import Output + from azure.ai.ml.entities._job.pipeline._io import NodeOutput + +# Maps schema class name to formatted error message pointing to Microsoft docs reference page for a schema's YAML +REF_DOC_ERROR_MESSAGE_MAP = { + DataSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(YAMLRefDocSchemaNames.DATA, YAMLRefDocLinks.DATA), + EnvironmentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.ENVIRONMENT, YAMLRefDocLinks.ENVIRONMENT + ), + ModelSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(YAMLRefDocSchemaNames.MODEL, YAMLRefDocLinks.MODEL), + CommandComponentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.COMMAND_COMPONENT, YAMLRefDocLinks.COMMAND_COMPONENT + ), + ParallelComponentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.PARALLEL_COMPONENT, YAMLRefDocLinks.PARALLEL_COMPONENT + ), + AmlComputeSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + 
YAMLRefDocSchemaNames.AML_COMPUTE, YAMLRefDocLinks.AML_COMPUTE + ), + ComputeInstanceSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.COMPUTE_INSTANCE, YAMLRefDocLinks.COMPUTE_INSTANCE + ), + VirtualMachineComputeSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.VIRTUAL_MACHINE_COMPUTE, + YAMLRefDocLinks.VIRTUAL_MACHINE_COMPUTE, + ), + AzureDataLakeGen1Schema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.DATASTORE_DATA_LAKE_GEN_1, + YAMLRefDocLinks.DATASTORE_DATA_LAKE_GEN_1, + ), + AzureBlobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.DATASTORE_BLOB, YAMLRefDocLinks.DATASTORE_BLOB + ), + AzureFileSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.DATASTORE_FILE, YAMLRefDocLinks.DATASTORE_FILE + ), + AzureDataLakeGen2Schema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.DATASTORE_DATA_LAKE_GEN_2, + YAMLRefDocLinks.DATASTORE_DATA_LAKE_GEN_2, + ), + BatchEndpointSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.BATCH_ENDPOINT, YAMLRefDocLinks.BATCH_ENDPOINT + ), + KubernetesOnlineEndpointSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.ONLINE_ENDPOINT, YAMLRefDocLinks.ONLINE_ENDPOINT + ), + ManagedOnlineEndpointSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.ONLINE_ENDPOINT, YAMLRefDocLinks.ONLINE_ENDPOINT + ), + BatchDeploymentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.BATCH_DEPLOYMENT, YAMLRefDocLinks.BATCH_DEPLOYMENT + ), + ManagedOnlineDeploymentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.MANAGED_ONLINE_DEPLOYMENT, + YAMLRefDocLinks.MANAGED_ONLINE_DEPLOYMENT, + ), + KubernetesOnlineDeploymentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.KUBERNETES_ONLINE_DEPLOYMENT, + YAMLRefDocLinks.KUBERNETES_ONLINE_DEPLOYMENT, + ), + PipelineJobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.PIPELINE_JOB, YAMLRefDocLinks.PIPELINE_JOB + ), + JobScheduleSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.JOB_SCHEDULE, YAMLRefDocLinks.JOB_SCHEDULE + ), + SweepJobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.SWEEP_JOB, YAMLRefDocLinks.SWEEP_JOB + ), + CommandJobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.COMMAND_JOB, YAMLRefDocLinks.COMMAND_JOB + ), + ParallelJobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.PARALLEL_JOB, YAMLRefDocLinks.PARALLEL_JOB + ), + WorkspaceSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format( + YAMLRefDocSchemaNames.WORKSPACE, YAMLRefDocLinks.WORKSPACE + ), +} + + +def find_field_in_override(field: str, params_override: Optional[list] = None) -> Optional[str]: + """Find specific field in params override. + + :param field: The name of the field to find + :type field: str + :param params_override: The params override + :type params_override: Optional[list] + :return: The type + :rtype: Optional[str] + """ + params_override = params_override or [] + for override in params_override: + if field in override: + res: Optional[str] = override[field] + return res + return None + + +def find_type_in_override(params_override: Optional[list] = None) -> Optional[str]: + """Find type in params override. 
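+
+ Illustrative only (assumes ``CommonYamlFields.TYPE`` resolves to the string ``"type"``)::
+
+ find_type_in_override([{"type": "command"}]) # -> "command"
+ find_type_in_override([{"compute": "azureml:cpu-cluster"}]) # -> None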
+ + :param params_override: The params override + :type params_override: Optional[list] + :return: The type + :rtype: Optional[str] + """ + return find_field_in_override(CommonYamlFields.TYPE, params_override) + + +def is_compute_in_override(params_override: Optional[list] = None) -> bool: + """Check if compute is in params override. + + :param params_override: The params override + :type params_override: Optional[list] + :return: True if compute is in params override + :rtype: bool + """ + if params_override is not None: + return any(EndpointYamlFields.COMPUTE in param for param in params_override) + return False + + +def load_from_dict(schema: Any, data: Dict, context: Dict, additional_message: str = "", **kwargs: Any) -> Any: + """Load data from dict. + + :param schema: The schema to load data with. + :type schema: Any + :param data: The data to load. + :type data: Dict + :param context: The context of the data. + :type context: Dict + :param additional_message: The additional message to add to the error message. + :type additional_message: str + :return: The loaded data. + :rtype: Any + """ + try: + return schema(context=context).load(data, **kwargs) + except ValidationError as e: + pretty_error = json.dumps(e.normalized_messages(), indent=2) + raise ValidationError(decorate_validation_error(schema, pretty_error, additional_message)) from e + + +def decorate_validation_error(schema: Any, pretty_error: str, additional_message: str = "") -> str: + """Decorate validation error with additional message. + + :param schema: The schema that failed validation. + :type schema: Any + :param pretty_error: The pretty error message. + :type pretty_error: str + :param additional_message: The additional message to add. + :type additional_message: str + :return: The decorated error message. + :rtype: str + """ + ref_doc_link_error_msg = REF_DOC_ERROR_MESSAGE_MAP.get(schema, "") + if ref_doc_link_error_msg: + additional_message += f"\n{ref_doc_link_error_msg}" + additional_message += ( + "\nThe easiest way to author a specification file is using IntelliSense and auto-completion Azure ML VS " + "code extension provides: https://code.visualstudio.com/docs/datascience/azure-machine-learning. " + "To set up: https://learn.microsoft.com/azure/machine-learning/how-to-setup-vs-code" + ) + return f"Validation for {schema.__name__} failed:\n\n {pretty_error} \n\n {additional_message}" + + +def get_md5_string(text: Optional[str]) -> str: + """Get md5 string for a given text. + + :param text: The text to get md5 string for. + :type text: str + :return: The md5 string. + :rtype: str + """ + try: + if text is not None: + return hashlib.md5(text.encode("utf8")).hexdigest() # nosec + return "" + except Exception as ex: + raise ex + + +def validate_attribute_type(attrs_to_check: Dict[str, Any], attr_type_map: Dict[str, Type]) -> None: + """Validate if attributes of object are set with valid types, raise error + if don't. + + :param attrs_to_check: Mapping from attributes name to actual value. + :type attrs_to_check: Dict[str, Any] + :param attr_type_map: Mapping from attributes name to tuple of expecting type + :type attr_type_map: Dict[str, Type] + """ + # + kwargs = attrs_to_check.get("kwargs", {}) + attrs_to_check.update(kwargs) + for attr, expecting_type in attr_type_map.items(): + attr_val = attrs_to_check.get(attr, None) + if attr_val is not None and not isinstance(attr_val, expecting_type): + msg = "Expecting {} for {}, got {} instead." 
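+ # Illustrative: validate_attribute_type({"name": 1}, {"name": str}) reaches this branch,
+ # since 1 is not an instance of str, and raises the ValidationException below.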
+ raise ValidationException( + message=msg.format(expecting_type, attr, type(attr_val)), + no_personal_data_message=msg.format(expecting_type, "[attr]", type(attr_val)), + target=ErrorTarget.GENERAL, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + +def is_empty_target(obj: Optional[Dict]) -> bool: + """Determines if it's empty target + + :param obj: The object to check + :type obj: Optional[Dict] + :return: True if obj is None or an empty Dict + :rtype: bool + """ + return ( + obj is None + # some objs have overloaded "==" and will cause error. e.g CommandComponent obj + or (isinstance(obj, dict) and len(obj) == 0) + ) + + +def convert_ordered_dict_to_dict(target_object: Union[Dict, List], remove_empty: bool = True) -> Union[Dict, List]: + """Convert ordered dict to dict. Remove keys with None value. + This is a workaround for rest request must be in dict instead of + ordered dict. + + :param target_object: The object to convert + :type target_object: Union[Dict, List] + :param remove_empty: Whether to omit values that are None or empty dictionaries. Defaults to True. + :type remove_empty: bool + :return: Converted ordered dict with removed None values + :rtype: Union[Dict, List] + """ + # OrderedDict can appear nested in a list + if isinstance(target_object, list): + new_list = [] + for item in target_object: + item = convert_ordered_dict_to_dict(item) + if not is_empty_target(item) or not remove_empty: + new_list.append(item) + return new_list + if isinstance(target_object, dict): + new_dict = {} + for key, value in target_object.items(): + value = convert_ordered_dict_to_dict(value) + if not is_empty_target(value) or not remove_empty: + new_dict[key] = value + return new_dict + return target_object + + +def _general_copy(src: Union[str, os.PathLike], dst: Union[str, os.PathLike], make_dirs: bool = True) -> None: + """Wrapped `shutil.copy2` function for possible "Function not implemented" exception raised by it. + + Background: `shutil.copy2` will throw OSError when dealing with Azure File. + See https://stackoverflow.com/questions/51616058 for more information. + + :param src: The source path to copy from + :type src: Union[str, os.PathLike] + :param dst: The destination path to copy to + :type dst: Union[str, os.PathLike] + :param make_dirs: Whether to ensure the destination path exists. Defaults to True. + :type make_dirs: bool + """ + if make_dirs: + os.makedirs(os.path.dirname(dst), exist_ok=True) + if hasattr(os, "listxattr"): + with mock.patch("shutil._copyxattr", return_value=[]): + shutil.copy2(src, dst) + else: + shutil.copy2(src, dst) + + +def _dump_data_binding_expression_in_fields(obj: Any) -> Any: + for key, value in obj.__dict__.items(): + # PipelineInput is subclass of NodeInput + from ._job.pipeline._io import NodeInput + + if isinstance(value, NodeInput): + obj.__dict__[key] = str(value) + elif isinstance(value, RestTranslatableMixin): + _dump_data_binding_expression_in_fields(value) + return obj + + +T = TypeVar("T") + + +def get_rest_dict_for_node_attrs( + target_obj: Union[T, str], clear_empty_value: bool = False +) -> Union[T, Dict, List, str, int, float, bool]: + """Convert object to dict and convert OrderedDict to dict. + Allow data binding expression as value, disregarding of the type defined in rest object. + + :param target_obj: The object to convert + :type target_obj: T + :param clear_empty_value: Whether to clear empty values. Defaults to False. 
+ :type clear_empty_value: bool + :return: The translated dict, or the the original object + :rtype: Union[T, Dict] + """ + # pylint: disable=too-many-return-statements + from azure.ai.ml.entities._job.pipeline._io import PipelineInput + + if target_obj is None: + return None + if isinstance(target_obj, dict): + result_dict: dict = {} + for key, value in target_obj.items(): + if value is None: + continue + if key in ["additional_properties"]: + continue + result_dict[key] = get_rest_dict_for_node_attrs(value, clear_empty_value) + return result_dict + if isinstance(target_obj, list): + result_list: list = [] + for item in target_obj: + result_list.append(get_rest_dict_for_node_attrs(item, clear_empty_value)) + return result_list + if isinstance(target_obj, RestTranslatableMixin): + # note that the rest object may be invalid as data binding expression may not fit + # rest object structure + # pylint: disable=protected-access + _target_obj = _dump_data_binding_expression_in_fields(copy.deepcopy(target_obj)) + + from azure.ai.ml.entities._credentials import _BaseIdentityConfiguration + + if isinstance(_target_obj, _BaseIdentityConfiguration): + # TODO: Bug Item number: 2883348 + return get_rest_dict_for_node_attrs( + _target_obj._to_job_rest_object(), clear_empty_value=clear_empty_value # type: ignore + ) + return get_rest_dict_for_node_attrs(_target_obj._to_rest_object(), clear_empty_value=clear_empty_value) + + if isinstance(target_obj, msrest.serialization.Model): + # can't use result.as_dict() as data binding expression may not fit rest object structure + return get_rest_dict_for_node_attrs(target_obj.__dict__, clear_empty_value=clear_empty_value) + + if isinstance(target_obj, PipelineInput): + return get_rest_dict_for_node_attrs(str(target_obj), clear_empty_value=clear_empty_value) + + if not isinstance(target_obj, (str, int, float, bool)): + raise ValueError("Unexpected type {}".format(type(target_obj))) + + return target_obj + + +class _DummyRestModelFromDict(msrest.serialization.Model): + """A dummy rest model that can be initialized from dict, return base_dict[attr_name] + for getattr(self, attr_name) when attr_name is a public attrs; return None when trying to get + a non-existent public attribute. + """ + + def __init__(self, rest_dict: Optional[dict]): + self._rest_dict = rest_dict or {} + super().__init__() + + def __getattribute__(self, item: str) -> Any: + if not item.startswith("_"): + return self._rest_dict.get(item, None) + return super().__getattribute__(item) + + +def from_rest_dict_to_dummy_rest_object(rest_dict: Optional[Dict]) -> _DummyRestModelFromDict: + """Create a dummy rest object based on a rest dict, which is a primitive dict containing + attributes in a rest object. + For example, for a rest object class like: + class A(msrest.serialization.Model): + def __init__(self, a, b): + self.a = a + self.b = b + rest_object = A(1, None) + rest_dict = {"a": 1} + regenerated_rest_object = from_rest_dict_to_fake_rest_object(rest_dict) + assert regenerated_rest_object.a == 1 + assert regenerated_rest_object.b is None + + :param rest_dict: The rest dict + :type rest_dict: Optional[Dict] + :return: A dummy rest object + :rtype: _DummyRestModelFromDict + """ + if rest_dict is None or isinstance(rest_dict, dict): + return _DummyRestModelFromDict(rest_dict) + raise ValueError("Unexpected type {}".format(type(rest_dict))) + + +def extract_label(input_str: str) -> Union[Tuple, List]: + """Extract label from input string. 
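+
+ Illustrative behavior, matching the implementation below (note that the ``@`` case returns the
+ two-item list produced by ``rsplit``, while the no-label case returns a tuple)::
+
+ extract_label("my_component@latest") # -> ["my_component", "latest"]
+ extract_label("my_component") # -> ("my_component", None)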
+ + :param input_str: The input string + :type input_str: str + :return: The rest of the string and the label + :rtype: Tuple[str, Optional[str]] + """ + if not isinstance(input_str, str): + return None, None + if "@" in input_str: + return input_str.rsplit("@", 1) + return input_str, None + + +@overload +def resolve_pipeline_parameters(pipeline_parameters: None, remove_empty: bool = False) -> None: ... + + +@overload +def resolve_pipeline_parameters( + pipeline_parameters: Dict[str, T], remove_empty: bool = False +) -> Dict[str, Union[T, str, "NodeOutput"]]: ... + + +def resolve_pipeline_parameters(pipeline_parameters: Optional[Dict], remove_empty: bool = False) -> Optional[Dict]: + """Resolve pipeline parameters. + + 1. Resolve BaseNode and OutputsAttrDict type to NodeOutput. + 2. Remove empty value (optional). + + :param pipeline_parameters: The pipeline parameters + :type pipeline_parameters: Optional[Dict[str, T]] + :param remove_empty: Whether to remove None values. Defaults to False. + :type remove_empty: bool + :return: + * None if pipeline_parameters is None + * The resolved dict of pipeline parameters + :rtype: Optional[Dict[str, Union[T, str, "NodeOutput"]]] + """ + + if pipeline_parameters is None: + return None + if not isinstance(pipeline_parameters, dict): + raise ValidationException( + message="pipeline_parameters must in dict {parameter: value} format.", + no_personal_data_message="pipeline_parameters must in dict {parameter: value} format.", + target=ErrorTarget.PIPELINE, + ) + + updated_parameters = {} + for k, v in pipeline_parameters.items(): + v = resolve_pipeline_parameter(v) + if v is None and remove_empty: + continue + updated_parameters[k] = v + pipeline_parameters = updated_parameters + return pipeline_parameters + + +def resolve_pipeline_parameter(data: Any) -> Union[T, str, "NodeOutput"]: + """Resolve pipeline parameter. + 1. Resolve BaseNode and OutputsAttrDict type to NodeOutput. + 2. Remove empty value (optional). + :param data: The pipeline parameter + :type data: T + :return: + * None if data is None + * The resolved pipeline parameter + :rtype: Union[T, str, "NodeOutput"] + """ + from azure.ai.ml.entities._builders.base_node import BaseNode + from azure.ai.ml.entities._builders.pipeline import Pipeline + from azure.ai.ml.entities._job.pipeline._io import NodeOutput, OutputsAttrDict + from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression + + if isinstance(data, PipelineExpression): + data = cast(Union[str, BaseNode], data.resolve()) + if isinstance(data, (BaseNode, Pipeline)): + # For the case use a node/pipeline node as the input, we use its only one output as the real input. + # Here we set node = node.outputs, then the following logic will get the output object. + data = cast(OutputsAttrDict, data.outputs) + if isinstance(data, OutputsAttrDict): + # For the case that use the outputs of another component as the input, + # we use the only one output as the real input, + # if multiple outputs are provided, an exception is raised. + output_len = len(data) + if output_len != 1: + raise ValidationException( + message="Setting input failed: Exactly 1 output is required, got %d. 
(%s)" % (output_len, data), + no_personal_data_message="multiple output(s) found of specified outputs, exactly 1 output required.", + target=ErrorTarget.PIPELINE, + ) + data = cast(NodeOutput, list(data.values())[0]) + return cast(Union[T, str, "NodeOutput"], data) + + +def normalize_job_input_output_type(input_output_value: Union[RestJobOutput, RestJobInput, Dict]) -> None: + """Normalizes the `job_input_type`, `job_output_type`, and `type` keys for REST job output and input objects. + + :param input_output_value: Either a REST input or REST output of a job + :type input_output_value: Union[RestJobOutput, RestJobInput, Dict] + + .. note:: + + We have changed the api starting v2022_06_01_preview version and there are some api interface changes, + which will result in pipeline submitted by v2022_02_01_preview can't be parsed correctly. And this will block + az ml job list/show. So we convert the input/output type of camel to snake to be compatible with the Jun/Oct + api. + + """ + + FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING = { + JobInputType02.CUSTOM_MODEL: JobInputType10.CUSTOM_MODEL, + JobInputType02.LITERAL: JobInputType10.LITERAL, + JobInputType02.ML_FLOW_MODEL: JobInputType10.MLFLOW_MODEL, + JobInputType02.ML_TABLE: JobInputType10.MLTABLE, + JobInputType02.TRITON_MODEL: JobInputType10.TRITON_MODEL, + JobInputType02.URI_FILE: JobInputType10.URI_FILE, + JobInputType02.URI_FOLDER: JobInputType10.URI_FOLDER, + } + if ( + hasattr(input_output_value, "job_input_type") + and input_output_value.job_input_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING + ): + input_output_value.job_input_type = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[input_output_value.job_input_type] + elif ( + hasattr(input_output_value, "job_output_type") + and input_output_value.job_output_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING + ): + input_output_value.job_output_type = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[input_output_value.job_output_type] + elif isinstance(input_output_value, dict): + job_output_type = input_output_value.get("job_output_type", None) + job_input_type = input_output_value.get("job_input_type", None) + job_type = input_output_value.get("type", None) + + if job_output_type and job_output_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING: + input_output_value["job_output_type"] = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[job_output_type] + if job_input_type and job_input_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING: + input_output_value["job_input_type"] = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[job_input_type] + if job_type and job_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING: + input_output_value["type"] = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[job_type] + + +def get_type_from_spec(data: dict, *, valid_keys: Iterable[str]) -> str: + """Get the type of the node or component from the yaml spec. + + Yaml spec must have a key named "type" and exception will be raised if it's not once of valid_keys. + + If internal components are enabled, related factory and schema will be updated. 
+ + :param data: The data + :type data: dict + :keyword valid_keys: An iterable of valid types + :paramtype valid_keys: Iterable[str] + :return: The type of the node or component + :rtype: str + """ + _type, _ = extract_label(data.get(CommonYamlFields.TYPE, None)) + + # we should keep at least 1 place outside _internal to enable internal components + # and this is the only place + try_enable_internal_components() + # todo: refine Hard code for now to support different task type for DataTransfer component + if _type == NodeType.DATA_TRANSFER: + _type = "_".join([NodeType.DATA_TRANSFER, data.get("task", " ")]) + if _type not in valid_keys: + is_internal_component_data(data, raise_if_not_enabled=True) + + raise ValidationException( + message="Unsupported component type: %s." % _type, + target=ErrorTarget.COMPONENT, + no_personal_data_message="Unsupported component type", + error_category=ErrorCategory.USER_ERROR, + ) + res: str = _type + return res + + +def copy_output_setting(source: Union["Output", "NodeOutput"], target: "NodeOutput") -> None: + """Copy node output setting from source to target. + + Currently only path, name, version will be copied. + + :param source: The Output to copy from + :type source: Union[Output, NodeOutput] + :param target: The Output to copy to + :type target: NodeOutput + """ + # pylint: disable=protected-access + from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineOutput + + if not isinstance(source, NodeOutput): + # Only copy when source is an output builder + return + source_data = source._data + if isinstance(source_data, PipelineOutput): + source_data = source_data._data + if source_data: + target._data = copy.deepcopy(source_data) + # copy pipeline component output's node output to subgraph builder + if source._binding_output is not None: + target._binding_output = source._binding_output diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validate_funcs.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validate_funcs.py new file mode 100644 index 00000000..8a082cb5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validate_funcs.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access +from os import PathLike +from typing import IO, Any, AnyStr, Callable, Dict, List, Optional, Union, cast + +from marshmallow import ValidationError + +from azure.ai.ml import MLClient + +from ..exceptions import ValidationException +from . 
import Component, Job +from ._load_functions import _load_common_raising_marshmallow_error, _try_load_yaml_dict +from ._validation import PathAwareSchemaValidatableMixin, ValidationResult, ValidationResultBuilder + + +def validate_common( + cls: Any, + path: Union[str, PathLike, IO[AnyStr]], + validate_func: Optional[Callable], + params_override: Optional[List[Dict]] = None, +) -> ValidationResult: + params_override = params_override or [] + yaml_dict = _try_load_yaml_dict(path) + + try: + cls, _ = cls._resolve_cls_and_type(data=yaml_dict, params_override=params_override) + + entity = _load_common_raising_marshmallow_error( + cls=cls, yaml_dict=yaml_dict, relative_origin=path, params_override=params_override + ) + + if validate_func is not None: + res = cast(ValidationResult, validate_func(entity)) + return res + if isinstance(entity, PathAwareSchemaValidatableMixin): + return entity._validate() + return ValidationResultBuilder.success() + except ValidationException as err: + return ValidationResultBuilder.from_single_message(err.message) + except ValidationError as err: + return ValidationResultBuilder.from_validation_error(err, source_path=path) + + +def validate_component( + path: Union[str, PathLike, IO[AnyStr]], + ml_client: Optional[MLClient] = None, + params_override: Optional[List[Dict]] = None, +) -> ValidationResult: + """Validate a component defined in a local file. + + :param path: The path to the component definition file. + :type path: Union[str, PathLike, IO[AnyStr]] + :param ml_client: The client to use for validation. Will skip remote validation if None. + :type ml_client: azure.ai.ml.core.AzureMLComputeClient + :param params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :type params_override: List[Dict] + :return: The validation result. + :rtype: ValidationResult + """ + return validate_common( + cls=Component, + path=path, + validate_func=ml_client.components.validate if ml_client is not None else None, + params_override=params_override, + ) + + +def validate_job( + path: Union[str, PathLike, IO[AnyStr]], + ml_client: Optional[MLClient] = None, + params_override: Optional[List[Dict]] = None, +) -> ValidationResult: + """Validate a job defined in a local file. + + :param path: The path to the job definition file. + :type path: str + :param ml_client: The client to use for validation. Will skip remote validation if None. + :type ml_client: azure.ai.ml.core.AzureMLComputeClient + :param params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}] + :type params_override: List[Dict] + :return: The validation result. + :rtype: ValidationResult + """ + return validate_common( + cls=Job, + path=path, + validate_func=ml_client.jobs.validate if ml_client is not None else None, + params_override=params_override, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py new file mode 100644 index 00000000..29ba05c5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + + +from .core import MutableValidationResult, ValidationResult, ValidationResultBuilder +from .path_aware_schema import PathAwareSchemaValidatableMixin +from .remote import RemoteValidatableMixin +from .schema import SchemaValidatableMixin + +__all__ = [ + "SchemaValidatableMixin", + "PathAwareSchemaValidatableMixin", + "RemoteValidatableMixin", + "MutableValidationResult", + "ValidationResult", + "ValidationResultBuilder", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py new file mode 100644 index 00000000..a7516c1d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py @@ -0,0 +1,531 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import copy +import json +import logging +import os.path +import typing +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union, cast + +import pydash +import strictyaml +from marshmallow import ValidationError + +module_logger = logging.getLogger(__name__) + + +class _ValidationStatus: + """Validation status class. + + Validation status is used to indicate the status of an validation result. It can be one of the following values: + Succeeded, Failed. + """ + + SUCCEEDED = "Succeeded" + """Succeeded.""" + FAILED = "Failed" + """Failed.""" + + +class Diagnostic(object): + """Represents a diagnostic of an asset validation error with the location info.""" + + def __init__(self, yaml_path: str, message: Optional[str], error_code: Optional[str]) -> None: + """Init Diagnostic. + + :keyword yaml_path: A dash path from root to the target element of the diagnostic. jobs.job_a.inputs.input_str + :paramtype yaml_path: str + :keyword message: Error message of diagnostic. + :paramtype message: str + :keyword error_code: Error code of diagnostic. + :paramtype error_code: str + """ + self.yaml_path = yaml_path + self.message = message + self.error_code = error_code + self.local_path, self.value = None, None + + def __repr__(self) -> str: + """The asset friendly name and error message. + + :return: The formatted diagnostic + :rtype: str + """ + return "{}: {}".format(self.yaml_path, self.message) + + @classmethod + def create_instance( + cls, + yaml_path: str, + message: Optional[str] = None, + error_code: Optional[str] = None, + ) -> "Diagnostic": + """Create a diagnostic instance. + + :param yaml_path: A dash path from root to the target element of the diagnostic. jobs.job_a.inputs.input_str + :type yaml_path: str + :param message: Error message of diagnostic. + :type message: str + :param error_code: Error code of diagnostic. + :type error_code: str + :return: The created instance + :rtype: Diagnostic + """ + return cls( + yaml_path=yaml_path, + message=message, + error_code=error_code, + ) + + +class ValidationResult(object): + """Represents the result of job/asset validation. + + This class is used to organize and parse diagnostics from both client & server side before expose them. The result + is immutable. 
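+
+ Typical read-only usage once a result has been obtained, for example from one of the
+ ``validate_*`` helpers above (the file name and field path below are made up)::
+
+ result = validate_job("job.yml")
+ if not result.passed:
+ print(result.error_messages) # e.g. {"jobs.job_a.inputs.input_str": "..."}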
+ """ + + def __init__(self) -> None: + self._target_obj: Optional[Dict] = None + self._errors: List = [] + self._warnings: List = [] + + @property + def error_messages(self) -> Dict: + """ + Return all messages of errors in the validation result. + + :return: A dictionary of error messages. The key is the yaml path of the error, and the value is the error + message. + :rtype: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START validation_result] + :end-before: [END validation_result] + :language: markdown + :dedent: 8 + """ + messages = {} + for diagnostic in self._errors: + if diagnostic.yaml_path not in messages: + messages[diagnostic.yaml_path] = diagnostic.message + else: + messages[diagnostic.yaml_path] += "; " + diagnostic.message + return messages + + @property + def passed(self) -> bool: + """Returns boolean indicating whether any errors were found. + + :return: True if the validation passed, False otherwise. + :rtype: bool + """ + return not self._errors + + def _to_dict(self) -> typing.Dict[str, typing.Any]: + result: Dict = { + "result": _ValidationStatus.SUCCEEDED if self.passed else _ValidationStatus.FAILED, + } + for diagnostic_type, diagnostics in [ + ("errors", self._errors), + ("warnings", self._warnings), + ]: + messages = [] + for diagnostic in diagnostics: + message = { + "message": diagnostic.message, + "path": diagnostic.yaml_path, + "value": pydash.get(self._target_obj, diagnostic.yaml_path, diagnostic.value), + } + if diagnostic.local_path: + message["location"] = str(diagnostic.local_path) + messages.append(message) + if messages: + result[diagnostic_type] = messages + return result + + def __repr__(self) -> str: + """Get the string representation of the validation result. + + :return: The string representation + :rtype: str + """ + return json.dumps(self._to_dict(), indent=2) + + +class MutableValidationResult(ValidationResult): + """Used by the client side to construct a validation result. + + The result is mutable and should not be exposed to the user. + """ + + def __init__(self, target_obj: Optional[Dict] = None): + super().__init__() + self._target_obj = target_obj + + def merge_with( + self, + target: ValidationResult, + field_name: Optional[str] = None, + condition_skip: Optional[typing.Callable] = None, + overwrite: bool = False, + ) -> "MutableValidationResult": + """Merge errors & warnings in another validation results into current one. + + Will update current validation result. + If field_name is not None, then yaml_path in the other validation result will be updated accordingly. + * => field_name, jobs.job_a => field_name.jobs.job_a e.g.. If None, then no update. + + :param target: Validation result to merge. + :type target: ValidationResult + :param field_name: The base field name for the target to merge. + :type field_name: str + :param condition_skip: A function to determine whether to skip the merge of a diagnostic in the target. + :type condition_skip: typing.Callable + :param overwrite: Whether to overwrite the current validation result. If False, all diagnostics will be kept; + if True, current diagnostics with the same yaml_path will be dropped. + :type overwrite: bool + :return: The current validation result. 
+ :rtype: MutableValidationResult + """ + for source_diagnostics, target_diagnostics in [ + (target._errors, self._errors), + (target._warnings, self._warnings), + ]: + if overwrite: + keys_to_remove = set(map(lambda x: x.yaml_path, source_diagnostics)) + target_diagnostics[:] = [ + diagnostic for diagnostic in target_diagnostics if diagnostic.yaml_path not in keys_to_remove + ] + for diagnostic in source_diagnostics: + if condition_skip and condition_skip(diagnostic): + continue + new_diagnostic = copy.deepcopy(diagnostic) + if field_name: + if new_diagnostic.yaml_path == "*": + new_diagnostic.yaml_path = field_name + else: + new_diagnostic.yaml_path = field_name + "." + new_diagnostic.yaml_path + target_diagnostics.append(new_diagnostic) + return self + + def try_raise( + self, + raise_error: Optional[bool] = True, + *, + error_func: Optional[typing.Callable[[str, str], Exception]] = None, + ) -> "MutableValidationResult": + """Try to raise an error from the validation result. + + If the validation is passed or raise_error is False, this method + will return the validation result. + + :param raise_error: Whether to raise the error. + :type raise_error: bool + :keyword error_func: A function to create the error. If None, a marshmallow.ValidationError will be created. + The first parameter of the function is the string representation of the validation result, + and the second parameter is the error message without personal data. + :type error_func: typing.Callable[[str, str], Exception] + :return: The current validation result. + :rtype: MutableValidationResult + """ + # pylint: disable=logging-not-lazy + if raise_error is False: + return self + + if self._warnings: + module_logger.warning("Warnings: %s" % str(self._warnings)) + + if not self.passed: + if error_func is None: + + def error_func(msg: Union[str, list, dict], _: Any) -> ValidationError: + return ValidationError(message=msg) + + raise error_func( + self.__repr__(), + "validation failed on the following fields: " + ", ".join(self.error_messages), + ) + return self + + def append_error( + self, + yaml_path: str = "*", + message: Optional[str] = None, + error_code: Optional[str] = None, + ) -> "MutableValidationResult": + """Append an error to the validation result. + + :param yaml_path: The yaml path of the error. + :type yaml_path: str + :param message: The message of the error. + :type message: str + :param error_code: The error code of the error. + :type error_code: str + :return: The current validation result. + :rtype: MutableValidationResult + """ + self._errors.append( + Diagnostic.create_instance( + yaml_path=yaml_path, + message=message, + error_code=error_code, + ) + ) + return self + + def resolve_location_for_diagnostics(self, source_path: str, resolve_value: bool = False) -> None: + """Resolve location/value for diagnostics based on the source path where the validatable object is loaded. + + Location includes local path of the exact file (can be different from the source path) & line number of the + invalid field. Value of a diagnostic is resolved from the validatable object in transfering to a dict by + default; however, when the validatable object is not available for the validation result, validation result is + created from marshmallow.ValidationError.messages e.g., it can be resolved from the source path. + + :param source_path: The path of the source file. + :type source_path: str + :param resolve_value: Whether to resolve the value of the invalid field from source file. 
+ :type resolve_value: bool + """ + resolver = _YamlLocationResolver(source_path) + for diagnostic in self._errors + self._warnings: + res = resolver.resolve(diagnostic.yaml_path) + if res is not None: + diagnostic.local_path, value = res + if value is not None and resolve_value: + diagnostic.value = value + + def append_warning( + self, + yaml_path: str = "*", + message: Optional[str] = None, + error_code: Optional[str] = None, + ) -> "MutableValidationResult": + """Append a warning to the validation result. + + :param yaml_path: The yaml path of the warning. + :type yaml_path: str + :param message: The message of the warning. + :type message: str + :param error_code: The error code of the warning. + :type error_code: str + :return: The current validation result. + :rtype: MutableValidationResult + """ + self._warnings.append( + Diagnostic.create_instance( + yaml_path=yaml_path, + message=message, + error_code=error_code, + ) + ) + return self + + +class ValidationResultBuilder: + """A helper class to create a validation result.""" + + UNKNOWN_MESSAGE = "Unknown field." + + def __init__(self) -> None: + pass + + @classmethod + def success(cls) -> MutableValidationResult: + """Create a validation result with success status. + + :return: A validation result + :rtype: MutableValidationResult + """ + return MutableValidationResult() + + @classmethod + def from_single_message( + cls, singular_error_message: Optional[str] = None, yaml_path: str = "*", data: Optional[dict] = None + ) -> MutableValidationResult: + """Create a validation result with only 1 diagnostic. + + :param singular_error_message: diagnostic.message. + :type singular_error_message: Optional[str] + :param yaml_path: diagnostic.yaml_path. + :type yaml_path: str + :param data: serializedvalidation target. + :type data: Optional[Dict] + :return: The validation result + :rtype: MutableValidationResult + """ + obj = MutableValidationResult(target_obj=data) + if singular_error_message: + obj.append_error(message=singular_error_message, yaml_path=yaml_path) + return obj + + @classmethod + def from_validation_error( + cls, + error: ValidationError, + *, + source_path: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + error_on_unknown_field: bool = False, + ) -> MutableValidationResult: + """Create a validation result from a ValidationError, which will be raised in marshmallow.Schema.load. Please + use this function only for exception in loading file. + + :param error: ValidationError raised by marshmallow.Schema.load. + :type error: ValidationError + :keyword source_path: The path to the source file. + :paramtype source_path: Optional[Union[str, PathLike, IO[AnyStr]]] + :keyword error_on_unknown_field: whether to raise error if there are unknown field diagnostics. + :paramtype error_on_unknown_field: bool + :return: The validation result + :rtype: MutableValidationResult + """ + obj = cls.from_validation_messages( + error.messages, data=error.data, error_on_unknown_field=error_on_unknown_field + ) + if source_path: + obj.resolve_location_for_diagnostics(cast(str, source_path), resolve_value=True) + return obj + + @classmethod + def from_validation_messages( + cls, errors: typing.Dict, data: typing.Dict, *, error_on_unknown_field: bool = False + ) -> MutableValidationResult: + """Create a validation result from error messages, which will be returned by marshmallow.Schema.validate. + + :param errors: error message returned by marshmallow.Schema.validate. 
+ :type errors: dict + :param data: serialized data to validate + :type data: dict + :keyword error_on_unknown_field: whether to raise error if there are unknown field diagnostics. + :paramtype error_on_unknown_field: bool + :return: The validation result + :rtype: MutableValidationResult + """ + instance = MutableValidationResult(target_obj=data) + errors = copy.deepcopy(errors) + cls._from_validation_messages_recursively(errors, [], instance, error_on_unknown_field=error_on_unknown_field) + return instance + + @classmethod + def _from_validation_messages_recursively( + cls, + errors: typing.Union[typing.Dict, typing.List, str], + path_stack: typing.List[str], + instance: MutableValidationResult, + error_on_unknown_field: bool, + ) -> None: + cur_path = ".".join(path_stack) if path_stack else "*" + # single error message + if isinstance(errors, dict) and "_schema" in errors: + instance.append_error( + message=";".join(errors["_schema"]), + yaml_path=cur_path, + ) + # errors on attributes + elif isinstance(errors, dict): + for field, msgs in errors.items(): + # fields.Dict + if field in ["key", "value"]: + cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field) + else: + # Todo: Add hack logic here to deal with error message in nested TypeSensitiveUnionField in + # DataTransfer: will be a nested dict with None field as dictionary key. + # open a item to track: https://msdata.visualstudio.com/Vienna/_workitems/edit/2244262/ + if field is None: + cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field) + else: + path_stack.append(field) + cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field) + path_stack.pop() + + # detailed error message + elif isinstance(errors, list) and all(isinstance(msg, str) for msg in errors): + if cls.UNKNOWN_MESSAGE in errors and not error_on_unknown_field: + # Unknown field is not a real error, so we should remove it and append a warning. + errors.remove(cls.UNKNOWN_MESSAGE) + instance.append_warning(message=cls.UNKNOWN_MESSAGE, yaml_path=cur_path) + if errors: + instance.append_error(message=";".join(errors), yaml_path=cur_path) + # union field + elif isinstance(errors, list): + + def msg2str(msg: Any) -> Any: + if isinstance(msg, str): + return msg + if isinstance(msg, dict) and len(msg) == 1 and "_schema" in msg and len(msg["_schema"]) == 1: + return str(msg["_schema"][0]) + + return str(msg) + + instance.append_error(message="; ".join([msg2str(x) for x in errors]), yaml_path=cur_path) + # unknown error + else: + instance.append_error(message=str(errors), yaml_path=cur_path) + + +class _YamlLocationResolver: + def __init__(self, source_path: str): + self._source_path = source_path + + def resolve(self, yaml_path: str, source_path: Optional[str] = None) -> Optional[Tuple]: + """Resolve the location & value of a yaml path starting from source_path. + + :param yaml_path: yaml path. + :type yaml_path: str + :param source_path: source path. + :type source_path: str + :return: the location & value of the yaml path based on source_path. 
+ :rtype: Tuple[str, str] + """ + source_path = source_path or self._source_path + if source_path is None or not os.path.isfile(source_path): + return None, None + if yaml_path is None or yaml_path == "*": + return source_path, None + + attrs = yaml_path.split(".") + attrs.reverse() + + res: Optional[Tuple] = self._resolve_recursively(attrs, Path(source_path)) + return res + + def _resolve_recursively(self, attrs: List[str], source_path: Path) -> Optional[Tuple]: + with open(source_path, encoding="utf-8") as f: + try: + loaded_yaml = strictyaml.load(f.read()) + except Exception as e: # pylint: disable=W0718 + msg = "Can't load source file %s as a strict yaml:\n%s" % (source_path, str(e)) + module_logger.debug(msg) + return None, None + + while attrs: + attr = attrs[-1] + if loaded_yaml.is_mapping() and attr in loaded_yaml: + loaded_yaml = loaded_yaml.get(attr) + attrs.pop() + else: + try: + # if current object is a path of a valid yaml file, try to resolve location in new source file + next_path = Path(loaded_yaml.value) + if not next_path.is_absolute(): + next_path = source_path.parent / next_path + if next_path.is_file(): + return self._resolve_recursively(attrs, source_path=next_path) + except OSError: + pass + except TypeError: + pass + # if not, return current section + break + return ( + f"{source_path.resolve().absolute()}#line {loaded_yaml.start_line}", + None if attrs else loaded_yaml.value, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py new file mode 100644 index 00000000..959de310 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py @@ -0,0 +1,53 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import typing +from os import PathLike +from pathlib import Path + +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + +from ..._schema import PathAwareSchema +from .._job.pipeline._attr_dict import try_get_non_arbitrary_attr +from .._util import convert_ordered_dict_to_dict +from .schema import SchemaValidatableMixin + + +class PathAwareSchemaValidatableMixin(SchemaValidatableMixin): + """The mixin class for schema validation. Entity classes inheriting from this class should have a base path + and a schema of PathAwareSchema. + """ + + @property + def __base_path_for_validation(self) -> typing.Union[str, PathLike]: + """Get the base path of the resource. + + It will try to return self.base_path, then self._base_path, then Path.cwd() if above attrs are non-existent or + `None. 
+ + :return: The base path of the resource + :rtype: typing.Union[str, os.PathLike] + """ + return ( + try_get_non_arbitrary_attr(self, BASE_PATH_CONTEXT_KEY) + or try_get_non_arbitrary_attr(self, f"_{BASE_PATH_CONTEXT_KEY}") + or Path.cwd() + ) + + def _default_context(self) -> dict: + # Note that, although context can be passed, nested.schema will be initialized only once + # base_path works well because it's fixed after loaded + return {BASE_PATH_CONTEXT_KEY: self.__base_path_for_validation} + + @classmethod + def _create_schema_for_validation(cls, context: typing.Any) -> PathAwareSchema: + raise NotImplementedError() + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception: + raise NotImplementedError() + + def _dump_for_validation(self) -> typing.Dict: + # this is not a necessary step but to keep the same behavior as before + # empty items will be removed when converting to dict + return typing.cast(dict, convert_ordered_dict_to_dict(super()._dump_for_validation())) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py new file mode 100644 index 00000000..06f022a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py @@ -0,0 +1,162 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +import typing + +import msrest + +from azure.ai.ml._vendor.azure_resources.models import ( + Deployment, + DeploymentProperties, + DeploymentValidateResult, + ErrorResponse, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +from .core import MutableValidationResult, ValidationResultBuilder + +module_logger = logging.getLogger(__name__) + + +class PreflightResource(msrest.serialization.Model): + """Specified resource. + + Variables are only populated by the server, and will be ignored when sending a request. + + :ivar id: Resource ID. + :vartype id: str + :ivar name: Resource name. + :vartype name: str + :ivar type: Resource type. + :vartype type: str + :param location: Resource location. + :type location: str + :param tags: A set of tags. Resource tags. + :type tags: dict[str, str] + """ + + _attribute_map = { + "type": {"key": "type", "type": "str"}, + "name": {"key": "name", "type": "str"}, + "location": {"key": "location", "type": "str"}, + "api_version": {"key": "apiversion", "type": "str"}, + "properties": {"key": "properties", "type": "object"}, + } + + def __init__(self, **kwargs: typing.Any): + super(PreflightResource, self).__init__(**kwargs) + self.name = kwargs.get("name", None) + self.type = kwargs.get("type", None) + self.location = kwargs.get("location", None) + self.properties = kwargs.get("properties", None) + self.api_version = kwargs.get("api_version", None) + + +class ValidationTemplateRequest(msrest.serialization.Model): + """Export resource group template request parameters. + + :param resources: The rest objects to be validated. + :type resources: list[_models.Resource] + :param options: The export template options. A CSV-formatted list containing zero or more of + the following: 'IncludeParameterDefaultValue', 'IncludeComments', + 'SkipResourceNameParameterization', 'SkipAllParameterization'. 
+ :type options: str + """ + + _attribute_map = { + "resources": {"key": "resources", "type": "[PreflightResource]"}, + "content_version": {"key": "contentVersion", "type": "str"}, + "parameters": {"key": "parameters", "type": "object"}, + "_schema": { + "key": "$schema", + "type": "str", + "default": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", + }, + } + + def __init__(self, **kwargs: typing.Any): + super(ValidationTemplateRequest, self).__init__(**kwargs) + self._schema = kwargs.get("_schema", None) + self.content_version = kwargs.get("content_version", None) + self.parameters = kwargs.get("parameters", None) + self.resources = kwargs.get("resources", None) + + +class RemoteValidatableMixin(RestTranslatableMixin): + @classmethod + def _get_resource_type(cls) -> str: + """Return resource type to be used in remote validation. + + Should be overridden by subclass. + + :return: The resource type + :rtype: str + """ + raise NotImplementedError() + + def _get_resource_name_version(self) -> typing.Tuple: + """Return resource name and version to be used in remote validation. + + Should be overridden by subclass. + + :return: The name and version + :rtype: typing.Tuple[str, str] + """ + raise NotImplementedError() + + def _to_preflight_resource(self, location: str, workspace_name: str) -> PreflightResource: + """Return the preflight resource to be used in remote validation. + + :param location: The location of the resource. + :type location: str + :param workspace_name: The workspace name + :type workspace_name: str + :return: The preflight resource + :rtype: PreflightResource + """ + name, version = self._get_resource_name_version() + return PreflightResource( + type=self._get_resource_type(), + name=f"{workspace_name}/{name}/{version}", + location=location, + properties=self._to_rest_object().properties, + api_version="2023-03-01-preview", + ) + + def _build_rest_object_for_remote_validation(self, location: str, workspace_name: str) -> Deployment: + return Deployment( + properties=DeploymentProperties( + mode="Incremental", + template=ValidationTemplateRequest( + _schema="https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", + content_version="1.0.0.0", + parameters={}, + resources=[self._to_preflight_resource(location=location, workspace_name=workspace_name)], + ), + ) + ) + + @classmethod + def _build_validation_result_from_rest_object(cls, rest_obj: DeploymentValidateResult) -> MutableValidationResult: + """Create a validation result from a rest object. Note that the created validation result does not have + target_obj so should only be used for merging. 
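+
+ For illustration, a detail with ``target="properties/environment"`` is assumed to surface as a
+ diagnostic with ``yaml_path="properties.environment"``, since slashes are rewritten to dots below.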
+ + :param rest_obj: The Deployment Validate REST obj + :type rest_obj: DeploymentValidateResult + :return: The validation result created from rest_obj + :rtype: MutableValidationResult + """ + if not rest_obj.error or not rest_obj.error.details: + return ValidationResultBuilder.success() + result = MutableValidationResult(target_obj=None) + details: typing.List[ErrorResponse] = rest_obj.error.details + for detail in details: + result.append_error( + message=detail.message, + yaml_path=detail.target.replace("/", "."), + error_code=detail.code, + # will always be UserError for now, not sure if innerError can be passed back + ) + return result diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py new file mode 100644 index 00000000..9e34173d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py @@ -0,0 +1,156 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import json +import logging +import typing + +from marshmallow import Schema, ValidationError + +from .core import MutableValidationResult, ValidationResultBuilder + +module_logger = logging.getLogger(__name__) + + +class SchemaValidatableMixin: + """The mixin class for schema validation.""" + + @classmethod + def _create_empty_validation_result(cls) -> MutableValidationResult: + """Simply create an empty validation result + + To reduce _ValidationResultBuilder importing, which is a private class. + + :return: An empty validation result + :rtype: MutableValidationResult + """ + return ValidationResultBuilder.success() + + @classmethod + def _load_with_schema( + cls, data: typing.Any, *, context: typing.Any, raise_original_exception: bool = False, **kwargs: typing.Any + ) -> typing.Any: + schema = cls._create_schema_for_validation(context=context) + + try: + return schema.load(data, **kwargs) + except ValidationError as e: + if raise_original_exception: + raise e + msg = "Trying to load data with schema failed. Data:\n%s\nError: %s" % ( + json.dumps(data, indent=4) if isinstance(data, dict) else data, + json.dumps(e.messages, indent=4), + ) + raise cls._create_validation_error( + message=msg, + no_personal_data_message=str(e), + ) from e + + @classmethod + # pylint: disable-next=docstring-missing-param + def _create_schema_for_validation(cls, context: typing.Any) -> Schema: + """Create a schema of the resource with specific context. Should be overridden by subclass. + + :return: The schema of the resource. + :rtype: Schema. + """ + raise NotImplementedError() + + def _default_context(self) -> dict: + """Get the default context for schema validation. Should be overridden by subclass. + + :return: The default context for schema validation + :rtype: dict + """ + raise NotImplementedError() + + @property + def _schema_for_validation(self) -> Schema: + """Return the schema of this Resource with default context. Do not override this method. + Override _create_schema_for_validation instead. + + :return: The schema of the resource. + :rtype: Schema. + """ + return self._create_schema_for_validation(context=self._default_context()) + + def _dump_for_validation(self) -> typing.Dict: + """Convert the resource to a dictionary. 
+ + :return: Converted dictionary + :rtype: typing.Dict + """ + res: dict = self._schema_for_validation.dump(self) + return res + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception: + """The function to create the validation exception to raise in _try_raise and _validate when + raise_error is True. + + Should be overridden by subclass. + + :param message: The error message containing detailed information + :type message: str + :param no_personal_data_message: The error message without personal data + :type no_personal_data_message: str + :return: The validation exception to raise + :rtype: Exception + """ + raise NotImplementedError() + + @classmethod + def _try_raise( + cls, validation_result: MutableValidationResult, *, raise_error: typing.Optional[bool] = True + ) -> MutableValidationResult: + return validation_result.try_raise(raise_error=raise_error, error_func=cls._create_validation_error) + + def _validate(self, raise_error: typing.Optional[bool] = False) -> MutableValidationResult: + """Validate the resource. If raise_error is True, raise ValidationError if validation fails and log warnings if + applicable; Else, return the validation result. + + :param raise_error: Whether to raise ValidationError if validation fails. + :type raise_error: bool + :return: The validation result + :rtype: MutableValidationResult + """ + result = self.__schema_validate() + result.merge_with(self._customized_validate()) + return self._try_raise(result, raise_error=raise_error) + + def _customized_validate(self) -> MutableValidationResult: + """Validate the resource with customized logic. + + Override this method to add customized validation logic. + + :return: The customized validation result + :rtype: MutableValidationResult + """ + return self._create_empty_validation_result() + + @classmethod + def _get_skip_fields_in_schema_validation( + cls, + ) -> typing.List[str]: + """Get the fields that should be skipped in schema validation. + + Override this method to add customized validation logic. + + :return: The fields to skip in schema validation + :rtype: typing.List[str] + """ + return [] + + def __schema_validate(self) -> MutableValidationResult: + """Validate the resource with the schema. + + :return: The validation result + :rtype: MutableValidationResult + """ + data = self._dump_for_validation() + messages = self._schema_for_validation.validate(data) + for skip_field in self._get_skip_fields_in_schema_validation(): + if skip_field in messages: + del messages[skip_field] + return ValidationResultBuilder.from_validation_messages(messages, data=data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
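The mixin above leaves three hooks to subclasses (_create_schema_for_validation, _default_context, _create_validation_error). As a hedged, toy-only illustration (DummySchema and DummyResource are not part of the SDK, and the context argument is ignored here), a minimal subclass could look like this:

    from marshmallow import Schema, fields

    from azure.ai.ml.entities._validation.schema import SchemaValidatableMixin


    class DummySchema(Schema):
        # Toy schema with a single required field.
        name = fields.Str(required=True)


    class DummyResource(SchemaValidatableMixin):
        def __init__(self, name=None):
            self.name = name

        @classmethod
        def _create_schema_for_validation(cls, context):
            # Context is ignored in this toy example; real resources pass it to a PathAwareSchema.
            return DummySchema()

        def _default_context(self) -> dict:
            return {}

        @classmethod
        def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception:
            return ValueError(no_personal_data_message)


    # With raise_error=False, validation returns a result object instead of raising;
    # a missing required "name" should surface as a schema error on that result.
    result = DummyResource()._validate(raise_error=False)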
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/_constants.py new file mode 100644 index 00000000..1e75a1c2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/_constants.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +ENDPOINT_AI_SERVICE_KIND = "AIServices" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/capability_host.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/capability_host.py new file mode 100644 index 00000000..f86ea8ed --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/capability_host.py @@ -0,0 +1,187 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import os +from os import PathLike +from typing import ( + List, + Optional, + Union, + IO, + Any, + AnyStr, + Dict, +) +from pathlib import Path +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.constants._workspace import CapabilityHostKind +from azure.ai.ml.constants._common import ( + BASE_PATH_CONTEXT_KEY, + PARAMS_OVERRIDE_KEY, +) + +from azure.ai.ml._schema.workspace.ai_workspaces.capability_host import ( + CapabilityHostSchema, +) +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.entities._util import load_from_dict +from azure.ai.ml._restclient.v2024_10_01_preview.models._models_py3 import ( + CapabilityHost as RestCapabilityHost, +) +from azure.ai.ml._restclient.v2024_10_01_preview.models._models_py3 import ( + CapabilityHostProperties as RestCapabilityHostProperties, +) + + +@experimental +class CapabilityHost(Resource): + """Initialize a CapabilityHost instance. + Capabilityhost management is controlled by MLClient's capabilityhosts operations. + + :param name: The name of the capability host. + :type name: str + :param description: The description of the capability host. + :type description: Optional[str] + :param vector_store_connections: A list of vector store (AI Search) connections. + :type vector_store_connections: Optional[List[str]] + :param ai_services_connections: A list of OpenAI service connection. + :type ai_services_connections: Optional[List[str]] + :param storage_connections: A list of storage connections. 
Default storage connection value is + projectname/workspaceblobstore for project workspace. + :type storage_connections: Optional[List[str]] + :param capability_host_kind: The kind of capability host, either as a string or CapabilityHostKind enum. + Default is AGENTS. + :type capability_host_kind: Union[str, CapabilityHostKind] + :param kwargs: Additional keyword arguments. + :type kwargs: Any + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_capability_host.py + :start-after: [START capability_host_object_create] + :end-before: [END capability_host_object_create] + :language: python + :dedent: 8 + :caption: Create a CapabilityHost object. + """ + + def __init__( + self, + *, + name: str, + description: Optional[str] = None, + vector_store_connections: Optional[List[str]] = None, + ai_services_connections: Optional[List[str]] = None, + storage_connections: Optional[List[str]] = None, + capability_host_kind: Union[str, CapabilityHostKind] = CapabilityHostKind.AGENTS, + **kwargs: Any, + ): + super().__init__(name=name, description=description, **kwargs) + self.capability_host_kind = capability_host_kind + self.ai_services_connections = ai_services_connections + self.storage_connections = storage_connections + self.vector_store_connections = vector_store_connections + + def dump( + self, + dest: Optional[Union[str, PathLike, IO[AnyStr]]], + **kwargs: Any, + ) -> None: + """Dump the CapabilityHost content into a file in yaml format. + + :param dest: The destination to receive this CapabilityHost's content. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable. + :type dest: Union[PathLike, str, IO[AnyStr]] + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def _to_dict(self) -> Dict: + """Dump the object into a dictionary. + + :return: Dictionary representation of the object. + :rtype: Dict + """ + + return CapabilityHostSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + + @classmethod + def _load( + cls, + data: Optional[dict] = None, + yaml_path: Optional[Union[os.PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "CapabilityHost": + """Load a capabilityhost object from a yaml file. + + :param cls: Indicates that this is a class method. + :type cls: class + :param data: Data Dictionary, defaults to None + :type data: Dict + :param yaml_path: YAML Path, defaults to None + :type yaml_path: Union[PathLike, str] + :param params_override: Fields to overwrite on top of the yaml file. + Format is [{"field1": "value1"}, {"field2": "value2"}], defaults to None + :type params_override: List[Dict] + :raises Exception: An exception + :return: Loaded CapabilityHost object. 
+ :rtype: ~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost + """ + params_override = params_override or [] + data = data or {} + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + return cls(**load_from_dict(CapabilityHostSchema, data, context, **kwargs)) + + @classmethod + def _from_rest_object(cls, rest_obj: RestCapabilityHost) -> "CapabilityHost": + """Convert a REST object into a CapabilityHost object. + + :param cls: Indicates that this is a class method. + :type cls: class + :param rest_obj: The REST object to convert. + :type rest_obj: ~azure.ai.ml._restclient.v2024_10_01_preview.models._models_py3.CapabilityHost + :return: CapabilityHost object. + :rtype: ~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost + """ + capability_host = cls( + name=str(rest_obj.name), + description=(rest_obj.properties.description if rest_obj.properties else None), + ai_services_connections=(rest_obj.properties.ai_services_connections if rest_obj.properties else None), + storage_connections=(rest_obj.properties.storage_connections if rest_obj.properties else None), + vector_store_connections=(rest_obj.properties.vector_store_connections if rest_obj.properties else None), + capability_host_kind=( + rest_obj.properties.capability_host_kind if rest_obj.properties else CapabilityHostKind.AGENTS + ), + ) + return capability_host + + def _to_rest_object(self) -> RestCapabilityHost: + """ + Convert the CapabilityHost instance to a RestCapabilityHost object. + + :return: A RestCapabilityHost object representing the capability host for a Hub or Project workspace. + :rtype: azure.ai.ml._restclient.v2024_10_01_preview.models._models_py3.CapabilityHost + """ + + properties = RestCapabilityHostProperties( + ai_services_connections=self.ai_services_connections, + storage_connections=self.storage_connections, + vector_store_connections=self.vector_store_connections, + description=self.description, + capability_host_kind=self.capability_host_kind, + ) + resource = RestCapabilityHost( + properties=properties, + ) + return resource diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/hub.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/hub.py new file mode 100644 index 00000000..4caac057 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/hub.py @@ -0,0 +1,220 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
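Putting the CapabilityHost constructor and dump() shown above together, a hedged sketch of building one locally and serializing it to YAML; the connection names and file name are placeholders, and creation against a workspace still goes through MLClient capability host operations.

    from azure.ai.ml.constants._workspace import CapabilityHostKind
    from azure.ai.ml.entities._workspace._ai_workspaces.capability_host import CapabilityHost

    # Placeholder connection names; real values come from the target hub/project.
    cap_host = CapabilityHost(
        name="my-capability-host",
        description="Agents capability host",
        ai_services_connections=["my-aoai-connection"],
        vector_store_connections=["my-search-connection"],
        storage_connections=["my-project/workspaceblobstore"],
        capability_host_kind=CapabilityHostKind.AGENTS,
    )

    # dump() serializes via CapabilityHostSchema and writes a new YAML file
    # (it raises if the destination file already exists).
    cap_host.dump("capability_host.yml")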
+# --------------------------------------------------------- + +from typing import Any, Dict, List, Optional + +from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace as RestWorkspace +from azure.ai.ml._restclient.v2024_10_01_preview.models import WorkspaceHubConfig as RestWorkspaceHubConfig +from azure.ai.ml._schema.workspace import HubSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import WorkspaceKind +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._workspace.customer_managed_key import CustomerManagedKey +from azure.ai.ml.entities._workspace.network_acls import NetworkAcls +from azure.ai.ml.entities._workspace.networking import ManagedNetwork +from azure.ai.ml.entities._workspace.workspace import Workspace + + +@experimental +class Hub(Workspace): + """A Hub is a special type of workspace that acts as a parent and resource container for lightweight child + workspaces called projects. Resources like the hub's storage account, key vault, + and container registry are shared by all child projects. + + As a type of workspace, hub management is controlled by an MLClient's workspace operations. + + :param name: Name of the hub. + :type name: str + :param description: Description of the hub. + :type description: str + :param tags: Tags of the hub. + :type tags: dict + :param display_name: Display name for the hub. This is non-unique within the resource group. + :type display_name: str + :param location: The location to create the hub in. + If not specified, the same location as the resource group will be used. + :type location: str + :param resource_group: Name of resource group to create the hub in. + :type resource_group: str + :param managed_network: Hub's Managed Network configuration + :type managed_network: ~azure.ai.ml.entities.ManagedNetwork + :param storage_account: The resource ID of an existing storage account to use instead of creating a new one. + :type storage_account: str + :param key_vault: The resource ID of an existing key vault to use instead of creating a new one. + :type key_vault: str + :param container_registry: The resource ID of an existing container registry + to use instead of creating a new one. + :type container_registry: str + :param customer_managed_key: Key vault details for encrypting data with customer-managed keys. + If not specified, Microsoft-managed keys will be used by default. + :type customer_managed_key: ~azure.ai.ml.entities.CustomerManagedKey + :param image_build_compute: The name of the compute target to use for building environment. + Docker images with the container registry is behind a VNet. + :type image_build_compute: str + :param public_network_access: Whether to allow public endpoint connectivity. + when a workspace is private link enabled. + :type public_network_access: str + :param network_acls: The network access control list (ACL) settings of the workspace. + :type network_acls: ~azure.ai.ml.entities.NetworkAcls + :param identity: The hub's Managed Identity (user assigned, or system assigned). + :type identity: ~azure.ai.ml.entities.IdentityConfiguration + :param primary_user_assigned_identity: The hub's primary user assigned identity. + :type primary_user_assigned_identity: str + :param enable_data_isolation: A flag to determine if workspace has data isolation enabled. + The flag can only be set at the creation phase, it can't be updated. 
+ :type enable_data_isolation: bool + :param default_resource_group: The resource group that will be used by projects + created under this hub if no resource group is specified. + :type default_resource_group: str + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_hub] + :end-before: [END workspace_hub] + :language: python + :dedent: 8 + :caption: Creating a Hub object. + """ + + # The field 'additional_workspace_storage_accounts' exists in the API but is currently unused. + + def __init__( + self, + *, + name: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + display_name: Optional[str] = None, + location: Optional[str] = None, + resource_group: Optional[str] = None, + managed_network: Optional[ManagedNetwork] = None, + storage_account: Optional[str] = None, + key_vault: Optional[str] = None, + container_registry: Optional[str] = None, + customer_managed_key: Optional[CustomerManagedKey] = None, + public_network_access: Optional[str] = None, + network_acls: Optional[NetworkAcls] = None, + identity: Optional[IdentityConfiguration] = None, + primary_user_assigned_identity: Optional[str] = None, + enable_data_isolation: bool = False, + default_resource_group: Optional[str] = None, + associated_workspaces: Optional[List[str]] = None, # hidden input for rest->client conversions. + **kwargs: Any, + ): + self._workspace_id = kwargs.pop("workspace_id", "") + # Ensure user can't overwrite/double input kind. + kwargs.pop("kind", None) + super().__init__( + name=name, + description=description, + tags=tags, + kind=WorkspaceKind.HUB, + display_name=display_name, + location=location, + storage_account=storage_account, + key_vault=key_vault, + container_registry=container_registry, + resource_group=resource_group, + customer_managed_key=customer_managed_key, + public_network_access=public_network_access, + network_acls=network_acls, + identity=identity, + primary_user_assigned_identity=primary_user_assigned_identity, + managed_network=managed_network, + enable_data_isolation=enable_data_isolation, + **kwargs, + ) + self._default_resource_group = default_resource_group + self._associated_workspaces = associated_workspaces + + @classmethod + def _get_schema_class(cls): + return HubSchema + + @classmethod + def _from_rest_object(cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None) -> Optional["Hub"]: + if not rest_obj: + return None + + workspace_object = Workspace._from_rest_object(rest_obj, v2_service_context) + + default_resource_group = None + + if hasattr(rest_obj, "workspace_hub_config"): + if rest_obj.workspace_hub_config and isinstance(rest_obj.workspace_hub_config, RestWorkspaceHubConfig): + default_resource_group = rest_obj.workspace_hub_config.default_workspace_resource_group + + if workspace_object is not None: + return Hub( + name=workspace_object.name if workspace_object.name is not None else "", + description=workspace_object.description, + tags=workspace_object.tags, + display_name=workspace_object.display_name, + location=workspace_object.location, + resource_group=workspace_object.resource_group, + managed_network=workspace_object.managed_network, + customer_managed_key=workspace_object.customer_managed_key, + public_network_access=workspace_object.public_network_access, + network_acls=workspace_object.network_acls, + identity=workspace_object.identity, + 
primary_user_assigned_identity=workspace_object.primary_user_assigned_identity, + storage_account=rest_obj.storage_account, + key_vault=rest_obj.key_vault, + container_registry=rest_obj.container_registry, + workspace_id=rest_obj.workspace_id, + enable_data_isolation=rest_obj.enable_data_isolation, + default_resource_group=default_resource_group, + associated_workspaces=rest_obj.associated_workspaces if rest_obj.associated_workspaces else [], + id=rest_obj.id, + ) + return None + + # Helper function to deal with sub-rest object conversion. + def _hub_values_to_rest_object(self) -> RestWorkspaceHubConfig: + additional_workspace_storage_accounts = None + default_resource_group = None + if hasattr(self, "additional_workspace_storage_accounts"): + additional_workspace_storage_accounts = None + if hasattr(self, "default_resource_group"): + default_resource_group = None + return RestWorkspaceHubConfig( + additional_workspace_storage_accounts=additional_workspace_storage_accounts, + default_workspace_resource_group=default_resource_group, + ) + + def _to_rest_object(self) -> RestWorkspace: + restWorkspace = super()._to_rest_object() + restWorkspace.workspace_hub_config = self._hub_values_to_rest_object() + return restWorkspace + + @property + def default_resource_group(self) -> Optional[str]: + """The default resource group for this hub and its children. + + :return: The resource group. + :rtype: Optional[str] + """ + return self._default_resource_group + + @default_resource_group.setter + def default_resource_group(self, value: str): + """Set the default resource group for child projects of this hub. + + :param value: The new resource group. + :type value: str + """ + if not value: + return + self._default_resource_group = value + + # No setter, read-only + @property + def associated_workspaces(self) -> Optional[List[str]]: + """The workspaces associated with the hub. + + :return: The resource group. + :rtype: Optional[List[str]] + """ + return self._associated_workspaces diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/project.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/project.py new file mode 100644 index 00000000..ffad4922 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/project.py @@ -0,0 +1,89 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from typing import Any, Dict, Optional + +from azure.ai.ml._schema.workspace import ProjectSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import WorkspaceKind +from azure.ai.ml.entities._workspace.workspace import Workspace + + +# Effectively a lightweight wrapper around a v2 SDK workspace +@experimental +class Project(Workspace): + """A Project is a lightweight object for orchestrating AI applications, and is parented by a hub. + Unlike a standard workspace, a project does not have a variety of sub-resources directly associated with it. + Instead, its parent hub managed these resources, which are then used by the project and its siblings. + + As a type of workspace, project management is controlled by an MLClient's workspace operations. + + :param name: The name of the project. + :type name: str + :param hub_id: The hub parent of the project, as a resource ID. 
+ :type hub_id: str + :param description: The description of the project. + :type description: Optional[str] + :param tags: Tags associated with the project. + :type tags: Optional[Dict[str, str]] + :param display_name: The display name of the project. + :type display_name: Optional[str] + :param location: The location of the project. Must match that of the parent hub + and is automatically assigned to match the parent hub's location during creation. + :type location: Optional[str] + :param resource_group: The project's resource group name. + :type resource_group: Optional[str] + """ + + def __init__( + self, + *, + name: str, + hub_id: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + display_name: Optional[str] = None, + location: Optional[str] = None, + resource_group: Optional[str] = None, + **kwargs, + ) -> None: + # Ensure user can't overwrite/double input kind. + kwargs.pop("kind", None) + super().__init__( + name=name, + description=description, + tags=tags, + kind=WorkspaceKind.PROJECT, + display_name=display_name, + location=location, + resource_group=resource_group, + hub_id=hub_id, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Any: + return ProjectSchema + + @property + def hub_id(self) -> str: + """The UID of the hub parent of the project. + + :return: Resource ID of the parent hub. + :rtype: str + """ + return self._hub_id if self._hub_id else "" + + @hub_id.setter + def hub_id(self, value: str): + """Set the parent hub id of the project. + + :param value: The hub id to assign to the project. + Note: cannot be reassigned after creation. + :type value: str + """ + if not value: + return + self._hub_id = value diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/compute_runtime.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/compute_runtime.py new file mode 100644 index 00000000..bc7ee127 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/compute_runtime.py @@ -0,0 +1,41 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional + +from azure.ai.ml._restclient.v2023_06_01_preview.models import ComputeRuntimeDto as RestComputeRuntimeDto +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class ComputeRuntime(RestTranslatableMixin): + """Spark compute runtime configuration. + + :keyword spark_runtime_version: Spark runtime version. + :paramtype spark_runtime_version: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_runtime] + :end-before: [END compute_runtime] + :language: python + :dedent: 8 + :caption: Creating a ComputeRuntime object. 
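With both the Hub and Project constructors in view, a hedged sketch of building the pair locally; names, location, and the parent hub resource ID are placeholders, and actual provisioning still goes through an MLClient's workspace operations.

    from azure.ai.ml.entities._workspace._ai_workspaces.hub import Hub
    from azure.ai.ml.entities._workspace._ai_workspaces.project import Project

    # Placeholder hub definition; kind is fixed to "hub" by the constructor.
    hub = Hub(
        name="my-hub",
        description="Shared parent workspace for AI projects",
        location="eastus",
        resource_group="my-resource-group",
        default_resource_group="my-resource-group",
    )

    # Placeholder ARM resource ID of the parent hub.
    hub_id = (
        "/subscriptions/<subscription-id>/resourceGroups/my-resource-group"
        "/providers/Microsoft.MachineLearningServices/workspaces/my-hub"
    )
    project = Project(
        name="my-project",
        hub_id=hub_id,
        description="Lightweight project parented by my-hub",
    )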
+ """ + + def __init__( + self, + *, + spark_runtime_version: Optional[str] = None, + ) -> None: + self.spark_runtime_version = spark_runtime_version + + def _to_rest_object(self) -> RestComputeRuntimeDto: + return RestComputeRuntimeDto(spark_runtime_version=self.spark_runtime_version) + + @classmethod + def _from_rest_object(cls, obj: RestComputeRuntimeDto) -> Optional["ComputeRuntime"]: + if not obj: + return None + return ComputeRuntime(spark_runtime_version=obj.spark_runtime_version) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/connection_subtypes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/connection_subtypes.py new file mode 100644 index 00000000..d97e513e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/connection_subtypes.py @@ -0,0 +1,748 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import re +from typing import Any, Dict, List, Optional, Type, Union + +from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory +from azure.ai.ml._schema.workspace.connections.connection_subtypes import ( + APIKeyConnectionSchema, + AzureAISearchConnectionSchema, + AzureAIServicesConnectionSchema, + AzureBlobStoreConnectionSchema, + AzureContentSafetyConnectionSchema, + AzureOpenAIConnectionSchema, + AzureSpeechServicesConnectionSchema, + MicrosoftOneLakeConnectionSchema, + OpenAIConnectionSchema, + SerpConnectionSchema, + ServerlessConnectionSchema, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._common import ( + CONNECTION_ACCOUNT_NAME_KEY, + CONNECTION_API_TYPE_KEY, + CONNECTION_API_VERSION_KEY, + CONNECTION_CONTAINER_NAME_KEY, + CONNECTION_KIND_KEY, + CONNECTION_RESOURCE_ID_KEY, + CognitiveServiceKinds, + ConnectionTypes, +) +from azure.ai.ml.entities._credentials import AadCredentialConfiguration, ApiKeyConfiguration + +from .one_lake_artifacts import OneLakeConnectionArtifact +from .workspace_connection import WorkspaceConnection + +# Dev notes: Any new classes require modifying the elif chains in the following functions in the +# WorkspaceConnection parent class: _from_rest_object, _get_entity_class_from_type, _get_schema_class_from_type + + +class AzureBlobStoreConnection(WorkspaceConnection): + """A connection to an Azure Blob Store. + + :param name: Name of the connection. + :type name: str + :param url: The URL or ARM resource ID of the external resource. + :type url: str + :param container_name: The name of the container. + :type container_name: str + :param account_name: The name of the account. 
+ :type account_name: str + :param credentials: The credentials for authenticating to the blob store. This type of + connection accepts 3 types of credentials: account key and SAS token credentials, + or NoneCredentialConfiguration for credential-less connections. + :type credentials: Union[ + ~azure.ai.ml.entities.AccountKeyConfiguration, + ~azure.ai.ml.entities.SasTokenConfiguration, + ~azure.ai.ml.entities.AadCredentialConfiguration, + ] + :param metadata: Metadata dictionary. + :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + url: str, + container_name: str, + account_name: str, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + # Blob store connections returned from the API generally have no credentials, but we still don't want + # to silently run over user inputted connections if they want to play with them locally, so double-check + # kwargs for them. + if metadata is None: + metadata = {} + metadata[CONNECTION_CONTAINER_NAME_KEY] = container_name + metadata[CONNECTION_ACCOUNT_NAME_KEY] = account_name + + super().__init__( + url=url, + type=camel_to_snake(ConnectionCategory.AZURE_BLOB), + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + return [CONNECTION_CONTAINER_NAME_KEY, CONNECTION_ACCOUNT_NAME_KEY] + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureBlobStoreConnectionSchema + + @property + def container_name(self) -> Optional[str]: + """The name of the connection's container. + + :return: The name of the container. + :rtype: Optional[str] + """ + if self.metadata is not None: + return self.metadata.get(CONNECTION_CONTAINER_NAME_KEY, None) + return None + + @container_name.setter + def container_name(self, value: str) -> None: + """Set the container name of the connection. + + :param value: The new container name to set. + :type value: str + """ + if self.metadata is None: + self.metadata = {} + self.metadata[CONNECTION_CONTAINER_NAME_KEY] = value + + @property + def account_name(self) -> Optional[str]: + """The name of the connection's account + + :return: The name of the account. + :rtype: Optional[str] + """ + if self.metadata is not None: + return self.metadata.get(CONNECTION_ACCOUNT_NAME_KEY, None) + return None + + @account_name.setter + def account_name(self, value: str) -> None: + """Set the account name of the connection. + + :param value: The new account name to set. + :type value: str + """ + if self.metadata is None: + self.metadata = {} + self.metadata[CONNECTION_ACCOUNT_NAME_KEY] = value + + +# Dev note: One lake connections are unfortunately unique in that it's extremely +# difficult for customers to find out what the target for their system ought to be. +# Due to this, we construct the target internally by composing more inputs +# that are more user-accessible. +class MicrosoftOneLakeConnection(WorkspaceConnection): + """A connection to a Microsoft One Lake. Connections of this type + are further specified by their artifact class type, although + the number of artifact classes is currently limited. + + :param name: Name of the connection. + :type name: str + :param endpoint: The endpoint of the connection. + :type endpoint: str + :param artifact: The artifact class used to further specify the connection. 
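Based on the AzureBlobStoreConnection constructor above, a hedged sketch of a credential-less (Entra ID) blob store connection; the storage account, container, and URL are placeholders.

    from azure.ai.ml.entities._credentials import AadCredentialConfiguration
    from azure.ai.ml.entities._workspace.connections.connection_subtypes import AzureBlobStoreConnection

    # Placeholder storage details; container and account names are stored in connection metadata.
    blob_conn = AzureBlobStoreConnection(
        name="my-blob-connection",
        url="https://mystorageaccount.blob.core.windows.net/mycontainer",
        account_name="mystorageaccount",
        container_name="mycontainer",
        credentials=AadCredentialConfiguration(),
    )
    print(blob_conn.account_name, blob_conn.container_name)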
+ :type artifact: Optional[~azure.ai.ml.entities.OneLakeArtifact] + :param one_lake_workspace_name: The name, not ID, of the workspace where the One Lake + resource lives. + :type one_lake_workspace_name: Optional[str] + :param credentials: The credentials for authenticating to the blob store. This type of + connection accepts 3 types of credentials: account key and SAS token credentials, + or NoneCredentialConfiguration for credential-less connections. + :type credentials: Union[ + ~azure.ai.ml.entities.AccessKeyConfiguration, + ~azure.ai.ml.entities.SasTokenConfiguration, + ~azure.ai.ml.entities.AadCredentialConfiguration, + ] + :param metadata: Metadata dictionary. + :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + endpoint: str, + artifact: Optional[OneLakeConnectionArtifact] = None, + one_lake_workspace_name: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + + # Allow target to be inputted for from-rest conversions where we don't + # need to worry about data-availability nonsense. + target = kwargs.pop("target", None) + if target is None: + if artifact is None: + raise ValueError("If target is unset, then artifact must be set") + if endpoint is None: + raise ValueError("If target is unset, then endpoint must be set") + if one_lake_workspace_name is None: + raise ValueError("If target is unset, then one_lake_workspace_name must be set") + target = MicrosoftOneLakeConnection._construct_target(endpoint, one_lake_workspace_name, artifact) + super().__init__( + target=target, + type=camel_to_snake(ConnectionCategory.AZURE_ONE_LAKE), + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return MicrosoftOneLakeConnectionSchema + + # Target is constructed from user inputs, because it's apparently very difficult for users to + # directly access a One Lake's target URL. + @classmethod + def _construct_target(cls, endpoint: str, workspace: str, artifact: OneLakeConnectionArtifact) -> str: + artifact_name = artifact.name + # If an id is supplied, the format is different + if re.match(".{7}-.{4}-.{4}-.{4}.{12}", artifact_name): + return f"https://{endpoint}/{workspace}/{artifact_name}" + return f"https://{endpoint}/{workspace}/{artifact_name}.Lakehouse" + + +# There are enough types of connections that their only accept an api key credential, +# or just an api key credential or no credentials, that it merits a parent class for +# all of them. One that's slightly more specific than the base Connection. +# This file contains that parent class, as well as all of its children. +# Not experimental since users should never see this, +# No need to add an extra warning. +class ApiOrAadConnection(WorkspaceConnection): + """Internal parent class for all connections that accept either an api key or + entra ID as credentials. Entra ID credentials are implicitly assumed if no api key is provided. + + :param name: Name of the connection. + :type name: str + :param target: The URL or ARM resource ID of the external resource. + :type target: str + :param api_key: The api key to connect to the azure endpoint. + If unset, tries to use the user's Entra ID as credentials instead. + :type api_key: Optional[str] + :param api_version: The api version that this connection was created for. + :type api_version: Optional[str] + :param type: The type of the connection. 
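Following _construct_target above, a hedged sketch of how a One Lake connection target is composed from the endpoint, workspace name, and artifact; all values are placeholders, and a non-GUID artifact name gets a ".Lakehouse" suffix.

    from azure.ai.ml.entities._credentials import AadCredentialConfiguration
    from azure.ai.ml.entities._workspace.connections.connection_subtypes import MicrosoftOneLakeConnection
    from azure.ai.ml.entities._workspace.connections.one_lake_artifacts import OneLakeConnectionArtifact

    # Placeholder Fabric/One Lake details.
    one_lake_conn = MicrosoftOneLakeConnection(
        name="my-onelake-connection",
        endpoint="onelake.dfs.fabric.microsoft.com",
        one_lake_workspace_name="my-fabric-workspace",
        artifact=OneLakeConnectionArtifact(name="my-lakehouse"),
        credentials=AadCredentialConfiguration(),
    )
    # Resulting target shape: https://<endpoint>/<workspace>/my-lakehouse.Lakehouse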
+ :type type: str + :param allow_entra: Whether or not this connection allows initialization without + an API key via Aad. Defaults to True. + :type allow_entra: bool + """ + + def __init__( + self, + *, + api_key: Optional[str] = None, + allow_entra: bool = True, + type: str, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs: Any, + ): + # See if credentials directly inputted via kwargs + credentials: Union[AadCredentialConfiguration, ApiKeyConfiguration] = kwargs.pop( + "credentials", AadCredentialConfiguration() + ) + # Replace anything that isn't an API credential with an AAD credential. + # Importantly, this replaced the None credential default from the parent YAML schema. + if not isinstance(credentials, ApiKeyConfiguration): + credentials = AadCredentialConfiguration() + # Further replace that if a key is provided + if api_key: + credentials = ApiKeyConfiguration(key=api_key) + elif not allow_entra and isinstance(credentials, AadCredentialConfiguration): + # If no creds are provided in any capacity when needed. complain. + raise ValueError("This connection type must set the api_key value.") + + super().__init__( + type=type, + credentials=credentials, + metadata=metadata, + **kwargs, + ) + + @property + def api_key(self) -> Optional[str]: + """The API key of the connection. + + :return: The API key of the connection. + :rtype: Optional[str] + """ + if isinstance(self._credentials, ApiKeyConfiguration): + return self._credentials.key + return None + + @api_key.setter + def api_key(self, value: str) -> None: + """Set the API key of the connection. Setting this to None will + cause the connection to use the user's Entra ID as credentials. + + :param value: The new API key to set. + :type value: str + """ + if value is None: + self._credentials = AadCredentialConfiguration() + else: + self._credentials = ApiKeyConfiguration(key=value) + + +@experimental +class AzureOpenAIConnection(ApiOrAadConnection): + """A Connection that is specifically designed for handling connections + to Azure Open AI. + + :param name: Name of the connection. + :type name: str + :param azure_endpoint: The URL or ARM resource ID of the Azure Open AI Resource. + :type azure_endpoint: str + :param api_key: The api key to connect to the azure endpoint. + If unset, tries to use the user's Entra ID as credentials instead. + :type api_key: Optional[str] + :param open_ai_resource_id: The fully qualified ID of the Azure Open AI resource to connect to. + :type open_ai_resource_id: Optional[str] + :param api_version: The api version that this connection was created for. + :type api_version: Optional[str] + :param metadata: Metadata dictionary. + :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + azure_endpoint: str, + api_key: Optional[str] = None, + api_version: Optional[str] = None, + api_type: str = "Azure", # Required API input, hidden to allow for rare overrides + open_ai_resource_id: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs: Any, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + # Sneak in resource ID as it's inputted from rest conversions as a kwarg. 
+ from_rest_resource_id = kwargs.pop("resource_id", None) + if open_ai_resource_id is None and from_rest_resource_id is not None: + open_ai_resource_id = from_rest_resource_id + + if metadata is None: + metadata = {} + metadata[CONNECTION_API_VERSION_KEY] = api_version + metadata[CONNECTION_API_TYPE_KEY] = api_type + metadata[CONNECTION_RESOURCE_ID_KEY] = open_ai_resource_id + + super().__init__( + azure_endpoint=azure_endpoint, + api_key=api_key, + type=camel_to_snake(ConnectionCategory.AZURE_OPEN_AI), + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + return [CONNECTION_API_VERSION_KEY, CONNECTION_API_TYPE_KEY, CONNECTION_RESOURCE_ID_KEY] + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureOpenAIConnectionSchema + + @property + def api_version(self) -> Optional[str]: + """The API version of the connection. + + :return: The API version of the connection. + :rtype: Optional[str] + """ + if self.metadata is not None and CONNECTION_API_VERSION_KEY in self.metadata: + res: str = self.metadata[CONNECTION_API_VERSION_KEY] + return res + return None + + @api_version.setter + def api_version(self, value: str) -> None: + """Set the API version of the connection. + + :param value: The new api version to set. + :type value: str + """ + if not hasattr(self, "metadata") or self.metadata is None: + self.metadata = {} + self.metadata[CONNECTION_API_VERSION_KEY] = value + + @property + def open_ai_resource_id(self) -> Optional[str]: + """The fully qualified ID of the Azure Open AI resource this connects to. + + :return: The fully qualified ID of the Azure Open AI resource this connects to. + :rtype: Optional[str] + """ + if self.metadata is not None and CONNECTION_RESOURCE_ID_KEY in self.metadata: + res: str = self.metadata[CONNECTION_RESOURCE_ID_KEY] + return res + return None + + @open_ai_resource_id.setter + def open_ai_resource_id(self, value: Optional[str]) -> None: + """Set the fully qualified ID of the Azure Open AI resource to connect to. + + :param value: The new resource id to set. + :type value: Optional[str] + """ + if not hasattr(self, "metadata") or self.metadata is None: + self.metadata = {} + if value is None: + self.metadata.pop(CONNECTION_RESOURCE_ID_KEY, None) + return + self.metadata[CONNECTION_RESOURCE_ID_KEY] = value + + +@experimental +class AzureAIServicesConnection(ApiOrAadConnection): + """A Connection geared towards Azure AI services. + + :param name: Name of the connection. + :type name: str + :param endpoint: The URL or ARM resource ID of the external resource. + :type endpoint: str + :param api_key: The api key to connect to the azure endpoint. + If unset, tries to use the user's Entra ID as credentials instead. + :type api_key: Optional[str] + :param ai_services_resource_id: The fully qualified ID of the Azure AI service resource to connect to. + :type ai_services_resource_id: str + :param metadata: Metadata dictionary. 
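A hedged sketch of the AzureOpenAIConnection defined above; the endpoint, API version, and resource ID are placeholders, and leaving api_key unset means the connection falls back to Entra ID credentials.

    from azure.ai.ml.entities._workspace.connections.connection_subtypes import AzureOpenAIConnection

    # Placeholder endpoint and resource ID; no api_key, so AAD credentials are assumed.
    aoai_conn = AzureOpenAIConnection(
        name="my-aoai-connection",
        azure_endpoint="https://my-aoai-resource.openai.azure.com/",
        api_version="2024-02-01",
        open_ai_resource_id=(
            "/subscriptions/<subscription-id>/resourceGroups/my-resource-group"
            "/providers/Microsoft.CognitiveServices/accounts/my-aoai-resource"
        ),
    )
    print(aoai_conn.api_version, aoai_conn.api_key)  # api_key is None for AAD-backed connections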
+ :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + endpoint: str, + api_key: Optional[str] = None, + ai_services_resource_id: str, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs: Any, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + if metadata is None: + metadata = {} + metadata[CONNECTION_RESOURCE_ID_KEY] = ai_services_resource_id + super().__init__( + endpoint=endpoint, + api_key=api_key, + type=ConnectionTypes.AZURE_AI_SERVICES, + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureAIServicesConnectionSchema + + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + return [CONNECTION_RESOURCE_ID_KEY] + + @property + def ai_services_resource_id(self) -> Optional[str]: + """The resource id of the ai service being connected to. + + :return: The resource id of the ai service being connected to. + :rtype: Optional[str] + """ + if self.metadata is not None and CONNECTION_RESOURCE_ID_KEY in self.metadata: + res: str = self.metadata[CONNECTION_RESOURCE_ID_KEY] + return res + return None + + @ai_services_resource_id.setter + def ai_services_resource_id(self, value: str) -> None: + """Set the ai service resource id of the connection. + + :param value: The new ai service resource id to set. + :type value: str + """ + if not hasattr(self, "metadata") or self.metadata is None: + self.metadata = {} + self.metadata[CONNECTION_RESOURCE_ID_KEY] = value + + +class AzureAISearchConnection(ApiOrAadConnection): + """A Connection that is specifically designed for handling connections to + Azure AI Search. + + :param name: Name of the connection. + :type name: str + :param endpoint: The URL or ARM resource ID of the Azure AI Search Service + :type endpoint: str + :param api_key: The API key needed to connect to the Azure AI Search Service. + :type api_key: Optional[str] + :param metadata: Metadata dictionary. + :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + endpoint: str, + api_key: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs: Any, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + + super().__init__( + endpoint=endpoint, + api_key=api_key, + type=ConnectionTypes.AZURE_SEARCH, + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureAISearchConnectionSchema + + +class AzureContentSafetyConnection(ApiOrAadConnection): + """A Connection geared towards a Azure Content Safety service. + + :param name: Name of the connection. + :type name: str + :param endpoint: The URL or ARM resource ID of the external resource. + :type endpoint: str + :param api_key: The api key to connect to the azure endpoint. + If unset, tries to use the user's Entra ID as credentials instead. + :type api_key: Optional[str] + :param metadata: Metadata dictionary. 
+ :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + endpoint: str, + api_key: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs: Any, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + + if metadata is None: + metadata = {} + metadata[CONNECTION_KIND_KEY] = CognitiveServiceKinds.CONTENT_SAFETY + + super().__init__( + endpoint=endpoint, + api_key=api_key, + type=ConnectionTypes.AZURE_CONTENT_SAFETY, + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureContentSafetyConnectionSchema + + +class AzureSpeechServicesConnection(ApiOrAadConnection): + """A Connection geared towards an Azure Speech service. + + :param name: Name of the connection. + :type name: str + :param endpoint: The URL or ARM resource ID of the external resource. + :type endpoint: str + :param api_key: The api key to connect to the azure endpoint. + If unset, tries to use the user's Entra ID as credentials instead. + :type api_key: Optional[str] + :param metadata: Metadata dictionary. + :type metadata: Optional[dict[str,str]] + """ + + # kinds AzureOpenAI", "ContentSafety", and "Speech" + + def __init__( + self, + *, + endpoint: str, + api_key: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs: Any, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + + if metadata is None: + metadata = {} + metadata[CONNECTION_KIND_KEY] = CognitiveServiceKinds.SPEECH + super().__init__( + endpoint=endpoint, + api_key=api_key, + type=ConnectionTypes.AZURE_SPEECH_SERVICES, + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return AzureSpeechServicesConnectionSchema + + +@experimental +class APIKeyConnection(ApiOrAadConnection): + """A generic connection for any API key-based service. + + :param name: Name of the connection. + :type name: str + :param api_base: The URL to target with this connection. + :type api_base: str + :param api_key: The API key needed to connect to the api_base. + :type api_key: Optional[str] + :param metadata: Metadata dictionary. + :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + api_base: str, + api_key: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + super().__init__( + api_base=api_base, + api_key=api_key, + type=camel_to_snake(ConnectionCategory.API_KEY), + allow_entra=False, + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return APIKeyConnectionSchema + + +@experimental +class OpenAIConnection(ApiOrAadConnection): + """A connection geared towards direct connections to Open AI. + Not to be confused with the AzureOpenAIWorkspaceConnection, which is for Azure's Open AI services. + + :param name: Name of the connection. + :type name: str + :param api_key: The API key needed to connect to the Open AI. + :type api_key: Optional[str] + :param metadata: Metadata dictionary. 
+ :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + api_key: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + super().__init__( + type=ConnectionCategory.Open_AI, + api_key=api_key, + allow_entra=False, + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return OpenAIConnectionSchema + + +@experimental +class SerpConnection(ApiOrAadConnection): + """A connection geared towards a Serp service (Open source search API Service) + + :param name: Name of the connection. + :type name: str + :param api_key: The API key needed to connect to the Open AI. + :type api_key: Optional[str] + :param metadata: Metadata dictionary. + :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + api_key: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + super().__init__( + type=ConnectionCategory.SERP, + api_key=api_key, + allow_entra=False, + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return SerpConnectionSchema + + +@experimental +class ServerlessConnection(ApiOrAadConnection): + """A connection geared towards a MaaS endpoint (Serverless). + + :param name: Name of the connection. + :type name: str + :param endpoint: The serverless endpoint. + :type endpoint: str + :param api_key: The API key needed to connect to the endpoint. + :type api_key: Optional[str] + :param metadata: Metadata dictionary. + :type metadata: Optional[dict[str,str]] + """ + + def __init__( + self, + *, + endpoint: str, + api_key: Optional[str] = None, + metadata: Optional[Dict[Any, Any]] = None, + **kwargs, + ): + kwargs.pop("type", None) # make sure we never somehow use wrong type + super().__init__( + type=ConnectionCategory.SERVERLESS, + endpoint=endpoint, + api_key=api_key, + allow_entra=False, + from_child=True, + metadata=metadata, + **kwargs, + ) + + @classmethod + def _get_schema_class(cls) -> Type: + return ServerlessConnectionSchema diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/one_lake_artifacts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/one_lake_artifacts.py new file mode 100644 index 00000000..ea81602f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/one_lake_artifacts.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any +from azure.ai.ml._utils._experimental import experimental + +# Dev note: Supposedly there's going to be more artifact subclasses at some point. +# If/when that comes to pass, we can worry about adding polymorphism to these classes. +# For now, this is a one-off that's needed to help match the object structure that PF uses. + + +# Why is this not called a "LakeHouseArtifact"? Because despite the under-the-hood type, +# users expect this variety to be called "OneLake". +@experimental +class OneLakeConnectionArtifact: + """Artifact class used by the Connection subclass known + as a MicrosoftOneLakeConnection. Supplying this class further + specifies the connection as a Lake House connection. 
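For the API-key-only subclasses above (APIKeyConnection, OpenAIConnection, SerpConnection, ServerlessConnection), allow_entra is False, so omitting api_key raises. A hedged sketch with a placeholder endpoint and key:

    from azure.ai.ml.entities._workspace.connections.connection_subtypes import ServerlessConnection

    # Placeholder MaaS endpoint and key; api_key is effectively required for this connection type.
    serverless_conn = ServerlessConnection(
        name="my-maas-connection",
        endpoint="https://my-serverless-endpoint.eastus.models.ai.azure.com",
        api_key="<placeholder-key>",
    )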
+ """ + + # Note: Kwargs exist just to silently absorb type from schema. + def __init__(self, *, name: str, **kwargs: Any): # pylint: disable=unused-argument + self.name = name + self.type = "lake_house" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/workspace_connection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/workspace_connection.py new file mode 100644 index 00000000..ab1ee9f8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/workspace_connection.py @@ -0,0 +1,677 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import warnings +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Type, Union, cast + + +from azure.ai.ml._restclient.v2024_04_01_preview.models import ( + WorkspaceConnectionPropertiesV2BasicResource as RestWorkspaceConnection, +) +from azure.ai.ml._restclient.v2024_04_01_preview.models import ( + ConnectionCategory, + NoneAuthTypeWorkspaceConnectionProperties, + AADAuthTypeWorkspaceConnectionProperties, +) + +from azure.ai.ml._schema.workspace.connections.workspace_connection import WorkspaceConnectionSchema +from azure.ai.ml._utils.utils import _snake_to_camel, camel_to_snake, dump_yaml_to_file +from azure.ai.ml.constants._common import ( + BASE_PATH_CONTEXT_KEY, + PARAMS_OVERRIDE_KEY, + ConnectionTypes, + CognitiveServiceKinds, + CONNECTION_KIND_KEY, + CONNECTION_RESOURCE_ID_KEY, +) +from azure.ai.ml.entities._credentials import ( + AccessKeyConfiguration, + ApiKeyConfiguration, + ManagedIdentityConfiguration, + NoneCredentialConfiguration, + PatTokenConfiguration, + SasTokenConfiguration, + ServicePrincipalConfiguration, + UsernamePasswordConfiguration, + _BaseIdentityConfiguration, + AccountKeyConfiguration, + AadCredentialConfiguration, +) +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict + + +CONNECTION_CATEGORY_TO_CREDENTIAL_MAP = { + ConnectionCategory.AZURE_BLOB: [AccessKeyConfiguration, SasTokenConfiguration, AadCredentialConfiguration], + ConnectionTypes.AZURE_DATA_LAKE_GEN_2: [ + ServicePrincipalConfiguration, + AadCredentialConfiguration, + ManagedIdentityConfiguration, + ], + ConnectionCategory.GIT: [PatTokenConfiguration, NoneCredentialConfiguration, UsernamePasswordConfiguration], + ConnectionCategory.PYTHON_FEED: [UsernamePasswordConfiguration, PatTokenConfiguration, NoneCredentialConfiguration], + ConnectionCategory.CONTAINER_REGISTRY: [ManagedIdentityConfiguration, UsernamePasswordConfiguration], +} + +DATASTORE_CONNECTIONS = { + ConnectionCategory.AZURE_BLOB, + ConnectionTypes.AZURE_DATA_LAKE_GEN_2, + ConnectionCategory.AZURE_ONE_LAKE, +} + +CONNECTION_ALTERNATE_TARGET_NAMES = ["target", "api_base", "url", "azure_endpoint", "endpoint"] + + +# Dev note: The acceptable strings for the type field are all snake_cased versions of the string constants defined +# In the rest client enum defined at _azure_machine_learning_services_enums.ConnectionCategory. +# We avoid directly referencing it in the docs to avoid restclient references. 
+class WorkspaceConnection(Resource): + """Azure ML connection provides a secure way to store authentication and configuration information needed + to connect and interact with the external resources. + + Note: For connections to OpenAI, Cognitive Search, and Cognitive Services, use the respective subclasses + (ex: ~azure.ai.ml.entities.OpenAIConnection) instead of instantiating this class directly. + + :param name: Name of the connection. + :type name: str + :param target: The URL or ARM resource ID of the external resource. + :type target: str + :param metadata: Metadata dictionary. + :type metadata: Optional[Dict[str, Any]] + :param type: The category of external resource for this connection. + :type type: The type of connection, possible values are: "git", "python_feed", "container_registry", + "feature_store", "s3", "snowflake", "azure_sql_db", "azure_synapse_analytics", "azure_my_sql_db", + "azure_postgres_db", "adls_gen_2", "azure_one_lake", "custom". + :param credentials: The credentials for authenticating to the external resource. Note that certain connection + types (as defined by the type input) only accept certain types of credentials. + :type credentials: Union[ + ~azure.ai.ml.entities.PatTokenConfiguration, + ~azure.ai.ml.entities.SasTokenConfiguration, + ~azure.ai.ml.entities.UsernamePasswordConfiguration, + ~azure.ai.ml.entities.ManagedIdentityConfiguration + ~azure.ai.ml.entities.ServicePrincipalConfiguration, + ~azure.ai.ml.entities.AccessKeyConfiguration, + ~azure.ai.ml.entities.ApiKeyConfiguration, + ~azure.ai.ml.entities.NoneCredentialConfiguration + ~azure.ai.ml.entities.AccountKeyConfiguration, + ~azure.ai.ml.entities.AadCredentialConfiguration, + None + ] + :param is_shared: For connections in project, this controls whether or not this connection + is shared amongst other projects that are shared by the parent hub. Defaults to true. + :type is_shared: bool + """ + + def __init__( + self, + *, + # TODO : Check if this is okay since it shadows builtin-type type + type: str, # pylint: disable=redefined-builtin + credentials: Union[ + PatTokenConfiguration, + SasTokenConfiguration, + UsernamePasswordConfiguration, + ManagedIdentityConfiguration, + ServicePrincipalConfiguration, + AccessKeyConfiguration, + ApiKeyConfiguration, + NoneCredentialConfiguration, + AccountKeyConfiguration, + AadCredentialConfiguration, + ], + is_shared: bool = True, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): + + # Dev note: This initializer has an undocumented kwarg "from_child" to determine if this initialization + # is from a child class. + # This kwarg is required to allow instantiation of types that are associated with subtypes without a + # warning printout. + # The additional undocumented kwarg "strict_typing" turns the warning into a value error. + from_child = kwargs.pop("from_child", False) + strict_typing = kwargs.pop("strict_typing", False) + correct_class = WorkspaceConnection._get_entity_class_from_type(type) + if not from_child and correct_class != WorkspaceConnection: + if strict_typing: + raise ValueError( + f"Cannot instantiate a base Connection with a type of {type}. " + f"Please use the appropriate subclass {correct_class.__name__} instead." + ) + warnings.warn( + f"The connection of {type} has additional fields and should not be instantiated directly " + f"from the Connection class. 
Please use its subclass {correct_class.__name__} instead.", + ) + # This disgusting code allows for a variety of inputs names to technically all + # act like the target field, while still maintaining the aggregate field as required. + target = None + for target_name in CONNECTION_ALTERNATE_TARGET_NAMES: + target = kwargs.pop(target_name, target) + if target is None and type not in {ConnectionCategory.SERP, ConnectionCategory.Open_AI}: + raise ValueError("target is a required field for Connection.") + + tags = kwargs.pop("tags", None) + if tags is not None: + if metadata is not None: + # Update tags updated with metadata to make sure metadata values are preserved in case of conflicts. + tags.update(metadata) + metadata = tags + warnings.warn( + "Tags are a deprecated field for connections, use metadata instead. Since both " + + "metadata and tags are assigned, metadata values will take precedence in the event of conflicts." + ) + else: + metadata = tags + warnings.warn("Tags are a deprecated field for connections, use metadata instead.") + + super().__init__(**kwargs) + + self.type = type + self._target = target + self._credentials = credentials + self._is_shared = is_shared + self._metadata = metadata + self._validate_cred_for_conn_cat() + + def _validate_cred_for_conn_cat(self) -> None: + """Given a connection type, ensure that the given credentials are valid for that connection type. + Does not validate the actual data of the inputted credential, just that they are of the right class + type. + + """ + # Convert none credentials to AAD credentials for datastore connection types. + # The backend stores datastore aad creds as none, unlike other connection types with aad, + # which actually list them as aad. This IS distinct from regular none credentials, or so I've been told, + # so I will endeavor to smooth over that inconsistency here. + converted_type = _snake_to_camel(self.type).lower() + if self._credentials == NoneCredentialConfiguration() and any( + converted_type == _snake_to_camel(item).lower() for item in DATASTORE_CONNECTIONS + ): + self._credentials = AadCredentialConfiguration() + + if self.type in CONNECTION_CATEGORY_TO_CREDENTIAL_MAP: + allowed_credentials = CONNECTION_CATEGORY_TO_CREDENTIAL_MAP[self.type] + if self.credentials is None and NoneCredentialConfiguration not in allowed_credentials: + raise ValueError( + f"Cannot instantiate a Connection with a type of {self.type} and no credentials." + f"Please supply credentials from one of the following types: {allowed_credentials}." + ) + cred_type = type(self.credentials) + if cred_type not in allowed_credentials: + raise ValueError( + f"Cannot instantiate a Connection with a type of {self.type} and credentials of type" + f" {cred_type}. Please supply credentials from one of the following types: {allowed_credentials}." + ) + # For unknown types, just let the user do whatever they want. + + @property + def type(self) -> Optional[str]: + """Type of the connection, supported are 'git', 'python_feed' and 'container_registry'. + + :return: Type of the job. + :rtype: str + """ + return self._type + + @type.setter + def type(self, value: str) -> None: + """Set the type of the connection, supported are 'git', 'python_feed' and 'container_registry'. + + :param value: value for the type of connection. + :type: str + """ + if not value: + return + self._type: Optional[str] = camel_to_snake(value) + + @property + def target(self) -> Optional[str]: + """Target url for the connection. + + :return: Target of the connection. 
+ :rtype: Optional[str] + """ + return self._target + + @property + def endpoint(self) -> Optional[str]: + """Alternate name for the target of the connection, + which is used by some connection subclasses. + + :return: The target of the connection. + :rtype: str + """ + return self.target + + @property + def azure_endpoint(self) -> Optional[str]: + """Alternate name for the target of the connection, + which is used by some connection subclasses. + + :return: The target of the connection. + :rtype: str + """ + return self.target + + @property + def url(self) -> Optional[str]: + """Alternate name for the target of the connection, + which is used by some connection subclasses. + + :return: The target of the connection. + :rtype: str + """ + return self.target + + @property + def api_base(self) -> Optional[str]: + """Alternate name for the target of the connection, + which is used by some connection subclasses. + + :return: The target of the connection. + :rtype: str + """ + return self.target + + @property + def credentials( + self, + ) -> Union[ + PatTokenConfiguration, + SasTokenConfiguration, + UsernamePasswordConfiguration, + ManagedIdentityConfiguration, + ServicePrincipalConfiguration, + AccessKeyConfiguration, + ApiKeyConfiguration, + NoneCredentialConfiguration, + AccountKeyConfiguration, + AadCredentialConfiguration, + ]: + """Credentials for connection. + + :return: Credentials for connection. + :rtype: Union[ + ~azure.ai.ml.entities.PatTokenConfiguration, + ~azure.ai.ml.entities.SasTokenConfiguration, + ~azure.ai.ml.entities.UsernamePasswordConfiguration, + ~azure.ai.ml.entities.ManagedIdentityConfiguration + ~azure.ai.ml.entities.ServicePrincipalConfiguration, + ~azure.ai.ml.entities.AccessKeyConfiguration, + ~azure.ai.ml.entities.ApiKeyConfiguration + ~azure.ai.ml.entities.NoneCredentialConfiguration, + ~azure.ai.ml.entities.AccountKeyConfiguration, + ~azure.ai.ml.entities.AadCredentialConfiguration, + ] + """ + return self._credentials + + @property + def metadata(self) -> Optional[Dict[str, Any]]: + """The connection's metadata dictionary. + :return: This connection's metadata. + :rtype: Optional[Dict[str, Any]] + """ + return self._metadata if self._metadata is not None else {} + + @metadata.setter + def metadata(self, value: Optional[Dict[str, Any]]) -> None: + """Set the metadata for the connection. Be warned that setting this will override + ALL metadata values, including those implicitly set by certain connection types to manage their + extra data. Usually, you should probably access the metadata dictionary, then add or remove values + individually as needed. + :param value: The new metadata for connection. + This completely overwrites the existing metadata dictionary. + :type value: Optional[Dict[str, Any]] + """ + if not value: + return + self._metadata = value + + @property + def tags(self) -> Optional[Dict[str, Any]]: + """Deprecated. Use metadata instead. + :return: This connection's metadata. + :rtype: Optional[Dict[str, Any]] + """ + return self._metadata if self._metadata is not None else {} + + @tags.setter + def tags(self, value: Optional[Dict[str, Any]]) -> None: + """Deprecated use metadata instead + :param value: The new metadata for connection. + This completely overwrites the existing metadata dictionary. + :type value: Optional[Dict[str, Any]] + """ + if not value: + return + self._metadata = value + + @property + def is_shared(self) -> bool: + """Get the Boolean describing if this connection is shared amongst its cohort within a hub. 
+ Only applicable for connections created within a project. + + :rtype: bool + """ + return self._is_shared + + @is_shared.setter + def is_shared(self, value: bool) -> None: + """Assign the is_shared property of the connection, determining if it is shared amongst other projects + within its parent hub. Only applicable for connections created within a project. + + :param value: The new is_shared value. + :type value: bool + """ + if not value: + return + self._is_shared = value + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the connection spec into a file in yaml format. + + :param dest: The destination to receive this connection's spec. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable. + :type dest: Union[PathLike, str, IO[AnyStr]] + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "WorkspaceConnection": + data = data or {} + params_override = params_override or [] + context = { + BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"), + PARAMS_OVERRIDE_KEY: params_override, + } + return cls._load_from_dict(data=data, context=context, **kwargs) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "WorkspaceConnection": + conn_type = data["type"] if "type" in data else None + schema_class = cls._get_schema_class_from_type(conn_type) + loaded_data: WorkspaceConnection = load_from_dict(schema_class, data, context, **kwargs) + return loaded_data + + def _to_dict(self) -> Dict: + schema_class = WorkspaceConnection._get_schema_class_from_type(self.type) + # Not sure what this pylint complaint was about, probably due to the polymorphic + # tricks at play. Disabling since testing indicates no issue. + res: dict = schema_class(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _from_rest_object(cls, rest_obj: RestWorkspaceConnection) -> "WorkspaceConnection": + conn_class = cls._get_entity_class_from_rest_obj(rest_obj) + + popped_metadata = conn_class._get_required_metadata_fields() + + rest_kwargs = cls._extract_kwargs_from_rest_obj(rest_obj=rest_obj, popped_metadata=popped_metadata) + # Check for alternative name for custom connection type (added for client clarity). + if rest_kwargs["type"].lower() == camel_to_snake(ConnectionCategory.CUSTOM_KEYS).lower(): + rest_kwargs["type"] = ConnectionTypes.CUSTOM + if rest_kwargs["type"].lower() == camel_to_snake(ConnectionCategory.ADLS_GEN2).lower(): + rest_kwargs["type"] = ConnectionTypes.AZURE_DATA_LAKE_GEN_2 + target = rest_kwargs.get("target", "") + # This dumb code accomplishes 2 things. 
+ # It ensures that sub-classes properly input their target, regardless of which
+ # arbitrary name they replace it with, while also still allowing our official
+ # client specs to list those inputs as 'required'
+ for target_name in CONNECTION_ALTERNATE_TARGET_NAMES:
+ rest_kwargs[target_name] = target
+ if rest_obj.properties.category == ConnectionCategory.AZURE_ONE_LAKE:
+ # The Microsoft OneLake connection uniquely has client-only inputs
+ # that aren't just an alternate name for the target.
+ # This sets those inputs so that the initializer can still
+ # require those fields for users.
+ rest_kwargs["artifact"] = ""
+ rest_kwargs["one_lake_workspace_name"] = ""
+ if rest_obj.properties.category == ConnectionTypes.AI_SERVICES_REST_PLACEHOLDER:
+ # AI Services renames its metadata field when surfaced to users and passed
+ # into its initializer for clarity. ResourceId doesn't really tell much on its own.
+ # No default in pop; this should fail if we somehow don't get a resource ID.
+ rest_kwargs["ai_services_resource_id"] = rest_kwargs.pop(camel_to_snake(CONNECTION_RESOURCE_ID_KEY))
+ connection = conn_class(**rest_kwargs)
+ return cast(WorkspaceConnection, connection)
+
+ def _validate(self) -> str:
+ return str(self.name)
+
+ def _to_rest_object(self) -> RestWorkspaceConnection:
+ connection_properties_class: Any = NoneAuthTypeWorkspaceConnectionProperties
+ if self._credentials:
+ connection_properties_class = self._credentials._get_rest_properties_class()
+ # Convert from the human-readable type to the corresponding API enum if needed.
+ conn_type = self.type
+ if conn_type == ConnectionTypes.CUSTOM:
+ conn_type = ConnectionCategory.CUSTOM_KEYS
+ elif conn_type == ConnectionTypes.AZURE_DATA_LAKE_GEN_2:
+ conn_type = ConnectionCategory.ADLS_GEN2
+ elif conn_type in {
+ ConnectionTypes.AZURE_CONTENT_SAFETY,
+ ConnectionTypes.AZURE_SPEECH_SERVICES,
+ }:
+ conn_type = ConnectionCategory.COGNITIVE_SERVICE
+ elif conn_type == ConnectionTypes.AZURE_SEARCH:
+ conn_type = ConnectionCategory.COGNITIVE_SEARCH
+ elif conn_type == ConnectionTypes.AZURE_AI_SERVICES:
+ # The ConnectionCategory.AI_SERVICES category was accidentally unpublished.
+ conn_type = ConnectionTypes.AI_SERVICES_REST_PLACEHOLDER
+ # Some credential property bags have no credential input.
+ if connection_properties_class in {
+ NoneAuthTypeWorkspaceConnectionProperties,
+ AADAuthTypeWorkspaceConnectionProperties,
+ }:
+ properties = connection_properties_class(
+ target=self.target,
+ metadata=self.metadata,
+ category=_snake_to_camel(conn_type),
+ is_shared_to_all=self.is_shared,
+ )
+ else:
+ properties = connection_properties_class(
+ target=self.target,
+ credentials=self.credentials._to_workspace_connection_rest_object() if self._credentials else None,
+ metadata=self.metadata,
+ category=_snake_to_camel(conn_type),
+ is_shared_to_all=self.is_shared,
+ )
+
+ return RestWorkspaceConnection(properties=properties)
+
+ @classmethod
+ def _extract_kwargs_from_rest_obj(
+ cls, rest_obj: RestWorkspaceConnection, popped_metadata: List[str]
+ ) -> Dict[str, str]:
+ """Internal helper function that extracts all the fields needed to initialize a connection object
+ from its associated restful object. Pulls extra fields based on the supplied `popped_metadata` input.
+ Returns all the fields as a dictionary, which is expected to then be supplied to a
+ connection initializer as kwargs.
+
+ :param rest_obj: The rest object representation of a connection
+ :type rest_obj: RestWorkspaceConnection
+ :param popped_metadata: Key names that should be pulled from the rest object's metadata and
+ injected as top-level fields into the client connection's initializer.
+ This is needed for subclasses that require extra inputs compared to the base Connection class.
+ :type popped_metadata: List[str]
+
+ :return: A dictionary containing all kwargs needed to construct a connection.
+ :rtype: Dict[str, str]
+ """
+ properties = rest_obj.properties
+ credentials: Any = NoneCredentialConfiguration()
+
+ credentials_class = _BaseIdentityConfiguration._get_credential_class_from_rest_type(properties.auth_type)
+ # None and AAD auth types have a property bag class, but no credentials inside that.
+ # Thankfully they both have no inputs.
+
+ if credentials_class is AadCredentialConfiguration:
+ credentials = AadCredentialConfiguration()
+ elif credentials_class is not NoneCredentialConfiguration:
+ credentials = credentials_class._from_workspace_connection_rest_object(properties.credentials)
+
+ metadata = properties.metadata if hasattr(properties, "metadata") else {}
+ rest_kwargs = {
+ "id": rest_obj.id,
+ "name": rest_obj.name,
+ "target": properties.target,
+ "creation_context": SystemData._from_rest_object(rest_obj.system_data) if rest_obj.system_data else None,
+ "type": camel_to_snake(properties.category),
+ "credentials": credentials,
+ "metadata": metadata,
+ "is_shared": properties.is_shared_to_all if hasattr(properties, "is_shared_to_all") else True,
+ }
+
+ for name in popped_metadata:
+ if name in metadata:
+ rest_kwargs[camel_to_snake(name)] = metadata[name]
+ return rest_kwargs
+
+ @classmethod
+ def _get_entity_class_from_type(cls, type: str) -> Type:
+ """Helper function that derives the correct connection class given the client or server type.
+ Differs slightly from the rest object version in that it doesn't need to account for
+ rest object metadata.
+
+ The reason there are two functions at all is that certain API connection types
+ are obfuscated with different names when presented to the client. These types are
+ accounted for in the ConnectionTypes class in the constants file.
+
+ :param type: The type string describing the connection.
+ :type type: str
+
+ :return: The connection class the conn_type corresponds to.
+ :rtype: Type
+ """
+ from .connection_subtypes import (
+ AzureBlobStoreConnection,
+ MicrosoftOneLakeConnection,
+ AzureOpenAIConnection,
+ AzureAIServicesConnection,
+ AzureAISearchConnection,
+ AzureContentSafetyConnection,
+ AzureSpeechServicesConnection,
+ APIKeyConnection,
+ OpenAIConnection,
+ SerpConnection,
+ ServerlessConnection,
+ )
+
+ conn_type = _snake_to_camel(type).lower()
+ if conn_type is None:
+ return WorkspaceConnection
+
+ # Connection categories don't follow perfectly consistent camel casing, so lower
+ # case everything to avoid problems.
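+ # For example, both the client-facing string "azure_open_ai" and the REST
+ # category "AzureOpenAI" normalize to "azureopenai" here and resolve to
+ # AzureOpenAIConnection through the map below.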
+ CONNECTION_CATEGORY_TO_SUBCLASS_MAP = { + ConnectionCategory.AZURE_OPEN_AI.lower(): AzureOpenAIConnection, + ConnectionCategory.AZURE_BLOB.lower(): AzureBlobStoreConnection, + ConnectionCategory.AZURE_ONE_LAKE.lower(): MicrosoftOneLakeConnection, + ConnectionCategory.API_KEY.lower(): APIKeyConnection, + ConnectionCategory.OPEN_AI.lower(): OpenAIConnection, + ConnectionCategory.SERP.lower(): SerpConnection, + ConnectionCategory.SERVERLESS.lower(): ServerlessConnection, + _snake_to_camel(ConnectionTypes.AZURE_CONTENT_SAFETY).lower(): AzureContentSafetyConnection, + _snake_to_camel(ConnectionTypes.AZURE_SPEECH_SERVICES).lower(): AzureSpeechServicesConnection, + ConnectionCategory.COGNITIVE_SEARCH.lower(): AzureAISearchConnection, + _snake_to_camel(ConnectionTypes.AZURE_SEARCH).lower(): AzureAISearchConnection, + _snake_to_camel(ConnectionTypes.AZURE_AI_SERVICES).lower(): AzureAIServicesConnection, + ConnectionTypes.AI_SERVICES_REST_PLACEHOLDER.lower(): AzureAIServicesConnection, + } + return CONNECTION_CATEGORY_TO_SUBCLASS_MAP.get(conn_type, WorkspaceConnection) + + @classmethod + def _get_entity_class_from_rest_obj(cls, rest_obj: RestWorkspaceConnection) -> Type: + """Helper function that converts a restful connection into the associated + connection class or subclass. Accounts for potential snake/camel case and + capitalization differences in the type, and sub-typing derived from metadata. + + :param rest_obj: The rest object representation of the connection to derive a class from. + :type rest_obj: RestWorkspaceConnection + + :return: The connection class the conn_type corresponds to. + :rtype: Type + """ + conn_type = rest_obj.properties.category + conn_type = _snake_to_camel(conn_type).lower() + if conn_type is None: + return WorkspaceConnection + + # Imports are done here to avoid circular imports on load. + from .connection_subtypes import ( + AzureContentSafetyConnection, + AzureSpeechServicesConnection, + ) + + # Cognitive search connections have further subdivisions based on the kind of service. + if ( + conn_type == ConnectionCategory.COGNITIVE_SERVICE.lower() + and hasattr(rest_obj.properties, "metadata") + and rest_obj.properties.metadata is not None + ): + kind = rest_obj.properties.metadata.get(CONNECTION_KIND_KEY, "").lower() + if kind == CognitiveServiceKinds.CONTENT_SAFETY.lower(): + return AzureContentSafetyConnection + if kind == CognitiveServiceKinds.SPEECH.lower(): + return AzureSpeechServicesConnection + return WorkspaceConnection + + return cls._get_entity_class_from_type(type=conn_type) + + @classmethod + def _get_schema_class_from_type(cls, conn_type: Optional[str]) -> Type: + """Helper function that converts a rest client connection category into the associated + connection schema class or subclass. Accounts for potential snake/camel case and + capitalization differences. + + :param conn_type: The connection type. + :type conn_type: str + + :return: The connection schema class the conn_type corresponds to. + :rtype: Type + """ + if conn_type is None: + return WorkspaceConnectionSchema + entity_class = cls._get_entity_class_from_type(conn_type) + return entity_class._get_schema_class() + + @classmethod + def _get_required_metadata_fields(cls) -> List[str]: + """Helper function that returns the required metadata fields for specific + connection type. This parent function returns nothing, but needs to be overwritten by child + classes, which are created under the expectation that they have extra fields that need to be + accounted for. 
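+ For example (illustrative only), a subclass whose REST metadata carries a
+ "ResourceId" entry could return ["ResourceId"] here so that
+ _extract_kwargs_from_rest_obj promotes it to a top-level "resource_id" kwarg.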
+ + :return: A list of the required metadata fields for the specific connection type. + :rtype: List[str] + """ + return [] + + @classmethod + def _get_schema_class(cls) -> Type: + """Helper function that maps this class to its associated schema class. Needs to be overridden by + child classes to allow the base class to be polymorphic in its schema reading. + + :return: The appropriate schema class to use with this entity class. + :rtype: Type + """ + return WorkspaceConnectionSchema diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/customer_managed_key.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/customer_managed_key.py new file mode 100644 index 00000000..88474dab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/customer_managed_key.py @@ -0,0 +1,48 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from typing import Optional + + +class CustomerManagedKey: + """Key vault details for encrypting data with customer-managed keys. + + :param key_vault: Key vault that is holding the customer-managed key. + :type key_vault: str + :param key_uri: URI for the customer-managed key. + :type key_uri: str + :param cosmosdb_id: ARM id of bring-your-own cosmosdb account that customer brings + to store customer's data with encryption. + :type cosmosdb_id: str + :param storage_id: ARM id of bring-your-own storage account that customer brings + to store customer's data with encryption. + :type storage_id: str + :param search_id: ARM id of bring-your-own search account that customer brings + to store customer's data with encryption. + :type search_id: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START customermanagedkey] + :end-before: [END customermanagedkey] + :language: python + :dedent: 8 + :caption: Creating a CustomerManagedKey object. + """ + + def __init__( + self, + key_vault: Optional[str] = None, + key_uri: Optional[str] = None, + cosmosdb_id: Optional[str] = None, + storage_id: Optional[str] = None, + search_id: Optional[str] = None, + ): + self.key_vault = key_vault + self.key_uri = key_uri + self.cosmosdb_id = cosmosdb_id or "" + self.storage_id = storage_id or "" + self.search_id = search_id or "" diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/diagnose.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/diagnose.py new file mode 100644 index 00000000..fa923dc4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/diagnose.py @@ -0,0 +1,214 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import json +from typing import Any, Dict, List, Optional + +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + DiagnoseRequestProperties as RestDiagnoseRequestProperties, +) +from azure.ai.ml._restclient.v2024_10_01_preview.models import DiagnoseResponseResult as RestDiagnoseResponseResult +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + DiagnoseResponseResultValue as RestDiagnoseResponseResultValue, +) +from azure.ai.ml._restclient.v2024_10_01_preview.models import DiagnoseResult as RestDiagnoseResult +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + DiagnoseWorkspaceParameters as RestDiagnoseWorkspaceParameters, +) + + +class DiagnoseRequestProperties: + """DiagnoseRequestProperties.""" + + def __init__( + self, + *, + udr: Optional[Dict[str, Any]] = None, + nsg: Optional[Dict[str, Any]] = None, + resource_lock: Optional[Dict[str, Any]] = None, + dns_resolution: Optional[Dict[str, Any]] = None, + storage_account: Optional[Dict[str, Any]] = None, + key_vault: Optional[Dict[str, Any]] = None, + container_registry: Optional[Dict[str, Any]] = None, + application_insights: Optional[Dict[str, Any]] = None, + others: Optional[Dict[str, Any]] = None, + ): + self.udr = udr + self.nsg = nsg + self.resource_lock = resource_lock + self.dns_resolution = dns_resolution + self.storage_account = storage_account + self.key_vault = key_vault + self.container_registry = container_registry + self.application_insights = application_insights + self.others = others + + @classmethod + def _from_rest_object(cls, rest_obj: RestDiagnoseRequestProperties) -> "DiagnoseRequestProperties": + return cls( + udr=rest_obj.udr, + nsg=rest_obj.nsg, + resource_lock=rest_obj.resource_lock, + dns_resolution=rest_obj.dns_resolution, + storage_account=rest_obj.storage_account, + key_vault=rest_obj.key_vault, + container_registry=rest_obj.container_registry, + application_insights=rest_obj.application_insights, + others=rest_obj.others, + ) + + def _to_rest_object(self) -> RestDiagnoseRequestProperties: + return RestDiagnoseRequestProperties( + udr=self.udr, + nsg=self.nsg, + resource_lock=self.resource_lock, + dns_resolution=self.dns_resolution, + storage_account=self.storage_account, + key_vault=self.key_vault, + container_registry=self.container_registry, + application_insights=self.application_insights, + others=self.others, + ) + + +class DiagnoseResponseResult: + """DiagnoseResponseResult.""" + + def __init__( + self, + *, + value: Optional["DiagnoseResponseResultValue"] = None, + ): + self.value = value + + @classmethod + def _from_rest_object(cls, rest_obj: RestDiagnoseResponseResult) -> "DiagnoseResponseResult": + val = None + if rest_obj and rest_obj.value and isinstance(rest_obj.value, RestDiagnoseResponseResultValue): + # pylint: disable=protected-access + val = DiagnoseResponseResultValue._from_rest_object(rest_obj.value) + return cls(value=val) + + def _to_rest_object(self) -> RestDiagnoseResponseResult: + return RestDiagnoseResponseResult(value=self.value) + + +class DiagnoseResponseResultValue: + """DiagnoseResponseResultValue.""" + + def __init__( + self, + *, + user_defined_route_results: Optional[List["DiagnoseResult"]] = None, + network_security_rule_results: Optional[List["DiagnoseResult"]] = None, + resource_lock_results: Optional[List["DiagnoseResult"]] = None, + dns_resolution_results: Optional[List["DiagnoseResult"]] = None, + storage_account_results: Optional[List["DiagnoseResult"]] = 
None, + key_vault_results: Optional[List["DiagnoseResult"]] = None, + container_registry_results: Optional[List["DiagnoseResult"]] = None, + application_insights_results: Optional[List["DiagnoseResult"]] = None, + other_results: Optional[List["DiagnoseResult"]] = None, + ): + self.user_defined_route_results = user_defined_route_results + self.network_security_rule_results = network_security_rule_results + self.resource_lock_results = resource_lock_results + self.dns_resolution_results = dns_resolution_results + self.storage_account_results = storage_account_results + self.key_vault_results = key_vault_results + self.container_registry_results = container_registry_results + self.application_insights_results = application_insights_results + self.other_results = other_results + + @classmethod + def _from_rest_object(cls, rest_obj: RestDiagnoseResponseResultValue) -> "DiagnoseResponseResultValue": + return cls( + user_defined_route_results=rest_obj.user_defined_route_results, + network_security_rule_results=rest_obj.network_security_rule_results, + resource_lock_results=rest_obj.resource_lock_results, + dns_resolution_results=rest_obj.dns_resolution_results, + storage_account_results=rest_obj.storage_account_results, + key_vault_results=rest_obj.key_vault_results, + container_registry_results=rest_obj.container_registry_results, + application_insights_results=rest_obj.application_insights_results, + other_results=rest_obj.other_results, + ) + + def _to_rest_object(self) -> RestDiagnoseResponseResultValue: + return RestDiagnoseResponseResultValue( + user_defined_route_results=self.user_defined_route_results, + network_security_rule_results=self.network_security_rule_results, + resource_lock_results=self.resource_lock_results, + dns_resolution_results=self.dns_resolution_results, + storage_account_results=self.storage_account_results, + key_vault_results=self.key_vault_results, + container_registry_results=self.container_registry_results, + application_insights_results=self.application_insights_results, + other_results=self.other_results, + ) + + def __json__(self): + results = self.__dict__.copy() + for k, v in results.items(): + results[k] = [item.__dict__ for item in v] + return results + + def __str__(self) -> str: + return json.dumps(self, default=lambda o: o.__json__(), indent=2) + + +class DiagnoseResult: + """Result of Diagnose.""" + + def __init__( + self, + *, + code: Optional[str] = None, + level: Optional[str] = None, + message: Optional[str] = None, + ): + self.code = code + self.level = level + self.message = message + + @classmethod + def _from_rest_object(cls, rest_obj: RestDiagnoseResult) -> "DiagnoseResult": + return cls( + code=rest_obj.code, + level=rest_obj.level, + message=rest_obj.message, + ) + + def _to_rest_object(self) -> RestDiagnoseResult: + return RestDiagnoseResult( + code=self.code, + level=self.level, + message=self.message, + ) + + +class DiagnoseWorkspaceParameters: + """Parameters to diagnose a workspace.""" + + def __init__( + self, + *, + value: Optional["DiagnoseRequestProperties"] = None, + ): + self.value = value + + @classmethod + def _from_rest_object(cls, rest_obj: RestDiagnoseWorkspaceParameters) -> "DiagnoseWorkspaceParameters": + val = None + if rest_obj.value and isinstance(rest_obj.value, DiagnoseRequestProperties): + # TODO: Bug Item number: 2883283 + # pylint: disable=protected-access + val = rest_obj.value._from_rest_object() # type: ignore + return cls(value=val) + + def _to_rest_object(self) -> RestDiagnoseWorkspaceParameters: + val = None + 
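+ # Only convert the nested value when it is actually a DiagnoseRequestProperties;
+ # otherwise the REST parameters object is built with an empty value.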
if self.value and isinstance(self.value, DiagnoseRequestProperties): + # pylint: disable=protected-access + val = self.value._to_rest_object() + return RestDiagnoseWorkspaceParameters(value=val) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/feature_store_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/feature_store_settings.py new file mode 100644 index 00000000..8c264db0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/feature_store_settings.py @@ -0,0 +1,61 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Optional + +from azure.ai.ml._restclient.v2024_10_01_preview.models import FeatureStoreSettings as RestFeatureStoreSettings +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +from .compute_runtime import ComputeRuntime + + +class FeatureStoreSettings(RestTranslatableMixin): + """Feature Store Settings + + :param compute_runtime: The spark compute runtime settings. defaults to None. + :type compute_runtime: Optional[~compute_runtime.ComputeRuntime] + :param offline_store_connection_name: The offline store connection name. Defaults to None. + :type offline_store_connection_name: Optional[str] + :param online_store_connection_name: The online store connection name. Defaults to None. + :type online_store_connection_name: Optional[str] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_featurestore.py + :start-after: [START configure_feature_store_settings] + :end-before: [END configure_feature_store_settings] + :language: python + :dedent: 8 + :caption: Instantiating FeatureStoreSettings + """ + + def __init__( + self, + *, + compute_runtime: Optional[ComputeRuntime] = None, + offline_store_connection_name: Optional[str] = None, + online_store_connection_name: Optional[str] = None, + ) -> None: + self.compute_runtime = compute_runtime if compute_runtime else ComputeRuntime(spark_runtime_version="3.4.0") + self.offline_store_connection_name = offline_store_connection_name + self.online_store_connection_name = online_store_connection_name + + def _to_rest_object(self) -> RestFeatureStoreSettings: + return RestFeatureStoreSettings( + compute_runtime=ComputeRuntime._to_rest_object(self.compute_runtime), + offline_store_connection_name=self.offline_store_connection_name, + online_store_connection_name=self.online_store_connection_name, + ) + + @classmethod + def _from_rest_object(cls, obj: RestFeatureStoreSettings) -> Optional["FeatureStoreSettings"]: + if not obj: + return None + return FeatureStoreSettings( + compute_runtime=ComputeRuntime._from_rest_object(obj.compute_runtime), + offline_store_connection_name=obj.offline_store_connection_name, + online_store_connection_name=obj.online_store_connection_name, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/network_acls.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/network_acls.py new file mode 100644 index 00000000..fbb3b9ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/network_acls.py @@ -0,0 +1,90 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from typing import List, Optional + +from azure.ai.ml._restclient.v2024_10_01_preview.models import IPRule as RestIPRule +from azure.ai.ml._restclient.v2024_10_01_preview.models import NetworkAcls as RestNetworkAcls +from azure.ai.ml.entities._mixins import RestTranslatableMixin + + +class IPRule(RestTranslatableMixin): + """Represents an IP rule with a value. + + :param value: An IPv4 address or range in CIDR notation. + :type value: str + """ + + def __init__(self, value: Optional[str]): + self.value = value + + def __repr__(self): + return f"IPRule(value={self.value})" + + def _to_rest_object(self) -> RestIPRule: + return RestIPRule(value=self.value) + + @classmethod + def _from_rest_object(cls, obj: RestIPRule) -> "IPRule": + return cls(value=obj.value) + + +class DefaultActionType: + """Specifies the default action when no IP rules are matched.""" + + DENY = "Deny" + ALLOW = "Allow" + + +class NetworkAcls(RestTranslatableMixin): + """Network Access Setting for Workspace + + :param default_action: Specifies the default action when no IP rules are matched. + :type default_action: str + :param ip_rules: Rules governing the accessibility of a resource from a specific IP address or IP range. + :type ip_rules: Optional[List[IPRule]] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_network_access_settings] + :end-before: [END workspace_network_access_settings] + :language: python + :dedent: 8 + :caption: Configuring one of the three public network access settings. + """ + + def __init__( + self, + *, + default_action: str = DefaultActionType.ALLOW, + ip_rules: Optional[List[IPRule]] = None, + ): + self.default_action = default_action + self.ip_rules = ip_rules if ip_rules is not None else [] + + def __repr__(self): + ip_rules_repr = ", ".join(repr(ip_rule) for ip_rule in self.ip_rules) + return f"NetworkAcls(default_action={self.default_action}, ip_rules=[{ip_rules_repr}])" + + def _to_rest_object(self) -> RestNetworkAcls: + return RestNetworkAcls( + default_action=self.default_action, + ip_rules=( + [ip_rule._to_rest_object() for ip_rule in self.ip_rules] # pylint: disable=protected-access + if self.ip_rules + else None + ), + ) + + @classmethod + def _from_rest_object(cls, obj: RestNetworkAcls) -> "NetworkAcls": + return cls( + default_action=obj.default_action, + ip_rules=( + [IPRule._from_rest_object(ip_rule) for ip_rule in obj.ip_rules] # pylint: disable=protected-access + if obj.ip_rules + else [] + ), + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/networking.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/networking.py new file mode 100644 index 00000000..4576eac9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/networking.py @@ -0,0 +1,348 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from abc import ABC +from typing import Any, Dict, List, Optional + +from azure.ai.ml._restclient.v2024_10_01_preview.models import FqdnOutboundRule as RestFqdnOutboundRule +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + ManagedNetworkProvisionStatus as RestManagedNetworkProvisionStatus, +) +from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkSettings as RestManagedNetwork +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + PrivateEndpointDestination as RestPrivateEndpointOutboundRuleDestination, +) +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + PrivateEndpointOutboundRule as RestPrivateEndpointOutboundRule, +) +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + ServiceTagDestination as RestServiceTagOutboundRuleDestination, +) +from azure.ai.ml._restclient.v2024_10_01_preview.models import ServiceTagOutboundRule as RestServiceTagOutboundRule +from azure.ai.ml.constants._workspace import IsolationMode, OutboundRuleCategory, OutboundRuleType + + +class OutboundRule(ABC): + """Base class for Outbound Rules, cannot be instantiated directly. Please see FqdnDestination, + PrivateEndpointDestination, and ServiceTagDestination objects to create outbound rules. + + :param name: Name of the outbound rule. + :type name: str + :param type: Type of the outbound rule. Supported types are "FQDN", "PrivateEndpoint", "ServiceTag" + :type type: str + :ivar type: Type of the outbound rule. Supported types are "FQDN", "PrivateEndpoint", "ServiceTag" + :vartype type: str + """ + + def __init__( + self, + *, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.name = name + self.parent_rule_names = kwargs.pop("parent_rule_names", None) + self.type = kwargs.pop("type", None) + self.category = kwargs.pop("category", OutboundRuleCategory.USER_DEFINED) + self.status = kwargs.pop("status", None) + + @classmethod + def _from_rest_object(cls, rest_obj: Any, name: str) -> Optional["OutboundRule"]: + if isinstance(rest_obj, RestFqdnOutboundRule): + rule_fqdnDestination = FqdnDestination(destination=rest_obj.destination, name=name) + rule_fqdnDestination.category = rest_obj.category + rule_fqdnDestination.status = rest_obj.status + return rule_fqdnDestination + if isinstance(rest_obj, RestPrivateEndpointOutboundRule): + rule_privateEndpointDestination = PrivateEndpointDestination( + service_resource_id=rest_obj.destination.service_resource_id, + subresource_target=rest_obj.destination.subresource_target, + spark_enabled=rest_obj.destination.spark_enabled, + fqdns=rest_obj.fqdns, + name=name, + ) + rule_privateEndpointDestination.category = rest_obj.category + rule_privateEndpointDestination.status = rest_obj.status + return rule_privateEndpointDestination + if isinstance(rest_obj, RestServiceTagOutboundRule): + rule = ServiceTagDestination( + service_tag=rest_obj.destination.service_tag, + protocol=rest_obj.destination.protocol, + port_ranges=rest_obj.destination.port_ranges, + address_prefixes=rest_obj.destination.address_prefixes, + name=name, + ) + rule.category = rest_obj.category + rule.status = rest_obj.status + return rule + + return None + + +class FqdnDestination(OutboundRule): + """Class representing a FQDN outbound rule. + + :param name: Name of the outbound rule. + :type name: str + :param destination: Fully qualified domain name to which outbound connections are allowed. + For example: “xxxxxx.contoso.com”. 
+ :type destination: str + :ivar type: Type of the outbound rule. Set to "FQDN" for this class. + :vartype type: str + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START fqdn_outboundrule] + :end-before: [END fqdn_outboundrule] + :language: python + :dedent: 8 + :caption: Creating a FqdnDestination outbound rule object. + """ + + def __init__(self, *, name: str, destination: str, **kwargs: Any) -> None: + self.destination = destination + OutboundRule.__init__(self, type=OutboundRuleType.FQDN, name=name, **kwargs) + + def _to_rest_object(self) -> RestFqdnOutboundRule: + return RestFqdnOutboundRule(type=self.type, category=self.category, destination=self.destination) + + def _to_dict(self) -> Dict: + return { + "name": self.name, + "type": OutboundRuleType.FQDN, + "category": self.category, + "destination": self.destination, + "status": self.status, + } + + +class PrivateEndpointDestination(OutboundRule): + """Class representing a Private Endpoint outbound rule. + + :param name: Name of the outbound rule. + :type name: str + :param service_resource_id: The resource URI of the root service that supports creation of the private link. + :type service_resource_id: str + :param subresource_target: The target endpoint of the subresource of the service. + :type subresource_target: str + :param spark_enabled: Indicates if the private endpoint can be used for Spark jobs, default is “false”. + :type spark_enabled: bool + :param fqdns: String list of FQDNs particular to the Private Endpoint resource creation. For application + gateway Private Endpoints, this is the FQDN which will resolve to the private IP of the application + gateway PE inside the workspace's managed network. + :type fqdns: List[str] + :ivar type: Type of the outbound rule. Set to "PrivateEndpoint" for this class. + :vartype type: str + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START private_endpoint_outboundrule] + :end-before: [END private_endpoint_outboundrule] + :language: python + :dedent: 8 + :caption: Creating a PrivateEndpointDestination outbound rule object. + """ + + def __init__( + self, + *, + name: str, + service_resource_id: str, + subresource_target: str, + spark_enabled: bool = False, + fqdns: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + self.service_resource_id = service_resource_id + self.subresource_target = subresource_target + self.spark_enabled = spark_enabled + self.fqdns = fqdns + OutboundRule.__init__(self, type=OutboundRuleType.PRIVATE_ENDPOINT, name=name, **kwargs) + + def _to_rest_object(self) -> RestPrivateEndpointOutboundRule: + return RestPrivateEndpointOutboundRule( + type=self.type, + category=self.category, + destination=RestPrivateEndpointOutboundRuleDestination( + service_resource_id=self.service_resource_id, + subresource_target=self.subresource_target, + spark_enabled=self.spark_enabled, + ), + fqdns=self.fqdns, + ) + + def _to_dict(self) -> Dict: + return { + "name": self.name, + "type": OutboundRuleType.PRIVATE_ENDPOINT, + "category": self.category, + "destination": { + "service_resource_id": self.service_resource_id, + "subresource_target": self.subresource_target, + "spark_enabled": self.spark_enabled, + }, + "fqdns": self.fqdns, + "status": self.status, + } + + +class ServiceTagDestination(OutboundRule): + """Class representing a Service Tag outbound rule. + + :param name: Name of the outbound rule. 
+ :type name: str + :param service_tag: Service Tag of an Azure service, maps to predefined IP addresses for its service endpoints. + :type service_tag: str + :param protocol: Allowed transport protocol, can be "TCP", "UDP", "ICMP" or "*" for all supported protocols. + :type protocol: str + :param port_ranges: A comma-separated list of single ports and/or range of ports, such as "80,1024-65535". + Traffics should be allowed to these port ranges. + :type port_ranges: str + :param address_prefixes: Optional list of CIDR prefixes or IP ranges, when provided, service_tag argument will + be ignored and address_prefixes will be used instead. + :type address_prefixes: List[str] + :ivar type: Type of the outbound rule. Set to "ServiceTag" for this class. + :vartype type: str + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START service_tag_outboundrule] + :end-before: [END service_tag_outboundrule] + :language: python + :dedent: 8 + :caption: Creating a ServiceTagDestination outbound rule object. + """ + + def __init__( + self, + *, + name: str, + protocol: str, + port_ranges: str, + service_tag: Optional[str] = None, + address_prefixes: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + self.service_tag = service_tag + self.protocol = protocol + self.port_ranges = port_ranges + self.address_prefixes = address_prefixes + OutboundRule.__init__(self, type=OutboundRuleType.SERVICE_TAG, name=name, **kwargs) + + def _to_rest_object(self) -> RestServiceTagOutboundRule: + return RestServiceTagOutboundRule( + type=self.type, + category=self.category, + destination=RestServiceTagOutboundRuleDestination( + service_tag=self.service_tag, + protocol=self.protocol, + port_ranges=self.port_ranges, + address_prefixes=self.address_prefixes, + ), + ) + + def _to_dict(self) -> Dict: + return { + "name": self.name, + "type": OutboundRuleType.SERVICE_TAG, + "category": self.category, + "destination": { + "service_tag": self.service_tag, + "protocol": self.protocol, + "port_ranges": self.port_ranges, + "address_prefixes": self.address_prefixes, + }, + "status": self.status, + } + + +class ManagedNetwork: + """Managed Network settings for a workspace. + + :param isolation_mode: Isolation of the managed network, defaults to Disabled. + :type isolation_mode: str + :param firewall_sku: Firewall Sku for FQDN rules in AllowOnlyApprovedOutbound.. + :type firewall_sku: str + :param outbound_rules: List of outbound rules for the managed network. + :type outbound_rules: List[~azure.ai.ml.entities.OutboundRule] + :param network_id: Network id for the managed network, not meant to be set by user. + :type network_id: str + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_managed_network] + :end-before: [END workspace_managed_network] + :language: python + :dedent: 8 + :caption: Creating a ManagedNetwork object with one of each rule type. 
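+
+ A minimal inline sketch (rule names and destinations are placeholders):
+
+ network = ManagedNetwork(
+ isolation_mode=IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND,
+ outbound_rules=[FqdnDestination(name="pypi", destination="pypi.org")],
+ )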
+ """ + + def __init__( + self, + *, + isolation_mode: str = IsolationMode.DISABLED, + outbound_rules: Optional[List[OutboundRule]] = None, + firewall_sku: Optional[str] = None, + network_id: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.isolation_mode = isolation_mode + self.firewall_sku = firewall_sku + self.network_id = network_id + self.outbound_rules = outbound_rules + self.status = kwargs.pop("status", None) + + def _to_rest_object(self) -> RestManagedNetwork: + rest_outbound_rules = ( + { + # pylint: disable=protected-access + outbound_rule.name: outbound_rule._to_rest_object() # type: ignore[attr-defined] + for outbound_rule in self.outbound_rules + } + if self.outbound_rules + else {} + ) + return RestManagedNetwork( + isolation_mode=self.isolation_mode, outbound_rules=rest_outbound_rules, firewall_sku=self.firewall_sku + ) + + @classmethod + def _from_rest_object(cls, obj: RestManagedNetwork) -> "ManagedNetwork": + from_rest_outbound_rules = ( + [ + OutboundRule._from_rest_object(obj.outbound_rules[name], name=name) # pylint: disable=protected-access + for name in obj.outbound_rules + ] + if obj.outbound_rules + else {} + ) + return ManagedNetwork( + isolation_mode=obj.isolation_mode, + outbound_rules=from_rest_outbound_rules, # type: ignore[arg-type] + network_id=obj.network_id, + status=obj.status, + firewall_sku=obj.firewall_sku, + ) + + +class ManagedNetworkProvisionStatus: + """ManagedNetworkProvisionStatus. + + :param status: Status for managed network provision. + :type status: str + :param spark_ready: Bool value indicating if managed network is spark ready + :type spark_ready: bool + """ + + def __init__( + self, + *, + status: Optional[str] = None, + spark_ready: Optional[bool] = None, + ): + self.status = status + self.spark_ready = spark_ready + + @classmethod + def _from_rest_object(cls, rest_obj: RestManagedNetworkProvisionStatus) -> "ManagedNetworkProvisionStatus": + return cls( + status=rest_obj.status, + spark_ready=rest_obj.spark_ready, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/private_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/private_endpoint.py new file mode 100644 index 00000000..c9e8882e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/private_endpoint.py @@ -0,0 +1,53 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Dict, Optional + + +class EndpointConnection: + """Private Endpoint Connection related to a workspace private endpoint. + + :param subscription_id: Subscription id of the connection. + :type subscription_id: str + :param resource_group: Resource group of the connection. + :type resource_group: str + :param vnet_name: Name of the virtual network of the connection. + :type vnet_name: str + :param subnet_name: Name of the subnet of the connection. + :type subnet_name: str + :param location: Location of the connection. + :type location: str + """ + + def __init__( + self, + subscription_id: str, + resource_group: str, + vnet_name: str, + subnet_name: str, + location: Optional[str] = None, + ): + self.subscription_id = subscription_id + self.resource_group = resource_group + self.location = location + self.vnet_name = vnet_name + self.subnet_name = subnet_name + + +class PrivateEndpoint: + """Private Endpoint of a workspace. 
+ + :param approval_type: Approval type of the private endpoint. + :type approval_type: str + :param connections: List of private endpoint connections. + :type connections: List[~azure.ai.ml.entities.EndpointConnection] + """ + + def __init__( + self, + approval_type: Optional[str] = None, + connections: Optional[Dict[str, EndpointConnection]] = None, + ): + self.approval_type = approval_type + self.connections = connections diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/serverless_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/serverless_compute.py new file mode 100644 index 00000000..b78ede06 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/serverless_compute.py @@ -0,0 +1,52 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Optional, Union + +from marshmallow.exceptions import ValidationError + +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + ServerlessComputeSettings as RestServerlessComputeSettings, +) +from azure.ai.ml._schema._utils.utils import ArmId + + +class ServerlessComputeSettings: + custom_subnet: Optional[ArmId] + no_public_ip: bool = False + + def __init__(self, *, custom_subnet: Optional[Union[str, ArmId]] = None, no_public_ip: bool = False) -> None: + """Settings regarding serverless compute(s) in an Azure ML workspace. + + :keyword custom_subnet: The ARM ID of the subnet to use for serverless compute(s). + :paramtype custom_subnet: Optional[Union[str, ArmId]] + :keyword no_public_ip: Whether or not to disable public IP addresses for serverless compute(s). + Defaults to False. + :paramtype no_public_ip: bool + :raises ValidationError: If the custom_subnet is not formatted as an ARM ID. + """ + if isinstance(custom_subnet, str): + self.custom_subnet = ArmId(custom_subnet) + elif isinstance(custom_subnet, ArmId) or custom_subnet is None: + self.custom_subnet = custom_subnet + else: + raise ValidationError("custom_subnet must be a string, ArmId, or None.") + self.no_public_ip = no_public_ip + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ServerlessComputeSettings): + return NotImplemented + return self.custom_subnet == other.custom_subnet and self.no_public_ip == other.no_public_ip + + def _to_rest_object(self) -> RestServerlessComputeSettings: + return RestServerlessComputeSettings( + serverless_compute_custom_subnet=self.custom_subnet, + serverless_compute_no_public_ip=self.no_public_ip, + ) + + @classmethod + def _from_rest_object(cls, obj: RestServerlessComputeSettings) -> "ServerlessComputeSettings": + return cls( + custom_subnet=obj.serverless_compute_custom_subnet, + no_public_ip=obj.serverless_compute_no_public_ip, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace.py new file mode 100644 index 00000000..495e00b0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace.py @@ -0,0 +1,491 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=too-many-instance-attributes + +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Type, Union + +from azure.ai.ml._restclient.v2024_10_01_preview.models import FeatureStoreSettings as RestFeatureStoreSettings +from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkSettings as RestManagedNetwork +from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedServiceIdentity as RestManagedServiceIdentity +from azure.ai.ml._restclient.v2024_10_01_preview.models import NetworkAcls as RestNetworkAcls +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + ServerlessComputeSettings as RestServerlessComputeSettings, +) +from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace as RestWorkspace +from azure.ai.ml._schema.workspace.workspace import WorkspaceSchema +from azure.ai.ml._utils.utils import dump_yaml_to_file +from azure.ai.ml.constants._common import ( + BASE_PATH_CONTEXT_KEY, + PARAMS_OVERRIDE_KEY, + CommonYamlFields, + WorkspaceKind, + WorkspaceResourceConstants, +) +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._resource import Resource +from azure.ai.ml.entities._util import find_field_in_override, load_from_dict +from azure.ai.ml.entities._workspace.serverless_compute import ServerlessComputeSettings +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .customer_managed_key import CustomerManagedKey +from .feature_store_settings import FeatureStoreSettings +from .network_acls import NetworkAcls +from .networking import ManagedNetwork + + +class Workspace(Resource): + """Azure ML workspace. + + :param name: Name of the workspace. + :type name: str + :param description: Description of the workspace. + :type description: str + :param tags: Tags of the workspace. + :type tags: dict + :param display_name: Display name for the workspace. This is non-unique within the resource group. + :type display_name: str + :param location: The location to create the workspace in. + If not specified, the same location as the resource group will be used. + :type location: str + :param resource_group: Name of resource group to create the workspace in. + :type resource_group: str + :param hbi_workspace: Whether the customer data is of high business impact (HBI), + containing sensitive business information. + For more information, see + https://learn.microsoft.com/azure/machine-learning/concept-data-encryption#encryption-at-rest. + :type hbi_workspace: bool + :param storage_account: The resource ID of an existing storage account to use instead of creating a new one. + :type storage_account: str + :param container_registry: The resource ID of an existing container registry + to use instead of creating a new one. + :type container_registry: str + :param key_vault: The resource ID of an existing key vault to use instead of creating a new one. + :type key_vault: str + :param application_insights: The resource ID of an existing application insights + to use instead of creating a new one. + :type application_insights: str + :param customer_managed_key: Key vault details for encrypting data with customer-managed keys. + If not specified, Microsoft-managed keys will be used by default. 
+    :type customer_managed_key: ~azure.ai.ml.entities.CustomerManagedKey
+    :param image_build_compute: The name of the compute target to use for building environment
+        Docker images when the container registry is behind a VNet.
+    :type image_build_compute: str
+    :param public_network_access: Whether to allow public endpoint connectivity
+        when a workspace is private link enabled.
+    :type public_network_access: str
+    :param network_acls: The network access control list (ACL) settings of the workspace.
+    :type network_acls: ~azure.ai.ml.entities.NetworkAcls
+    :param identity: The workspace's managed identity (user-assigned or system-assigned).
+    :type identity: ~azure.ai.ml.entities.IdentityConfiguration
+    :param primary_user_assigned_identity: The workspace's primary user-assigned identity.
+    :type primary_user_assigned_identity: str
+    :param managed_network: The workspace's managed network configuration.
+    :type managed_network: ~azure.ai.ml.entities.ManagedNetwork
+    :param provision_network_now: Set to True to provision the managed VNet with the default options when
+        creating a workspace with the managed VNet enabled; otherwise it does nothing.
+    :type provision_network_now: Optional[bool]
+    :param system_datastores_auth_mode: The authentication mode for system datastores.
+    :type system_datastores_auth_mode: str
+    :param enable_data_isolation: A flag to determine if the workspace has data isolation enabled.
+        The flag can only be set at creation time and cannot be updated afterwards.
+    :type enable_data_isolation: bool
+    :param allow_roleassignment_on_rg: Whether to allow workspace role assignments at the resource group level.
+    :type allow_roleassignment_on_rg: Optional[bool]
+    :param serverless_compute: The serverless compute settings for the workspace.
+    :type serverless_compute: ~azure.ai.ml.entities.ServerlessComputeSettings
+    :param workspace_hub: Deprecated resource ID of an existing workspace hub used to create a project workspace.
+        Use the Project class instead.
+    :type workspace_hub: Optional[str]
+    :param kwargs: A dictionary of additional configuration parameters.
+    :type kwargs: dict
+
+    .. literalinclude:: ../samples/ml_samples_workspace.py
+        :start-after: [START workspace]
+        :end-before: [END workspace]
+        :language: python
+        :dedent: 8
+        :caption: Creating a Workspace object.
+    """
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        description: Optional[str] = None,
+        tags: Optional[Dict[str, str]] = None,
+        display_name: Optional[str] = None,
+        location: Optional[str] = None,
+        resource_group: Optional[str] = None,
+        hbi_workspace: bool = False,
+        storage_account: Optional[str] = None,
+        container_registry: Optional[str] = None,
+        key_vault: Optional[str] = None,
+        application_insights: Optional[str] = None,
+        customer_managed_key: Optional[CustomerManagedKey] = None,
+        image_build_compute: Optional[str] = None,
+        public_network_access: Optional[str] = None,
+        network_acls: Optional[NetworkAcls] = None,
+        identity: Optional[IdentityConfiguration] = None,
+        primary_user_assigned_identity: Optional[str] = None,
+        managed_network: Optional[ManagedNetwork] = None,
+        provision_network_now: Optional[bool] = None,
+        system_datastores_auth_mode: Optional[str] = None,
+        enable_data_isolation: bool = False,
+        allow_roleassignment_on_rg: Optional[bool] = None,
+        hub_id: Optional[str] = None,  # Hidden input, surfaced by Project
+        workspace_hub: Optional[str] = None,  # Deprecated input maintained for backwards compat.
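+        # Serverless compute settings for the workspace; see ServerlessComputeSettings
+        # in serverless_compute.py above.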
+        serverless_compute: Optional[ServerlessComputeSettings] = None,
+        **kwargs: Any,
+    ):
+        # Workspaces have subclasses that are differentiated by the 'kind' field in the REST API.
+        # Now that this value is occasionally surfaced (for sub-class YAML specifications),
+        # we've switched to using 'type' in the SDK for consistency with other polymorphic classes.
+        # That said, the code below still quietly supports 'kind' as an input
+        # to maintain backwards compatibility with internal systems that may still use 'kind' somewhere.
+        # 'type' takes precedence over 'kind' if they're both set, and this defaults to a normal workspace's type
+        # if nothing is set.
+        # pylint: disable=too-many-locals
+        self._kind = kwargs.pop("kind", None)
+        if self._kind is None:
+            self._kind = WorkspaceKind.DEFAULT
+
+        self.print_as_yaml = True
+        self._discovery_url: Optional[str] = kwargs.pop("discovery_url", None)
+        self._mlflow_tracking_uri: Optional[str] = kwargs.pop("mlflow_tracking_uri", None)
+        self._workspace_id = kwargs.pop("workspace_id", None)
+        self._feature_store_settings: Optional[FeatureStoreSettings] = kwargs.pop("feature_store_settings", None)
+        super().__init__(name=name, description=description, tags=tags, **kwargs)
+
+        self.display_name = display_name
+        self.location = location
+        self.resource_group = resource_group
+        self.hbi_workspace = hbi_workspace
+        self.storage_account = storage_account
+        self.container_registry = container_registry
+        self.key_vault = key_vault
+        self.application_insights = application_insights
+        self.customer_managed_key = customer_managed_key
+        self.image_build_compute = image_build_compute
+        self.public_network_access = public_network_access
+        self.identity = identity
+        self.primary_user_assigned_identity = primary_user_assigned_identity
+        self.managed_network = managed_network
+        self.provision_network_now = provision_network_now
+        self.system_datastores_auth_mode = system_datastores_auth_mode
+        self.enable_data_isolation = enable_data_isolation
+        self.allow_roleassignment_on_rg = allow_roleassignment_on_rg
+        if workspace_hub and not hub_id:
+            hub_id = workspace_hub
+        self.__hub_id = hub_id
+        # Overwrite kind if hub_id is provided. Technically not needed anymore,
+        # but kept for backwards compatibility in case people try to use a normal
+        # workspace like a project.
+        if hub_id:
+            self._kind = WorkspaceKind.PROJECT
+        self.serverless_compute: Optional[ServerlessComputeSettings] = serverless_compute
+        self.network_acls: Optional[NetworkAcls] = network_acls
+
+    @property
+    def discovery_url(self) -> Optional[str]:
+        """Backend service base URLs for the workspace.
+
+        :return: Backend service URLs of the workspace.
+        :rtype: str
+        """
+        return self._discovery_url
+
+    # Exists to appease tox's mypy rules.
+    @property
+    def _hub_id(self) -> Optional[str]:
+        """The UID of the hub parent of the project. This is an internal property
+        that's surfaced by the Project sub-class, but exists here for backwards-compatibility
+        reasons.
+
+        :return: Resource ID of the parent hub.
+        :rtype: str
+        """
+        return self.__hub_id
+
+    # Exists to appease tox's mypy rules.
+    @_hub_id.setter
+    def _hub_id(self, value: str):
+        """Set the hub of the project. This is an internal property
+        that's surfaced by the Project sub-class, but exists here for backwards-compatibility
+        reasons.
+
+        :param value: The hub ID to assign to the project.
+            Note: cannot be reassigned after creation.
+ :type value: str + """ + if not value: + return + self.__hub_id = value + + @property + def mlflow_tracking_uri(self) -> Optional[str]: + """MLflow tracking uri for the workspace. + + :return: Returns mlflow tracking uri of the workspace. + :rtype: str + """ + return self._mlflow_tracking_uri + + def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None: + """Dump the workspace spec into a file in yaml format. + + :param dest: The destination to receive this workspace's spec. + Must be either a path to a local file, or an already-open file stream. + If dest is a file path, a new file will be created, + and an exception is raised if the file exists. + If dest is an open file, the file will be written to directly, + and an exception will be raised if the file is not writable. + :type dest: Union[PathLike, str, IO[AnyStr]] + """ + path = kwargs.pop("path", None) + yaml_serialized = self._to_dict() + dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs) + + def _to_dict(self) -> Dict: + res: dict = self._get_schema_class()(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) + return res + + @classmethod + def _resolve_sub_cls_and_kind( + cls, data: Dict, params_override: Optional[List[Dict]] = None + ) -> Tuple[Type["Workspace"], str]: + """Given a workspace data dictionary, determine the appropriate workspace class and type string. + Allows for easier polymorphism between the workspace class and its children. + Adapted from similar code in the Job class. + + :param data: A dictionary of values describing the workspace. + :type data: Dict + :param params_override: Override values from alternative sources (ex: CLI input). + :type params_override: Optional[List[Dict]] + :return: A tuple containing the workspace class and type string. + :rtype: Tuple[Type["Workspace"], str] + """ + from azure.ai.ml.entities import Hub, Project + + workspace_class: Optional[Type["Workspace"]] = None + type_in_override = find_field_in_override(CommonYamlFields.KIND, params_override) + type_str = type_in_override or data.get(CommonYamlFields.KIND, WorkspaceKind.DEFAULT) + if type_str is not None: + type_str = type_str.lower() + if type_str == WorkspaceKind.HUB: + workspace_class = Hub + elif type_str == WorkspaceKind.PROJECT: + workspace_class = Project + elif type_str == WorkspaceKind.DEFAULT: + workspace_class = Workspace + else: + msg = f"Unsupported workspace type: {type_str}." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.WORKSPACE, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return workspace_class, type_str + + @classmethod + def _load( + cls, + data: Optional[Dict] = None, + yaml_path: Optional[Union[PathLike, str]] = None, + params_override: Optional[list] = None, + **kwargs: Any, + ) -> "Workspace": + # This _load function is polymorphic and can return child classes. + # It was adapted from the Job class's similar function. 
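+        # Flow: resolve the concrete subclass (Workspace, Hub, or Project) from the
+        # 'kind' field, validate the input dict against that subclass's schema, then
+        # construct the subclass from the validated values.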
+        data = data or {}
+        params_override = params_override or []
+        context = {
+            BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+            PARAMS_OVERRIDE_KEY: params_override,
+        }
+        workspace_class, type_str = cls._resolve_sub_cls_and_kind(data, params_override)
+        schema_type = workspace_class._get_schema_class()  # pylint: disable=protected-access
+        loaded_schema = load_from_dict(
+            schema_type,
+            data=data,
+            context=context,
+            additional_message=f"If you are trying to configure a workspace that is not of type {type_str},"
+            f" please specify the correct workspace type in the 'type' property.",
+            **kwargs,
+        )
+        result = workspace_class(**loaded_schema)
+        if yaml_path:
+            result._source_path = yaml_path  # pylint: disable=protected-access
+        return result
+
+    @classmethod
+    def _from_rest_object(
+        cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None
+    ) -> Optional["Workspace"]:
+
+        if not rest_obj:
+            return None
+        customer_managed_key = (
+            CustomerManagedKey(
+                key_vault=rest_obj.encryption.key_vault_properties.key_vault_arm_id,
+                key_uri=rest_obj.encryption.key_vault_properties.key_identifier,
+            )
+            if rest_obj.encryption
+            and rest_obj.encryption.status == WorkspaceResourceConstants.ENCRYPTION_STATUS_ENABLED
+            else None
+        )
+
+        # TODO: Remove attribute check once Oct API version is out
+        mlflow_tracking_uri = None
+
+        if hasattr(rest_obj, "ml_flow_tracking_uri"):
+            try:
+                if v2_service_context:
+                    # v2_service_context is required (not None) in get_mlflow_tracking_uri_v2
+                    from azureml.mlflow import get_mlflow_tracking_uri_v2
+
+                    mlflow_tracking_uri = get_mlflow_tracking_uri_v2(rest_obj, v2_service_context)
+                else:
+                    mlflow_tracking_uri = rest_obj.ml_flow_tracking_uri
+            except ImportError:
+                mlflow_tracking_uri = rest_obj.ml_flow_tracking_uri
+
+        # TODO: Remove once Online Endpoints updates API version to at least 2023-08-01
+        allow_roleassignment_on_rg = None
+        if hasattr(rest_obj, "allow_role_assignment_on_rg"):
+            allow_roleassignment_on_rg = rest_obj.allow_role_assignment_on_rg
+        system_datastores_auth_mode = None
+        if hasattr(rest_obj, "system_datastores_auth_mode"):
+            system_datastores_auth_mode = rest_obj.system_datastores_auth_mode
+
+        # TODO: remove this once it is included in API response
+        managed_network = None
+        if hasattr(rest_obj, "managed_network"):
+            if rest_obj.managed_network and isinstance(rest_obj.managed_network, RestManagedNetwork):
+                managed_network = ManagedNetwork._from_rest_object(  # pylint: disable=protected-access
+                    rest_obj.managed_network
+                )
+
+        # TODO: Remove once it's included in response
+        provision_network_now = None
+        if hasattr(rest_obj, "provision_network_now"):
+            provision_network_now = rest_obj.provision_network_now
+
+        # The resource group name is the fifth segment of the ARM ID
+        # (/subscriptions/<sub>/resourceGroups/<rg>/...), so at least five segments are required.
+        armid_parts = str(rest_obj.id).split("/")
+        group = None if len(armid_parts) < 5 else armid_parts[4]
+        identity = None
+        if rest_obj.identity and isinstance(rest_obj.identity, RestManagedServiceIdentity):
+            identity = IdentityConfiguration._from_workspace_rest_object(  # pylint: disable=protected-access
+                rest_obj.identity
+            )
+        feature_store_settings = None
+        if rest_obj.feature_store_settings and isinstance(rest_obj.feature_store_settings, RestFeatureStoreSettings):
+            feature_store_settings = FeatureStoreSettings._from_rest_object(  # pylint: disable=protected-access
+                rest_obj.feature_store_settings
+            )
+        serverless_compute = None
+        # TODO: Remove attribute check once serverless_compute_settings is in API response contract
+        if hasattr(rest_obj, "serverless_compute_settings"):
+            if rest_obj.serverless_compute_settings and isinstance(
+                rest_obj.serverless_compute_settings, RestServerlessComputeSettings
+            ):
+                serverless_compute = ServerlessComputeSettings._from_rest_object(  # pylint: disable=protected-access
+                    rest_obj.serverless_compute_settings
+                )
+        network_acls = None
+        if hasattr(rest_obj, "network_acls"):
+            if rest_obj.network_acls and isinstance(rest_obj.network_acls, RestNetworkAcls):
+                network_acls = NetworkAcls._from_rest_object(rest_obj.network_acls)  # pylint: disable=protected-access
+
+        return cls(
+            name=rest_obj.name,
+            id=rest_obj.id,
+            description=rest_obj.description,
+            kind=rest_obj.kind.lower() if rest_obj.kind else None,
+            tags=rest_obj.tags,
+            location=rest_obj.location,
+            resource_group=group,
+            display_name=rest_obj.friendly_name,
+            discovery_url=rest_obj.discovery_url,
+            hbi_workspace=rest_obj.hbi_workspace,
+            storage_account=rest_obj.storage_account,
+            container_registry=rest_obj.container_registry,
+            key_vault=rest_obj.key_vault,
+            application_insights=rest_obj.application_insights,
+            customer_managed_key=customer_managed_key,
+            image_build_compute=rest_obj.image_build_compute,
+            public_network_access=rest_obj.public_network_access,
+            network_acls=network_acls,
+            mlflow_tracking_uri=mlflow_tracking_uri,
+            identity=identity,
+            primary_user_assigned_identity=rest_obj.primary_user_assigned_identity,
+            managed_network=managed_network,
+            provision_network_now=provision_network_now,
+            system_datastores_auth_mode=system_datastores_auth_mode,
+            feature_store_settings=feature_store_settings,
+            enable_data_isolation=rest_obj.enable_data_isolation,
+            allow_roleassignment_on_rg=allow_roleassignment_on_rg,
+            hub_id=rest_obj.hub_resource_id,
+            workspace_id=rest_obj.workspace_id,
+            serverless_compute=serverless_compute,
+        )
+
+    def _to_rest_object(self) -> RestWorkspace:
+        """Note: Unlike most entities, the create operation for workspaces does NOT use this function,
+        and instead relies on its own internal conversion process to produce a valid ARM template.
+
+        :return: The REST API object-equivalent of this workspace.
+        :rtype: RestWorkspace
+        """
+        feature_store_settings = None
+        if self._feature_store_settings:
+            feature_store_settings = self._feature_store_settings._to_rest_object()  # pylint: disable=protected-access
+
+        serverless_compute_settings = None
+        if self.serverless_compute:
+            serverless_compute_settings = self.serverless_compute._to_rest_object()  # pylint: disable=protected-access
+
+        return RestWorkspace(
+            name=self.name,
+            identity=(
+                self.identity._to_workspace_rest_object() if self.identity else None  # pylint: disable=protected-access
+            ),
+            location=self.location,
+            tags=self.tags,
+            description=self.description,
+            kind=self._kind,
+            friendly_name=self.display_name,
+            key_vault=self.key_vault,
+            application_insights=self.application_insights,
+            container_registry=self.container_registry,
+            storage_account=self.storage_account,
+            discovery_url=self.discovery_url,
+            hbi_workspace=self.hbi_workspace,
+            image_build_compute=self.image_build_compute,
+            public_network_access=self.public_network_access,
+            primary_user_assigned_identity=self.primary_user_assigned_identity,
+            managed_network=(
+                self.managed_network._to_rest_object()  # pylint: disable=protected-access
+                if self.managed_network
+                else None
+            ),
+            provision_network_now=self.provision_network_now,
+            system_datastores_auth_mode=self.system_datastores_auth_mode,
+            feature_store_settings=feature_store_settings,
+            enable_data_isolation=self.enable_data_isolation,
+            allow_role_assignment_on_rg=self.allow_roleassignment_on_rg,  # name differs due to swagger/REST client casing
+            hub_resource_id=self._hub_id,
+            serverless_compute_settings=serverless_compute_settings,
+        )
+
+    # Helper for sub-class polymorphism. Needs to be overridden by child classes
+    # if they don't want to redefine things like _to_dict.
+    @classmethod
+    def _get_schema_class(cls) -> Type[WorkspaceSchema]:
+        return WorkspaceSchema
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace_keys.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace_keys.py
new file mode 100644
index 00000000..4213b419
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace_keys.py
@@ -0,0 +1,100 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import List, Optional
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import ListWorkspaceKeysResult
+
+
+class ContainerRegistryCredential:
+    """Key for ACR associated with given workspace.
+
+    :param location: Location of the ACR
+    :type location: str
+    :param username: Username of the ACR
+    :type username: str
+    :param passwords: Passwords to access the ACR
+    :type passwords: List[str]
+    """
+
+    def __init__(
+        self, *, location: Optional[str] = None, username: Optional[str] = None, passwords: Optional[List[str]] = None
+    ):
+        self.location = location
+        self.username = username
+        self.passwords = passwords
+
+
+class NotebookAccessKeys:
+    """Key for notebook resource associated with given workspace.
+ + :param primary_access_key: Primary access key of notebook resource + :type primary_access_key: str + :param secondary_access_key: Secondary access key of notebook resource + :type secondary_access_key: str + """ + + def __init__(self, *, primary_access_key: Optional[str] = None, secondary_access_key: Optional[str] = None): + self.primary_access_key = primary_access_key + self.secondary_access_key = secondary_access_key + + +class WorkspaceKeys: + """Workspace Keys. + + :param user_storage_key: Key for storage account associated with given workspace + :type user_storage_key: str + :param user_storage_resource_id: Resource id of storage account associated with given workspace + :type user_storage_resource_id: str + :param app_insights_instrumentation_key: Key for app insights associated with given workspace + :type app_insights_instrumentation_key: str + :param container_registry_credentials: Key for ACR associated with given workspace + :type container_registry_credentials: ContainerRegistryCredential + :param notebook_access_keys: Key for notebook resource associated with given workspace + :type notebook_access_keys: NotebookAccessKeys + """ + + def __init__( + self, + *, + user_storage_key: Optional[str] = None, + user_storage_resource_id: Optional[str] = None, + app_insights_instrumentation_key: Optional[str] = None, + container_registry_credentials: Optional[ContainerRegistryCredential] = None, + notebook_access_keys: Optional[NotebookAccessKeys] = None + ): + self.user_storage_key = user_storage_key + self.user_storage_resource_id = user_storage_resource_id + self.app_insights_instrumentation_key = app_insights_instrumentation_key + self.container_registry_credentials = container_registry_credentials + self.notebook_access_keys = notebook_access_keys + + @classmethod + def _from_rest_object(cls, rest_obj: ListWorkspaceKeysResult) -> Optional["WorkspaceKeys"]: + if not rest_obj: + return None + + container_registry_credentials = None + notebook_access_keys = None + + if hasattr(rest_obj, "container_registry_credentials") and rest_obj.container_registry_credentials is not None: + container_registry_credentials = ContainerRegistryCredential( + location=rest_obj.container_registry_credentials.location, + username=rest_obj.container_registry_credentials.username, + passwords=rest_obj.container_registry_credentials.passwords, + ) + + if hasattr(rest_obj, "notebook_access_keys") and rest_obj.notebook_access_keys is not None: + notebook_access_keys = NotebookAccessKeys( + primary_access_key=rest_obj.notebook_access_keys.primary_access_key, + secondary_access_key=rest_obj.notebook_access_keys.secondary_access_key, + ) + + return WorkspaceKeys( + user_storage_key=rest_obj.user_storage_key, + user_storage_resource_id=rest_obj.user_storage_arm_id, + app_insights_instrumentation_key=rest_obj.app_insights_instrumentation_key, + container_registry_credentials=container_registry_credentials, + notebook_access_keys=notebook_access_keys, + ) |
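
A minimal usage sketch of the entities added in this patch, assuming the azure-ai-ml package from this environment is installed and importable. The subnet ARM ID, workspace name, and output file name below are hypothetical placeholders; the classes are imported from the module paths added above, though they are typically also re-exported from azure.ai.ml.entities.

# Illustrative sketch only: builds a Workspace entity that uses the serverless
# compute settings defined in serverless_compute.py. All resource names below
# are hypothetical placeholders.
from azure.ai.ml.entities._workspace.serverless_compute import ServerlessComputeSettings
from azure.ai.ml.entities._workspace.workspace import Workspace

serverless = ServerlessComputeSettings(
    # custom_subnet accepts a str or ArmId; a plain string is wrapped in ArmId
    # by the constructor shown above.
    custom_subnet=(
        "/subscriptions/00000000-0000-0000-0000-000000000000"
        "/resourceGroups/my-rg/providers/Microsoft.Network"
        "/virtualNetworks/my-vnet/subnets/my-subnet"
    ),
    no_public_ip=True,
)

ws = Workspace(
    name="my-workspace",
    location="eastus",
    serverless_compute=serverless,
)

# Serialize the entity to YAML via Workspace.dump(); raises if the target file already exists.
ws.dump("workspace.yml")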