author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_schema
parent     cc961e04ba734dd72309fb548a2f97d67d578813
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema')
257 files changed, 14301 insertions, 0 deletions
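Note: every module in this diff follows the same marshmallow pattern: a schema class declares typed fields, and a @post_load hook converts the validated dict into an SDK entity object. A minimal self-contained sketch of that pattern, for orientation only (RetrySettings below is a hypothetical stand-in, not the azure-ai-ml entity; the field names are borrowed from BatchRetrySettingsSchema in this diff):

from marshmallow import Schema, fields, post_load

class RetrySettings:
    # Stand-in entity class; the real schemas return azure.ai.ml entities.
    def __init__(self, max_retries=None, timeout=None):
        self.max_retries = max_retries
        self.timeout = timeout

class RetrySettingsSchema(Schema):
    # Field declarations mirror BatchRetrySettingsSchema below.
    max_retries = fields.Int()
    timeout = fields.Int()

    @post_load
    def make(self, data, **kwargs):
        # marshmallow passes the validated dict; return the entity instead.
        return RetrySettings(**data)

settings = RetrySettingsSchema().load({"max_retries": 3, "timeout": 30})
assert settings.max_retries == 3 and settings.timeout == 30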
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/__init__.py
new file mode 100644
index 00000000..115a65bb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/__init__.py
@@ -0,0 +1,60 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from ._data_import import DataImportSchema
+from ._sweep import SweepJobSchema
+from .assets.code_asset import AnonymousCodeAssetSchema, CodeAssetSchema
+from .assets.data import DataSchema
+from .assets.environment import AnonymousEnvironmentSchema, EnvironmentSchema
+from .assets.index import IndexAssetSchema
+from .assets.model import ModelSchema
+from .assets.workspace_asset_reference import WorkspaceAssetReferenceSchema
+from .component import CommandComponentSchema
+from .core.fields import (
+    ArmStr,
+    ArmVersionedStr,
+    ExperimentalField,
+    NestedField,
+    RegistryStr,
+    StringTransformedEnum,
+    UnionField,
+)
+from .core.schema import PathAwareSchema, YamlFileSchema
+from .core.schema_meta import PatchedSchemaMeta
+from .job import CommandJobSchema, ParallelJobSchema, SparkJobSchema
+
+# TODO: enable in PuP
+# from .job import ImportJobSchema
+# from .component import ImportComponentSchema
+
+__all__ = [
+    # "ImportJobSchema",
+    # "ImportComponentSchema",
+    "ArmStr",
+    "ArmVersionedStr",
+    "DataSchema",
+    "StringTransformedEnum",
+    "CodeAssetSchema",
+    "CommandJobSchema",
+    "SparkJobSchema",
+    "ParallelJobSchema",
+    "EnvironmentSchema",
+    "AnonymousEnvironmentSchema",
+    "NestedField",
+    "PatchedSchemaMeta",
+    "PathAwareSchema",
+    "ModelSchema",
+    "SweepJobSchema",
+    "UnionField",
+    "YamlFileSchema",
+    "CommandComponentSchema",
+    "AnonymousCodeAssetSchema",
+    "ExperimentalField",
+    "RegistryStr",
+    "WorkspaceAssetReferenceSchema",
+    "DataImportSchema",
+    "IndexAssetSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_path_schemas.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_path_schemas.py
new file mode 100644
index 00000000..0156743e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_path_schemas.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class MLTableMetadataPathFileSchema(metaclass=PatchedSchemaMeta):
+    file = fields.Str(
+        metadata={"description": "This specifies path of file containing data."},
+        required=True,
+    )
+
+
+class MLTableMetadataPathFolderSchema(metaclass=PatchedSchemaMeta):
+    folder = fields.Str(
+        metadata={"description": "This specifies path of folder containing data."},
+        required=True,
+    )
+
+
+class MLTableMetadataPathPatternSchema(metaclass=PatchedSchemaMeta):
+    pattern = fields.Str(
+        metadata={
+            "description": "This specifies a search pattern to allow globbing of files and folders containing data."
+        },
+        required=True,
+    )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_schema.py
new file mode 100644
index 00000000..99861bc3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_schema.py
@@ -0,0 +1,40 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NestedField, UnionField
+from azure.ai.ml._schema.core.schema import YamlFileSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+from .mltable_metadata_path_schemas import (
+    MLTableMetadataPathFileSchema,
+    MLTableMetadataPathFolderSchema,
+    MLTableMetadataPathPatternSchema,
+)
+
+
+class MLTableMetadataSchema(YamlFileSchema):
+    paths = fields.List(
+        UnionField(
+            [
+                NestedField(MLTableMetadataPathFileSchema()),
+                NestedField(MLTableMetadataPathFolderSchema()),
+                NestedField(MLTableMetadataPathPatternSchema()),
+            ]
+        ),
+        required=True,
+    )
+    transformations = fields.List(fields.Raw(), required=False)
+
+    @post_load
+    def make(self, data: Dict, **kwargs):
+        from azure.ai.ml.entities._data.mltable_metadata import MLTableMetadata, MLTableMetadataPath
+
+        paths = [MLTableMetadataPath(pathDict=pathDict) for pathDict in data.pop("paths")]
+        return MLTableMetadata(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data, paths=paths)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/__init__.py
new file mode 100644
index 00000000..28719d1f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/__init__.py
@@ -0,0 +1,9 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .data_import import DataImportSchema
+
+__all__ = ["DataImportSchema"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/data_import.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/data_import.py
new file mode 100644
index 00000000..a731e1da
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/data_import.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import post_load
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.job.input_output_entry import DatabaseSchema, FileSystemSchema
+from azure.ai.ml._utils._experimental import experimental
+from ..core.fields import UnionField
+from ..assets.data import DataSchema
+
+
+@experimental
+class DataImportSchema(DataSchema):
+    source = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._data_import.data_import import DataImport
+
+        return DataImport(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/schedule.py
new file mode 100644
index 00000000..20a7e3d2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/schedule.py
@@ -0,0 +1,39 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import yaml
+
+from azure.ai.ml._schema.core.fields import NestedField, FileRefField
+from azure.ai.ml._schema.schedule.schedule import ScheduleSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from ..core.fields import UnionField
+from .data_import import DataImportSchema
+
+
+class ImportDataFileRefField(FileRefField):
+    def _deserialize(self, value, attr, data, **kwargs) -> "DataImport":
+        # Get component info from component yaml file.
+        data = super()._deserialize(value, attr, data, **kwargs)
+        data_import_dict = yaml.safe_load(data)
+
+        from azure.ai.ml.entities._data_import.data_import import DataImport
+
+        return DataImport._load(
+            data=data_import_dict,
+            yaml_path=self.context[BASE_PATH_CONTEXT_KEY] / value,
+            **kwargs,
+        )
+
+
+@experimental
+class ImportDataScheduleSchema(ScheduleSchema):
+    import_data = UnionField(
+        [
+            ImportDataFileRefField,
+            NestedField(DataImportSchema),
+        ]
+    )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/__init__.py
new file mode 100644
index 00000000..18774380
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/__init__.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
+
+from .adls_gen1 import AzureDataLakeGen1Schema
+from .azure_storage import AzureBlobSchema, AzureDataLakeGen2Schema, AzureFileSchema, AzureStorageSchema
+from .credentials import (
+    AccountKeySchema,
+    BaseTenantCredentialSchema,
+    CertificateSchema,
+    NoneCredentialsSchema,
+    SasTokenSchema,
+    ServicePrincipalSchema,
+)
+
+__all__ = [
+    "AccountKeySchema",
+    "AzureBlobSchema",
+    "AzureDataLakeGen1Schema",
+    "AzureDataLakeGen2Schema",
+    "AzureFileSchema",
+    "AzureStorageSchema",
+    "BaseTenantCredentialSchema",
+    "CertificateSchema",
+    "NoneCredentialsSchema",
+    "SasTokenSchema",
+    "ServicePrincipalSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem.py
new file mode 100644
index 00000000..1f0a9710
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem.py
@@ -0,0 +1,40 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import DatastoreType
+from azure.ai.ml._schema.core.fields import NestedField, PathAwareSchema, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+
+from ._on_prem_credentials import KerberosKeytabSchema, KerberosPasswordSchema
+
+
+class HdfsSchema(PathAwareSchema):
+    name = fields.Str(required=True)
+    id = fields.Str(dump_only=True)
+    type = StringTransformedEnum(
+        allowed_values=DatastoreType.HDFS,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    hdfs_server_certificate = fields.Str()
+    name_node_address = fields.Str(required=True)
+    protocol = fields.Str()
+    credentials = UnionField(
+        [NestedField(KerberosPasswordSchema), NestedField(KerberosKeytabSchema)],
+        required=True,
+    )
+    description = fields.Str()
+    tags = fields.Dict(keys=fields.Str(), values=fields.Dict())
+
+    @post_load
+    def make(self, data: Dict[str, Any], **kwargs) -> "HdfsDatastore":
+        from azure.ai.ml.entities._datastore._on_prem import HdfsDatastore
+
+        return HdfsDatastore(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem_credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem_credentials.py
new file mode 100644
index 00000000..ada92afc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem_credentials.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Dict
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class BaseKerberosCredentials(metaclass=PatchedSchemaMeta):
+    kerberos_realm = fields.Str(required=True)
+    kerberos_kdc_address = fields.Str(required=True)
+    kerberos_principal = fields.Str(required=True)
+
+
+class KerberosPasswordSchema(BaseKerberosCredentials):
+    kerberos_password = fields.Str(required=True)
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> "KerberosPasswordCredentials":
+        from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosPasswordCredentials
+
+        return KerberosPasswordCredentials(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosPasswordCredentials
+
+        if not isinstance(data, KerberosPasswordCredentials):
+            raise ValidationError("Cannot dump non-KerberosPasswordCredentials object into KerberosPasswordCredentials")
+        return data
+
+
+class KerberosKeytabSchema(BaseKerberosCredentials):
+    kerberos_keytab = fields.Str(required=True)
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> "KerberosKeytabCredentials":
+        from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosKeytabCredentials
+
+        return KerberosKeytabCredentials(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosKeytabCredentials
+
+        if not isinstance(data, KerberosKeytabCredentials):
+            raise ValidationError("Cannot dump non-KerberosKeytabCredentials object into KerberosKeytabCredentials")
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/adls_gen1.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/adls_gen1.py
new file mode 100644
index 00000000..7a575fc6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/adls_gen1.py
@@ -0,0 +1,41 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType
+from azure.ai.ml._schema.core.fields import NestedField, PathAwareSchema, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+
+from .credentials import CertificateSchema, NoneCredentialsSchema, ServicePrincipalSchema
+
+
+class AzureDataLakeGen1Schema(PathAwareSchema):
+    name = fields.Str(required=True)
+    id = fields.Str(dump_only=True)
+    type = StringTransformedEnum(
+        allowed_values=DatastoreType.AZURE_DATA_LAKE_GEN1,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    store_name = fields.Str(required=True)
+    credentials = UnionField(
+        [
+            NestedField(ServicePrincipalSchema),
+            NestedField(CertificateSchema),
+            NestedField(NoneCredentialsSchema),
+        ]
+    )
+    description = fields.Str()
+    tags = fields.Dict(keys=fields.Str(), values=fields.Dict())
+
+    @post_load
+    def make(self, data: Dict[str, Any], **kwargs) -> "AzureDataLakeGen1Datastore":
+        from azure.ai.ml.entities import AzureDataLakeGen1Datastore
+
+        return AzureDataLakeGen1Datastore(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/azure_storage.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/azure_storage.py
new file mode 100644
index 00000000..ffe8c61c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/azure_storage.py
@@ -0,0 +1,97 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType
+from azure.ai.ml._schema.core.fields import NestedField, PathAwareSchema, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+
+from .credentials import (
+    AccountKeySchema,
+    CertificateSchema,
+    NoneCredentialsSchema,
+    SasTokenSchema,
+    ServicePrincipalSchema,
+)
+
+
+class AzureStorageSchema(PathAwareSchema):
+    name = fields.Str(required=True)
+    id = fields.Str(dump_only=True)
+    account_name = fields.Str(required=True)
+    endpoint = fields.Str()
+    protocol = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+
+class AzureFileSchema(AzureStorageSchema):
+    type = StringTransformedEnum(
+        allowed_values=DatastoreType.AZURE_FILE,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    file_share_name = fields.Str(required=True)
+    credentials = UnionField(
+        [
+            NestedField(AccountKeySchema),
+            NestedField(SasTokenSchema),
+            NestedField(NoneCredentialsSchema),
+        ]
+    )
+
+    @post_load
+    def make(self, data: Dict[str, Any], **kwargs) -> "AzureFileDatastore":  # type: ignore[name-defined]
+        from azure.ai.ml.entities import AzureFileDatastore
+
+        return AzureFileDatastore(**data)
+
+
+class AzureBlobSchema(AzureStorageSchema):
+    type = StringTransformedEnum(
+        allowed_values=DatastoreType.AZURE_BLOB,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    container_name = fields.Str(required=True)
+    credentials = UnionField(
+        [
+            NestedField(AccountKeySchema),
+            NestedField(SasTokenSchema),
+            NestedField(NoneCredentialsSchema),
+        ],
+    )
+
+    @post_load
+    def make(self, data: Dict[str, Any], **kwargs) -> "AzureBlobDatastore":  # type: ignore[name-defined]
+        from azure.ai.ml.entities import AzureBlobDatastore
+
+        return AzureBlobDatastore(**data)
+
+
+class AzureDataLakeGen2Schema(AzureStorageSchema):
+    type = StringTransformedEnum(
+        allowed_values=DatastoreType.AZURE_DATA_LAKE_GEN2,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    filesystem = fields.Str(required=True)
+    credentials = UnionField(
+        [
+            NestedField(ServicePrincipalSchema),
+            NestedField(CertificateSchema),
+            NestedField(NoneCredentialsSchema),
+        ]
+    )
+
+    @post_load
+    def make(self, data: Dict[str, Any], **kwargs) -> "AzureDataLakeGen2Datastore":
+        from azure.ai.ml.entities import AzureDataLakeGen2Datastore
+
+        return AzureDataLakeGen2Datastore(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/credentials.py
new file mode 100644
index 00000000..a4b46aa0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/credentials.py
@@ -0,0 +1,99 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import ValidationError, fields, post_load, pre_dump, pre_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.entities._credentials import (
+    AccountKeyConfiguration,
+    CertificateConfiguration,
+    NoneCredentialConfiguration,
+    SasTokenConfiguration,
+    ServicePrincipalConfiguration,
+)
+
+
+class NoneCredentialsSchema(metaclass=PatchedSchemaMeta):
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> NoneCredentialConfiguration:
+        return NoneCredentialConfiguration(**data)
+
+
+class AccountKeySchema(metaclass=PatchedSchemaMeta):
+    account_key = fields.Str(required=True)
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> AccountKeyConfiguration:
+        return AccountKeyConfiguration(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        if not isinstance(data, AccountKeyConfiguration):
+            raise ValidationError("Cannot dump non-AccountKeyCredentials object into AccountKeyCredentials")
+        return data
+
+
+class SasTokenSchema(metaclass=PatchedSchemaMeta):
+    sas_token = fields.Str(required=True)
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> SasTokenConfiguration:
+        return SasTokenConfiguration(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        if not isinstance(data, SasTokenConfiguration):
+            raise ValidationError("Cannot dump non-SasTokenCredentials object into SasTokenCredentials")
+        return data
+
+
+class BaseTenantCredentialSchema(metaclass=PatchedSchemaMeta):
+    authority_url = fields.Str()
+    resource_url = fields.Str()
+    tenant_id = fields.Str(required=True)
+    client_id = fields.Str(required=True)
+
+    @pre_load
+    def accept_backward_compatible_keys(self, data, **kwargs):
+        acceptable_keys = [key for key in data.keys() if key in ("authority_url", "authority_uri")]
+        if len(acceptable_keys) > 1:
+            raise ValidationError(
+                "Cannot specify both 'authority_url' and 'authority_uri'. Please use 'authority_url'."
+            )
+        if acceptable_keys:
+            data["authority_url"] = data.pop(acceptable_keys[0])
+        return data
+
+
+class ServicePrincipalSchema(BaseTenantCredentialSchema):
+    client_secret = fields.Str(required=True)
+
+    @post_load
+    def make(self, data: Dict[str, str], **kwargs) -> ServicePrincipalConfiguration:
+        return ServicePrincipalConfiguration(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        if not isinstance(data, ServicePrincipalConfiguration):
+            raise ValidationError("Cannot dump non-ServicePrincipalCredentials object into ServicePrincipalCredentials")
+        return data
+
+
+class CertificateSchema(BaseTenantCredentialSchema):
+    certificate = fields.Str()
+    thumbprint = fields.Str(required=True)
+
+    @post_load
+    def make(self, data: Dict[str, Any], **kwargs) -> CertificateConfiguration:
+        return CertificateConfiguration(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        if not isinstance(data, CertificateConfiguration):
+            raise ValidationError("Cannot dump non-CertificateCredentials object into CertificateCredentials")
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/one_lake.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/one_lake.py
new file mode 100644
index 00000000..4b5e7b66
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/one_lake.py
@@ -0,0 +1,49 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import Schema, fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType, OneLakeArtifactType
+from azure.ai.ml._schema.core.fields import NestedField, PathAwareSchema, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+
+from .credentials import NoneCredentialsSchema, ServicePrincipalSchema
+
+
+class OneLakeArtifactSchema(Schema):
+    name = fields.Str(required=True)
+    type = StringTransformedEnum(allowed_values=OneLakeArtifactType.LAKE_HOUSE, casing_transform=camel_to_snake)
+
+
+class OneLakeSchema(PathAwareSchema):
+    name = fields.Str(required=True)
+    id = fields.Str(dump_only=True)
+    type = StringTransformedEnum(
+        allowed_values=DatastoreType.ONE_LAKE,
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+    # required fields for OneLake
+    one_lake_workspace_name = fields.Str(required=True)
+    endpoint = fields.Str(required=True)
+    artifact = NestedField(OneLakeArtifactSchema)
+    # ServicePrincipal and UserIdentity are the two supported credential types
+    credentials = UnionField(
+        [
+            NestedField(ServicePrincipalSchema),
+            NestedField(NoneCredentialsSchema),
+        ]
+    )
+    description = fields.Str()
+    tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+    @post_load
+    def make(self, data: Dict[str, Any], **kwargs) -> "OneLakeDatastore":
+        from azure.ai.ml.entities import OneLakeDatastore
+
+        return OneLakeDatastore(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py
new file mode 100644
index 00000000..7a69176b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py
@@ -0,0 +1,92 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,no-else-return
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+from marshmallow.exceptions import ValidationError
+from azure.ai.ml._schema import (
+    UnionField,
+    ArmVersionedStr,
+    ArmStr,
+    RegistryStr,
+)
+from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job.creation_context import CreationContextSchema
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction, BatchDeploymentType
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchDeploymentSchema(DeploymentSchema):
+    compute = ComputeField(required=False)
+    error_threshold = fields.Int(
+        metadata={
+            "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+            the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+            For FileDataset, this value is the count of file failures.\r\n
+            For TabularDataset, this value is the count of record failures.\r\n
+            If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+        }
+    )
+    retry_settings = NestedField(BatchRetrySettingsSchema)
+    mini_batch_size = fields.Int()
+    logging_level = fields.Str(
+        metadata={
+            "description": """A string of the logging level name, which is defined in 'logging'.
+            Possible values are 'warning', 'info', and 'debug'."""
+        }
+    )
+    output_action = StringTransformedEnum(
+        allowed_values=[
+            BatchDeploymentOutputAction.APPEND_ROW,
+            BatchDeploymentOutputAction.SUMMARY_ONLY,
+        ],
+        metadata={"description": "Indicates how batch inferencing will handle output."},
+        dump_default=BatchDeploymentOutputAction.APPEND_ROW,
+    )
+    output_file_name = fields.Str(metadata={"description": "Customized output file name for append_row output action."})
+    max_concurrency_per_instance = fields.Int(
+        metadata={"description": "Indicates maximum number of parallelism per instance."}
+    )
+    resources = NestedField(JobResourceConfigurationSchema)
+    type = StringTransformedEnum(
+        allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
+    )
+
+    job_definition = ArmStr(azureml_type=AzureMLResourceType.JOB)
+    component = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    creation_context = NestedField(CreationContextSchema, dump_only=True)
+    provisioning_state = fields.Str(dump_only=True)
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:
+        from azure.ai.ml.entities import BatchDeployment, ModelBatchDeployment, PipelineComponentBatchDeployment
+
+        if "type" not in data:
+            return BatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+        elif data["type"] == BatchDeploymentType.PIPELINE:
+            return PipelineComponentBatchDeployment(**data)
+        elif data["type"] == BatchDeploymentType.MODEL:
+            return ModelBatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+        else:
+            raise ValidationError(
+                "Deployment type must be of type "
+                f"{BatchDeploymentType.PIPELINE} or {BatchDeploymentType.MODEL}."
+            )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py
new file mode 100644
index 00000000..2a36352c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchRetrySettingsSchema(metaclass=PatchedSchemaMeta):
+    max_retries = fields.Int(
+        metadata={"description": "The number of maximum tries for a failed or timeout mini batch."},
+    )
+    timeout = fields.Int(metadata={"description": "The timeout for a mini batch."})
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> BatchRetrySettings:
+        return BatchRetrySettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py
new file mode 100644
index 00000000..a1496f1e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py
@@ -0,0 +1,132 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from typing import Any
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import (
+    BatchJob,
+    CustomModelJobInput,
+    CustomModelJobOutput,
+    DataVersion,
+    LiteralJobInput,
+    MLFlowModelJobInput,
+    MLFlowModelJobOutput,
+    MLTableJobInput,
+    MLTableJobOutput,
+    TritonModelJobInput,
+    TritonModelJobOutput,
+    UriFileJobInput,
+    UriFileJobOutput,
+    UriFolderJobInput,
+    UriFolderJobOutput,
+)
+from azure.ai.ml._schema.core.fields import ArmStr, NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants import AssetTypes
+from azure.ai.ml.constants._common import AzureMLResourceType, InputTypes
+from azure.ai.ml.constants._endpoint import EndpointYamlFields
+from azure.ai.ml.entities import ComputeConfiguration
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+from .compute_binding import ComputeBindingSchema
+
+
+class OutputDataSchema(metaclass=PatchedSchemaMeta):
+    datastore_id = ArmStr(azureml_type=AzureMLResourceType.DATASTORE)
+    path = fields.Str()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:
+        return DataVersion(**data)
+
+
+class BatchJobSchema(PathAwareSchema):
+    compute = NestedField(ComputeBindingSchema)
+    dataset = fields.Str()
+    error_threshold = fields.Int()
+    input_data = fields.Dict()
+    mini_batch_size = fields.Int()
+    name = fields.Str(data_key="job_name")
+    output_data = fields.Dict()
+    output_dataset = NestedField(OutputDataSchema)
+    output_file_name = fields.Str()
+    retry_settings = NestedField(BatchRetrySettingsSchema)
+    properties = fields.Dict(data_key="properties")
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=too-many-branches
+        if data.get(EndpointYamlFields.BATCH_JOB_INPUT_DATA, None):
+            for key, input_data in data[EndpointYamlFields.BATCH_JOB_INPUT_DATA].items():
+                if isinstance(input_data, Input):
+                    if input_data.type == AssetTypes.URI_FILE:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = UriFileJobInput(uri=input_data.path)
+                    if input_data.type == AssetTypes.URI_FOLDER:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = UriFolderJobInput(uri=input_data.path)
+                    if input_data.type == AssetTypes.TRITON_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = TritonModelJobInput(
+                            mode=input_data.mode, uri=input_data.path
+                        )
+                    if input_data.type == AssetTypes.MLFLOW_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = MLFlowModelJobInput(
+                            mode=input_data.mode, uri=input_data.path
+                        )
+                    if input_data.type == AssetTypes.MLTABLE:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = MLTableJobInput(
+                            mode=input_data.mode, uri=input_data.path
+                        )
+                    if input_data.type == AssetTypes.CUSTOM_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = CustomModelJobInput(
+                            mode=input_data.mode, uri=input_data.path
+                        )
+                    if input_data.type in {
+                        InputTypes.INTEGER,
+                        InputTypes.NUMBER,
+                        InputTypes.STRING,
+                        InputTypes.BOOLEAN,
+                    }:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = LiteralJobInput(value=input_data.default)
+        if data.get(EndpointYamlFields.BATCH_JOB_OUTPUT_DATA, None):
+            for key, output_data in data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA].items():
+                if isinstance(output_data, Output):
+                    if output_data.type == AssetTypes.URI_FILE:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = UriFileJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.URI_FOLDER:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = UriFolderJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.TRITON_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = TritonModelJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.MLFLOW_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = MLFlowModelJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.MLTABLE:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = MLTableJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.CUSTOM_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = CustomModelJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+
+        if data.get(EndpointYamlFields.COMPUTE, None):
+            data[EndpointYamlFields.COMPUTE] = ComputeConfiguration(
+                **data[EndpointYamlFields.COMPUTE]
+            )._to_rest_object()
+
+        if data.get(EndpointYamlFields.RETRY_SETTINGS, None):
+            data[EndpointYamlFields.RETRY_SETTINGS] = data[EndpointYamlFields.RETRY_SETTINGS]._to_rest_object()
+
+        return BatchJob(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py
new file mode 100644
index 00000000..f0b22fd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import (
+    ArmVersionedStr,
+    PatchedSchemaMeta,
+    StringTransformedEnum,
+    UnionField,
+    ArmStr,
+    RegistryStr,
+)
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._job.job import JobType
+
+module_logger = logging.getLogger(__name__)
+
+
+# pylint: disable-next=name-too-long
+class BatchPipelineComponentDeploymentConfiguarationsSchema(metaclass=PatchedSchemaMeta):
+    component_id = fields.Str()
+    job = UnionField(
+        [
+            ArmStr(azureml_type=AzureMLResourceType.JOB),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    component = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    type = StringTransformedEnum(required=True, allowed_values=[JobType.PIPELINE])
+    settings = fields.Dict()
+    name = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.job_definition import JobDefinition
+
+        return JobDefinition(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py
new file mode 100644
index 00000000..2e4b0348
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import ValidationError, fields, validates_schema
+
+from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._common import LOCAL_COMPUTE_TARGET, AzureMLResourceType
+
+module_logger = logging.getLogger(__name__)
+
+
+class ComputeBindingSchema(metaclass=PatchedSchemaMeta):
+    target = UnionField(
+        [
+            StringTransformedEnum(allowed_values=[LOCAL_COMPUTE_TARGET]),
+            ArmStr(azureml_type=AzureMLResourceType.COMPUTE),
+            # Case for virtual clusters
+            ArmStr(azureml_type=AzureMLResourceType.VIRTUALCLUSTER),
+        ]
+    )
+    instance_count = fields.Integer()
+    instance_type = fields.Str(metadata={"description": "The instance type to make available to this job."})
+    location = fields.Str(metadata={"description": "The locations where this job may run."})
+    properties = fields.Dict(keys=fields.Str())
+
+    @validates_schema
+    def validate(self, data: Any, **kwargs):
+        if data.get("target") == LOCAL_COMPUTE_TARGET and data.get("instance_count", 1) != 1:
+            raise ValidationError("Local runs must have node count of 1.")
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py
new file mode 100644
index 00000000..269f1da7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import (
+    ArmVersionedStr,
+    PatchedSchemaMeta,
+    StringTransformedEnum,
+    UnionField,
+    ArmStr,
+    RegistryStr,
+)
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._job.job import JobType
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobDefinitionSchema(metaclass=PatchedSchemaMeta):
+    component_id = fields.Str()
+    job = UnionField(
+        [
+            ArmStr(azureml_type=AzureMLResourceType.JOB),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    component = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    type = StringTransformedEnum(required=True, allowed_values=[JobType.PIPELINE])
+    settings = fields.Dict()
+    name = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.job_definition import JobDefinition
+
+        return JobDefinition(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py
new file mode 100644
index 00000000..0dbd8463
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py
@@ -0,0 +1,46 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentType
+from azure.ai.ml._schema import ExperimentalField
+from .model_batch_deployment_settings import ModelBatchDeploymentSettingsSchema
+
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelBatchDeploymentSchema(DeploymentSchema):
+    compute = ComputeField(required=True)
+    error_threshold = fields.Int(
+        metadata={
+            "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+            the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+            For FileDataset, this value is the count of file failures.\r\n
+            For TabularDataset, this value is the count of record failures.\r\n
+            If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+        }
+    )
+    resources = NestedField(JobResourceConfigurationSchema)
+    type = StringTransformedEnum(
+        allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
+    )
+
+    settings = ExperimentalField(NestedField(ModelBatchDeploymentSettingsSchema))
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:
+        from azure.ai.ml.entities import ModelBatchDeployment
+
+        return ModelBatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py
new file mode 100644
index 00000000..e1945751
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py
@@ -0,0 +1,56 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelBatchDeploymentSettingsSchema(metaclass=PatchedSchemaMeta):
+    error_threshold = fields.Int(
+        metadata={
+            "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+            the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+            For FileDataset, this value is the count of file failures.\r\n
+            For TabularDataset, this value is the count of record failures.\r\n
+            If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+        }
+    )
+    instance_count = fields.Int()
+    retry_settings = NestedField(BatchRetrySettingsSchema)
+    mini_batch_size = fields.Int()
+    logging_level = fields.Str(
+        metadata={
+            "description": """A string of the logging level name, which is defined in 'logging'.
+            Possible values are 'warning', 'info', and 'debug'."""
+        }
+    )
+    output_action = StringTransformedEnum(
+        allowed_values=[
+            BatchDeploymentOutputAction.APPEND_ROW,
+            BatchDeploymentOutputAction.SUMMARY_ONLY,
+        ],
+        metadata={"description": "Indicates how batch inferencing will handle output."},
+        dump_default=BatchDeploymentOutputAction.APPEND_ROW,
+    )
+    output_file_name = fields.Str(metadata={"description": "Customized output file name for append_row output action."})
+    max_concurrency_per_instance = fields.Int(
+        metadata={"description": "Indicates maximum number of parallelism per instance."}
+    )
+    environment_variables = fields.Dict()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities import ModelBatchDeploymentSettings
+
+        return ModelBatchDeploymentSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py
new file mode 100644
index 00000000..4bc884b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py
@@ -0,0 +1,70 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema import (
+    ArmVersionedStr,
+    ArmStr,
+    UnionField,
+    RegistryStr,
+    NestedField,
+)
+from azure.ai.ml._schema.core.fields import PipelineNodeNameStr, TypeSensitiveUnionField, PathAwareSchema
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._component import NodeType
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineComponentBatchDeploymentSchema(PathAwareSchema):
+    name = fields.Str()
+    endpoint_name = fields.Str()
+    component = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    settings = fields.Dict()
+    name = fields.Str()
+    type = fields.Str()
+    job_definition = UnionField(
+        [
+            ArmStr(azureml_type=AzureMLResourceType.JOB),
+            NestedField("PipelineSchema", unknown=INCLUDE),
+        ]
+    )
+    tags = fields.Dict()
+    description = fields.Str(metadata={"description": "Description of the endpoint deployment."})
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.pipeline_component_batch_deployment import (
+            PipelineComponentBatchDeployment,
+        )
+
+        return PipelineComponentBatchDeployment(**data)
+
+
+class NodeNameStr(PipelineNodeNameStr):
+    def _get_field_name(self) -> str:
+        return "Pipeline node"
+
+
+def PipelineJobsField():
+    pipeline_enable_job_type = {NodeType.PIPELINE: [NestedField("PipelineSchema", unknown=INCLUDE)]}
+
+    pipeline_job_field = fields.Dict(
+        keys=NodeNameStr(),
+        values=TypeSensitiveUnionField(pipeline_enable_job_type),
+    )
+
+    return pipeline_job_field
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py
new file mode 100644
index 00000000..54661ada
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class RunSettingsSchema(metaclass=PatchedSchemaMeta):
+    name = fields.Str()
+    display_name = fields.Str()
+    experiment_name = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict()
+    settings = fields.Dict()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.run_settings import RunSettings
+
+        return RunSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/code_configuration_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/code_configuration_schema.py
new file mode 100644
index 00000000..e9b3eac4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/code_configuration_schema.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class CodeConfigurationSchema(PathAwareSchema):
+    code = fields.Str()
+    scoring_script = fields.Str()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:
+        from azure.ai.ml.entities import CodeConfiguration
+
+        return CodeConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/deployment.py
new file mode 100644
index 00000000..669a96ad
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/deployment.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema, EnvironmentSchema
+from azure.ai.ml._schema.assets.model import AnonymousModelSchema
+from azure.ai.ml._schema.core.fields import ArmVersionedStr, NestedField, PathAwareSchema, RegistryStr, UnionField
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+from .code_configuration_schema import CodeConfigurationSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class DeploymentSchema(PathAwareSchema):
+    name = fields.Str(required=True)
+    endpoint_name = fields.Str(required=True)
+    description = fields.Str(metadata={"description": "Description of the endpoint deployment."})
+    id = fields.Str()
+    tags = fields.Dict()
+    properties = fields.Dict()
+    model = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.MODEL),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, allow_default_version=True),
+            NestedField(AnonymousModelSchema),
+        ],
+        metadata={"description": "Reference to the model asset for the endpoint deployment."},
+    )
+    code_configuration = NestedField(
+        CodeConfigurationSchema,
+        metadata={"description": "Code configuration for the endpoint deployment."},
+    )
+    environment = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True),
+            NestedField(EnvironmentSchema),
+            NestedField(AnonymousEnvironmentSchema),
+        ]
+    )
+    environment_variables = fields.Dict(
+        metadata={"description": "Environment variables configuration for the deployment."}
+    )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_asset_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_asset_schema.py
new file mode 100644
index 00000000..84bd37e3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_asset_schema.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class DataAssetSchema(metaclass=PatchedSchemaMeta):
+    name = fields.Str()
+    path = fields.Str()
+    version = fields.Str()
+    data_id = fields.Str()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.data_asset import DataAsset
+
+        return DataAsset(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_collector_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_collector_schema.py
new file mode 100644
index 00000000..633f96fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_collector_schema.py
@@ -0,0 +1,39 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load, validates, ValidationError
+
+from azure.ai.ml._schema import NestedField, PatchedSchemaMeta, StringTransformedEnum
+from azure.ai.ml._schema._deployment.online.request_logging_schema import RequestLoggingSchema
+from azure.ai.ml._schema._deployment.online.deployment_collection_schema import DeploymentCollectionSchema
+
+from azure.ai.ml.constants._common import RollingRate
+
+module_logger = logging.getLogger(__name__)
+
+
+class DataCollectorSchema(metaclass=PatchedSchemaMeta):
+    collections = fields.Dict(keys=fields.Str, values=NestedField(DeploymentCollectionSchema))
+    rolling_rate = StringTransformedEnum(
+        required=False,
+        allowed_values=[RollingRate.MINUTE, RollingRate.DAY, RollingRate.HOUR],
+    )
+    sampling_rate = fields.Float()  # Should be copied to each of the collections
+    request_logging = NestedField(RequestLoggingSchema)
+
+    # pylint: disable=unused-argument
+    @validates("sampling_rate")
+    def validate_sampling_rate(self, value, **kwargs):
+        if value > 1.0 or value < 0.0:
+            raise ValidationError("Sampling rate must be an number in range (0.0-1.0)")
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.data_collector import DataCollector
+
+        return DataCollector(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/deployment_collection_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/deployment_collection_schema.py
new file mode 100644
index 00000000..4be4a9cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/deployment_collection_schema.py
@@ -0,0 +1,32 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +import logging +from typing import Any + +from marshmallow import post_load, fields + +from azure.ai.ml._schema import PatchedSchemaMeta, StringTransformedEnum, NestedField, UnionField +from azure.ai.ml._schema._deployment.online.data_asset_schema import DataAssetSchema +from azure.ai.ml.constants._common import Boolean + +module_logger = logging.getLogger(__name__) + + +class DeploymentCollectionSchema(metaclass=PatchedSchemaMeta): + enabled = StringTransformedEnum(required=True, allowed_values=[Boolean.TRUE, Boolean.FALSE]) + data = UnionField( + [ + NestedField(DataAssetSchema), + fields.Str(), + ] + ) + client_id = fields.Str() + + # pylint: disable=unused-argument + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities._deployment.deployment_collection import DeploymentCollection + + return DeploymentCollection(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/event_hub_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/event_hub_schema.py new file mode 100644 index 00000000..27b603de --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/event_hub_schema.py @@ -0,0 +1,31 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import ValidationError, fields, post_load, validates + +from azure.ai.ml._schema import NestedField, PatchedSchemaMeta +from azure.ai.ml._schema._deployment.online.oversize_data_config_schema import OversizeDataConfigSchema + +module_logger = logging.getLogger(__name__) + + +class EventHubSchema(metaclass=PatchedSchemaMeta): + namespace = fields.Str() + oversize_data_config = NestedField(OversizeDataConfigSchema) + + @validates("namespace") + def validate_namespace(self, value, **kwargs): + if len(value.split(".")) != 2: + raise ValidationError("Namespace must follow format of {namespace}.{name}") + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities._deployment.event_hub import EventHub + + return EventHub(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/liveness_probe.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/liveness_probe.py new file mode 100644 index 00000000..d1008b8b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/liveness_probe.py @@ -0,0 +1,28 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class LivenessProbeSchema(metaclass=PatchedSchemaMeta): + period = fields.Int() + initial_delay = fields.Int() + timeout = fields.Int() + success_threshold = fields.Int() + failure_threshold = fields.Int() + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import ProbeSettings + + return ProbeSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/online_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/online_deployment.py new file mode 100644 index 00000000..7f0760fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/online_deployment.py @@ -0,0 +1,79 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointComputeType +from azure.ai.ml._schema._deployment.deployment import DeploymentSchema +from azure.ai.ml._schema._utils.utils import exit_if_registry_assets +from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PublicNetworkAccess +from azure.ai.ml._schema.job.creation_context import CreationContextSchema + +from .data_collector_schema import DataCollectorSchema +from .liveness_probe import LivenessProbeSchema +from .request_settings_schema import RequestSettingsSchema +from .resource_requirements_schema import ResourceRequirementsSchema +from .scale_settings_schema import DefaultScaleSettingsSchema, TargetUtilizationScaleSettingsSchema + +module_logger = logging.getLogger(__name__) + + +class OnlineDeploymentSchema(DeploymentSchema): + app_insights_enabled = fields.Bool() + scale_settings = UnionField( + [ + NestedField(DefaultScaleSettingsSchema), + NestedField(TargetUtilizationScaleSettingsSchema), + ] + ) + request_settings = NestedField(RequestSettingsSchema) + liveness_probe = NestedField(LivenessProbeSchema) + readiness_probe = NestedField(LivenessProbeSchema) + provisioning_state = fields.Str() + instance_count = fields.Int() + type = StringTransformedEnum( + required=False, + allowed_values=[ + EndpointComputeType.MANAGED.value, + EndpointComputeType.KUBERNETES.value, + ], + casing_transform=camel_to_snake, + ) + model_mount_path = fields.Str() + instance_type = fields.Str() + data_collector = ExperimentalField(NestedField(DataCollectorSchema)) + + +class KubernetesOnlineDeploymentSchema(OnlineDeploymentSchema): + resources = NestedField(ResourceRequirementsSchema) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import KubernetesOnlineDeployment + + exit_if_registry_assets(data=data, caller="K8SDeployment") + return KubernetesOnlineDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) + + +class ManagedOnlineDeploymentSchema(OnlineDeploymentSchema): + 
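# Note: managed online deployments differ from the Kubernetes flavor above in
# that they size capacity with a required `instance_type` (a VM SKU string)
# rather than pod-level `resources`, and they expose managed-network controls
# such as egress_public_network_access.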
instance_type = fields.Str(required=True) + egress_public_network_access = StringTransformedEnum( + allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED] + ) + private_network_connection = ExperimentalField(fields.Bool()) + data_collector = NestedField(DataCollectorSchema) + creation_context = NestedField(CreationContextSchema, dump_only=True) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import ManagedOnlineDeployment + + return ManagedOnlineDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/oversize_data_config_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/oversize_data_config_schema.py new file mode 100644 index 00000000..8103681a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/oversize_data_config_schema.py @@ -0,0 +1,31 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import Any + +from marshmallow import ValidationError, fields, post_load, validates + +from azure.ai.ml._schema import PatchedSchemaMeta +from azure.ai.ml._utils._storage_utils import AzureMLDatastorePathUri + +module_logger = logging.getLogger(__name__) + + +class OversizeDataConfigSchema(metaclass=PatchedSchemaMeta): + path = fields.Str() + + # pylint: disable=unused-argument + @validates("path") + def validate_path(self, value, **kwargs): + datastore_path = AzureMLDatastorePathUri(value) + if datastore_path.uri_type != "Datastore": + raise ValidationError(f"Path '{value}' is not a properly formatted datastore path.") + + # pylint: disable=unused-argument + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities._deployment.oversize_data_config import OversizeDataConfig + + return OversizeDataConfig(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/payload_response_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/payload_response_schema.py new file mode 100644 index 00000000..172af4f1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/payload_response_schema.py @@ -0,0 +1,24 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import logging +from typing import Any + +from marshmallow import post_load + +from azure.ai.ml._schema import PatchedSchemaMeta, StringTransformedEnum +from azure.ai.ml.constants._common import Boolean + +module_logger = logging.getLogger(__name__) + + +class PayloadResponseSchema(metaclass=PatchedSchemaMeta): + enabled = StringTransformedEnum(required=True, allowed_values=[Boolean.TRUE, Boolean.FALSE]) + + # pylint: disable=unused-argument + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities._deployment.payload_response import PayloadResponse + + return PayloadResponse(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_logging_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_logging_schema.py new file mode 100644 index 00000000..4ac0b466 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_logging_schema.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class RequestLoggingSchema(metaclass=PatchedSchemaMeta): + capture_headers = fields.List(fields.Str()) + + # pylint: disable=unused-argument + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities._deployment.request_logging import RequestLogging + + return RequestLogging(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_settings_schema.py new file mode 100644 index 00000000..887a71c5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_settings_schema.py @@ -0,0 +1,26 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class RequestSettingsSchema(metaclass=PatchedSchemaMeta): + request_timeout_ms = fields.Int(required=False) + max_concurrent_requests_per_instance = fields.Int(required=False) + max_queue_wait_ms = fields.Int(required=False) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import OnlineRequestSettings + + return OnlineRequestSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_requirements_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_requirements_schema.py new file mode 100644 index 00000000..7f43d91f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_requirements_schema.py @@ -0,0 +1,28 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
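# --- Editor's illustrative sketch, not SDK code: each schema above finishes
# with a @post_load hook that turns the validated dict into a typed entity.
# The pattern in miniature, with a hypothetical dataclass standing in for
# OnlineRequestSettings.
from dataclasses import dataclass
from marshmallow import Schema, fields, post_load

@dataclass
class RequestSettingsSketch:
    request_timeout_ms: int = 5000
    max_concurrent_requests_per_instance: int = 1

class RequestSettingsSketchSchema(Schema):
    request_timeout_ms = fields.Int(required=False)
    max_concurrent_requests_per_instance = fields.Int(required=False)

    @post_load
    def make(self, data, **kwargs):
        return RequestSettingsSketch(**data)  # validated dict -> typed object

print(RequestSettingsSketchSchema().load({"request_timeout_ms": 90000}))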
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import post_load + +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + +from .resource_settings_schema import ResourceSettingsSchema + +module_logger = logging.getLogger(__name__) + + +class ResourceRequirementsSchema(metaclass=PatchedSchemaMeta): + requests = NestedField(ResourceSettingsSchema) + limits = NestedField(ResourceSettingsSchema) + + @post_load + def make(self, data: Any, **kwargs: Any) -> "ResourceRequirementsSettings": + from azure.ai.ml.entities import ResourceRequirementsSettings + + return ResourceRequirementsSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_settings_schema.py new file mode 100644 index 00000000..21a229ad --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_settings_schema.py @@ -0,0 +1,32 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load, pre_load + +from azure.ai.ml._schema._utils.utils import replace_key_in_odict +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class ResourceSettingsSchema(metaclass=PatchedSchemaMeta): + cpu = fields.String() + memory = fields.String() + gpu = fields.String() + + @pre_load + def conversion(self, data: Any, **kwargs) -> Any: + data = replace_key_in_odict(data, "nvidia.com/gpu", "gpu") + return data + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import ResourceSettings + + return ResourceSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/scale_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/scale_settings_schema.py new file mode 100644 index 00000000..6c5c5283 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/scale_settings_schema.py @@ -0,0 +1,51 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
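# --- Editor's illustrative sketch, not SDK code: ResourceSettingsSchema above
# uses a @pre_load hook to accept the Kubernetes-style "nvidia.com/gpu" key
# and remap it onto the plain `gpu` field before validation. In isolation:
from marshmallow import Schema, fields, pre_load

class ResourceSettingsSketch(Schema):
    cpu = fields.String()
    memory = fields.String()
    gpu = fields.String()

    @pre_load
    def rename_gpu_key(self, data, **kwargs):
        if "nvidia.com/gpu" in data:
            data = dict(data)  # avoid mutating the caller's mapping
            data["gpu"] = data.pop("nvidia.com/gpu")
        return data

print(ResourceSettingsSketch().load({"cpu": "2", "nvidia.com/gpu": "1"}))
# -> {'cpu': '2', 'gpu': '1'}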
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2022_10_01.models import ScaleType +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + +module_logger = logging.getLogger(__name__) + + +class DefaultScaleSettingsSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + required=True, + allowed_values=ScaleType.DEFAULT, + casing_transform=camel_to_snake, + data_key="type", + ) + + @post_load + def make(self, data: Any, **kwargs: Any) -> "DefaultScaleSettings": + from azure.ai.ml.entities import DefaultScaleSettings + + return DefaultScaleSettings(**data) + + +class TargetUtilizationScaleSettingsSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + required=True, + allowed_values=ScaleType.TARGET_UTILIZATION, + casing_transform=camel_to_snake, + data_key="type", + ) + polling_interval = fields.Int() + target_utilization_percentage = fields.Int() + min_instances = fields.Int() + max_instances = fields.Int() + + @post_load + def make(self, data: Any, **kwargs: Any) -> "TargetUtilizationScaleSettings": + from azure.ai.ml.entities import TargetUtilizationScaleSettings + + return TargetUtilizationScaleSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py new file mode 100644 index 00000000..437d8743 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py @@ -0,0 +1,17 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .distillation_job import DistillationJobSchema +from .endpoint_request_settings import EndpointRequestSettingsSchema +from .prompt_settings import PromptSettingsSchema +from .teacher_model_settings import TeacherModelSettingsSchema + +__all__ = [ + "DistillationJobSchema", + "PromptSettingsSchema", + "EndpointRequestSettingsSchema", + "TeacherModelSettingsSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py new file mode 100644 index 00000000..d72f2457 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py @@ -0,0 +1,84 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
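# --- Editor's sketch: the scale-settings `type` fields above pass the SDK's
# own camel_to_snake (from azure.ai.ml._utils.utils) as the casing transform,
# mapping e.g. "TargetUtilization" to "target_utilization". A common rendition
# of such a transform, for illustration only:
import re

def camel_to_snake_sketch(name: str) -> str:
    # Insert "_" before every interior uppercase letter, then lowercase.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

assert camel_to_snake_sketch("TargetUtilization") == "target_utilization"
assert camel_to_snake_sketch("Default") == "default"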
+# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema._distillation.prompt_settings import PromptSettingsSchema +from azure.ai.ml._schema._distillation.teacher_model_settings import TeacherModelSettingsSchema +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + LocalPathField, + NestedField, + RegistryStr, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.job import BaseJobSchema +from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, ModelInputSchema +from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField +from azure.ai.ml._schema.job_resource_configuration import ResourceConfigurationSchema +from azure.ai.ml._schema.workspace.connections import ServerlessConnectionSchema, WorkspaceConnectionSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants import DataGenerationTaskType, DataGenerationType, JobType +from azure.ai.ml.constants._common import AzureMLResourceType + + +@experimental +class DistillationJobSchema(BaseJobSchema): + type = StringTransformedEnum(required=True, allowed_values=JobType.DISTILLATION) + data_generation_type = StringTransformedEnum( + allowed_values=[DataGenerationType.LABEL_GENERATION, DataGenerationType.DATA_GENERATION], + required=True, + ) + data_generation_task_type = StringTransformedEnum( + allowed_values=[ + DataGenerationTaskType.NLI, + DataGenerationTaskType.NLU_QA, + DataGenerationTaskType.CONVERSATION, + DataGenerationTaskType.MATH, + DataGenerationTaskType.SUMMARIZATION, + ], + casing_transform=str.upper, + required=True, + ) + teacher_model_endpoint_connection = UnionField( + [NestedField(WorkspaceConnectionSchema), NestedField(ServerlessConnectionSchema)], required=True + ) + student_model = UnionField( + [ + NestedField(ModelInputSchema), + RegistryStr(azureml_type=AzureMLResourceType.MODEL), + ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, allow_default_version=True), + ], + required=True, + ) + training_data = UnionField( + [ + NestedField(DataInputSchema), + ArmVersionedStr(azureml_type=AzureMLResourceType.DATA), + fields.Str(metadata={"pattern": r"^(http(s)?):.*"}), + fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}), + LocalPathField(pattern=r"^file:.*"), + LocalPathField( + pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*", + ), + ] + ) + validation_data = UnionField( + [ + NestedField(DataInputSchema), + ArmVersionedStr(azureml_type=AzureMLResourceType.DATA), + fields.Str(metadata={"pattern": r"^(http(s)?):.*"}), + fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}), + LocalPathField(pattern=r"^file:.*"), + LocalPathField( + pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*", + ), + ] + ) + teacher_model_settings = NestedField(TeacherModelSettingsSchema) + prompt_settings = NestedField(PromptSettingsSchema) + hyperparameters = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True)) + resources = NestedField(ResourceConfigurationSchema) + outputs = OutputsField() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py new file mode 100644 index 00000000..960e7d2a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft 
Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class EndpointRequestSettingsSchema(metaclass=PatchedSchemaMeta): + request_batch_size = fields.Int() + min_endpoint_success_ratio = fields.Number() + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + """Post-load processing of the schema data + + :param data: Dictionary of parsed values from the yaml. + :type data: typing.Dict + :return: EndpointRequestSettings made from the yaml + :rtype: EndpointRequestSettings + """ + from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings + + return EndpointRequestSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py new file mode 100644 index 00000000..3b21908a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py @@ -0,0 +1,29 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class PromptSettingsSchema(metaclass=PatchedSchemaMeta): + enable_chain_of_thought = fields.Bool() + enable_chain_of_density = fields.Bool() + max_len_summary = fields.Int() + # custom_prompt = fields.Str() + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + """Post-load processing of the schema data + + :param data: Dictionary of parsed values from the yaml. + :type data: typing.Dict + :return: PromptSettings made from the yaml + :rtype: PromptSettings + """ + from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings + + return PromptSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py new file mode 100644 index 00000000..ecf32047 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py @@ -0,0 +1,29 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields, post_load + +from azure.ai.ml._schema._distillation.endpoint_request_settings import EndpointRequestSettingsSchema +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class TeacherModelSettingsSchema(metaclass=PatchedSchemaMeta): + inference_parameters = fields.Dict(keys=fields.Str(), values=fields.Raw()) + endpoint_request_settings = NestedField(EndpointRequestSettingsSchema) + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + """Post-load processing of the schema data + + :param data: Dictionary of parsed values from the yaml. 
+ :type data: typing.Dict + :return: TeacherModelSettings made from the yaml + :rtype: TeacherModelSettings + """ + from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings + + return TeacherModelSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py new file mode 100644 index 00000000..e9538cbb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py @@ -0,0 +1,15 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + + +from .batch.batch_endpoint import BatchEndpointSchema +from .online.online_endpoint import KubernetesOnlineEndpointSchema, ManagedOnlineEndpointSchema + +__all__ = [ + "BatchEndpointSchema", + "KubernetesOnlineEndpointSchema", + "ManagedOnlineEndpointSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py new file mode 100644 index 00000000..0bee2493 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import post_load + +from azure.ai.ml._schema._endpoint.batch.batch_endpoint_defaults import BatchEndpointsDefaultsSchema +from azure.ai.ml._schema._endpoint.endpoint import EndpointSchema +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + +module_logger = logging.getLogger(__name__) + + +class BatchEndpointSchema(EndpointSchema): + defaults = NestedField(BatchEndpointsDefaultsSchema) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import BatchEndpoint + + return BatchEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py new file mode 100644 index 00000000..49699bb0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py @@ -0,0 +1,28 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
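# --- Editor's illustrative sketch, not SDK code: the @post_load hooks above
# pull a base path out of the schema context so the constructed entity can
# resolve relative file references. Context plumbing in miniature ("base_path"
# is a hypothetical key; the SDK uses BASE_PATH_CONTEXT_KEY):
from marshmallow import Schema, fields, post_load

class ContextSketchSchema(Schema):
    name = fields.Str(required=True)

    @post_load
    def make(self, data, **kwargs):
        return {"base_path": self.context.get("base_path", "."), **data}

schema = ContextSketchSchema()
schema.context["base_path"] = "/workspace/project"  # hypothetical path
print(schema.load({"name": "my-endpoint"}))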
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpointDefaults +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class BatchEndpointsDefaultsSchema(metaclass=PatchedSchemaMeta): + deployment_name = fields.Str( + metadata={ + "description": """Name of the deployment that will be the default for the endpoint. + This deployment will end up getting 100% traffic when the endpoint scoring URL is invoked.""" + } + ) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + return BatchEndpointDefaults(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py new file mode 100644 index 00000000..1ff43338 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py @@ -0,0 +1,41 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging + +from marshmallow import fields, validate + +from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthMode +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml._schema.identity import IdentitySchema +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._endpoint import EndpointConfigurations + +module_logger = logging.getLogger(__name__) + + +class EndpointSchema(PathAwareSchema): + id = fields.Str() + name = fields.Str(required=True, validate=validate.Regexp(EndpointConfigurations.NAME_REGEX_PATTERN)) + description = fields.Str(metadata={"description": "Description of the inference endpoint."}) + tags = fields.Dict() + provisioning_state = fields.Str(metadata={"description": "Provisioning state for the endpoint."}) + properties = fields.Dict() + auth_mode = StringTransformedEnum( + allowed_values=[ + EndpointAuthMode.AML_TOKEN, + EndpointAuthMode.KEY, + EndpointAuthMode.AAD_TOKEN, + ], + casing_transform=camel_to_snake, + metadata={ + "description": """Authentication method: no auth, key-based, or Azure ML token-based. + aad_token is only valid for batch endpoints.""" + }, + ) + scoring_uri = fields.Str(metadata={"description": "The endpoint URI that can be used for scoring."}) + location = fields.Str() + openapi_uri = fields.Str(metadata={"description": "Endpoint OpenAPI URI."}) + identity = NestedField(IdentitySchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py new file mode 100644 index 00000000..fdf8caba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py new file mode 100644 index 00000000..84b34636 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py @@ -0,0 +1,66 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import ValidationError, fields, post_load, validates + +from azure.ai.ml._schema._endpoint.endpoint import EndpointSchema +from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType, PublicNetworkAccess + +module_logger = logging.getLogger(__name__) + + +class OnlineEndpointSchema(EndpointSchema): + traffic = fields.Dict( + keys=fields.Str(), + values=fields.Int(), + metadata={ + "description": """A dict with deployment name as key and traffic percentage as value. + The values must sum to 100 or less.""" + }, + ) + kind = fields.Str(dump_only=True) + + mirror_traffic = fields.Dict( + keys=fields.Str(), + values=fields.Int(), + metadata={ + "description": """A dict with deployment name as key and traffic percentage as value. + Only one key is accepted, and its value must be less than or equal to 50%.""" + }, + ) + + @validates("traffic") + def validate_traffic(self, data, **kwargs): + if sum(data.values()) > 100: + raise ValidationError("Traffic rule percentages must sum to less than or equal to 100%") + + +class KubernetesOnlineEndpointSchema(OnlineEndpointSchema): + provisioning_state = fields.Str(metadata={"description": "Status of the deployment provisioning operation."}) + compute = ArmStr(azureml_type=AzureMLResourceType.COMPUTE) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import KubernetesOnlineEndpoint + + return KubernetesOnlineEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) + + +class ManagedOnlineEndpointSchema(OnlineEndpointSchema): + provisioning_state = fields.Str() + public_network_access = StringTransformedEnum( + allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED] + ) + + @post_load + def make(self, data: Any, **kwargs: Any) -> Any: + from azure.ai.ml.entities import ManagedOnlineEndpoint + + return ManagedOnlineEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/__init__.py new file mode 100644 index 00000000..69c1cdbd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/__init__.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved.
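# --- Editor's illustrative sketch, not SDK code: the traffic validator above
# caps the total allocation at 100%. The same check in isolation; TrafficSketch
# is a hypothetical name.
from marshmallow import Schema, ValidationError, fields, validates

class TrafficSketch(Schema):
    traffic = fields.Dict(keys=fields.Str(), values=fields.Int())

    @validates("traffic")
    def validate_traffic(self, data, **kwargs):
        if sum(data.values()) > 100:
            raise ValidationError("Traffic rule percentages must sum to 100% or less")

print(TrafficSketch().load({"traffic": {"blue": 90, "green": 10}}))  # ok
# TrafficSketch().load({"traffic": {"blue": 90, "green": 20}})  # would raise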
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .delay_metadata_schema import DelayMetadataSchema +from .feature_schema import FeatureSchema +from .feature_set_schema import FeatureSetSchema +from .featureset_spec_metadata_schema import FeaturesetSpecMetadataSchema +from .feature_set_specification_schema import FeatureSetSpecificationSchema +from .materialization_settings_schema import MaterializationSettingsSchema +from .source_metadata_schema import SourceMetadataSchema +from .timestamp_column_metadata_schema import TimestampColumnMetadataSchema + +__all__ = [ + "DelayMetadataSchema", + "FeatureSchema", + "FeatureSetSchema", + "FeaturesetSpecMetadataSchema", + "FeatureSetSpecificationSchema", + "MaterializationSettingsSchema", + "SourceMetadataSchema", + "TimestampColumnMetadataSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/delay_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/delay_metadata_schema.py new file mode 100644 index 00000000..5ad78a7a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/delay_metadata_schema.py @@ -0,0 +1,21 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class DelayMetadataSchema(metaclass=PatchedSchemaMeta): + days = fields.Int(required=False) + hours = fields.Int(required=False) + minutes = fields.Int(required=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_set.delay_metadata import DelayMetadata + + return DelayMetadata(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_schema.py new file mode 100644 index 00000000..6d248270 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_schema.py @@ -0,0 +1,29 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class FeatureSchema(metaclass=PatchedSchemaMeta): + name = fields.Str( + required=True, + allow_none=False, + ) + data_type = fields.Str( + required=True, + allow_none=False, + data_key="type", + ) + description = fields.Str(required=False) + tags = fields.Dict(keys=fields.Str(), values=fields.Str(), required=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_set.feature import Feature + + return Feature(description=data.pop("description", None), **data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_backfill_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_backfill_schema.py new file mode 100644 index 00000000..0ee5af8e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_backfill_schema.py @@ -0,0 +1,22 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema._feature_set.feature_window_schema import FeatureWindowSchema +from azure.ai.ml._schema._feature_set.materialization_settings_schema import MaterializationComputeResourceSchema +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.core.schema import YamlFileSchema + + +class FeatureSetBackfillSchema(YamlFileSchema): + name = fields.Str(required=True) + version = fields.Str(required=True) + feature_window = NestedField(FeatureWindowSchema) + description = fields.Str() + tags = fields.Dict() + resource = NestedField(MaterializationComputeResourceSchema) + spark_configuration = fields.Dict() + data_status = fields.List(fields.Str()) + job_id = fields.Str() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_schema.py new file mode 100644 index 00000000..08722402 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_schema.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from marshmallow import fields, post_dump, validate + +from azure.ai.ml._schema import NestedField +from azure.ai.ml._schema.core.schema import YamlFileSchema + +from .feature_set_specification_schema import FeatureSetSpecificationSchema +from .materialization_settings_schema import MaterializationSettingsSchema + + +class FeatureSetSchema(YamlFileSchema): + name = fields.Str(required=True, allow_none=False) + version = fields.Str(required=True, allow_none=False) + latest_version = fields.Str(dump_only=True) + specification = NestedField(FeatureSetSpecificationSchema, required=True, allow_none=False) + entities = fields.List(fields.Str, required=True, allow_none=False) + stage = fields.Str(validate=validate.OneOf(["Development", "Production", "Archived"]), dump_default="Development") + description = fields.Str() + tags = fields.Dict(keys=fields.Str(), values=fields.Str()) + materialization_settings = NestedField(MaterializationSettingsSchema) + + @post_dump + def remove_empty_values(self, data, **kwargs): # pylint: disable=unused-argument + return {key: value for key, value in data.items() if value} diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_specification_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_specification_schema.py new file mode 100644 index 00000000..64b399fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_specification_schema.py @@ -0,0 +1,19 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class FeatureSetSpecificationSchema(metaclass=PatchedSchemaMeta): + path = fields.Str(required=True, allow_none=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_set.feature_set_specification import FeatureSetSpecification + + return FeatureSetSpecification(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_transformation_code_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_transformation_code_metadata_schema.py new file mode 100644 index 00000000..8b173865 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_transformation_code_metadata_schema.py @@ -0,0 +1,22 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
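# --- Editor's illustrative sketch, not SDK code: FeatureSetSchema above trims
# falsy values on dump so serialized output stays minimal. The @post_dump
# pattern in isolation; SparseDumpSketch is a hypothetical name.
from marshmallow import Schema, fields, post_dump

class SparseDumpSketch(Schema):
    name = fields.Str()
    description = fields.Str()
    tags = fields.Dict()

    @post_dump
    def remove_empty_values(self, data, **kwargs):
        return {key: value for key, value in data.items() if value}

print(SparseDumpSketch().dump({"name": "transactions", "description": "", "tags": {}}))
# -> {'name': 'transactions'}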
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class FeatureTransformationCodeMetadataSchema(metaclass=PatchedSchemaMeta): + path = fields.Str(required=False) + transformer_class = fields.Str(required=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_set.feature_transformation_code_metadata import ( + FeatureTransformationCodeMetadata, + ) + + return FeatureTransformationCodeMetadata(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_window_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_window_schema.py new file mode 100644 index 00000000..d114c731 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_window_schema.py @@ -0,0 +1,11 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields +from azure.ai.ml._schema.core.schema import YamlFileSchema + + +class FeatureWindowSchema(YamlFileSchema): + feature_window_end = fields.Str() + feature_window_start = fields.Str() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_metadata_schema.py new file mode 100644 index 00000000..251ccd6e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_metadata_schema.py @@ -0,0 +1,33 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Dict + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.core.schema import YamlFileSchema +from azure.ai.ml._schema._feature_store_entity.data_column_schema import DataColumnSchema + +from .source_metadata_schema import SourceMetadataSchema +from .delay_metadata_schema import DelayMetadataSchema +from .feature_schema import FeatureSchema +from .feature_transformation_code_metadata_schema import FeatureTransformationCodeMetadataSchema + + +class FeaturesetSpecMetadataSchema(YamlFileSchema): + source = fields.Nested(SourceMetadataSchema, required=True) + feature_transformation_code = fields.Nested(FeatureTransformationCodeMetadataSchema, required=False) + features = fields.List(NestedField(FeatureSchema), required=True, allow_none=False) + index_columns = fields.List(NestedField(DataColumnSchema), required=False) + source_lookback = fields.Nested(DelayMetadataSchema, required=False) + temporal_join_lookback = fields.Nested(DelayMetadataSchema, required=False) + + @post_load + def make(self, data: Dict, **kwargs): + from azure.ai.ml.entities._feature_set.featureset_spec_metadata import FeaturesetSpecMetadata + + return FeaturesetSpecMetadata(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_properties_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_properties_schema.py new file mode 100644 index 00000000..e3a56542 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_properties_schema.py @@ -0,0 +1,55 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta, YamlFileSchema + +from .source_process_code_metadata_schema import SourceProcessCodeSchema +from .timestamp_column_metadata_schema import TimestampColumnMetadataSchema + + +# pylint: disable-next=name-too-long +class FeatureTransformationCodePropertiesSchema(metaclass=PatchedSchemaMeta): + path = fields.Str(data_key="Path") + transformer_class = fields.Str(data_key="TransformerClass") + + +class DelayMetadataPropertiesSchema(metaclass=PatchedSchemaMeta): + days = fields.Int(data_key="Days") + hours = fields.Int(data_key="Hours") + minutes = fields.Int(data_key="Minutes") + + +class FeaturePropertiesSchema(metaclass=PatchedSchemaMeta): + name = fields.Str(data_key="FeatureName") + data_type = fields.Str(data_key="DataType") + description = fields.Str(data_key="Description") + tags = fields.Dict(keys=fields.Str(), values=fields.Str(), data_key="Tags") + + +class ColumnPropertiesSchema(metaclass=PatchedSchemaMeta): + name = fields.Str(data_key="ColumnName") + type = fields.Str(data_key="DataType") + + +class SourcePropertiesSchema(metaclass=PatchedSchemaMeta): + type = fields.Str(required=True) + path = fields.Str(required=False) + timestamp_column = fields.Nested(TimestampColumnMetadataSchema, data_key="timestampColumn") + source_delay = fields.Nested(DelayMetadataPropertiesSchema, data_key="sourceDelay") + source_process_code = fields.Nested(SourceProcessCodeSchema) + dict = fields.Dict(keys=fields.Str(), values=fields.Str(), data_key="kwargs") + + +class FeaturesetSpecPropertiesSchema(YamlFileSchema): + source = fields.Nested(SourcePropertiesSchema, data_key="source") + feature_transformation_code = fields.Nested( + FeatureTransformationCodePropertiesSchema, data_key="featureTransformationCode" + ) + features = fields.List(NestedField(FeaturePropertiesSchema), data_key="features") + index_columns = fields.List(NestedField(ColumnPropertiesSchema), data_key="indexColumns") + source_lookback = fields.Nested(DelayMetadataPropertiesSchema, data_key="sourceLookback") + temporal_join_lookback = fields.Nested(DelayMetadataPropertiesSchema, data_key="temporalJoinLookback") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/materialization_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/materialization_settings_schema.py new file mode 100644 index 00000000..8cf68b67 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/materialization_settings_schema.py @@ -0,0 +1,37 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
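# --- Editor's illustrative sketch, not SDK code: the properties schemas above
# use `data_key` to bridge camelCase wire names and snake_case attributes.
# In isolation (CamelKeySketch is a hypothetical name):
from marshmallow import Schema, fields

class CamelKeySketch(Schema):
    source_lookback = fields.Int(data_key="sourceLookback")

print(CamelKeySketch().load({"sourceLookback": 7}))   # -> {'source_lookback': 7}
print(CamelKeySketch().dump({"source_lookback": 7}))  # -> {'sourceLookback': 7}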
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema import NestedField +from azure.ai.ml._schema._notification.notification_schema import NotificationSchema +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._schema.schedule.trigger import RecurrenceTriggerSchema + + +class MaterializationComputeResourceSchema(metaclass=PatchedSchemaMeta): + instance_type = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource + + return MaterializationComputeResource(instance_type=data.pop("instance_type"), **data) + + +class MaterializationSettingsSchema(metaclass=PatchedSchemaMeta): + schedule = NestedField(RecurrenceTriggerSchema) + notification = NestedField(NotificationSchema) + resource = NestedField(MaterializationComputeResourceSchema) + spark_configuration = fields.Dict() + offline_enabled = fields.Boolean() + online_enabled = fields.Boolean() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_set.materialization_settings import MaterializationSettings + + return MaterializationSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_metadata_schema.py new file mode 100644 index 00000000..345c9084 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_metadata_schema.py @@ -0,0 +1,30 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Dict + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + +from .delay_metadata_schema import DelayMetadataSchema +from .source_process_code_metadata_schema import SourceProcessCodeSchema +from .timestamp_column_metadata_schema import TimestampColumnMetadataSchema + + +class SourceMetadataSchema(metaclass=PatchedSchemaMeta): + type = fields.Str(required=True) + path = fields.Str(required=False) + timestamp_column = fields.Nested(TimestampColumnMetadataSchema, required=False) + source_delay = fields.Nested(DelayMetadataSchema, required=False) + source_process_code = fields.Nested(SourceProcessCodeSchema, load_only=True, required=False) + dict = fields.Dict(keys=fields.Str(), values=fields.Str(), data_key="kwargs", load_only=True, required=False) + + @post_load + def make(self, data: Dict, **kwargs): + from azure.ai.ml.entities._feature_set.source_metadata import SourceMetadata + + return SourceMetadata(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_process_code_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_process_code_metadata_schema.py new file mode 100644 index 00000000..b8b93739 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_process_code_metadata_schema.py @@ -0,0 +1,20 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class SourceProcessCodeSchema(metaclass=PatchedSchemaMeta): + path = fields.Str(required=True, allow_none=False) + process_class = fields.Str(required=True, allow_none=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_set.source_process_code_metadata import SourceProcessCodeMetadata + + return SourceProcessCodeMetadata(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/timestamp_column_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/timestamp_column_metadata_schema.py new file mode 100644 index 00000000..6d7982be --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/timestamp_column_metadata_schema.py @@ -0,0 +1,20 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class TimestampColumnMetadataSchema(metaclass=PatchedSchemaMeta): + name = fields.Str(required=True) + format = fields.Str(required=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_set.timestamp_column_metadata import TimestampColumnMetadata + + return TimestampColumnMetadata(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/__init__.py new file mode 100644 index 00000000..5e7d7822 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/__init__.py @@ -0,0 +1,15 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .compute_runtime_schema import ComputeRuntimeSchema +from .feature_store_schema import FeatureStoreSchema +from .materialization_store_schema import MaterializationStoreSchema + +__all__ = [ + "ComputeRuntimeSchema", + "FeatureStoreSchema", + "MaterializationStoreSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/compute_runtime_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/compute_runtime_schema.py new file mode 100644 index 00000000..48db586f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/compute_runtime_schema.py @@ -0,0 +1,19 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class ComputeRuntimeSchema(metaclass=PatchedSchemaMeta): + spark_runtime_version = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._workspace.compute_runtime import ComputeRuntime + + return ComputeRuntime(spark_runtime_version=data.pop("spark_runtime_version")) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/feature_store_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/feature_store_schema.py new file mode 100644 index 00000000..78fb0642 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/feature_store_schema.py @@ -0,0 +1,43 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields, EXCLUDE + +from azure.ai.ml._schema._utils.utils import validate_arm_str +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml._schema.workspace.customer_managed_key import CustomerManagedKeySchema +from azure.ai.ml._schema.workspace.identity import IdentitySchema, UserAssignedIdentitySchema +from azure.ai.ml._utils.utils import snake_to_pascal +from azure.ai.ml.constants._common import PublicNetworkAccess +from azure.ai.ml._schema.workspace.networking import ManagedNetworkSchema +from .compute_runtime_schema import ComputeRuntimeSchema +from .materialization_store_schema import MaterializationStoreSchema + + +class FeatureStoreSchema(PathAwareSchema): + name = fields.Str(required=True) + compute_runtime = NestedField(ComputeRuntimeSchema) + offline_store = NestedField(MaterializationStoreSchema) + online_store = NestedField(MaterializationStoreSchema) + materialization_identity = NestedField(UserAssignedIdentitySchema) + description = fields.Str() + tags = fields.Dict(keys=fields.Str(), values=fields.Str()) + display_name = fields.Str() + location = fields.Str() + resource_group = fields.Str() + hbi_workspace = fields.Bool() + storage_account = fields.Str(validate=validate_arm_str) + container_registry = fields.Str(validate=validate_arm_str) + key_vault = fields.Str(validate=validate_arm_str) + application_insights = fields.Str(validate=validate_arm_str) + customer_managed_key = NestedField(CustomerManagedKeySchema) + image_build_compute = fields.Str() + public_network_access = StringTransformedEnum( + allowed_values=[PublicNetworkAccess.DISABLED, PublicNetworkAccess.ENABLED], + casing_transform=snake_to_pascal, + ) + identity = NestedField(IdentitySchema) + primary_user_assigned_identity = fields.Str() + managed_network = NestedField(ManagedNetworkSchema, unknown=EXCLUDE) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/materialization_store_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/materialization_store_schema.py new file mode 100644 index 00000000..091cd4eb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/materialization_store_schema.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
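# --- Editor's sketch: FeatureStoreSchema above validates several fields with
# validate_arm_str (the SDK's own helper). A rough, hypothetical rendition of
# what an ARM resource-id check can look like; the real helper's rules may
# differ, so treat this as illustration only.
import re
from marshmallow import ValidationError

def rough_validate_arm_str(value: str) -> None:
    pattern = r"^/subscriptions/[^/]+/resourceGroups/[^/]+/providers/.+$"
    if not re.match(pattern, value):
        raise ValidationError(f"{value} is not a valid ARM resource id")

rough_validate_arm_str(
    "/subscriptions/000-000/resourceGroups/rg/providers/Microsoft.Storage/storageAccounts/sa"
)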
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class MaterializationStoreSchema(metaclass=PatchedSchemaMeta): + type = fields.Str(required=True, allow_none=False) + target = fields.Str(required=True, allow_none=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_store.materialization_store import MaterializationStore + + return MaterializationStore( + type=data.pop("type"), + target=data.pop("target"), + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/__init__.py new file mode 100644 index 00000000..8fec3153 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/__init__.py @@ -0,0 +1,13 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .data_column_schema import DataColumnSchema +from .feature_store_entity_schema import FeatureStoreEntitySchema + +__all__ = [ + "DataColumnSchema", + "FeatureStoreEntitySchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/data_column_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/data_column_schema.py new file mode 100644 index 00000000..9fffc055 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/data_column_schema.py @@ -0,0 +1,26 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class DataColumnSchema(metaclass=PatchedSchemaMeta): + name = fields.Str( + required=True, + allow_none=False, + ) + type = fields.Str( + required=True, + allow_none=False, + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._feature_store_entity.data_column import DataColumn + + return DataColumn(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/feature_store_entity_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/feature_store_entity_schema.py new file mode 100644 index 00000000..51505430 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/feature_store_entity_schema.py @@ -0,0 +1,26 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
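+# Example of a feature-store-entity YAML document accepted by
+# FeatureStoreEntitySchema below (an illustrative sketch; names and values
+# are hypothetical):
+#
+#   name: account
+#   version: "1"
+#   index_columns:
+#     - name: accountID
+#       type: string
+#   stage: Development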
+# --------------------------------------------------------- + + +from marshmallow import fields, post_dump, validate + +from azure.ai.ml._schema import NestedField +from azure.ai.ml._schema.core.schema import YamlFileSchema + +from .data_column_schema import DataColumnSchema + + +class FeatureStoreEntitySchema(YamlFileSchema): + name = fields.Str(required=True, allow_none=False) + version = fields.Str(required=True, allow_none=False) + latest_version = fields.Str(dump_only=True) + index_columns = fields.List(NestedField(DataColumnSchema), required=True, allow_none=False) + stage = fields.Str(validate=validate.OneOf(["Development", "Production", "Archived"]), dump_default="Development") + description = fields.Str() + tags = fields.Dict(keys=fields.Str(), values=fields.Str()) + properties = fields.Dict(keys=fields.Str(), values=fields.Str()) + + @post_dump + def remove_empty_values(self, data, **kwargs): # pylint: disable=unused-argument + return {key: value for key, value in data.items() if value} diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py new file mode 100644 index 00000000..e47aa230 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py @@ -0,0 +1,19 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .azure_openai_finetuning import AzureOpenAIFineTuningSchema +from .azure_openai_hyperparameters import AzureOpenAIHyperparametersSchema +from .custom_model_finetuning import CustomModelFineTuningSchema +from .finetuning_job import FineTuningJobSchema +from .finetuning_vertical import FineTuningVerticalSchema + +__all__ = [ + "AzureOpenAIFineTuningSchema", + "AzureOpenAIHyperparametersSchema", + "CustomModelFineTuningSchema", + "FineTuningJobSchema", + "FineTuningVerticalSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py new file mode 100644 index 00000000..f6d2a58d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py @@ -0,0 +1,54 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
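+# Example of the YAML fragment matched by AzureOpenAIFineTuningSchema below
+# (an illustrative sketch; the hyperparameter values are hypothetical):
+#
+#   model_provider: azure_open_ai
+#   hyperparameters:
+#     n_epochs: 3
+#     learning_rate_multiplier: 0.1
+#     batch_size: 8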
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ModelProvider
+from azure.ai.ml._schema._finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparametersSchema
+from azure.ai.ml._schema._finetuning.finetuning_vertical import FineTuningVerticalSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.finetuning import FineTuningConstants
+from azure.ai.ml.entities._job.finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparameters
+
+
+@experimental
+class AzureOpenAIFineTuningSchema(FineTuningVerticalSchema):
+    # This is meant to match the yaml definition NOT the models defined in _restclient
+
+    model_provider = StringTransformedEnum(
+        required=True, allowed_values=ModelProvider.AZURE_OPEN_AI, casing_transform=camel_to_snake
+    )
+    hyperparameters = NestedField(AzureOpenAIHyperparametersSchema(), data_key=FineTuningConstants.HyperParameters)
+
+    @post_load
+    def post_load_processing(self, data: Dict, **kwargs) -> Dict[str, Any]:
+        """Post-load processing for the schema.
+
+        :param data: Dictionary of parsed values from the yaml.
+        :type data: typing.Dict
+
+        :return: Dictionary of parsed values from the yaml.
+        :rtype: Dict[str, Any]
+        """
+        data.pop("model_provider")
+        hyperparameters = data.pop("hyperparameters", None)
+
+        if hyperparameters and not isinstance(hyperparameters, AzureOpenAIHyperparameters):
+            # Convert the raw mapping parsed from the yaml into the hyperparameters entity.
+            hyperparameters_dict = dict(hyperparameters)
+            azure_openai_hyperparameters = AzureOpenAIHyperparameters(
+                batch_size=hyperparameters_dict.get("batch_size", None),
+                learning_rate_multiplier=hyperparameters_dict.get("learning_rate_multiplier", None),
+                n_epochs=hyperparameters_dict.get("n_epochs", None),
+            )
+            data["hyperparameters"] = azure_openai_hyperparameters
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py
new file mode 100644
index 00000000..f421188d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureOpenAIHyperparametersSchema(metaclass=PatchedSchemaMeta):
+    n_epochs = fields.Int()
+    learning_rate_multiplier = fields.Float()
+    batch_size = fields.Int()
+    # TODO: Should be dict<string,string>; check the schema for the same.
+    # Not exposed for now, as we don't have the REST layer representation exposed.
+    # Need to check with the team.
+    # additional_parameters = fields.Dict()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py
new file mode 100644
index 00000000..3e14dca4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+class SnakeCaseFineTuningTaskTypes:
+    CHAT_COMPLETION = "chat_completion"
+    TEXT_COMPLETION = "text_completion"
+    TEXT_CLASSIFICATION = "text_classification"
+    QUESTION_ANSWERING = "question_answering"
+    TEXT_SUMMARIZATION = "text_summarization"
+    TOKEN_CLASSIFICATION = "token_classification"
+    TEXT_TRANSLATION = "text_translation"
+    IMAGE_CLASSIFICATION = "image_classification"
+    IMAGE_INSTANCE_SEGMENTATION = "image_instance_segmentation"
+    IMAGE_OBJECT_DETECTION = "image_object_detection"
+    VIDEO_MULTI_OBJECT_TRACKING = "video_multi_object_tracking"
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py
new file mode 100644
index 00000000..9d5b22a7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py
@@ -0,0 +1,35 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ModelProvider
+from azure.ai.ml._schema._finetuning.finetuning_vertical import FineTuningVerticalSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class CustomModelFineTuningSchema(FineTuningVerticalSchema):
+    # This is meant to match the yaml definition NOT the models defined in _restclient
+
+    model_provider = StringTransformedEnum(required=True, allowed_values=ModelProvider.CUSTOM)
+    hyperparameters = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+
+    @post_load
+    def post_load_processing(self, data: Dict, **kwargs) -> Dict[str, Any]:
+        """Post-load processing for the schema.
+
+        :param data: Dictionary of parsed values from the yaml.
+        :type data: typing.Dict
+
+        :return: Dictionary of parsed values from the yaml.
+        :rtype: Dict[str, Any]
+        """
+        data.pop("model_provider")
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py
new file mode 100644
index 00000000..e1b2270e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
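+# Example of the common fine-tuning job YAML fields handled by the fine-tuning
+# schemas in this package (an illustrative sketch; asset names, paths, and the
+# model fields are hypothetical):
+#
+#   type: finetuning
+#   task: text_completion
+#   model:
+#     type: mlflow_model
+#     path: azureml://registries/<registry>/models/<model>/versions/<version>
+#   training_data: azureml:my-training-data:1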
+# --------------------------------------------------------- + +from azure.ai.ml._schema.job import BaseJobSchema +from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._schema.core.fields import ( + NestedField, +) +from ..queue_settings import QueueSettingsSchema +from ..job_resources import JobResourcesSchema + +# This is meant to match the yaml definition NOT the models defined in _restclient + + +@experimental +class FineTuningJobSchema(BaseJobSchema): + outputs = OutputsField() + queue_settings = NestedField(QueueSettingsSchema) + resources = NestedField(JobResourcesSchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py new file mode 100644 index 00000000..10ac51ff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py @@ -0,0 +1,73 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema._finetuning.finetuning_job import FineTuningJobSchema +from azure.ai.ml._schema._finetuning.constants import SnakeCaseFineTuningTaskTypes +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + LocalPathField, + NestedField, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml.constants import JobType +from azure.ai.ml._utils.utils import snake_to_camel +from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, ModelInputSchema +from azure.ai.ml.constants._job.finetuning import FineTuningConstants +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._common import AzureMLResourceType + + +# This is meant to match the yaml definition NOT the models defined in _restclient + + +@experimental +class FineTuningVerticalSchema(FineTuningJobSchema): + type = StringTransformedEnum(required=True, allowed_values=JobType.FINE_TUNING) + model = NestedField(ModelInputSchema, required=True) + training_data = UnionField( + [ + NestedField(DataInputSchema), + ArmVersionedStr(azureml_type=AzureMLResourceType.DATA), + fields.Str(metadata={"pattern": r"^(http(s)?):.*"}), + fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}), + LocalPathField(pattern=r"^file:.*"), + LocalPathField( + pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*", + ), + ] + ) + validation_data = UnionField( + [ + NestedField(DataInputSchema), + ArmVersionedStr(azureml_type=AzureMLResourceType.DATA), + fields.Str(metadata={"pattern": r"^(http(s)?):.*"}), + fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}), + LocalPathField(pattern=r"^file:.*"), + LocalPathField( + pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*", + ), + ] + ) + + task = StringTransformedEnum( + allowed_values=[ + SnakeCaseFineTuningTaskTypes.CHAT_COMPLETION, + SnakeCaseFineTuningTaskTypes.TEXT_COMPLETION, + SnakeCaseFineTuningTaskTypes.TEXT_CLASSIFICATION, + SnakeCaseFineTuningTaskTypes.QUESTION_ANSWERING, + SnakeCaseFineTuningTaskTypes.TEXT_SUMMARIZATION, + SnakeCaseFineTuningTaskTypes.TOKEN_CLASSIFICATION, + SnakeCaseFineTuningTaskTypes.TEXT_TRANSLATION, + SnakeCaseFineTuningTaskTypes.IMAGE_CLASSIFICATION, + SnakeCaseFineTuningTaskTypes.IMAGE_INSTANCE_SEGMENTATION, + SnakeCaseFineTuningTaskTypes.IMAGE_OBJECT_DETECTION, + 
SnakeCaseFineTuningTaskTypes.VIDEO_MULTI_OBJECT_TRACKING, + ], + casing_transform=snake_to_camel, + data_key=FineTuningConstants.TaskType, + required=True, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/__init__.py new file mode 100644 index 00000000..b95c2d6d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/__init__.py @@ -0,0 +1,11 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .notification_schema import NotificationSchema + +__all__ = [ + "NotificationSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/notification_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/notification_schema.py new file mode 100644 index 00000000..21245bc9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/notification_schema.py @@ -0,0 +1,24 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, validate, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class NotificationSchema(metaclass=PatchedSchemaMeta): + email_on = fields.List( + fields.Str(validate=validate.OneOf(["JobCompleted", "JobFailed", "JobCancelled"])), + required=True, + allow_none=False, + ) + emails = fields.List(fields.Str, required=True, allow_none=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._notification.notification import Notification + + return Notification(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py new file mode 100644 index 00000000..1d08c92a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py @@ -0,0 +1,9 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .sweep_job import SweepJobSchema + +__all__ = ["SweepJobSchema"] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py new file mode 100644 index 00000000..644c3046 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py @@ -0,0 +1,6 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
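+# Example of the YAML fragment accepted by NotificationSchema above
+# (an illustrative sketch; addresses are hypothetical):
+#
+#   email_on: [JobCompleted, JobFailed]
+#   emails: [alice@contoso.com, bob@contoso.com]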
+# --------------------------------------------------------- + + +BASE_ERROR_MESSAGE = "Search space type not one of: " diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py new file mode 100644 index 00000000..e48c9637 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py @@ -0,0 +1,30 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, PathAwareSchema +from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema + +from ..job.job_limits import SweepJobLimitsSchema +from ..queue_settings import QueueSettingsSchema +from .sweep_fields_provider import EarlyTerminationField, SamplingAlgorithmField, SearchSpaceField +from .sweep_objective import SweepObjectiveSchema + + +class ParameterizedSweepSchema(PathAwareSchema): + """Shared schema for standalone and pipeline sweep job.""" + + sampling_algorithm = SamplingAlgorithmField() + search_space = SearchSpaceField() + objective = NestedField( + SweepObjectiveSchema, + required=True, + metadata={"description": "The name and optimization goal of the primary metric."}, + ) + early_termination = EarlyTerminationField() + limits = NestedField( + SweepJobLimitsSchema, + required=True, + ) + queue_settings = ExperimentalField(NestedField(QueueSettingsSchema)) + resources = NestedField(JobResourceConfigurationSchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py new file mode 100644 index 00000000..d206a9b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py @@ -0,0 +1,21 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .choice import ChoiceSchema +from .normal import IntegerQNormalSchema, NormalSchema, QNormalSchema +from .randint import RandintSchema +from .uniform import IntegerQUniformSchema, QUniformSchema, UniformSchema + +__all__ = [ + "ChoiceSchema", + "NormalSchema", + "QNormalSchema", + "RandintSchema", + "UniformSchema", + "QUniformSchema", + "IntegerQUniformSchema", + "IntegerQNormalSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py new file mode 100644 index 00000000..7e6b5a76 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py @@ -0,0 +1,63 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
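+# Example of sweep search_space entries handled by the schemas in this package
+# (an illustrative sketch; parameter names and ranges are hypothetical):
+#
+#   search_space:
+#     batch_size:
+#       type: choice
+#       values: [16, 32, 64]
+#     learning_rate:
+#       type: uniform
+#       min_value: 0.001
+#       max_value: 0.1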
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema._sweep.search_space.normal import NormalSchema, QNormalSchema +from azure.ai.ml._schema._sweep.search_space.randint import RandintSchema +from azure.ai.ml._schema._sweep.search_space.uniform import QUniformSchema, UniformSchema +from azure.ai.ml._schema.core.fields import ( + DumpableIntegerField, + DumpableStringField, + NestedField, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._job.sweep import SearchSpace + + +class ChoiceSchema(metaclass=PatchedSchemaMeta): + values = fields.List( + UnionField( + [ + DumpableIntegerField(strict=True), + DumpableStringField(), + fields.Float(), + fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + NestedField("ChoiceSchema"), + NestedField(NormalSchema()), + NestedField(QNormalSchema()), + NestedField(RandintSchema()), + NestedField(UniformSchema()), + NestedField(QUniformSchema()), + DumpableIntegerField(strict=True), + fields.Float(), + fields.Str(), + ] + ), + ), + ] + ) + ) + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.CHOICE) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import Choice + + return Choice(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import Choice + + if not isinstance(data, Choice): + raise ValidationError("Cannot dump non-Choice object into ChoiceSchema") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py new file mode 100644 index 00000000..b29f175e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
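+# Example of normal / qnormal search space entries handled by NormalSchema and
+# QNormalSchema below (an illustrative sketch; values are hypothetical):
+#
+#   momentum:
+#     type: normal
+#     mu: 0.9
+#     sigma: 0.05
+#   hidden_units:
+#     type: qnormal
+#     mu: 128
+#     sigma: 32
+#     q: 8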
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load +from marshmallow.decorators import pre_dump + +from azure.ai.ml._schema.core.fields import DumpableIntegerField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._common import TYPE +from azure.ai.ml.constants._job.sweep import SearchSpace + + +class NormalSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.NORMAL_LOGNORMAL) + mu = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + sigma = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import LogNormal, Normal + + return Normal(**data) if data[TYPE] == SearchSpace.NORMAL else LogNormal(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import Normal + + if not isinstance(data, Normal): + raise ValidationError("Cannot dump non-Normal object into NormalSchema") + return data + + +class QNormalSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.QNORMAL_QLOGNORMAL) + mu = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + sigma = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + q = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import QLogNormal, QNormal + + return QNormal(**data) if data[TYPE] == SearchSpace.QNORMAL else QLogNormal(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import QLogNormal, QNormal + + if not isinstance(data, (QNormal, QLogNormal)): + raise ValidationError("Cannot dump non-QNormal or non-QLogNormal object into QNormalSchema") + return data + + +class IntegerQNormalSchema(QNormalSchema): + mu = DumpableIntegerField(strict=True, required=True) + sigma = DumpableIntegerField(strict=True, required=True) + q = DumpableIntegerField(strict=True, required=True) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py new file mode 100644 index 00000000..8df0d4f5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py @@ -0,0 +1,30 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
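+# Example of a randint search space entry handled by RandintSchema below
+# (an illustrative sketch; randint draws a non-negative integer below `upper`):
+#
+#   layer_count:
+#     type: randint
+#     upper: 5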
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._job.sweep import SearchSpace + + +class RandintSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.RANDINT) + upper = fields.Integer(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import Randint + + return Randint(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import Randint + + if not isinstance(data, Randint): + raise ValidationError("Cannot dump non-Randint object into RandintSchema") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py new file mode 100644 index 00000000..2eb1d98f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py @@ -0,0 +1,62 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema._sweep._constants import BASE_ERROR_MESSAGE +from azure.ai.ml._schema.core.fields import DumpableIntegerField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._common import TYPE +from azure.ai.ml.constants._job.sweep import SearchSpace + + +class UniformSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.UNIFORM_LOGUNIFORM) + min_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + max_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import LogUniform, Uniform + + if not isinstance(data, (Uniform, LogUniform)): + raise ValidationError("Cannot dump non-Uniform or non-LogUniform object into UniformSchema") + if data.type.lower() not in SearchSpace.UNIFORM_LOGUNIFORM: + raise ValidationError(BASE_ERROR_MESSAGE + str(SearchSpace.UNIFORM_LOGUNIFORM)) + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import LogUniform, Uniform + + return Uniform(**data) if data[TYPE] == SearchSpace.UNIFORM else LogUniform(**data) + + +class QUniformSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=SearchSpace.QUNIFORM_QLOGUNIFORM) + min_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + max_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + q = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import QLogUniform, QUniform + + return QUniform(**data) if data[TYPE] == SearchSpace.QUNIFORM else QLogUniform(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.sweep import QLogUniform, QUniform + + if not isinstance(data, (QUniform, 
QLogUniform)):
+            raise ValidationError("Cannot dump non-QUniform or non-QLogUniform object into QUniformSchema")
+        return data
+
+
+class IntegerQUniformSchema(QUniformSchema):
+    min_value = DumpableIntegerField(strict=True, required=True)
+    max_value = DumpableIntegerField(strict=True, required=True)
+    q = DumpableIntegerField(strict=True, required=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py
new file mode 100644
index 00000000..e96d4fa2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py
@@ -0,0 +1,77 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import SamplingAlgorithmType
+from azure.ai.ml._schema._sweep.search_space import (
+    ChoiceSchema,
+    NormalSchema,
+    QNormalSchema,
+    QUniformSchema,
+    RandintSchema,
+    UniformSchema,
+)
+from azure.ai.ml._schema._sweep.sweep_sampling_algorithm import (
+    BayesianSamplingAlgorithmSchema,
+    GridSamplingAlgorithmSchema,
+    RandomSamplingAlgorithmSchema,
+)
+from azure.ai.ml._schema._sweep.sweep_termination import (
+    BanditPolicySchema,
+    MedianStoppingPolicySchema,
+    TruncationSelectionPolicySchema,
+)
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+
+
+def SamplingAlgorithmField():
+    return UnionField(
+        [
+            SamplingAlgorithmTypeField(),
+            NestedField(RandomSamplingAlgorithmSchema()),
+            NestedField(GridSamplingAlgorithmSchema()),
+            NestedField(BayesianSamplingAlgorithmSchema()),
+        ]
+    )
+
+
+def SamplingAlgorithmTypeField():
+    return StringTransformedEnum(
+        required=True,
+        allowed_values=[
+            SamplingAlgorithmType.BAYESIAN,
+            SamplingAlgorithmType.GRID,
+            SamplingAlgorithmType.RANDOM,
+        ],
+        metadata={"description": "The sampling algorithm to use for the hyperparameter sweep."},
+    )
+
+
+def SearchSpaceField():
+    return fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(ChoiceSchema()),
+                NestedField(UniformSchema()),
+                NestedField(QUniformSchema()),
+                NestedField(NormalSchema()),
+                NestedField(QNormalSchema()),
+                NestedField(RandintSchema()),
+            ]
+        ),
+        metadata={"description": "The parameters to sweep over the trial."},
+    )
+
+
+def EarlyTerminationField():
+    return UnionField(
+        [
+            NestedField(BanditPolicySchema()),
+            NestedField(MedianStoppingPolicySchema()),
+            NestedField(TruncationSelectionPolicySchema()),
+        ],
+        metadata={"description": "The early termination policy to be applied to the Sweep runs."},
+    )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py
new file mode 100644
index 00000000..f835ed0a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
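+# Example of a sweep job YAML skeleton accepted by SweepJobSchema below
+# (an illustrative sketch; the trial command, environment, and metric are hypothetical):
+#
+#   type: sweep
+#   sampling_algorithm: random
+#   objective:
+#     goal: minimize
+#     primary_metric: loss
+#   limits:
+#     max_total_trials: 20
+#   trial:
+#     command: python train.py --lr ${{search_space.learning_rate}}
+#     environment: azureml:my-env:1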
+# --------------------------------------------------------- + +from azure.ai.ml._schema._sweep.parameterized_sweep import ParameterizedSweepSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.job import BaseJobSchema, ParameterizedCommandSchema +from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField +from azure.ai.ml.constants import JobType + +# This is meant to match the yaml definition NOT the models defined in _restclient + + +class SweepJobSchema(BaseJobSchema, ParameterizedSweepSchema): + type = StringTransformedEnum(required=True, allowed_values=JobType.SWEEP) + trial = NestedField(ParameterizedCommandSchema, required=True) + inputs = InputsField() + outputs = OutputsField() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py new file mode 100644 index 00000000..fdc24fdf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py @@ -0,0 +1,31 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2022_10_01.models import Goal +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + +module_logger = logging.getLogger(__name__) + + +class SweepObjectiveSchema(metaclass=PatchedSchemaMeta): + goal = StringTransformedEnum( + required=True, + allowed_values=[Goal.MINIMIZE, Goal.MAXIMIZE], + casing_transform=camel_to_snake, + ) + primary_metric = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs) -> "Objective": + from azure.ai.ml.entities._job.sweep.objective import Objective + + return Objective(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py new file mode 100644 index 00000000..2b8137b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py @@ -0,0 +1,103 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
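+# Example of the expanded sampling_algorithm YAML form handled by the schemas
+# below (an illustrative sketch; seed and rule values are hypothetical):
+#
+#   sampling_algorithm:
+#     type: random
+#     seed: 42
+#     rule: sobol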
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._restclient.v2023_02_01_preview.models import RandomSamplingAlgorithmRule, SamplingAlgorithmType
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+module_logger = logging.getLogger(__name__)
+
+
+class RandomSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        required=True,
+        allowed_values=SamplingAlgorithmType.RANDOM,
+        casing_transform=camel_to_snake,
+    )
+
+    seed = fields.Int()
+
+    logbase = UnionField(
+        [
+            fields.Number(),
+            fields.Str(),
+        ],
+        data_key="logbase",
+    )
+
+    rule = StringTransformedEnum(
+        allowed_values=[
+            RandomSamplingAlgorithmRule.RANDOM,
+            RandomSamplingAlgorithmRule.SOBOL,
+        ],
+        casing_transform=camel_to_snake,
+    )
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.sweep import RandomSamplingAlgorithm
+
+        data.pop("type")
+        return RandomSamplingAlgorithm(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.sweep import RandomSamplingAlgorithm
+
+        if not isinstance(data, RandomSamplingAlgorithm):
+            raise ValidationError("Cannot dump non-RandomSamplingAlgorithm object into RandomSamplingAlgorithmSchema")
+        return data
+
+
+class GridSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        required=True,
+        allowed_values=SamplingAlgorithmType.GRID,
+        casing_transform=camel_to_snake,
+    )
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.sweep import GridSamplingAlgorithm
+
+        data.pop("type")
+        return GridSamplingAlgorithm(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.sweep import GridSamplingAlgorithm
+
+        if not isinstance(data, GridSamplingAlgorithm):
+            raise ValidationError("Cannot dump non-GridSamplingAlgorithm object into GridSamplingAlgorithmSchema")
+        return data
+
+
+class BayesianSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(
+        required=True,
+        allowed_values=SamplingAlgorithmType.BAYESIAN,
+        casing_transform=camel_to_snake,
+    )
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.sweep import BayesianSamplingAlgorithm
+
+        data.pop("type")
+        return BayesianSamplingAlgorithm(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.sweep import BayesianSamplingAlgorithm
+
+        if not isinstance(data, BayesianSamplingAlgorithm):
+            raise ValidationError(
+                "Cannot dump non-BayesianSamplingAlgorithm object into BayesianSamplingAlgorithmSchema"
+            )
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py
new file mode 100644
index 00000000..08fa9145
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py
@@ -0,0 +1,95 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
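+# Example of the early_termination YAML fragment handled by BanditPolicySchema
+# below (an illustrative sketch; values are hypothetical):
+#
+#   early_termination:
+#     type: bandit
+#     slack_factor: 0.2
+#     evaluation_interval: 1
+#     delay_evaluation: 5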
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import EarlyTerminationPolicyType
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+module_logger = logging.getLogger(__name__)
+
+
+class EarlyTerminationPolicySchema(metaclass=PatchedSchemaMeta):
+    evaluation_interval = fields.Int(allow_none=True)
+    delay_evaluation = fields.Int(allow_none=True)
+
+
+class BanditPolicySchema(EarlyTerminationPolicySchema):
+    type = StringTransformedEnum(
+        required=True,
+        allowed_values=EarlyTerminationPolicyType.BANDIT,
+        casing_transform=camel_to_snake,
+    )
+    slack_factor = fields.Float(allow_none=True)
+    slack_amount = fields.Float(allow_none=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.sweep import BanditPolicy
+
+        data.pop("type", None)
+        return BanditPolicy(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.sweep import BanditPolicy
+
+        if not isinstance(data, BanditPolicy):
+            raise ValidationError("Cannot dump non-BanditPolicy object into BanditPolicySchema")
+        return data
+
+
+class MedianStoppingPolicySchema(EarlyTerminationPolicySchema):
+    type = StringTransformedEnum(
+        required=True,
+        allowed_values=EarlyTerminationPolicyType.MEDIAN_STOPPING,
+        casing_transform=camel_to_snake,
+    )
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.sweep import MedianStoppingPolicy
+
+        data.pop("type", None)
+        return MedianStoppingPolicy(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.sweep import MedianStoppingPolicy
+
+        if not isinstance(data, MedianStoppingPolicy):
+            raise ValidationError("Cannot dump non-MedianStoppingPolicy object into MedianStoppingPolicySchema")
+        return data
+
+
+class TruncationSelectionPolicySchema(EarlyTerminationPolicySchema):
+    type = StringTransformedEnum(
+        required=True,
+        allowed_values=EarlyTerminationPolicyType.TRUNCATION_SELECTION,
+        casing_transform=camel_to_snake,
+    )
+    truncation_percentage = fields.Int(required=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.sweep import TruncationSelectionPolicy
+
+        data.pop("type", None)
+        return TruncationSelectionPolicy(**data)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.sweep import TruncationSelectionPolicy
+
+        if not isinstance(data, TruncationSelectionPolicy):
+            raise ValidationError(
+                "Cannot dump non-TruncationSelectionPolicy object into TruncationSelectionPolicySchema"
+            )
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py new file mode 100644 index 00000000..611c80a2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py @@ -0,0 +1,88 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Union + +from marshmallow import Schema, fields + +from azure.ai.ml._schema.core.fields import DataBindingStr, ExperimentalField, NestedField, UnionField +from azure.ai.ml._schema.core.schema import PathAwareSchema + +DATA_BINDING_SUPPORTED_KEY = "_data_binding_supported" + + +def _is_literal(field): + return not isinstance(field, (NestedField, fields.List, fields.Dict, UnionField)) + + +def _add_data_binding_to_field(field, attrs_to_skip, schema_stack): + if hasattr(field, DATA_BINDING_SUPPORTED_KEY) and getattr(field, DATA_BINDING_SUPPORTED_KEY): + return field + data_binding_field = DataBindingStr() + if isinstance(field, UnionField): + for field_obj in field.union_fields: + if not _is_literal(field_obj): + _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack) + field.insert_union_field(data_binding_field) + elif isinstance(field, fields.Dict): + # handle dict, dict value can be None + if field.value_field is not None: + field.value_field = _add_data_binding_to_field(field.value_field, attrs_to_skip, schema_stack=schema_stack) + elif isinstance(field, fields.List): + # handle list + field.inner = _add_data_binding_to_field(field.inner, attrs_to_skip, schema_stack=schema_stack) + elif isinstance(field, ExperimentalField): + field = ExperimentalField( + _add_data_binding_to_field(field.experimental_field, attrs_to_skip, schema_stack=schema_stack), + data_key=field.data_key, + attribute=field.attribute, + dump_only=field.dump_only, + required=field.required, + allow_none=field.allow_none, + ) + elif isinstance(field, NestedField): + # handle nested field + support_data_binding_expression_for_fields(field.schema, attrs_to_skip, schema_stack=schema_stack) + else: + # change basic fields to union + field = UnionField( + [data_binding_field, field], + data_key=field.data_key, + attribute=field.attribute, + dump_only=field.dump_only, + required=field.required, + allow_none=field.allow_none, + ) + + setattr(field, DATA_BINDING_SUPPORTED_KEY, True) + return field + + +# pylint: disable-next=docstring-missing-param +def support_data_binding_expression_for_fields( # pylint: disable=name-too-long + schema: Union[PathAwareSchema, Schema], attrs_to_skip=None, schema_stack=None +): + """Update fields inside schema to support data binding string. + + Only first layer of recursive schema is supported now. 
+ """ + if hasattr(schema, DATA_BINDING_SUPPORTED_KEY) and getattr(schema, DATA_BINDING_SUPPORTED_KEY): + return + + setattr(schema, DATA_BINDING_SUPPORTED_KEY, True) + + if attrs_to_skip is None: + attrs_to_skip = [] + if schema_stack is None: + schema_stack = [] + schema_type_name = type(schema).__name__ + if schema_type_name in schema_stack: + return + schema_stack.append(schema_type_name) + for attr, field_obj in schema.load_fields.items(): + if attr not in attrs_to_skip: + schema.load_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack) + for attr, field_obj in schema.dump_fields.items(): + if attr not in attrs_to_skip: + schema.dump_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack) + schema_stack.pop() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py new file mode 100644 index 00000000..c1ee3568 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import copy +import logging +import re +from collections import OrderedDict +from typing import Any, Dict, Optional, Union + +from marshmallow.exceptions import ValidationError + +module_logger = logging.getLogger(__name__) + + +class ArmId(str): + def __new__(cls, content): + validate_arm_str(content) + return str.__new__(cls, content) + + +def validate_arm_str(arm_str: Union[ArmId, str]) -> bool: + """Validate whether the given string is in fact in the format of an ARM ID. + + :param arm_str: The string to validate. + :type arm_str: Either a string (in case of incorrect formatting) or ArmID (in case of correct formatting). + :returns: True if the string is correctly formatted, False otherwise. 
+ :rtype: bool + """ + reg_str = ( + r"/subscriptions/[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}?/resourcegroups/.*/providers/[a-z.a-z]*/[a-z]*/.*" + ) + lowered = arm_str.lower() + match = re.match(reg_str, lowered) + if match and match.group() == lowered: + return True + raise ValidationError(f"ARM string {arm_str} is not formatted correctly.") + + +def get_subnet_str(vnet_name: str, subnet: str, sub_id: Optional[str] = None, rg: Optional[str] = None) -> str: + if vnet_name and not subnet: + raise ValidationError("Subnet is required when vnet name is specified.") + try: + validate_arm_str(subnet) + return subnet + except ValidationError: + return ( + f"/subscriptions/{sub_id}/resourceGroups/{rg}/" + f"providers/Microsoft.Network/virtualNetworks/{vnet_name}/subnets/{subnet}" + ) + + +def replace_key_in_odict(odict: OrderedDict, old_key: Any, new_key: Any): + if not odict or old_key not in odict: + return odict + return OrderedDict([(new_key, v) if k == old_key else (k, v) for k, v in odict.items()]) + + +# This is temporary until deployments(batch/K8S) support registry references +def exit_if_registry_assets(data: Dict, caller: str) -> None: + startswith = "azureml://registries/" + if ( + "environment" in data + and data["environment"] + and isinstance(data["environment"], str) + and data["environment"].startswith(startswith) + ): + raise ValidationError(f"Registry reference for environments is not supported for {caller}") + if "model" in data and data["model"] and isinstance(data["model"], str) and data["model"].startswith(startswith): + raise ValidationError(f"Registry reference for models is not supported for {caller}") + if ( + "code_configuration" in data + and data["code_configuration"].code + and isinstance(data["code_configuration"].code, str) + and data["code_configuration"].code.startswith(startswith) + ): + raise ValidationError(f"Registry reference for code_configuration.code is not supported for {caller}") + + +def _resolve_group_inputs_for_component(component, **kwargs): # pylint: disable=unused-argument + # Try resolve object's inputs & outputs and return a resolved new object + from azure.ai.ml.entities._inputs_outputs import GroupInput + + result = copy.copy(component) + + flatten_inputs = {} + for key, val in result.inputs.items(): + if isinstance(val, GroupInput): + flatten_inputs.update(val.flatten(group_parameter_name=key)) + continue + flatten_inputs[key] = val + + # Flatten group inputs + result._inputs = flatten_inputs # pylint: disable=protected-access + return result diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/artifact.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/artifact.py new file mode 100644 index 00000000..fc107a78 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/artifact.py @@ -0,0 +1,24 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load + +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + +from .asset import AssetSchema + +module_logger = logging.getLogger(__name__) + + +class ArtifactSchema(AssetSchema): + datastore = fields.Str(metadata={"description": "Name of the datastore to upload to."}, required=False) + + @post_load + def make(self, data, **kwargs): + data[BASE_PATH_CONTEXT_KEY] = self.context[BASE_PATH_CONTEXT_KEY] + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/asset.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/asset.py new file mode 100644 index 00000000..09edb115 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/asset.py @@ -0,0 +1,42 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import ValidationError, fields +from marshmallow.decorators import pre_load + +from azure.ai.ml._schema.core.auto_delete_setting import AutoDeleteSettingSchema +from azure.ai.ml._schema.core.fields import NestedField, VersionField, ExperimentalField +from azure.ai.ml._schema.job.creation_context import CreationContextSchema + +from ..core.resource import ResourceSchema + +module_logger = logging.getLogger(__name__) + + +class AssetSchema(ResourceSchema): + version = VersionField() + creation_context = NestedField(CreationContextSchema, dump_only=True) + latest_version = fields.Str(dump_only=True) + auto_delete_setting = ExperimentalField(NestedField(AutoDeleteSettingSchema)) + + +class AnonymousAssetSchema(AssetSchema): + version = VersionField(dump_only=True) + name = fields.Str(dump_only=True) + + @pre_load + def warn_if_named(self, data, **kwargs): + if isinstance(data, str): + raise ValidationError("Anonymous assets must be defined inline") + name = data.pop("name", None) + data.pop("version", None) + if name is not None: + module_logger.warning( + "Warning: the provided asset name '%s' will not be used for anonymous registration.", name + ) + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/code_asset.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/code_asset.py new file mode 100644 index 00000000..0610caff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/code_asset.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema.core.fields import ArmStr +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType + +from .artifact import ArtifactSchema +from .asset import AnonymousAssetSchema + +module_logger = logging.getLogger(__name__) + + +class CodeAssetSchema(ArtifactSchema): + id = ArmStr(azureml_type=AzureMLResourceType.CODE, dump_only=True) + path = fields.Str( + metadata={ + "description": "A local path or a Blob URI pointing to a file or directory where code asset is located." 
+ } + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._assets import Code + + return Code(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) + + +class AnonymousCodeAssetSchema(CodeAssetSchema, AnonymousAssetSchema): + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._assets import Code + + return Code(is_anonymous=True, base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) + + @pre_dump + def validate(self, data, **kwargs): + # AnonymousCodeAssetSchema does not support None or arm string(fall back to ArmVersionedStr) + if data is None or not hasattr(data, "get"): + raise ValidationError("Code cannot be None") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/data.py new file mode 100644 index 00000000..e14afd9b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/data.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields, validate + +from azure.ai.ml.constants._common import AssetTypes + +from .artifact import ArtifactSchema +from .asset import AnonymousAssetSchema + + +class DataSchema(ArtifactSchema): + path = fields.Str(metadata={"description": "URI pointing to a file or folder."}, required=True) + properties = fields.Dict(dump_only=True) + type = fields.Str( + metadata={"description": "the type of data. Valid values are uri_file, uri_folder, or mltable."}, + validate=validate.OneOf([AssetTypes.URI_FILE, AssetTypes.URI_FOLDER, AssetTypes.MLTABLE]), + dump_default=AssetTypes.URI_FOLDER, + error_messages={"validator_failed": "value must be uri_file, uri_folder, or mltable."}, + ) + + +class AnonymousDataSchema(DataSchema, AnonymousAssetSchema): + pass diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/environment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/environment.py new file mode 100644 index 00000000..3ca5333f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/environment.py @@ -0,0 +1,160 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
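+# Example of an environment YAML document accepted by EnvironmentSchema below
+# (an illustrative sketch; the image tag and conda file path are hypothetical):
+#
+#   name: my-training-env
+#   version: "1"
+#   image: mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest
+#   conda_file: ./conda.yaml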
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump, pre_load
+
+from azure.ai.ml._restclient.v2022_05_01.models import (
+    InferenceContainerProperties,
+    OperatingSystemType,
+    Route,
+)
+from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, UnionField, LocalPathField
+from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._common import (
+    ANONYMOUS_ENV_NAME,
+    BASE_PATH_CONTEXT_KEY,
+    CREATE_ENVIRONMENT_ERROR_MESSAGE,
+    AzureMLResourceType,
+    YAMLRefDocLinks,
+)
+
+from ..core.fields import ArmStr, RegistryStr, StringTransformedEnum, VersionField
+from .asset import AnonymousAssetSchema, AssetSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class BuildContextSchema(metaclass=PatchedSchemaMeta):
+    dockerfile_path = fields.Str()
+    path = UnionField(
+        [
+            LocalPathField(),
+            # the build context also supports an HTTP URL
+            fields.URL(),
+        ]
+    )
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._assets.environment import BuildContext
+
+        return BuildContext(**data)
+
+
+class RouteSchema(metaclass=PatchedSchemaMeta):
+    port = fields.Int(required=True)
+    path = fields.Str(required=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        return Route(**data)
+
+
+class InferenceConfigSchema(metaclass=PatchedSchemaMeta):
+    liveness_route = NestedField(RouteSchema, required=True)
+    scoring_route = NestedField(RouteSchema, required=True)
+    readiness_route = NestedField(RouteSchema, required=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        return InferenceContainerProperties(**data)
+
+
+class _BaseEnvironmentSchema(AssetSchema):
+    id = UnionField(
+        [
+            RegistryStr(dump_only=True),
+            ArmStr(azureml_type=AzureMLResourceType.ENVIRONMENT, dump_only=True),
+        ]
+    )
+    build = NestedField(
+        BuildContextSchema,
+        metadata={"description": "Docker build context to create the environment. Mutually exclusive with image."},
+    )
+    image = fields.Str()
+    conda_file = UnionField([fields.Raw(), fields.Str()])
+    inference_config = NestedField(InferenceConfigSchema)
+    os_type = StringTransformedEnum(
+        allowed_values=[OperatingSystemType.Linux, OperatingSystemType.Windows],
+        required=False,
+    )
+    datastore = fields.Str(
+        metadata={
+            "description": "Name of the datastore to upload to.",
+            "arm_type": AzureMLResourceType.DATASTORE,
+        },
+        required=False,
+    )
+    intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema), dump_only=True)
+
+    @pre_load
+    def pre_load(self, data, **kwargs):
+        if isinstance(data, str):
+            raise ValidationError("Environment schema data cannot be a string")
+        # validates that "channels" and "dependencies" are not included in the creation data.
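+        # e.g. an environment YAML whose top level contains "dependencies: [python=3.10]"
+        # (a conda spec pasted inline; illustrative snippet) is rejected here with a
+        # pointer to the environment YAML reference docs.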
+        # These properties should only appear in the environment's conda file, not in the environment creation file.
+        if "channels" in data or "dependencies" in data:
+            environment_message = CREATE_ENVIRONMENT_ERROR_MESSAGE.format(YAMLRefDocLinks.ENVIRONMENT)
+            raise ValidationError(environment_message)
+        return data
+
+    @pre_dump
+    def validate(self, data, **kwargs):
+        from azure.ai.ml.entities._assets import Environment
+
+        if isinstance(data, Environment):
+            if data._intellectual_property:  # pylint: disable=protected-access
+                ipp_field = data._intellectual_property  # pylint: disable=protected-access
+                if ipp_field:
+                    setattr(data, "intellectual_property", ipp_field)
+            return data
+        if data is None or not hasattr(data, "get"):
+            raise ValidationError("Environment cannot be None")
+        return data
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._assets import Environment
+
+        try:
+            obj = Environment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+        except FileNotFoundError as e:
+            # Environment.__init__() will raise FileNotFoundError if build.path is not found when trying to calculate
+            # the hash for anonymous environments. Raise ValidationError instead to collect all errors in schema
+            # validation.
+            raise ValidationError("Environment file not found: {}".format(e)) from e
+        return obj
+
+
+class EnvironmentSchema(_BaseEnvironmentSchema):
+    name = fields.Str(required=True)
+    version = VersionField()
+
+
+class AnonymousEnvironmentSchema(_BaseEnvironmentSchema, AnonymousAssetSchema):
+    @pre_load
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def trim_dump_only(self, data, **kwargs):
+        """trim_dump_only in PathAwareSchema removes all properties which are dump only.
+
+        By the time we reach this schema, the name and version properties have been removed, so no warning is shown.
+        This method overrides trim_dump_only in PathAwareSchema to check for name and version, warn if either is
+        present, and then call the base implementation.
+        """
+        if isinstance(data, str) or data is None:
+            return data
+        name = data.pop("name", None)
+        data.pop("version", None)
+        # CliV2AnonymousEnvironment is the default name for an anonymous environment
+        if name is not None and name != ANONYMOUS_ENV_NAME:
+            module_logger.warning(
+                "Warning: the provided asset name '%s' will not be used for anonymous registration.",
+                name,
+            )
+        return super(AnonymousEnvironmentSchema, self).trim_dump_only(data, **kwargs)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/federated_learning_silo.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/federated_learning_silo.py
new file mode 100644
index 00000000..80c4ba7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/federated_learning_silo.py
@@ -0,0 +1,24 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# TODO: determine where this file should live.
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.resource import YamlFileSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField
+
+
+# Inherits from YamlFileSchema instead of something more specific because
+# this does not represent a server-side resource.
+@experimental
+class FederatedLearningSiloSchema(YamlFileSchema):
+    """The YAML definition of a silo for describing a federated learning data target.
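+
+    A minimal silo YAML under this schema might look like the following (all
+    names and values are illustrative, not taken from a real workspace):
+
+        compute: azureml:silo1-cluster
+        datastore: azureml:silo1_datastore
+        inputs:
+          silo_data:
+            type: uri_folder
+            path: azureml://datastores/silo1_datastore/paths/data
+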
+ Unlike most SDK/CLI schemas, this schema does not represent an AML resource; + it is merely used to simplify the loading and validation of silos which are used + to create FL pipeline nodes. + """ + + compute = fields.Str() + datastore = fields.Str() + inputs = InputsField() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/index.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/index.py new file mode 100644 index 00000000..4a97c0ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/index.py @@ -0,0 +1,30 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.fields import ArmStr +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType + +from .artifact import ArtifactSchema + + +class IndexAssetSchema(ArtifactSchema): + name = fields.Str(required=True, allow_none=False) + id = ArmStr(azureml_type=AzureMLResourceType.INDEX, dump_only=True) + stage = fields.Str(default="Development") + path = fields.Str( + required=True, + metadata={ + "description": "A local path or a Blob URI pointing to a file or directory where index files are located." + }, + ) + properties = fields.Dict(keys=fields.Str(), values=fields.Str()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._assets import Index + + return Index(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py new file mode 100644 index 00000000..60c17f63 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py @@ -0,0 +1,65 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load, pre_dump + +from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField +from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml._schema.job import CreationContextSchema +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes, AzureMLResourceType + +from ..core.fields import ArmVersionedStr, StringTransformedEnum, VersionField + +module_logger = logging.getLogger(__name__) + + +class ModelSchema(PathAwareSchema): + name = fields.Str(required=True) + id = ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, dump_only=True) + type = StringTransformedEnum( + allowed_values=[ + AssetTypes.CUSTOM_MODEL, + AssetTypes.MLFLOW_MODEL, + AssetTypes.TRITON_MODEL, + ], + metadata={"description": "The storage format for this entity. 
Used for NCD."}, + ) + path = fields.Str() + version = VersionField() + description = fields.Str() + properties = fields.Dict() + tags = fields.Dict() + stage = fields.Str() + utc_time_created = fields.DateTime(format="iso", dump_only=True) + flavors = fields.Dict() + creation_context = NestedField(CreationContextSchema, dump_only=True) + job_name = fields.Str(dump_only=True) + latest_version = fields.Str(dump_only=True) + datastore = fields.Str(metadata={"description": "Name of the datastore to upload to."}, required=False) + intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema, required=False), dump_only=True) + system_metadata = fields.Dict() + + @pre_dump + def validate(self, data, **kwargs): + if data._intellectual_property: # pylint: disable=protected-access + ipp_field = data._intellectual_property # pylint: disable=protected-access + if ipp_field: + setattr(data, "intellectual_property", ipp_field) + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._assets import Model + + return Model(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data) + + +class AnonymousModelSchema(ModelSchema): + name = fields.Str() + version = VersionField() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/base_environment_source.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/base_environment_source.py new file mode 100644 index 00000000..09e0a56c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/base_environment_source.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from marshmallow import fields, post_load +from azure.ai.ml._schema.core.schema import PathAwareSchema + + +module_logger = logging.getLogger(__name__) + + +class BaseEnvironmentSourceSchema(PathAwareSchema): + type = fields.Str() + resource_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import BaseEnvironment + + return BaseEnvironment(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/inference_server.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/inference_server.py new file mode 100644 index 00000000..c6e38331 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/inference_server.py @@ -0,0 +1,51 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument,no-else-return + +import logging + +from marshmallow import post_load +from azure.ai.ml._schema._deployment.code_configuration_schema import CodeConfigurationSchema +from azure.ai.ml._schema.core.fields import StringTransformedEnum, NestedField +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml.constants._common import InferenceServerType +from .online_inference_configuration import OnlineInferenceConfigurationSchema + + +module_logger = logging.getLogger(__name__) + + +class InferenceServerSchema(PathAwareSchema): + type = StringTransformedEnum( + allowed_values=[ + InferenceServerType.AZUREML_ONLINE, + InferenceServerType.AZUREML_BATCH, + InferenceServerType.CUSTOM, + InferenceServerType.TRITON, + ], + required=True, + ) + code_configuration = NestedField(CodeConfigurationSchema) # required for batch and online + inference_configuration = NestedField(OnlineInferenceConfigurationSchema) # required for custom and Triton + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import ( + AzureMLOnlineInferencingServer, + AzureMLBatchInferencingServer, + CustomInferencingServer, + TritonInferencingServer, + ) + + if data["type"] == InferenceServerType.AZUREML_ONLINE: + return AzureMLOnlineInferencingServer(**data) + elif data["type"] == InferenceServerType.AZUREML_BATCH: + return AzureMLBatchInferencingServer(**data) + elif data["type"] == InferenceServerType.CUSTOM: + return CustomInferencingServer(**data) + elif data["type"] == InferenceServerType.TRITON: + return TritonInferencingServer(**data) + else: + return None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_configuration.py new file mode 100644 index 00000000..0e5a54a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_configuration.py @@ -0,0 +1,30 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml._schema.core.fields import StringTransformedEnum + + +module_logger = logging.getLogger(__name__) + + +class ModelConfigurationSchema(PathAwareSchema): + mode = StringTransformedEnum( + allowed_values=[ + "copy", + "download", + ] + ) + mount_path = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import ModelConfiguration + + return ModelConfiguration(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package.py new file mode 100644 index 00000000..142c85c8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package.py @@ -0,0 +1,41 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml._schema.core.fields import UnionField, NestedField, StringTransformedEnum +from .inference_server import InferenceServerSchema +from .model_configuration import ModelConfigurationSchema +from .model_package_input import ModelPackageInputSchema +from .base_environment_source import BaseEnvironmentSourceSchema + +module_logger = logging.getLogger(__name__) + + +class ModelPackageSchema(PathAwareSchema): + target_environment = UnionField( + union_fields=[ + fields.Dict(keys=StringTransformedEnum(allowed_values=["name"]), values=fields.Str()), + fields.Str(required=True), + ] + ) + base_environment_source = NestedField(BaseEnvironmentSourceSchema) + inferencing_server = NestedField(InferenceServerSchema) + model_configuration = NestedField(ModelConfigurationSchema) + inputs = fields.List(NestedField(ModelPackageInputSchema)) + tags = fields.Dict() + environment_variables = fields.Dict( + metadata={"description": "Environment variables configuration for the model package."} + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import ModelPackage + + return ModelPackage(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package_input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package_input.py new file mode 100644 index 00000000..a1a1dd8b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package_input.py @@ -0,0 +1,81 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField, NestedField + +module_logger = logging.getLogger(__name__) + + +class PathBaseSchema(PathAwareSchema): + input_path_type = StringTransformedEnum( + allowed_values=[ + "path_id", + "url", + "path_version", + ], + casing_transform=camel_to_snake, + ) + + +class PackageInputPathIdSchema(PathBaseSchema): + resource_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._assets._artifacts._package.model_package import PackageInputPathId + + return PackageInputPathId(**data) + + +class PackageInputPathUrlSchema(PathBaseSchema): + url = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._assets._artifacts._package.model_package import PackageInputPathUrl + + return PackageInputPathUrl(**data) + + +class PackageInputPathSchema(PathBaseSchema): + resource_name = fields.Str() + resource_version = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._assets._artifacts._package.model_package import PackageInputPathVersion + + return PackageInputPathVersion(**data) + + +class ModelPackageInputSchema(PathAwareSchema): + type = StringTransformedEnum(allowed_values=["uri_file", "uri_folder"], casing_transform=camel_to_snake) + mode = StringTransformedEnum( + allowed_values=[ + "read_only_mount", + "download", + ], + casing_transform=camel_to_snake, + ) + path = UnionField( + [ + NestedField(PackageInputPathIdSchema), + NestedField(PackageInputPathUrlSchema), + NestedField(PackageInputPathSchema), + ] + ) + mount_path = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._assets._artifacts._package.model_package import ModelPackageInput + + return ModelPackageInput(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/online_inference_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/online_inference_configuration.py new file mode 100644 index 00000000..b5c313ed --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/online_inference_configuration.py @@ -0,0 +1,30 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from .route import RouteSchema
+
+
+module_logger = logging.getLogger(__name__)
+
+
+class OnlineInferenceConfigurationSchema(PathAwareSchema):
+    liveness_route = NestedField(RouteSchema)
+    readiness_route = NestedField(RouteSchema)
+    scoring_route = NestedField(RouteSchema)
+    entry_script = fields.Str()
+    configuration = fields.Dict()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._assets._artifacts._package.inferencing_server import (
+            OnlineInferenceConfiguration,
+        )
+
+        return OnlineInferenceConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/route.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/route.py
new file mode 100644
index 00000000..86f37e06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/route.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+# PatchedSchemaMeta is a metaclass, so it must be applied via `metaclass=` (rather than
+# used as a base class) for RouteSchema to be a regular marshmallow schema like the
+# other schemas in this package.
+class RouteSchema(metaclass=PatchedSchemaMeta):
+    port = fields.Str()
+    path = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._assets._artifacts._package.inferencing_server import Route
+
+        return Route(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/workspace_asset_reference.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/workspace_asset_reference.py
new file mode 100644
index 00000000..83d6d793
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/workspace_asset_reference.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+from .asset import AssetSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class WorkspaceAssetReferenceSchema(AssetSchema):
+    destination_name = fields.Str()
+    destination_version = fields.Str()
+    source_asset_id = fields.Str(required=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._assets.workspace_asset_reference import WorkspaceAssetReference
+
+        return WorkspaceAssetReference(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/__init__.py
new file mode 100644
index 00000000..36befc7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/__init__.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .automl_job import AutoMLJobSchema +from .automl_vertical import AutoMLVerticalSchema +from .featurization_settings import FeaturizationSettingsSchema, TableFeaturizationSettingsSchema +from .forecasting_settings import ForecastingSettingsSchema +from .table_vertical.classification import AutoMLClassificationSchema +from .table_vertical.forecasting import AutoMLForecastingSchema +from .table_vertical.regression import AutoMLRegressionSchema +from .table_vertical.table_vertical import AutoMLTableVerticalSchema +from .table_vertical.table_vertical_limit_settings import AutoMLTableLimitsSchema +from .training_settings import TrainingSettingsSchema + +__all__ = [ + "AutoMLJobSchema", + "AutoMLVerticalSchema", + "FeaturizationSettingsSchema", + "TableFeaturizationSettingsSchema", + "ForecastingSettingsSchema", + "AutoMLClassificationSchema", + "AutoMLForecastingSchema", + "AutoMLRegressionSchema", + "AutoMLTableVerticalSchema", + "AutoMLTableLimitsSchema", + "TrainingSettingsSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_job.py new file mode 100644 index 00000000..ebec82c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_job.py @@ -0,0 +1,21 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, StringTransformedEnum +from azure.ai.ml._schema.job import BaseJobSchema +from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField +from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema +from azure.ai.ml._schema.queue_settings import QueueSettingsSchema +from azure.ai.ml.constants import JobType + + +class AutoMLJobSchema(BaseJobSchema): + type = StringTransformedEnum(required=True, allowed_values=JobType.AUTOML) + environment_id = fields.Str() + environment_variables = fields.Dict(keys=fields.Str(), values=fields.Str()) + outputs = OutputsField() + resources = NestedField(JobResourceConfigurationSchema()) + queue_settings = ExperimentalField(NestedField(QueueSettingsSchema)) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_vertical.py new file mode 100644 index 00000000..2cf3bb83 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_vertical.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from azure.ai.ml._restclient.v2023_04_01_preview.models import LogVerbosity +from azure.ai.ml._schema.automl.automl_job import AutoMLJobSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema +from azure.ai.ml._utils.utils import camel_to_snake + + +class AutoMLVerticalSchema(AutoMLJobSchema): + log_verbosity = StringTransformedEnum( + allowed_values=[o.value for o in LogVerbosity], + casing_transform=camel_to_snake, + load_default=LogVerbosity.INFO, + ) + training_data = UnionField([NestedField(MLTableInputSchema)]) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/featurization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/featurization_settings.py new file mode 100644 index 00000000..19998e45 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/featurization_settings.py @@ -0,0 +1,74 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields as flds +from marshmallow import post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import BlockedTransformers +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._job.automl import AutoMLConstants, AutoMLTransformerParameterKeys + + +class ColumnTransformerSchema(metaclass=PatchedSchemaMeta): + fields = flds.List(flds.Str()) + parameters = flds.Dict( + keys=flds.Str(), + values=UnionField([flds.Float(), flds.Str()], allow_none=True, load_default=None), + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.automl import ColumnTransformer + + return ColumnTransformer(**data) + + +class FeaturizationSettingsSchema(metaclass=PatchedSchemaMeta): + dataset_language = flds.Str() + + +class NlpFeaturizationSettingsSchema(FeaturizationSettingsSchema): + dataset_language = flds.Str() + + @post_load + def make(self, data, **kwargs) -> "NlpFeaturizationSettings": + from azure.ai.ml.automl import NlpFeaturizationSettings + + return NlpFeaturizationSettings(**data) + + +class TableFeaturizationSettingsSchema(FeaturizationSettingsSchema): + mode = StringTransformedEnum( + allowed_values=[ + AutoMLConstants.AUTO, + AutoMLConstants.OFF, + AutoMLConstants.CUSTOM, + ], + load_default=AutoMLConstants.AUTO, + ) + blocked_transformers = flds.List( + StringTransformedEnum( + allowed_values=[o.value for o in BlockedTransformers], + casing_transform=camel_to_snake, + ) + ) + column_name_and_types = flds.Dict(keys=flds.Str(), values=flds.Str()) + transformer_params = flds.Dict( + keys=StringTransformedEnum( + allowed_values=[o.value for o in AutoMLTransformerParameterKeys], + casing_transform=camel_to_snake, + ), + values=flds.List(NestedField(ColumnTransformerSchema())), + ) + enable_dnn_featurization = flds.Bool() + + @post_load + def make(self, data, **kwargs) -> "TabularFeaturizationSettings": + from azure.ai.ml.automl import TabularFeaturizationSettings + + return TabularFeaturizationSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/forecasting_settings.py 
b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/forecasting_settings.py new file mode 100644 index 00000000..56033e14 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/forecasting_settings.py @@ -0,0 +1,66 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import FeatureLags as FeatureLagsMode +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ForecastHorizonMode, + SeasonalityMode, + ShortSeriesHandlingConfiguration, + TargetAggregationFunction, + TargetLagsMode, + TargetRollingWindowSizeMode, +) +from azure.ai.ml._restclient.v2023_04_01_preview.models import UseStl as STLMode +from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class ForecastingSettingsSchema(metaclass=PatchedSchemaMeta): + country_or_region_for_holidays = fields.Str() + cv_step_size = fields.Int() + forecast_horizon = UnionField( + [ + StringTransformedEnum(allowed_values=[ForecastHorizonMode.AUTO]), + fields.Int(), + ] + ) + target_lags = UnionField( + [ + StringTransformedEnum(allowed_values=[TargetLagsMode.AUTO]), + fields.Int(), + fields.List(fields.Int()), + ] + ) + target_rolling_window_size = UnionField( + [ + StringTransformedEnum(allowed_values=[TargetRollingWindowSizeMode.AUTO]), + fields.Int(), + ] + ) + time_column_name = fields.Str() + time_series_id_column_names = UnionField([fields.Str(), fields.List(fields.Str())]) + frequency = fields.Str() + feature_lags = StringTransformedEnum(allowed_values=[FeatureLagsMode.NONE, FeatureLagsMode.AUTO]) + seasonality = UnionField( + [ + StringTransformedEnum(allowed_values=[SeasonalityMode.AUTO]), + fields.Int(), + ] + ) + short_series_handling_config = StringTransformedEnum( + allowed_values=[o.value for o in ShortSeriesHandlingConfiguration] + ) + use_stl = StringTransformedEnum(allowed_values=[STLMode.NONE, STLMode.SEASON, STLMode.SEASON_TREND]) + target_aggregate_function = StringTransformedEnum(allowed_values=[o.value for o in TargetAggregationFunction]) + features_unknown_at_forecast_time = UnionField([fields.Str(), fields.List(fields.Str())]) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._job.automl.tabular.forecasting_settings import ForecastingSettings + + return ForecastingSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_classification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_classification.py new file mode 100644 index 00000000..c539f037 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_classification.py @@ -0,0 +1,66 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Dict + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ClassificationMultilabelPrimaryMetrics, + ClassificationPrimaryMetrics, + TaskType, +) +from azure.ai.ml._schema.automl.image_vertical.image_model_distribution_settings import ( + ImageModelDistributionSettingsClassificationSchema, +) +from azure.ai.ml._schema.automl.image_vertical.image_model_settings import ImageModelSettingsClassificationSchema +from azure.ai.ml._schema.automl.image_vertical.image_vertical import ImageVerticalSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class ImageClassificationBaseSchema(ImageVerticalSchema): + training_parameters = NestedField(ImageModelSettingsClassificationSchema()) + search_space = fields.List(NestedField(ImageModelDistributionSettingsClassificationSchema())) + + +class ImageClassificationSchema(ImageClassificationBaseSchema): + task_type = StringTransformedEnum( + allowed_values=TaskType.IMAGE_CLASSIFICATION, + casing_transform=camel_to_snake, + data_key=AutoMLConstants.TASK_TYPE_YAML, + required=True, + ) + primary_metric = StringTransformedEnum( + allowed_values=[o.value for o in ClassificationPrimaryMetrics], + casing_transform=camel_to_snake, + load_default=camel_to_snake(ClassificationPrimaryMetrics.Accuracy), + ) + + @post_load + def make(self, data, **kwargs) -> Dict[str, Any]: + data.pop("task_type") + return data + + +class ImageClassificationMultilabelSchema(ImageClassificationBaseSchema): + task_type = StringTransformedEnum( + allowed_values=TaskType.IMAGE_CLASSIFICATION_MULTILABEL, + casing_transform=camel_to_snake, + data_key=AutoMLConstants.TASK_TYPE_YAML, + required=True, + ) + primary_metric = StringTransformedEnum( + allowed_values=[o.value for o in ClassificationMultilabelPrimaryMetrics], + casing_transform=camel_to_snake, + load_default=camel_to_snake(ClassificationMultilabelPrimaryMetrics.IOU), + ) + + @post_load + def make(self, data, **kwargs) -> Dict[str, Any]: + data.pop("task_type") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_limit_settings.py new file mode 100644 index 00000000..3f5c73e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_limit_settings.py @@ -0,0 +1,21 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class ImageLimitsSchema(metaclass=PatchedSchemaMeta):
+    max_concurrent_trials = fields.Int()
+    max_trials = fields.Int()
+    timeout_minutes = fields.Int()  # timeout in minutes (serialized as a duration)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.automl import ImageLimitSettings
+
+        return ImageLimitSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_distribution_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_distribution_settings.py
new file mode 100644
index 00000000..9f784038
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_distribution_settings.py
@@ -0,0 +1,216 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_dump, post_load, pre_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+    LearningRateScheduler,
+    ModelSize,
+    StochasticOptimizer,
+    ValidationMetricType,
+)
+from azure.ai.ml._schema._sweep.search_space import (
+    ChoiceSchema,
+    IntegerQNormalSchema,
+    IntegerQUniformSchema,
+    NormalSchema,
+    QNormalSchema,
+    QUniformSchema,
+    RandintSchema,
+    UniformSchema,
+)
+from azure.ai.ml._schema.core.fields import (
+    DumpableIntegerField,
+    DumpableStringField,
+    NestedField,
+    StringTransformedEnum,
+    UnionField,
+)
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+
+def choice_schema_of_type(cls, **kwargs):
+    class CustomChoiceSchema(ChoiceSchema):
+        values = fields.List(cls(**kwargs))
+
+    return CustomChoiceSchema()
+
+
+def choice_and_single_value_schema_of_type(cls, **kwargs):
+    # The Choice schema is listed before the plain field to allow a choice of booleans:
+    # when dumping, a UnionField of [Bool, Choice[Bool]] parses even a dict as True.
+    # Since UnionField tries its sub-fields sequentially, the plain field goes last.
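+    # e.g. choice_and_single_value_schema_of_type(fields.Bool) then accepts either a
+    # plain `true` or {"type": "choice", "values": [true, false]} (illustrative input).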
+ return UnionField([NestedField(choice_schema_of_type(cls, **kwargs)), cls(**kwargs)]) + + +FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD = UnionField( + [ + fields.Float(), + DumpableIntegerField(strict=True), + NestedField(choice_schema_of_type(DumpableIntegerField, strict=True)), + NestedField(choice_schema_of_type(fields.Float)), + NestedField(UniformSchema()), + NestedField(QUniformSchema()), + NestedField(NormalSchema()), + NestedField(QNormalSchema()), + NestedField(RandintSchema()), + ] +) + +INT_SEARCH_SPACE_DISTRIBUTION_FIELD = UnionField( + [ + DumpableIntegerField(strict=True), + NestedField(choice_schema_of_type(DumpableIntegerField, strict=True)), + NestedField(RandintSchema()), + NestedField(IntegerQUniformSchema()), + NestedField(IntegerQNormalSchema()), + ] +) + +STRING_SEARCH_SPACE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(DumpableStringField) +BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(fields.Bool) + +model_size_enum_args = {"allowed_values": [o.value for o in ModelSize], "casing_transform": camel_to_snake} +learning_rate_scheduler_enum_args = { + "allowed_values": [o.value for o in LearningRateScheduler], + "casing_transform": camel_to_snake, +} +optimizer_enum_args = {"allowed_values": [o.value for o in StochasticOptimizer], "casing_transform": camel_to_snake} +validation_metric_enum_args = { + "allowed_values": [o.value for o in ValidationMetricType], + "casing_transform": camel_to_snake, +} + + +MODEL_SIZE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(StringTransformedEnum, **model_size_enum_args) +LEARNING_RATE_SCHEDULER_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type( + StringTransformedEnum, **learning_rate_scheduler_enum_args +) +OPTIMIZER_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(StringTransformedEnum, **optimizer_enum_args) +VALIDATION_METRIC_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type( + StringTransformedEnum, **validation_metric_enum_args +) + + +class ImageModelDistributionSettingsSchema(metaclass=PatchedSchemaMeta): + ams_gradient = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD + augmentations = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD + beta1 = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + beta2 = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + distributed = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD + early_stopping = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD + early_stopping_delay = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + early_stopping_patience = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + evaluation_frequency = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + enable_onnx_normalization = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD + gradient_accumulation_step = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + layers_to_freeze = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + learning_rate = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + learning_rate_scheduler = LEARNING_RATE_SCHEDULER_DISTRIBUTION_FIELD + momentum = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + nesterov = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD + number_of_epochs = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + number_of_workers = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + optimizer = OPTIMIZER_DISTRIBUTION_FIELD + random_seed = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + step_lr_gamma = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + step_lr_step_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + training_batch_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + validation_batch_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + warmup_cosine_lr_cycles = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + warmup_cosine_lr_warmup_epochs 
= INT_SEARCH_SPACE_DISTRIBUTION_FIELD + weight_decay = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + + +# pylint: disable-next=name-too-long +class ImageModelDistributionSettingsClassificationSchema(ImageModelDistributionSettingsSchema): + model_name = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD + training_crop_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + validation_crop_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + validation_resize_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + weighted_loss = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + + @post_dump + def conversion(self, data, **kwargs): + if self.context.get("inside_pipeline", False): # pylint: disable=no-member + # AutoML job inside pipeline does load(dump) instead of calling to_rest_object + # explicitly for creating the autoRest Object from sdk job. + # Hence for pipeline job, we explicitly convert Sweep Distribution dict to str after dump in this method. + # For standalone automl job, same conversion happens in image_classification_job._to_rest_object() + from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_dict_to_str_dict + + data = _convert_sweep_dist_dict_to_str_dict(data) + return data + + @pre_load + def before_make(self, data, **kwargs): + if self.context.get("inside_pipeline", False): # pylint: disable=no-member + from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_str_to_dict + + # Converting Sweep Distribution str to Sweep Distribution dict for complying with search_space schema. + data = _convert_sweep_dist_str_to_dict(data) + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.automl import ImageClassificationSearchSpace + + return ImageClassificationSearchSpace(**data) + + +# pylint: disable-next=name-too-long +class ImageModelDistributionSettingsDetectionCommonSchema(ImageModelDistributionSettingsSchema): + box_detections_per_image = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + box_score_threshold = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + image_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + max_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + min_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + model_size = MODEL_SIZE_DISTRIBUTION_FIELD + multi_scale = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD + nms_iou_threshold = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + tile_grid_size = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD + tile_overlap_ratio = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + tile_predictions_nms_threshold = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + validation_iou_threshold = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + validation_metric_type = VALIDATION_METRIC_DISTRIBUTION_FIELD + + @post_dump + def conversion(self, data, **kwargs): + if self.context.get("inside_pipeline", False): # pylint: disable=no-member + # AutoML job inside pipeline does load(dump) instead of calling to_rest_object + # explicitly for creating the autoRest Object from sdk job object. + # Hence for pipeline job, we explicitly convert Sweep Distribution dict to str after dump in this method. 
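+            # e.g. a dumped {"type": "uniform", "min_value": 0.05, "max_value": 0.1}
+            # becomes the sweep string "uniform(0.05,0.1)" (illustrative values).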
+ # For standalone automl job, same conversion happens in image_object_detection_job._to_rest_object() + from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_dict_to_str_dict + + data = _convert_sweep_dist_dict_to_str_dict(data) + return data + + @pre_load + def before_make(self, data, **kwargs): + if self.context.get("inside_pipeline", False): # pylint: disable=no-member + from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_str_to_dict + + # Converting Sweep Distribution str to Sweep Distribution dict for complying with search_space schema. + data = _convert_sweep_dist_str_to_dict(data) + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.automl import ImageObjectDetectionSearchSpace + + return ImageObjectDetectionSearchSpace(**data) + + +# pylint: disable-next=name-too-long +class ImageModelDistributionSettingsObjectDetectionSchema(ImageModelDistributionSettingsDetectionCommonSchema): + model_name = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD + + +# pylint: disable-next=name-too-long +class ImageModelDistributionSettingsInstanceSegmentationSchema(ImageModelDistributionSettingsObjectDetectionSchema): + model_name = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_settings.py new file mode 100644 index 00000000..7c88e628 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_settings.py @@ -0,0 +1,96 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + LearningRateScheduler, + ModelSize, + StochasticOptimizer, + ValidationMetricType, +) +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + + +class ImageModelSettingsSchema(metaclass=PatchedSchemaMeta): + ams_gradient = fields.Bool() + advanced_settings = fields.Str() + beta1 = fields.Float() + beta2 = fields.Float() + checkpoint_frequency = fields.Int() + checkpoint_run_id = fields.Str() + distributed = fields.Bool() + early_stopping = fields.Bool() + early_stopping_delay = fields.Int() + early_stopping_patience = fields.Int() + evaluation_frequency = fields.Int() + enable_onnx_normalization = fields.Bool() + gradient_accumulation_step = fields.Int() + layers_to_freeze = fields.Int() + learning_rate = fields.Float() + learning_rate_scheduler = StringTransformedEnum( + allowed_values=[o.value for o in LearningRateScheduler], + casing_transform=camel_to_snake, + ) + model_name = fields.Str() + momentum = fields.Float() + nesterov = fields.Bool() + number_of_epochs = fields.Int() + number_of_workers = fields.Int() + optimizer = StringTransformedEnum( + allowed_values=[o.value for o in StochasticOptimizer], + casing_transform=camel_to_snake, + ) + random_seed = fields.Int() + step_lr_gamma = fields.Float() + step_lr_step_size = fields.Int() + training_batch_size = fields.Int() + validation_batch_size = fields.Int() + warmup_cosine_lr_cycles = fields.Float() + warmup_cosine_lr_warmup_epochs = fields.Int() + weight_decay = fields.Float() + + +class ImageModelSettingsClassificationSchema(ImageModelSettingsSchema): + training_crop_size = fields.Int() + validation_crop_size = fields.Int() + validation_resize_size = fields.Int() + weighted_loss = fields.Int() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsClassification + + return ImageModelSettingsClassification(**data) + + +class ImageModelSettingsObjectDetectionSchema(ImageModelSettingsSchema): + box_detections_per_image = fields.Int() + box_score_threshold = fields.Float() + image_size = fields.Int() + max_size = fields.Int() + min_size = fields.Int() + model_size = StringTransformedEnum(allowed_values=[o.value for o in ModelSize], casing_transform=camel_to_snake) + multi_scale = fields.Bool() + nms_iou_threshold = fields.Float() + tile_grid_size = fields.Str() + tile_overlap_ratio = fields.Float() + tile_predictions_nms_threshold = fields.Float() + validation_iou_threshold = fields.Float() + validation_metric_type = StringTransformedEnum( + allowed_values=[o.value for o in ValidationMetricType], + casing_transform=camel_to_snake, + ) + log_training_metrics = fields.Str() + log_validation_loss = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsObjectDetection + + return ImageModelSettingsObjectDetection(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_object_detection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_object_detection.py new file mode 100644 index 00000000..cb753882 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_object_detection.py @@ -0,0 +1,66 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Dict + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + InstanceSegmentationPrimaryMetrics, + ObjectDetectionPrimaryMetrics, + TaskType, +) +from azure.ai.ml._schema.automl.image_vertical.image_model_distribution_settings import ( + ImageModelDistributionSettingsInstanceSegmentationSchema, + ImageModelDistributionSettingsObjectDetectionSchema, +) +from azure.ai.ml._schema.automl.image_vertical.image_model_settings import ImageModelSettingsObjectDetectionSchema +from azure.ai.ml._schema.automl.image_vertical.image_vertical import ImageVerticalSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class ImageObjectDetectionSchema(ImageVerticalSchema): + task_type = StringTransformedEnum( + allowed_values=TaskType.IMAGE_OBJECT_DETECTION, + casing_transform=camel_to_snake, + data_key=AutoMLConstants.TASK_TYPE_YAML, + required=True, + ) + primary_metric = StringTransformedEnum( + allowed_values=ObjectDetectionPrimaryMetrics.MEAN_AVERAGE_PRECISION, + casing_transform=camel_to_snake, + load_default=camel_to_snake(ObjectDetectionPrimaryMetrics.MEAN_AVERAGE_PRECISION), + ) + training_parameters = NestedField(ImageModelSettingsObjectDetectionSchema()) + search_space = fields.List(NestedField(ImageModelDistributionSettingsObjectDetectionSchema())) + + @post_load + def make(self, data, **kwargs) -> Dict[str, Any]: + data.pop("task_type") + return data + + +class ImageInstanceSegmentationSchema(ImageVerticalSchema): + task_type = StringTransformedEnum( + allowed_values=TaskType.IMAGE_INSTANCE_SEGMENTATION, + casing_transform=camel_to_snake, + data_key=AutoMLConstants.TASK_TYPE_YAML, + required=True, + ) + primary_metric = StringTransformedEnum( + allowed_values=[InstanceSegmentationPrimaryMetrics.MEAN_AVERAGE_PRECISION], + casing_transform=camel_to_snake, + load_default=camel_to_snake(InstanceSegmentationPrimaryMetrics.MEAN_AVERAGE_PRECISION), + ) + training_parameters = NestedField(ImageModelSettingsObjectDetectionSchema()) + search_space = fields.List(NestedField(ImageModelDistributionSettingsInstanceSegmentationSchema())) + + @post_load + def make(self, data, **kwargs) -> Dict[str, Any]: + data.pop("task_type") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_sweep_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_sweep_settings.py new file mode 100644 index 00000000..66dfd7ae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_sweep_settings.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument,protected-access + +from marshmallow import post_load, pre_dump + +from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField, SamplingAlgorithmField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class ImageSweepSettingsSchema(metaclass=PatchedSchemaMeta): + sampling_algorithm = SamplingAlgorithmField() + early_termination = EarlyTerminationField() + + @pre_dump + def conversion(self, data, **kwargs): + rest_obj = data._to_rest_object() + rest_obj.early_termination = data.early_termination + return rest_obj + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.automl import ImageSweepSettings + + return ImageSweepSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_vertical.py new file mode 100644 index 00000000..fdfaa79f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_vertical.py @@ -0,0 +1,19 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._schema.automl.automl_vertical import AutoMLVerticalSchema +from azure.ai.ml._schema.automl.image_vertical.image_limit_settings import ImageLimitsSchema +from azure.ai.ml._schema.automl.image_vertical.image_sweep_settings import ImageSweepSettingsSchema +from azure.ai.ml._schema.core.fields import NestedField, UnionField, fields +from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema + + +class ImageVerticalSchema(AutoMLVerticalSchema): + limits = NestedField(ImageLimitsSchema()) + sweep = NestedField(ImageSweepSettingsSchema()) + target_column_name = fields.Str(required=True) + test_data = UnionField([NestedField(MLTableInputSchema)]) + test_data_size = fields.Float() + validation_data = UnionField([NestedField(MLTableInputSchema)]) + validation_data_size = fields.Float() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_fixed_parameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_fixed_parameters.py new file mode 100644 index 00000000..2a5cb336 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_fixed_parameters.py @@ -0,0 +1,33 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpLearningRateScheduler +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + + +class NlpFixedParametersSchema(metaclass=PatchedSchemaMeta): + gradient_accumulation_steps = fields.Int() + learning_rate = fields.Float() + learning_rate_scheduler = StringTransformedEnum( + allowed_values=[obj.value for obj in NlpLearningRateScheduler], + casing_transform=camel_to_snake, + ) + model_name = fields.Str() + number_of_epochs = fields.Int() + training_batch_size = fields.Int() + validation_batch_size = fields.Int() + warmup_ratio = fields.Float() + weight_decay = fields.Float() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.automl import NlpFixedParameters + + return NlpFixedParameters(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_parameter_subspace.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_parameter_subspace.py new file mode 100644 index 00000000..de963478 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_parameter_subspace.py @@ -0,0 +1,106 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_dump, post_load, pre_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpLearningRateScheduler +from azure.ai.ml._schema._sweep.search_space import ( + ChoiceSchema, + NormalSchema, + QNormalSchema, + QUniformSchema, + RandintSchema, + UniformSchema, +) +from azure.ai.ml._schema.core.fields import ( + DumpableIntegerField, + DumpableStringField, + NestedField, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + + +def choice_schema_of_type(cls, **kwargs): + class CustomChoiceSchema(ChoiceSchema): + values = fields.List(cls(**kwargs)) + + return CustomChoiceSchema() + + +def choice_and_single_value_schema_of_type(cls, **kwargs): + return UnionField([cls(**kwargs), NestedField(choice_schema_of_type(cls, **kwargs))]) + + +FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD = UnionField( + [ + fields.Float(), + DumpableIntegerField(strict=True), + NestedField(choice_schema_of_type(DumpableIntegerField, strict=True)), + NestedField(choice_schema_of_type(fields.Float)), + NestedField(UniformSchema()), + NestedField(QUniformSchema()), + NestedField(NormalSchema()), + NestedField(QNormalSchema()), + NestedField(RandintSchema()), + ] +) + +INT_SEARCH_SPACE_DISTRIBUTION_FIELD = UnionField( + [ + DumpableIntegerField(strict=True), + NestedField(choice_schema_of_type(DumpableIntegerField, strict=True)), + NestedField(RandintSchema()), + ] +) + +STRING_SEARCH_SPACE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(DumpableStringField) +BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(fields.Bool) + + +class NlpParameterSubspaceSchema(metaclass=PatchedSchemaMeta): + gradient_accumulation_steps = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + learning_rate = 
FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + learning_rate_scheduler = choice_and_single_value_schema_of_type( + StringTransformedEnum, + allowed_values=[obj.value for obj in NlpLearningRateScheduler], + casing_transform=camel_to_snake, + ) + model_name = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD + number_of_epochs = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + training_batch_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + validation_batch_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD + warmup_ratio = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + weight_decay = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD + + @post_dump + def conversion(self, data, **kwargs): + if self.context.get("inside_pipeline", False): # pylint: disable=no-member + # AutoML job inside pipeline does load(dump) instead of calling to_rest_object + # explicitly for creating the autoRest Object from sdk job. + # Hence for pipeline job, we explicitly convert Sweep Distribution dict to str after dump in this method. + # For standalone automl job, same conversion happens in text_classification_job._to_rest_object() + from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_dict_to_str_dict + + data = _convert_sweep_dist_dict_to_str_dict(data) + return data + + @pre_load + def before_make(self, data, **kwargs): + if self.context.get("inside_pipeline", False): # pylint: disable=no-member + from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_str_to_dict + + # Converting Sweep Distribution str to Sweep Distribution dict for complying with search_space schema. + data = _convert_sweep_dist_str_to_dict(data) + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.automl import NlpSearchSpace + + return NlpSearchSpace(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_sweep_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_sweep_settings.py new file mode 100644 index 00000000..ab9b5ec3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_sweep_settings.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument,protected-access + +from marshmallow import post_load, pre_dump + +from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField, SamplingAlgorithmField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class NlpSweepSettingsSchema(metaclass=PatchedSchemaMeta): + sampling_algorithm = SamplingAlgorithmField() + early_termination = EarlyTerminationField() + + @pre_dump + def conversion(self, data, **kwargs): + rest_obj = data._to_rest_object() + rest_obj.early_termination = data.early_termination + return rest_obj + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.automl import NlpSweepSettings + + return NlpSweepSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical.py new file mode 100644 index 00000000..f701ce95 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical.py @@ -0,0 +1,24 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema.automl.automl_vertical import AutoMLVerticalSchema +from azure.ai.ml._schema.automl.featurization_settings import NlpFeaturizationSettingsSchema +from azure.ai.ml._schema.automl.nlp_vertical.nlp_fixed_parameters import NlpFixedParametersSchema +from azure.ai.ml._schema.automl.nlp_vertical.nlp_parameter_subspace import NlpParameterSubspaceSchema +from azure.ai.ml._schema.automl.nlp_vertical.nlp_sweep_settings import NlpSweepSettingsSchema +from azure.ai.ml._schema.automl.nlp_vertical.nlp_vertical_limit_settings import NlpLimitsSchema +from azure.ai.ml._schema.core.fields import NestedField, UnionField +from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class NlpVerticalSchema(AutoMLVerticalSchema): + limits = NestedField(NlpLimitsSchema()) + sweep = NestedField(NlpSweepSettingsSchema()) + training_parameters = NestedField(NlpFixedParametersSchema()) + search_space = fields.List(NestedField(NlpParameterSubspaceSchema())) + featurization = NestedField(NlpFeaturizationSettingsSchema(), data_key=AutoMLConstants.FEATURIZATION_YAML) + validation_data = UnionField([NestedField(MLTableInputSchema)]) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical_limit_settings.py new file mode 100644 index 00000000..fe054f38 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical_limit_settings.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class NlpLimitsSchema(metaclass=PatchedSchemaMeta): + max_concurrent_trials = fields.Int() + max_trials = fields.Int() + max_nodes = fields.Int() + timeout_minutes = fields.Int() # type duration + trial_timeout_minutes = fields.Int() # type duration + + @post_load + def make(self, data, **kwargs) -> "NlpLimitSettings": + from azure.ai.ml.automl import NlpLimitSettings + + return NlpLimitSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification.py new file mode 100644 index 00000000..14e0b7d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification.py @@ -0,0 +1,36 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
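# Minimal sketch of the NlpLimitsSchema above (values hypothetical):
#
#     limits = NlpLimitsSchema().load({"max_trials": 4, "timeout_minutes": 60})
#     # -> azure.ai.ml.automl.NlpLimitSettings(max_trials=4, timeout_minutes=60),
#     # produced by the @post_load hook; omitted fields fall back to the entity's
#     # own defaults since only the loaded keys are passed through.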
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics, TaskType
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_vertical import NlpVerticalSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, fields
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class TextClassificationSchema(NlpVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.TEXT_CLASSIFICATION,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationPrimaryMetrics],
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ClassificationPrimaryMetrics.ACCURACY),
+ )
+ # declared as required here because target_column_name is only optional for text_ner
+ target_column_name = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification_multilabel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification_multilabel.py
new file mode 100644
index 00000000..56cd5bc1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification_multilabel.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationMultilabelPrimaryMetrics, TaskType
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_vertical import NlpVerticalSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, fields
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class TextClassificationMultilabelSchema(NlpVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.TEXT_CLASSIFICATION_MULTILABEL,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=ClassificationMultilabelPrimaryMetrics.ACCURACY,
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ClassificationMultilabelPrimaryMetrics.ACCURACY),
+ )
+ # declared as required here because target_column_name is only optional for text_ner
+ target_column_name = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_ner.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_ner.py
new file mode 100644
index 00000000..3609b1d0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_ner.py
@@ -0,0 +1,35 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
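# Editorial note on the text classification schemas above: when primary_metric is
# omitted, load_default supplies camel_to_snake(ClassificationPrimaryMetrics.ACCURACY),
# presumably "accuracy", and @post_load pops the already-validated task_type key so
# only the remaining fields reach the job constructor.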
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Dict + +from marshmallow import post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics, TaskType +from azure.ai.ml._schema.automl.nlp_vertical.nlp_vertical import NlpVerticalSchema +from azure.ai.ml._schema.core.fields import StringTransformedEnum, fields +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class TextNerSchema(NlpVerticalSchema): + task_type = StringTransformedEnum( + allowed_values=TaskType.TEXT_NER, + casing_transform=camel_to_snake, + data_key=AutoMLConstants.TASK_TYPE_YAML, + required=True, + ) + primary_metric = StringTransformedEnum( + allowed_values=ClassificationPrimaryMetrics.ACCURACY, + casing_transform=camel_to_snake, + load_default=camel_to_snake(ClassificationPrimaryMetrics.ACCURACY), + ) + target_column_name = fields.Str() + + @post_load + def make(self, data, **kwargs) -> Dict[str, Any]: + data.pop("task_type") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/classification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/classification.py new file mode 100644 index 00000000..f9ce7b8b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/classification.py @@ -0,0 +1,37 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Dict + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics, TaskType +from azure.ai.ml._schema.automl.table_vertical.table_vertical import AutoMLTableVerticalSchema +from azure.ai.ml._schema.automl.training_settings import ClassificationTrainingSettingsSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class AutoMLClassificationSchema(AutoMLTableVerticalSchema): + task_type = StringTransformedEnum( + allowed_values=TaskType.CLASSIFICATION, + casing_transform=camel_to_snake, + data_key=AutoMLConstants.TASK_TYPE_YAML, + required=True, + ) + primary_metric = StringTransformedEnum( + allowed_values=[o.value for o in ClassificationPrimaryMetrics], + casing_transform=camel_to_snake, + load_default=camel_to_snake(ClassificationPrimaryMetrics.AUC_WEIGHTED), + ) + positive_label = fields.Str() + training = NestedField(ClassificationTrainingSettingsSchema(), data_key=AutoMLConstants.TRAINING_YAML) + + @post_load + def make(self, data, **kwargs) -> Dict[str, Any]: + data.pop("task_type") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/forecasting.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/forecasting.py new file mode 100644 index 00000000..7f302c97 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/forecasting.py @@ -0,0 +1,38 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
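# Rough sketch for the AutoMLClassificationSchema above (the YAML key names below
# are assumptions; the real keys come from AutoMLConstants and are not in this diff):
#
#     job_kwargs = AutoMLClassificationSchema(context=ctx).load({
#         "task": "classification",  # data_key=AutoMLConstants.TASK_TYPE_YAML (assumed spelling)
#         "target_column_name": "label",
#     })
#     # primary_metric defaults to camel_to_snake(AUC_WEIGHTED) when omitted.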
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Dict + +from marshmallow import post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ForecastingPrimaryMetrics, TaskType +from azure.ai.ml._schema.automl.forecasting_settings import ForecastingSettingsSchema +from azure.ai.ml._schema.automl.table_vertical.table_vertical import AutoMLTableVerticalSchema +from azure.ai.ml._schema.automl.training_settings import ForecastingTrainingSettingsSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class AutoMLForecastingSchema(AutoMLTableVerticalSchema): + task_type = StringTransformedEnum( + allowed_values=TaskType.FORECASTING, + casing_transform=camel_to_snake, + data_key=AutoMLConstants.TASK_TYPE_YAML, + required=True, + ) + primary_metric = StringTransformedEnum( + allowed_values=[o.value for o in ForecastingPrimaryMetrics], + casing_transform=camel_to_snake, + load_default=camel_to_snake(ForecastingPrimaryMetrics.NORMALIZED_ROOT_MEAN_SQUARED_ERROR), + ) + training = NestedField(ForecastingTrainingSettingsSchema(), data_key=AutoMLConstants.TRAINING_YAML) + forecasting_settings = NestedField(ForecastingSettingsSchema(), data_key=AutoMLConstants.FORECASTING_YAML) + + @post_load + def make(self, data, **kwargs) -> Dict[str, Any]: + data.pop("task_type") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/regression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/regression.py new file mode 100644 index 00000000..fc1e3900 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/regression.py @@ -0,0 +1,36 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from typing import Any, Dict + +from marshmallow import post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import RegressionPrimaryMetrics, TaskType +from azure.ai.ml._schema.automl.table_vertical.table_vertical import AutoMLTableVerticalSchema +from azure.ai.ml._schema.automl.training_settings import RegressionTrainingSettingsSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class AutoMLRegressionSchema(AutoMLTableVerticalSchema): + task_type = StringTransformedEnum( + allowed_values=TaskType.REGRESSION, + casing_transform=camel_to_snake, + data_key=AutoMLConstants.TASK_TYPE_YAML, + required=True, + ) + primary_metric = StringTransformedEnum( + allowed_values=[o.value for o in RegressionPrimaryMetrics], + casing_transform=camel_to_snake, + load_default=camel_to_snake(RegressionPrimaryMetrics.NORMALIZED_ROOT_MEAN_SQUARED_ERROR), + ) + training = NestedField(RegressionTrainingSettingsSchema(), data_key=AutoMLConstants.TRAINING_YAML) + + @post_load + def make(self, data, **kwargs) -> Dict[str, Any]: + data.pop("task_type") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical.py new file mode 100644 index 00000000..e98d7066 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical.py @@ -0,0 +1,29 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from azure.ai.ml._restclient.v2023_04_01_preview.models import NCrossValidationsMode +from azure.ai.ml._schema.automl.automl_vertical import AutoMLVerticalSchema +from azure.ai.ml._schema.automl.featurization_settings import TableFeaturizationSettingsSchema +from azure.ai.ml._schema.automl.table_vertical.table_vertical_limit_settings import AutoMLTableLimitsSchema +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField, fields +from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class AutoMLTableVerticalSchema(AutoMLVerticalSchema): + limits = NestedField(AutoMLTableLimitsSchema(), data_key=AutoMLConstants.LIMITS_YAML) + featurization = NestedField(TableFeaturizationSettingsSchema(), data_key=AutoMLConstants.FEATURIZATION_YAML) + target_column_name = fields.Str(required=True) + validation_data = UnionField([NestedField(MLTableInputSchema)]) + validation_data_size = fields.Float() + cv_split_column_names = fields.List(fields.Str()) + n_cross_validations = UnionField( + [ + StringTransformedEnum(allowed_values=[NCrossValidationsMode.AUTO]), + fields.Int(), + ], + ) + weight_column_name = fields.Str() + test_data = UnionField([NestedField(MLTableInputSchema)]) + test_data_size = fields.Float() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical_limit_settings.py new file mode 100644 index 00000000..122774a6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical_limit_settings.py @@ -0,0 +1,28 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema import ExperimentalField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._job.automl import AutoMLConstants + + +class AutoMLTableLimitsSchema(metaclass=PatchedSchemaMeta): + enable_early_termination = fields.Bool() + exit_score = fields.Float() + max_concurrent_trials = fields.Int() + max_cores_per_trial = fields.Int() + max_nodes = ExperimentalField(fields.Int()) + max_trials = fields.Int(data_key=AutoMLConstants.MAX_TRIALS_YAML) + timeout_minutes = fields.Int() # type duration + trial_timeout_minutes = fields.Int() # type duration + + @post_load + def make(self, data, **kwargs) -> "TabularLimitSettings": + from azure.ai.ml.automl import TabularLimitSettings + + return TabularLimitSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py new file mode 100644 index 00000000..57a76892 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py @@ -0,0 +1,122 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
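# Illustrative note for the AutoMLTableVerticalSchema above: n_cross_validations is
# a UnionField, so either form below should load (values hypothetical):
#
#     {"target_column_name": "y", "n_cross_validations": "auto"}  # NCrossValidationsMode.AUTO
#     {"target_column_name": "y", "n_cross_validations": 5}       # plain int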
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ClassificationModels, + ForecastingModels, + RegressionModels, + StackMetaLearnerType, +) +from azure.ai.ml.constants import TabularTrainingMode +from azure.ai.ml._schema import ExperimentalField +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._job.automl import AutoMLConstants +from azure.ai.ml.entities._job.automl.training_settings import ( + ClassificationTrainingSettings, + ForecastingTrainingSettings, + RegressionTrainingSettings, +) + + +class StackEnsembleSettingsSchema(metaclass=PatchedSchemaMeta): + stack_meta_learner_kwargs = fields.Dict() + stack_meta_learner_train_percentage = fields.Float() + stack_meta_learner_type = StringTransformedEnum( + allowed_values=[o.value for o in StackMetaLearnerType], + casing_transform=camel_to_snake, + ) + + @post_load + def make(self, data, **kwargs): + # Converting it here, as there is no corresponding entity class + stack_meta_learner_type = data.pop("stack_meta_learner_type") + stack_meta_learner_type = StackMetaLearnerType[stack_meta_learner_type.upper()] + from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings + + return StackEnsembleSettings(stack_meta_learner_type=stack_meta_learner_type, **data) + + +class TrainingSettingsSchema(metaclass=PatchedSchemaMeta): + enable_dnn_training = fields.Bool() + enable_model_explainability = fields.Bool() + enable_onnx_compatible_models = fields.Bool() + enable_stack_ensemble = fields.Bool() + enable_vote_ensemble = fields.Bool() + ensemble_model_download_timeout = fields.Int(data_key=AutoMLConstants.ENSEMBLE_MODEL_DOWNLOAD_TIMEOUT_YAML) + stack_ensemble_settings = NestedField(StackEnsembleSettingsSchema()) + training_mode = ExperimentalField( + StringTransformedEnum( + allowed_values=[o.value for o in TabularTrainingMode], + casing_transform=camel_to_snake, + ) + ) + + +class ClassificationTrainingSettingsSchema(TrainingSettingsSchema): + allowed_training_algorithms = fields.List( + StringTransformedEnum( + allowed_values=[o.value for o in ClassificationModels], + casing_transform=camel_to_snake, + ), + data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML, + ) + blocked_training_algorithms = fields.List( + StringTransformedEnum( + allowed_values=[o.value for o in ClassificationModels], + casing_transform=camel_to_snake, + ), + data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML, + ) + + @post_load + def make(self, data, **kwargs) -> "ClassificationTrainingSettings": + return ClassificationTrainingSettings(**data) + + +class ForecastingTrainingSettingsSchema(TrainingSettingsSchema): + allowed_training_algorithms = fields.List( + StringTransformedEnum( + allowed_values=[o.value for o in ForecastingModels], + casing_transform=camel_to_snake, + ), + data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML, + ) + blocked_training_algorithms = fields.List( + StringTransformedEnum( + allowed_values=[o.value for o in ForecastingModels], + casing_transform=camel_to_snake, + ), + data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML, + ) + + @post_load + def make(self, data, **kwargs) -> "ForecastingTrainingSettings": + return ForecastingTrainingSettings(**data) + + +class 
RegressionTrainingSettingsSchema(TrainingSettingsSchema): + allowed_training_algorithms = fields.List( + StringTransformedEnum( + allowed_values=[o.value for o in RegressionModels], + casing_transform=camel_to_snake, + ), + data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML, + ) + blocked_training_algorithms = fields.List( + StringTransformedEnum( + allowed_values=[o.value for o in RegressionModels], + casing_transform=camel_to_snake, + ), + data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML, + ) + + @post_load + def make(self, data, **kwargs) -> "RegressionTrainingSettings": + return RegressionTrainingSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py new file mode 100644 index 00000000..1b92f18e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py @@ -0,0 +1,48 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore + +from .command_component import AnonymousCommandComponentSchema, CommandComponentSchema, ComponentFileRefField +from .component import ComponentSchema, ComponentYamlRefField +from .data_transfer_component import ( + AnonymousDataTransferCopyComponentSchema, + AnonymousDataTransferExportComponentSchema, + AnonymousDataTransferImportComponentSchema, + DataTransferCopyComponentFileRefField, + DataTransferCopyComponentSchema, + DataTransferExportComponentFileRefField, + DataTransferExportComponentSchema, + DataTransferImportComponentFileRefField, + DataTransferImportComponentSchema, +) +from .import_component import AnonymousImportComponentSchema, ImportComponentFileRefField, ImportComponentSchema +from .parallel_component import AnonymousParallelComponentSchema, ParallelComponentFileRefField, ParallelComponentSchema +from .spark_component import AnonymousSparkComponentSchema, SparkComponentFileRefField, SparkComponentSchema + +__all__ = [ + "ComponentSchema", + "CommandComponentSchema", + "AnonymousCommandComponentSchema", + "ComponentFileRefField", + "ParallelComponentSchema", + "AnonymousParallelComponentSchema", + "ParallelComponentFileRefField", + "ImportComponentSchema", + "AnonymousImportComponentSchema", + "ImportComponentFileRefField", + "AnonymousSparkComponentSchema", + "SparkComponentFileRefField", + "SparkComponentSchema", + "AnonymousDataTransferCopyComponentSchema", + "DataTransferCopyComponentFileRefField", + "DataTransferCopyComponentSchema", + "AnonymousDataTransferImportComponentSchema", + "DataTransferImportComponentFileRefField", + "DataTransferImportComponentSchema", + "AnonymousDataTransferExportComponentSchema", + "DataTransferExportComponentFileRefField", + "DataTransferExportComponentSchema", + "ComponentYamlRefField", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py new file mode 100644 index 00000000..aef98cca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
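# Rough sketch for the training-settings schemas above (the algorithm name is a
# guess at one snake_cased ClassificationModels value; the exact YAML key is
# whatever AutoMLConstants.ALLOWED_ALGORITHMS_YAML resolves to):
#
#     training = ClassificationTrainingSettingsSchema().load(
#         {<ALLOWED_ALGORITHMS_YAML key>: ["logistic_regression"]}
#     )
#     # -> ClassificationTrainingSettings via @post_load; each entry is validated
#     # against the snake_cased ClassificationModels enum values.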
+# --------------------------------------------------------- +from azure.ai.ml._restclient.v2022_10_01_preview.models import TaskType +from azure.ai.ml._schema.component.component import ComponentSchema +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants import JobType + + +class AutoMLComponentSchema(ComponentSchema): + """AutoMl component schema. + + Only has type & task property with basic component properties. No inputs & outputs are allowed. + """ + + type = StringTransformedEnum(required=True, allowed_values=JobType.AUTOML) + task = StringTransformedEnum( + # TODO: verify if this works + allowed_values=[t for t in TaskType], # pylint: disable=unnecessary-comprehension + casing_transform=camel_to_snake, + required=True, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py new file mode 100644 index 00000000..9d688ee0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py @@ -0,0 +1,137 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument,protected-access +from copy import deepcopy + +import yaml +from marshmallow import INCLUDE, fields, post_dump, post_load + +from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema +from azure.ai.ml._schema.component.component import ComponentSchema +from azure.ai.ml._schema.component.input_output import ( + OutputPortSchema, + PrimitiveOutputSchema, +) +from azure.ai.ml._schema.component.resource import ComponentResourceSchema +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml._schema.core.fields import ( + ExperimentalField, + FileRefField, + NestedField, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.job.distribution import ( + MPIDistributionSchema, + PyTorchDistributionSchema, + TensorFlowDistributionSchema, + RayDistributionSchema, +) +from azure.ai.ml._schema.job.parameterized_command import ParameterizedCommandSchema +from azure.ai.ml._utils.utils import is_private_preview_enabled +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureDevopsArtifactsType +from azure.ai.ml.constants._component import ComponentSource, NodeType + + +class AzureDevopsArtifactsSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=[AzureDevopsArtifactsType.ARTIFACT]) + feed = fields.Str() + name = fields.Str() + version = fields.Str() + scope = fields.Str() + organization = fields.Str() + project = fields.Str() + + +class CommandComponentSchema(ComponentSchema, ParameterizedCommandSchema): + class Meta: + exclude = ["environment_variables"] # component doesn't have environment variables + + type = StringTransformedEnum(allowed_values=[NodeType.COMMAND]) + resources = NestedField(ComponentResourceSchema, unknown=INCLUDE) + distribution = UnionField( + [ + NestedField(MPIDistributionSchema, unknown=INCLUDE), + NestedField(TensorFlowDistributionSchema, unknown=INCLUDE), + NestedField(PyTorchDistributionSchema, unknown=INCLUDE), + ExperimentalField(NestedField(RayDistributionSchema, unknown=INCLUDE)), + ], + metadata={"description": "Provides the configuration for a distributed run."}, + ) + # primitive output is 
only supported for command component & pipeline component + outputs = fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + NestedField(OutputPortSchema), + NestedField(PrimitiveOutputSchema, unknown=INCLUDE), + ] + ), + ) + properties = fields.Dict(keys=fields.Str(), values=fields.Raw()) + + # Note: AzureDevopsArtifactsSchema only available when private preview flag opened before init of command component + # schema class. + if is_private_preview_enabled(): + additional_includes = fields.List(UnionField([fields.Str(), NestedField(AzureDevopsArtifactsSchema)])) + else: + additional_includes = fields.List(fields.Str()) + + @post_dump + def remove_unnecessary_fields(self, component_schema_dict, **kwargs): + # remove empty properties to keep the component spec unchanged + if not component_schema_dict.get("properties"): + component_schema_dict.pop("properties", None) + if ( + component_schema_dict.get("additional_includes") is not None + and len(component_schema_dict["additional_includes"]) == 0 + ): + component_schema_dict.pop("additional_includes") + return component_schema_dict + + +class RestCommandComponentSchema(CommandComponentSchema): + """When component load from rest, won't validate on name since there might be existing component with invalid + name.""" + + name = fields.Str(required=True) + + +class AnonymousCommandComponentSchema(AnonymousAssetSchema, CommandComponentSchema): + """Anonymous command component schema. + + Note inheritance follows order: AnonymousAssetSchema, CommandComponentSchema because we need name and version to be + dump_only(marshmallow collects fields follows method resolution order). + """ + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import CommandComponent + + # Inline component will have source=YAML.JOB + # As we only regard full separate component file as YAML.COMPONENT + return CommandComponent( + base_path=self.context[BASE_PATH_CONTEXT_KEY], + _source=ComponentSource.YAML_JOB, + **data, + ) + + +class ComponentFileRefField(FileRefField): + def _deserialize(self, value, attr, data, **kwargs): + # Get component info from component yaml file. + data = super()._deserialize(value, attr, data, **kwargs) + component_dict = yaml.safe_load(data) + source_path = self.context[BASE_PATH_CONTEXT_KEY] / value + + # Update base_path to parent path of component file. + component_schema_context = deepcopy(self.context) + component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent + component = AnonymousCommandComponentSchema(context=component_schema_context).load( + component_dict, unknown=INCLUDE + ) + component._source_path = source_path + component._source = ComponentSource.YAML_COMPONENT + return component diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py new file mode 100644 index 00000000..5772a607 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py @@ -0,0 +1,143 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
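# Hedged summary of the loading flow above: ComponentFileRefField turns a string
# path found in a pipeline YAML into a CommandComponent by (1) reading the file,
# (2) re-rooting BASE_PATH_CONTEXT_KEY to the component file's own folder, and
# (3) loading through AnonymousCommandComponentSchema, then stamping _source as
# YAML.COMPONENT so the result is treated as a standalone component definition.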
+# --------------------------------------------------------- +from pathlib import Path + +from marshmallow import ValidationError, fields, post_dump, pre_dump, pre_load +from marshmallow.fields import Field + +from azure.ai.ml._schema.component.input_output import InputPortSchema, OutputPortSchema, ParameterSchema +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + ExperimentalField, + NestedField, + PythonFuncNameStr, + UnionField, +) +from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema +from azure.ai.ml._utils.utils import is_private_preview_enabled, load_yaml +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType + +from .._utils.utils import _resolve_group_inputs_for_component +from ..assets.asset import AssetSchema +from ..core.fields import RegistryStr + + +class ComponentNameStr(PythonFuncNameStr): + def _get_field_name(self): + return "Component" + + +class ComponentYamlRefField(Field): + """Allows you to nest a :class:`Schema <marshmallow.Schema>` + inside a yaml ref field. + """ + + def _jsonschema_type_mapping(self): + schema = {"type": "string"} + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _deserialize(self, value, attr, data, **kwargs): + if not isinstance(value, str): + raise ValidationError(f"Nested yaml ref field expected a string but got {type(value)}.") + + base_path = Path(self.context[BASE_PATH_CONTEXT_KEY]) + + source_path = Path(value) + # raise if the string is not a valid path, like "azureml:xxx" + try: + source_path.resolve() + except OSError as ex: + raise ValidationError(f"Nested file ref field expected a local path but got {value}.") from ex + + if not source_path.is_absolute(): + source_path = base_path / source_path + + if not source_path.is_file(): + raise ValidationError( + f"Nested yaml ref field expected a local path but can't find {value} based on {base_path.as_posix()}." + ) + + loaded_value = load_yaml(source_path) + + # local import to avoid circular import + from azure.ai.ml.entities import Component + + component = Component._load(data=loaded_value, yaml_path=source_path) # pylint: disable=protected-access + return component + + def _serialize(self, value, attr, obj, **kwargs): + raise ValidationError("Serialize on RefField is not supported.") + + +class ComponentSchema(AssetSchema): + schema = fields.Str(data_key="$schema", attribute="_schema") + name = ComponentNameStr(required=True) + id = UnionField( + [ + RegistryStr(dump_only=True), + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, dump_only=True), + ] + ) + display_name = fields.Str() + description = fields.Str() + tags = fields.Dict(keys=fields.Str(), values=fields.Str()) + is_deterministic = fields.Bool() + inputs = fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + NestedField(ParameterSchema), + NestedField(InputPortSchema), + ] + ), + ) + outputs = fields.Dict( + keys=fields.Str(), + values=NestedField(OutputPortSchema), + ) + # hide in private preview + if is_private_preview_enabled(): + intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema)) + + def __init__(self, *args, **kwargs): + # Remove schema_ignored to enable serialize and deserialize schema. 
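# (Assumption, not stated in this file: "schema_ignored" appears to be a helper
# field declared on a parent schema; popping it from _declared_fields keeps it
# out of both dump and load so that only the real "$schema" field round-trips.)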
+ self._declared_fields.pop("schema_ignored", None)
+ super().__init__(*args, **kwargs)
+
+ @pre_load
+ def convert_version_to_str(self, data, **kwargs): # pylint: disable=unused-argument
+ if isinstance(data, dict) and data.get("version", None):
+ data["version"] = str(data["version"])
+ return data
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the component object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+ @post_dump
+ def convert_input_value_to_str(self, data, **kwargs): # pylint:disable=unused-argument
+ if isinstance(data, dict) and data.get("inputs", None):
+ input_dict = data["inputs"]
+ for input_value in input_dict.values():
+ input_type = input_value.get("type", None)
+ if isinstance(input_type, str) and input_type.lower() == "float":
+ # Convert number to string to avoid precision issue
+ for key in ["default", "min", "max"]:
+ if input_value.get(key, None) is not None:
+ input_value[key] = str(input_value[key])
+ return data
+
+ @pre_dump
+ def flatten_group_inputs(self, data, **kwargs): # pylint: disable=unused-argument
+ return _resolve_group_inputs_for_component(data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py
new file mode 100644
index 00000000..70035d57
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py
@@ -0,0 +1,257 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load, validates, ValidationError
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import InputPortSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum, NestedField
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes
+from azure.ai.ml.constants._component import (
+ ComponentSource,
+ NodeType,
+ DataTransferTaskType,
+ DataCopyMode,
+ ExternalDataType,
+)
+
+
+class DataTransferComponentSchemaMixin(ComponentSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER])
+
+
+class DataTransferCopyComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True)
+ data_copy_mode = StringTransformedEnum(
+ allowed_values=[DataCopyMode.MERGE_WITH_OVERWRITE, DataCopyMode.FAIL_IF_CONFLICT]
+ )
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(InputPortSchema),
+ )
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ outputs_count = len(value)
+ if outputs_count != 1:
+ msg = "Only a single output is supported in {}, but there are {} outputs."
+ raise ValidationError(
+ message=msg.format(DataTransferTaskType.COPY_DATA, outputs_count), field_name="outputs"
+ )
+
+
+class SinkSourceSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[ExternalDataType.FILE_SYSTEM, ExternalDataType.DATABASE], required=True
+ )
+
+
+class SourceInputsSchema(metaclass=PatchedSchemaMeta):
+ """
+ For the export task in DataTransfer, the inputs type only supports uri_file for database and uri_folder for filesystem.
+ """
+
+ type = StringTransformedEnum(allowed_values=[AssetTypes.URI_FOLDER, AssetTypes.URI_FILE], required=True)
+
+
+class SinkOutputsSchema(metaclass=PatchedSchemaMeta):
+ """
+ For the import task in DataTransfer, the outputs type only supports mltable for database and uri_folder for filesystem.
+ """
+
+ type = StringTransformedEnum(allowed_values=[AssetTypes.MLTABLE, AssetTypes.URI_FOLDER], required=True)
+
+
+class DataTransferImportComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True)
+ source = NestedField(SinkSourceSchema, required=True)
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(SinkOutputsSchema),
+ )
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ raise ValidationError(f"inputs field is not a valid field in task type {DataTransferTaskType.IMPORT_DATA}.")
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ if len(value) != 1 or value and list(value.keys())[0] != "sink":
+ raise ValidationError(
+ f"outputs field only supports one output called sink in task type "
+ f"{DataTransferTaskType.IMPORT_DATA}."
+ )
+
+
+class DataTransferExportComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA], required=True)
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(SourceInputsSchema),
+ )
+ sink = NestedField(SinkSourceSchema(), required=True)
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ if len(value) != 1 or value and list(value.keys())[0] != "source":
+ raise ValidationError(
+ f"inputs field only supports one input called source in task type "
+ f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ raise ValidationError(
+ f"outputs field is not a valid field in task type {DataTransferTaskType.EXPORT_DATA}."
+ )
+
+
+class RestDataTransferCopyComponentSchema(DataTransferCopyComponentSchema):
+ """When a component is loaded from REST, name is not validated since there might
+ be existing components with invalid names."""
+
+ name = fields.Str(required=True)
+
+
+class RestDataTransferImportComponentSchema(DataTransferImportComponentSchema):
+ """When a component is loaded from REST, name is not validated since there might
+ be existing components with invalid names."""
+
+ name = fields.Str(required=True)
+
+
+class RestDataTransferExportComponentSchema(DataTransferExportComponentSchema):
+ """When a component is loaded from REST, name is not validated since there might
+ be existing components with invalid names."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousDataTransferCopyComponentSchema(AnonymousAssetSchema, DataTransferCopyComponentSchema):
+ """Anonymous data transfer copy component schema.
+
+ Note inheritance follows the order: AnonymousAssetSchema,
+ DataTransferCopyComponentSchema because we need name and version to be
+ dump_only (marshmallow collects fields following method resolution
+ order).
+ """ + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent + + # Inline component will have source=YAML.JOB + # As we only regard full separate component file as YAML.COMPONENT + return DataTransferCopyComponent( + base_path=self.context[BASE_PATH_CONTEXT_KEY], + _source=kwargs.pop("_source", ComponentSource.YAML_JOB), + **data, + ) + + +# pylint: disable-next=name-too-long +class AnonymousDataTransferImportComponentSchema(AnonymousAssetSchema, DataTransferImportComponentSchema): + """Anonymous data transfer import component schema. + + Note inheritance follows order: AnonymousAssetSchema, + DataTransferImportComponentSchema because we need name and version to be + dump_only(marshmallow collects fields follows method resolution + order). + """ + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._component.datatransfer_component import DataTransferImportComponent + + # Inline component will have source=YAML.JOB + # As we only regard full separate component file as YAML.COMPONENT + return DataTransferImportComponent( + base_path=self.context[BASE_PATH_CONTEXT_KEY], + _source=kwargs.pop("_source", ComponentSource.YAML_JOB), + **data, + ) + + +# pylint: disable-next=name-too-long +class AnonymousDataTransferExportComponentSchema(AnonymousAssetSchema, DataTransferExportComponentSchema): + """Anonymous data transfer export component schema. + + Note inheritance follows order: AnonymousAssetSchema, + DataTransferExportComponentSchema because we need name and version to be + dump_only(marshmallow collects fields follows method resolution + order). + """ + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._component.datatransfer_component import DataTransferExportComponent + + # Inline component will have source=YAML.JOB + # As we only regard full separate component file as YAML.COMPONENT + return DataTransferExportComponent( + base_path=self.context[BASE_PATH_CONTEXT_KEY], + _source=kwargs.pop("_source", ComponentSource.YAML_JOB), + **data, + ) + + +class DataTransferCopyComponentFileRefField(FileRefField): + def _deserialize(self, value, attr, data, **kwargs): + # Get component info from component yaml file. + data = super()._deserialize(value, attr, data, **kwargs) + component_dict = yaml.safe_load(data) + source_path = self.context[BASE_PATH_CONTEXT_KEY] / value + + # Update base_path to parent path of component file. + component_schema_context = deepcopy(self.context) + component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent + component = AnonymousDataTransferCopyComponentSchema(context=component_schema_context).load( + component_dict, unknown=INCLUDE + ) + component._source_path = source_path + component._source = ComponentSource.YAML_COMPONENT + return component + + +class DataTransferImportComponentFileRefField(FileRefField): + def _deserialize(self, value, attr, data, **kwargs): + # Get component info from component yaml file. + data = super()._deserialize(value, attr, data, **kwargs) + component_dict = yaml.safe_load(data) + source_path = self.context[BASE_PATH_CONTEXT_KEY] / value + + # Update base_path to parent path of component file. 
+ component_schema_context = deepcopy(self.context) + component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent + component = AnonymousDataTransferImportComponentSchema(context=component_schema_context).load( + component_dict, unknown=INCLUDE + ) + component._source_path = source_path + component._source = ComponentSource.YAML_COMPONENT + return component + + +class DataTransferExportComponentFileRefField(FileRefField): + def _deserialize(self, value, attr, data, **kwargs): + # Get component info from component yaml file. + data = super()._deserialize(value, attr, data, **kwargs) + component_dict = yaml.safe_load(data) + source_path = self.context[BASE_PATH_CONTEXT_KEY] / value + + # Update base_path to parent path of component file. + component_schema_context = deepcopy(self.context) + component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent + component = AnonymousDataTransferExportComponentSchema(context=component_schema_context).load( + component_dict, unknown=INCLUDE + ) + component._source_path = source_path + component._source = ComponentSource.YAML_COMPONENT + return component diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py new file mode 100644 index 00000000..848220d3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py @@ -0,0 +1,107 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema import YamlFileSchema +from azure.ai.ml._schema.component import ComponentSchema +from azure.ai.ml._schema.component.component import ComponentNameStr +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + EnvironmentField, + LocalPathField, + NestedField, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml.constants._common import AzureMLResourceType +from azure.ai.ml.constants._component import NodeType + + +class _ComponentMetadataSchema(metaclass=PatchedSchemaMeta): + """Schema to recognize metadata of a flow as a component.""" + + name = ComponentNameStr() + version = fields.Str() + display_name = fields.Str() + description = fields.Str() + tags = fields.Dict(keys=fields.Str(), values=fields.Str()) + + +class _FlowAttributesSchema(metaclass=PatchedSchemaMeta): + """Schema to recognize attributes of a flow.""" + + variant = fields.Str() + column_mappings = fields.Dict( + fields.Str(), + fields.Str(), + ) + connections = fields.Dict( + keys=fields.Str(), + values=fields.Dict( + keys=fields.Str(), + values=fields.Str(), + ), + ) + environment_variables = fields.Dict( + fields.Str(), + fields.Str(), + ) + + +class _FLowComponentOverridesSchema(metaclass=PatchedSchemaMeta): + environment = EnvironmentField() + is_deterministic = fields.Bool() + + +class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta): + # the field name must be the same as azure.ai.ml.constants._common.PROMPTFLOW_AZUREML_OVERRIDE_KEY + azureml = NestedField(_FLowComponentOverridesSchema) + + +class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema): + """Schema for flow.dag.yaml file.""" + + environment_variables = fields.Dict( + fields.Str(), + fields.Str(), + ) + additional_includes = fields.List(LocalPathField()) + + +class 
RunSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowAttributesSchema, _FlowComponentOverridableSchema): + """Schema for run.yaml file.""" + + flow = LocalPathField(required=True) + + +class FlowComponentSchema(ComponentSchema, _FlowAttributesSchema, _FLowComponentOverridesSchema): + """FlowSchema and FlowRunSchema are used to load flow while FlowComponentSchema is used to dump flow.""" + + class Meta: + """Override this to exclude inputs & outputs as component doesn't have them.""" + + exclude = ["inputs", "outputs"] # component doesn't have inputs & outputs + + # TODO: name should be required? + name = ComponentNameStr() + + type = StringTransformedEnum(allowed_values=[NodeType.FLOW_PARALLEL], required=True) + + # name, version, tags, display_name and is_deterministic are inherited from ComponentSchema + properties = fields.Dict( + fields.Str(), + fields.Str(), + ) + + # this is different from regular CodeField + code = UnionField( + [ + LocalPathField(), + ArmVersionedStr(azureml_type=AzureMLResourceType.CODE), + ], + metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."}, + ) + additional_includes = fields.List(LocalPathField(), load_only=True) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py new file mode 100644 index 00000000..b0ec14ea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py @@ -0,0 +1,74 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=protected-access +from copy import deepcopy + +import yaml +from marshmallow import INCLUDE, fields, post_load, validate + +from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema +from azure.ai.ml._schema.component.component import ComponentSchema +from azure.ai.ml._schema.component.input_output import OutputPortSchema, ParameterSchema +from azure.ai.ml._schema.core.fields import FileRefField, NestedField, StringTransformedEnum +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY +from azure.ai.ml.constants._component import ComponentSource, NodeType + + +class ImportComponentSchema(ComponentSchema): + class Meta: + exclude = ["inputs", "outputs"] # inputs or outputs property not applicable to import job + + type = StringTransformedEnum(allowed_values=[NodeType.IMPORT]) + source = fields.Dict( + keys=fields.Str(validate=validate.OneOf(["type", "connection", "query", "path"])), + values=NestedField(ParameterSchema), + required=True, + ) + + output = NestedField(OutputPortSchema, required=True) + + +class RestCommandComponentSchema(ImportComponentSchema): + """When component load from rest, won't validate on name since there might be existing component with invalid + name.""" + + name = fields.Str(required=True) + + +class AnonymousImportComponentSchema(AnonymousAssetSchema, ImportComponentSchema): + """Anonymous command component schema. + + Note inheritance follows order: AnonymousAssetSchema, CommandComponentSchema because we need name and version to be + dump_only(marshmallow collects fields follows method resolution order). 
+ """ + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + from azure.ai.ml.entities._component.import_component import ImportComponent + + # Inline component will have source=YAML.JOB + # As we only regard full separate component file as YAML.COMPONENT + return ImportComponent( + base_path=self.context[BASE_PATH_CONTEXT_KEY], + _source=ComponentSource.YAML_JOB, + **data, + ) + + +class ImportComponentFileRefField(FileRefField): + def _deserialize(self, value, attr, data, **kwargs): + # Get component info from component yaml file. + data = super()._deserialize(value, attr, data, **kwargs) + component_dict = yaml.safe_load(data) + source_path = self.context[BASE_PATH_CONTEXT_KEY] / value + + # Update base_path to parent path of component file. + component_schema_context = deepcopy(self.context) + component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent + component = AnonymousImportComponentSchema(context=component_schema_context).load( + component_dict, unknown=INCLUDE + ) + component._source_path = source_path + component._source = ComponentSource.YAML_COMPONENT + return component diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py new file mode 100644 index 00000000..9fef9489 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py @@ -0,0 +1,126 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import INCLUDE, fields, pre_dump + +from azure.ai.ml._schema.core.fields import DumpableEnumField, ExperimentalField, NestedField, UnionField +from azure.ai.ml._schema.core.intellectual_property import ProtectionLevelSchema +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import is_private_preview_enabled +from azure.ai.ml.constants._common import AssetTypes, InputOutputModes, LegacyAssetTypes +from azure.ai.ml.constants._component import ComponentParameterTypes + +# Here we use an adhoc way to collect all class constant attributes by checking if it's upper letter +# because making those constants enum will fail in string serialization in marshmallow +asset_type_obj = AssetTypes() +SUPPORTED_PORT_TYPES = [LegacyAssetTypes.PATH] + [ + getattr(asset_type_obj, k) for k in dir(asset_type_obj) if k.isupper() +] +param_obj = ComponentParameterTypes() +SUPPORTED_PARAM_TYPES = [getattr(param_obj, k) for k in dir(param_obj) if k.isupper()] + +input_output_type_obj = InputOutputModes() +# Link mode is only supported in component level currently +SUPPORTED_INPUT_OUTPUT_MODES = [ + getattr(input_output_type_obj, k) for k in dir(input_output_type_obj) if k.isupper() +] + ["link"] + + +class InputPortSchema(metaclass=PatchedSchemaMeta): + type = DumpableEnumField( + allowed_values=SUPPORTED_PORT_TYPES, + required=True, + ) + description = fields.Str() + optional = fields.Bool() + default = fields.Str() + mode = DumpableEnumField( + allowed_values=SUPPORTED_INPUT_OUTPUT_MODES, + ) + # hide in private preview + if is_private_preview_enabled(): + # only protection_level is allowed for inputs + intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema)) + + @pre_dump + def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument + # The ipp field is set on the 
input object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ if hasattr(data, "_intellectual_property"):
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+
+class OutputPortSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PORT_TYPES,
+ required=True,
+ )
+ description = fields.Str()
+ mode = DumpableEnumField(
+ allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
+ )
+ # hide in private preview
+ if is_private_preview_enabled():
+ # only protection_level is allowed for outputs
+ intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the output object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ if hasattr(data, "_intellectual_property"):
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+
+class PrimitiveOutputSchema(OutputPortSchema):
+ # Note: according to the marshmallow doc on Handling Unknown Fields:
+ # https://marshmallow.readthedocs.io/en/stable/quickstart.html#handling-unknown-fields
+ # specifying unknown at instantiation time will not take effect;
+ # it is still added here to explicitly declare this behavior:
+ # primitive type outputs can be used in environments where the private preview flag is not enabled.
+ class Meta:
+ unknown = INCLUDE
+
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PARAM_TYPES,
+ required=True,
+ )
+ # hide early_available in spec
+ if is_private_preview_enabled():
+ early_available = fields.Bool()
+
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def _serialize(self, obj, *, many: bool = False):
+ """Override to add private preview hidden fields
+
+ :keyword many: Whether obj is a collection of objects.
+ :paramtype many: bool
+ """
+ from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe
+
+ ret = super()._serialize(obj, many=many) # pylint: disable=no-member
+ if has_attr_safe(obj, "early_available") and obj.early_available is not None and "early_available" not in ret:
+ ret["early_available"] = obj.early_available
+ return ret
+
+
+class ParameterSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PARAM_TYPES,
+ required=True,
+ )
+ optional = fields.Bool()
+ default = UnionField([fields.Str(), fields.Number(), fields.Bool()])
+ description = fields.Str()
+ max = UnionField([fields.Str(), fields.Number()])
+ min = UnionField([fields.Str(), fields.Number()])
+ enum = fields.List(fields.Str())
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
new file mode 100644
index 00000000..70f286a9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
@@ -0,0 +1,108 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._schema.core.fields import DumpableEnumField, FileRefField, NestedField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LoggingLevel
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class ParallelComponentSchema(ComponentSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL], required=True)
+ resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
+ logging_level = DumpableEnumField(
+ allowed_values=[LoggingLevel.DEBUG, LoggingLevel.INFO, LoggingLevel.WARN],
+ dump_default=LoggingLevel.INFO,
+ metadata={
+ "description": "A string of the logging level name, which is defined in 'logging'. \
+ Possible values are 'WARNING', 'INFO', and 'DEBUG'."
+ },
+ )
+ task = NestedField(ComponentParallelTaskSchema, unknown=INCLUDE)
+ mini_batch_size = fields.Str(
+ metadata={"description": "The batch size of the current job."},
+ )
+ partition_keys = fields.List(
+ fields.Str(), metadata={"description": "The keys used to partition input data into mini-batches"}
+ )
+
+ input_data = fields.Str()
+ retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
+ max_concurrency_per_instance = fields.Integer(
+ dump_default=1,
+ metadata={"description": "The max parallelism that each compute instance has."},
+ )
+ error_threshold = fields.Integer(
+ dump_default=-1,
+ metadata={
+ "description": "The number of item processing failures that should be ignored. \
+ If the error_threshold is reached, the job terminates. \
+ For a list of files as inputs, one item means one file reference. \
+ This setting doesn't apply to command parallelization."
+ },
+ )
+ mini_batch_error_threshold = fields.Integer(
+ dump_default=-1,
+ metadata={
+ "description": "The number of mini-batch processing failures that should be ignored. \
+ If the mini_batch_error_threshold is reached, the job terminates. \
+ For a list of files as inputs, one item means one file reference. \
+ This setting can be used by either command or python function parallelization. \
+ Only one error_threshold setting can be used in one job."
+ },
+ )
+
+
+class RestParallelComponentSchema(ParallelComponentSchema):
+ """When a component is loaded from REST, name is not validated, since there might be existing components with
+ invalid names."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousParallelComponentSchema(AnonymousAssetSchema, ParallelComponentSchema):
+ """Anonymous parallel component schema.
+
+ Note inheritance follows the order AnonymousAssetSchema, ParallelComponentSchema because we need name and version
+ to be dump_only (marshmallow collects fields following method resolution order).
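+
+ A hedged YAML sketch of the kind of document this schema can load (all values
+ hypothetical):
+
+ type: parallel
+ mini_batch_size: "10"
+ task:
+ type: run_function
+ entry_script: score.py
+ environment: azureml:my-env:1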
+ """ + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._component.parallel_component import ParallelComponent + + return ParallelComponent( + base_path=self.context[BASE_PATH_CONTEXT_KEY], + _source=kwargs.pop("_source", ComponentSource.YAML_JOB), + **data, + ) + + +class ParallelComponentFileRefField(FileRefField): + def _deserialize(self, value, attr, data, **kwargs): + # Get component info from component yaml file. + data = super()._deserialize(value, attr, data, **kwargs) + component_dict = yaml.safe_load(data) + source_path = self.context[BASE_PATH_CONTEXT_KEY] / value + + # Update base_path to parent path of component file. + component_schema_context = deepcopy(self.context) + component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent + component = AnonymousParallelComponentSchema(context=component_schema_context).load( + component_dict, unknown=INCLUDE + ) + component._source_path = source_path + component._source = ComponentSource.YAML_COMPONENT + return component diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py new file mode 100644 index 00000000..390a6683 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from marshmallow import fields + +from azure.ai.ml._schema.core.fields import CodeField, EnvironmentField, StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants import ParallelTaskType + + +class ComponentParallelTaskSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=[ParallelTaskType.RUN_FUNCTION, ParallelTaskType.MODEL, ParallelTaskType.FUNCTION], + required=True, + ) + code = CodeField() + entry_script = fields.Str() + program_arguments = fields.Str() + model = fields.Str() + append_row_to = fields.Str() + environment = EnvironmentField(required=True) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py new file mode 100644 index 00000000..592d740c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py @@ -0,0 +1,22 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import INCLUDE, post_dump, post_load
+
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+
+
+class ComponentResourceSchema(JobResourceConfigurationSchema):
+ class Meta:
+ unknown = INCLUDE
+
+ @post_load
+ def make(self, data, **kwargs):
+ return data
+
+ @post_dump(pass_original=True)
+ def dump_override(self, data, original, **kwargs):
+ return original
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py
new file mode 100644
index 00000000..bac2c54d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py
@@ -0,0 +1,13 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import DataBindingStr, UnionField
+
+
+class RetrySettingsSchema(metaclass=PatchedSchemaMeta):
+ timeout = UnionField([fields.Int(), DataBindingStr])
+ max_retries = UnionField([fields.Int(), DataBindingStr])
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
new file mode 100644
index 00000000..445481ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_dump, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+from ..job.parameterized_spark import ParameterizedSparkSchema
+
+
+class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
+ additional_includes = fields.List(fields.Str())
+
+ @post_dump
+ def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+ if (
+ component_schema_dict.get("additional_includes") is not None
+ and len(component_schema_dict["additional_includes"]) == 0
+ ):
+ component_schema_dict.pop("additional_includes")
+ return component_schema_dict
+
+
+class RestSparkComponentSchema(SparkComponentSchema):
+ """When a component is loaded from REST, name is not validated, since
+ there might be existing components with invalid names."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousSparkComponentSchema(AnonymousAssetSchema, SparkComponentSchema):
+ """Anonymous spark component schema.
+
+ Note inheritance follows the order AnonymousAssetSchema,
+ SparkComponentSchema because we need name and version to be
+ dump_only (marshmallow collects fields following method
+ resolution order).
+ """ + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._component.spark_component import SparkComponent + + # Inline component will have source=YAML.JOB + # As we only regard full separate component file as YAML.COMPONENT + return SparkComponent( + base_path=self.context[BASE_PATH_CONTEXT_KEY], + _source=kwargs.pop("_source", ComponentSource.YAML_JOB), + **data, + ) + + +class SparkComponentFileRefField(FileRefField): + def _deserialize(self, value, attr, data, **kwargs): + # Get component info from component yaml file. + data = super()._deserialize(value, attr, data, **kwargs) + component_dict = yaml.safe_load(data) + source_path = self.context[BASE_PATH_CONTEXT_KEY] / value + + # Update base_path to parent path of component file. + component_schema_context = deepcopy(self.context) + component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent + component = AnonymousSparkComponentSchema(context=component_schema_context).load( + component_dict, unknown=INCLUDE + ) + component._source_path = source_path + component._source = ComponentSource.YAML_COMPONENT + return component diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py new file mode 100644 index 00000000..304b0eae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml.constants._compute import ComputeTier, ComputeType, ComputeSizeTier + +from ..core.fields import NestedField, StringTransformedEnum, UnionField +from .compute import ComputeSchema, IdentitySchema, NetworkSettingsSchema + + +class AmlComputeSshSettingsSchema(metaclass=PatchedSchemaMeta): + admin_username = fields.Str() + admin_password = fields.Str() + ssh_key_value = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import AmlComputeSshSettings + + return AmlComputeSshSettings(**data) + + +class AmlComputeSchema(ComputeSchema): + type = StringTransformedEnum(allowed_values=[ComputeType.AMLCOMPUTE], required=True) + size = UnionField( + union_fields=[ + fields.Str(metadata={"arm_type": ComputeSizeTier.AML_COMPUTE_DEDICATED, "tier": ComputeTier.DEDICATED}), + fields.Str(metadata={"arm_type": ComputeSizeTier.AML_COMPUTE_LOWPRIORITY, "tier": ComputeTier.LOWPRIORITY}), + ], + ) + tier = StringTransformedEnum(allowed_values=[ComputeTier.LOWPRIORITY, ComputeTier.DEDICATED]) + min_instances = fields.Int() + max_instances = fields.Int() + idle_time_before_scale_down = fields.Int() + ssh_public_access_enabled = fields.Bool() + ssh_settings = NestedField(AmlComputeSshSettingsSchema) + network_settings = NestedField(NetworkSettingsSchema) + identity = NestedField(IdentitySchema) + enable_node_public_ip = fields.Bool( + metadata={"description": "Enable or disable node public IP address provisioning."} + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute_node_info.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute_node_info.py new file mode 100644 index 00000000..983f76f6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute_node_info.py @@ -0,0 +1,15 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from marshmallow import fields + +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta + + +class AmlComputeNodeInfoSchema(metaclass=PatchedSchemaMeta): + node_id = fields.Str() + private_ip_address = fields.Str() + public_ip_address = fields.Str() + port = fields.Str() + node_state = fields.Str() + current_job_name = fields.Str() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/attached_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/attached_compute.py new file mode 100644 index 00000000..2ac4ce9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/attached_compute.py @@ -0,0 +1,12 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from marshmallow import fields + +from .compute import ComputeSchema + + +class AttachedComputeSchema(ComputeSchema): + resource_id = fields.Str(required=True) + ssh_port = fields.Int() + compute_location = fields.Str() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py new file mode 100644 index 00000000..4488b53d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py @@ -0,0 +1,85 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml._vendor.azure_resources.models._resource_management_client_enums import ResourceIdentityType +from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration + +from ..core.schema import PathAwareSchema + + +class ComputeSchema(PathAwareSchema): + name = fields.Str(required=True) + id = fields.Str(dump_only=True) + type = fields.Str() + location = fields.Str() + description = fields.Str() + provisioning_errors = fields.Str(dump_only=True) + created_on = fields.Str(dump_only=True) + provisioning_state = fields.Str(dump_only=True) + resource_id = fields.Str() + tags = fields.Dict(keys=fields.Str(), values=fields.Str()) + + +class NetworkSettingsSchema(PathAwareSchema): + vnet_name = fields.Str() + subnet = fields.Str() + public_ip_address = fields.Str(dump_only=True) + private_ip_address = fields.Str(dump_only=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import NetworkSettings + + return NetworkSettings(**data) + + +class UserAssignedIdentitySchema(PathAwareSchema): + resource_id = fields.Str() + principal_id = fields.Str(dump_only=True) + client_id = fields.Str(dump_only=True) + tenant_id = fields.Str(dump_only=True) + + @post_load + def make(self, data, **kwargs): + return ManagedIdentityConfiguration(**data) + + +class IdentitySchema(PathAwareSchema): + type = StringTransformedEnum( + allowed_values=[ + ResourceIdentityType.SYSTEM_ASSIGNED, + ResourceIdentityType.USER_ASSIGNED, + ResourceIdentityType.NONE, + ResourceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED, + ], + casing_transform=camel_to_snake, + metadata={"description": "resource identity type."}, + ) + user_assigned_identities = fields.List(NestedField(UserAssignedIdentitySchema)) + principal_id = fields.Str(dump_only=True) + tenant_id = fields.Str(dump_only=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import IdentityConfiguration + + user_assigned_identities_list = [] + user_assigned_identities = data.pop("user_assigned_identities", None) + if user_assigned_identities: + for identity in user_assigned_identities: + user_assigned_identities_list.append( + ManagedIdentityConfiguration( + resource_id=identity.get("resource_id", None), + client_id=identity.get("client_id", None), + object_id=identity.get("object_id", None), + ) + ) + data["user_assigned_identities"] = user_assigned_identities_list + return IdentityConfiguration(**data) diff --git 
a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute_instance.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute_instance.py new file mode 100644 index 00000000..c72e06bb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute_instance.py @@ -0,0 +1,83 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields +from marshmallow.decorators import post_load + +# pylint: disable=unused-argument +from azure.ai.ml._schema import PathAwareSchema +from azure.ai.ml.constants._compute import ComputeType, ComputeSizeTier + +from ..core.fields import ExperimentalField, NestedField, StringTransformedEnum +from .compute import ComputeSchema, IdentitySchema, NetworkSettingsSchema +from .schedule import ComputeSchedulesSchema +from .setup_scripts import SetupScriptsSchema +from .custom_applications import CustomApplicationsSchema + + +class ComputeInstanceSshSettingsSchema(PathAwareSchema): + admin_username = fields.Str(dump_only=True) + ssh_port = fields.Str(dump_only=True) + ssh_key_value = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import ComputeInstanceSshSettings + + return ComputeInstanceSshSettings(**data) + + +class CreateOnBehalfOfSchema(PathAwareSchema): + user_tenant_id = fields.Str() + user_object_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import AssignedUserConfiguration + + return AssignedUserConfiguration(**data) + + +class OsImageMetadataSchema(PathAwareSchema): + is_latest_os_image_version = fields.Bool(dump_only=True) + current_image_version = fields.Str(dump_only=True) + latest_image_version = fields.Str(dump_only=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import ImageMetadata + + return ImageMetadata(**data) + + +class ComputeInstanceSchema(ComputeSchema): + type = StringTransformedEnum(allowed_values=[ComputeType.COMPUTEINSTANCE], required=True) + size = fields.Str(metadata={"arm_type": ComputeSizeTier.COMPUTE_INSTANCE}) + network_settings = NestedField(NetworkSettingsSchema) + create_on_behalf_of = NestedField(CreateOnBehalfOfSchema) + ssh_settings = NestedField(ComputeInstanceSshSettingsSchema) + ssh_public_access_enabled = fields.Bool(dump_default=None) + state = fields.Str(dump_only=True) + last_operation = fields.Dict(keys=fields.Str(), values=fields.Str(), dump_only=True) + services = fields.List(fields.Dict(keys=fields.Str(), values=fields.Str()), dump_only=True) + schedules = NestedField(ComputeSchedulesSchema) + identity = ExperimentalField(NestedField(IdentitySchema)) + idle_time_before_shutdown = fields.Str() + idle_time_before_shutdown_minutes = fields.Int() + custom_applications = fields.List(NestedField(CustomApplicationsSchema)) + setup_scripts = NestedField(SetupScriptsSchema) + os_image_metadata = NestedField(OsImageMetadataSchema, dump_only=True) + enable_node_public_ip = fields.Bool( + metadata={"description": "Enable or disable node public IP address provisioning."} + ) + enable_sso = fields.Bool(metadata={"description": "Enable or disable single sign-on for the compute instance."}) + enable_root_access = fields.Bool( + metadata={"description": "Enable or disable root access for the compute instance."} + ) + release_quota_on_stop = fields.Bool( + metadata={"description": "Release quota on 
stop for the compute instance. Defaults to False."} + ) + enable_os_patching = fields.Bool( + metadata={"description": "Enable or disable OS patching for the compute instance. Defaults to False."} + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/custom_applications.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/custom_applications.py new file mode 100644 index 00000000..66fa587c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/custom_applications.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml.constants._compute import CustomApplicationDefaults + + +class ImageSettingsSchema(metaclass=PatchedSchemaMeta): + reference = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._compute._custom_applications import ImageSettings + + return ImageSettings(**data) + + +class EndpointsSettingsSchema(metaclass=PatchedSchemaMeta): + target = fields.Int() + published = fields.Int() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._compute._custom_applications import EndpointsSettings + + return EndpointsSettings(**data) + + +class VolumeSettingsSchema(metaclass=PatchedSchemaMeta): + source = fields.Str() + target = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._compute._custom_applications import VolumeSettings + + return VolumeSettings(**data) + + +class CustomApplicationsSchema(metaclass=PatchedSchemaMeta): + name = fields.Str(required=True) + type = StringTransformedEnum(allowed_values=[CustomApplicationDefaults.DOCKER]) + image = NestedField(ImageSettingsSchema) + endpoints = fields.List(NestedField(EndpointsSettingsSchema)) + environment_variables = fields.Dict() + bind_mounts = fields.List(NestedField(VolumeSettingsSchema)) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._compute._custom_applications import ( + CustomApplications, + ) + + return CustomApplications(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/kubernetes_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/kubernetes_compute.py new file mode 100644 index 00000000..a84102ca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/kubernetes_compute.py @@ -0,0 +1,16 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from marshmallow import fields + +from azure.ai.ml.constants._compute import ComputeType + +from ..core.fields import NestedField, StringTransformedEnum +from .compute import ComputeSchema, IdentitySchema + + +class KubernetesComputeSchema(ComputeSchema): + type = StringTransformedEnum(allowed_values=[ComputeType.KUBERNETES], required=True) + namespace = fields.Str(required=True, dump_default="default") + properties = fields.Dict() + identity = NestedField(IdentitySchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/schedule.py new file mode 100644 index 00000000..49f41edf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/schedule.py @@ -0,0 +1,118 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputePowerAction, RecurrenceFrequency +from azure.ai.ml._restclient.v2022_10_01_preview.models import ScheduleStatus as ScheduleState +from azure.ai.ml._restclient.v2022_10_01_preview.models import TriggerType, WeekDay +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta + + +class BaseTriggerSchema(metaclass=PatchedSchemaMeta): + start_time = fields.Str() + time_zone = fields.Str() + + +class CronTriggerSchema(BaseTriggerSchema): + type = StringTransformedEnum(required=True, allowed_values=TriggerType.CRON) + expression = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import CronTrigger + + data.pop("type") + return CronTrigger(**data) + + +class RecurrenceScheduleSchema(metaclass=PatchedSchemaMeta): + week_days = fields.List( + StringTransformedEnum( + allowed_values=[ + WeekDay.SUNDAY, + WeekDay.MONDAY, + WeekDay.TUESDAY, + WeekDay.WEDNESDAY, + WeekDay.THURSDAY, + WeekDay.FRIDAY, + WeekDay.SATURDAY, + ], + ) + ) + hours = fields.List(fields.Int()) + minutes = fields.List(fields.Int()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import RecurrencePattern + + return RecurrencePattern(**data) + + +class RecurrenceTriggerSchema(BaseTriggerSchema): + type = StringTransformedEnum(required=True, allowed_values=TriggerType.RECURRENCE) + frequency = StringTransformedEnum( + required=True, + allowed_values=[ + RecurrenceFrequency.MINUTE, + RecurrenceFrequency.HOUR, + RecurrenceFrequency.DAY, + RecurrenceFrequency.WEEK, + RecurrenceFrequency.MONTH, + ], + ) + interval = fields.Int() + schedule = NestedField(RecurrenceScheduleSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import RecurrenceTrigger + + data.pop("type") + return RecurrenceTrigger(**data) + + +class ComputeStartStopScheduleSchema(metaclass=PatchedSchemaMeta): + trigger = UnionField( + [ + NestedField(CronTriggerSchema()), + NestedField(RecurrenceTriggerSchema()), + ], + ) + action = StringTransformedEnum( + required=True, + allowed_values=[ + ComputePowerAction.START, + ComputePowerAction.STOP, + ], + ) + state = StringTransformedEnum( + allowed_values=[ + ScheduleState.ENABLED, + 
ScheduleState.DISABLED, + ], + ) + schedule_id = fields.Str(dump_only=True) + provisioning_state = fields.Str(dump_only=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import ComputeStartStopSchedule + + return ComputeStartStopSchedule(**data) + + +class ComputeSchedulesSchema(metaclass=PatchedSchemaMeta): + compute_start_stop = fields.List(NestedField(ComputeStartStopScheduleSchema)) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import ComputeSchedules + + return ComputeSchedules(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/setup_scripts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/setup_scripts.py new file mode 100644 index 00000000..da3f3c14 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/setup_scripts.py @@ -0,0 +1,33 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta + + +class ScriptReferenceSchema(metaclass=PatchedSchemaMeta): + path = fields.Str() + command = fields.Str() + timeout_minutes = fields.Int() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._compute._setup_scripts import ScriptReference + + return ScriptReference(**data) + + +class SetupScriptsSchema(metaclass=PatchedSchemaMeta): + creation_script = NestedField(ScriptReferenceSchema()) + startup_script = NestedField(ScriptReferenceSchema()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._compute._setup_scripts import SetupScripts + + return SetupScripts(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/synapsespark_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/synapsespark_compute.py new file mode 100644 index 00000000..11760186 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/synapsespark_compute.py @@ -0,0 +1,49 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml.constants._compute import ComputeType + +from ..core.fields import NestedField, StringTransformedEnum +from ..core.schema import PathAwareSchema +from .compute import ComputeSchema, IdentitySchema + + +class AutoScaleSettingsSchema(PathAwareSchema): + min_node_count = fields.Int(dump_only=True) + max_node_count = fields.Int(dump_only=True) + auto_scale_enabled = fields.Bool(dump_only=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import AutoScaleSettings + + return AutoScaleSettings(**data) + + +class AutoPauseSettingsSchema(PathAwareSchema): + delay_in_minutes = fields.Int(dump_only=True) + auto_pause_enabled = fields.Bool(dump_only=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import AutoPauseSettings + + return AutoPauseSettings(**data) + + +class SynapseSparkComputeSchema(ComputeSchema): + type = StringTransformedEnum(allowed_values=[ComputeType.SYNAPSESPARK], required=True) + resource_id = fields.Str(required=True) + identity = NestedField(IdentitySchema) + node_family = fields.Str(dump_only=True) + node_size = fields.Str(dump_only=True) + node_count = fields.Int(dump_only=True) + spark_version = fields.Str(dump_only=True) + scale_settings = NestedField(AutoScaleSettingsSchema) + auto_pause_settings = NestedField(AutoPauseSettingsSchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/usage.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/usage.py new file mode 100644 index 00000000..4860946b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/usage.py @@ -0,0 +1,42 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml._restclient.v2022_10_01_preview.models import UsageUnit +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake + + +class UsageNameSchema(metaclass=PatchedSchemaMeta): + value = fields.Str() + localized_value = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import UsageName + + return UsageName(**data) + + +class UsageSchema(metaclass=PatchedSchemaMeta): + id = fields.Str() + aml_workspace_location = fields.Str() + type = fields.Str() + unit = UnionField( + [ + fields.Str(), + StringTransformedEnum( + allowed_values=UsageUnit.COUNT, + casing_transform=camel_to_snake, + ), + ] + ) + current_value = fields.Int() + limit = fields.Int() + name = NestedField(UsageNameSchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/virtual_machine_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/virtual_machine_compute.py new file mode 100644 index 00000000..deb92d3c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/virtual_machine_compute.py @@ -0,0 +1,34 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields +from marshmallow.decorators import post_load + +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml.constants._compute import ComputeType + +from ..core.fields import NestedField, StringTransformedEnum +from .compute import ComputeSchema + + +class VirtualMachineSshSettingsSchema(metaclass=PatchedSchemaMeta): + admin_username = fields.Str() + admin_password = fields.Str() + ssh_port = fields.Int() + ssh_private_key_file = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import VirtualMachineSshSettings + + return VirtualMachineSshSettings(**data) + + +class VirtualMachineComputeSchema(ComputeSchema): + type = StringTransformedEnum(allowed_values=[ComputeType.VIRTUALMACHINE], required=True) + resource_id = fields.Str(required=True) + compute_location = fields.Str(dump_only=True) + ssh_settings = NestedField(VirtualMachineSshSettingsSchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/vm_size.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/vm_size.py new file mode 100644 index 00000000..79ee8ea7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/vm_size.py @@ -0,0 +1,19 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from marshmallow import fields + +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta + + +class VmSizeSchema(metaclass=PatchedSchemaMeta): + name = fields.Str() + family = fields.Str() + v_cp_us = fields.Int() + gpus = fields.Int() + os_vhd_size_mb = fields.Int() + max_resource_volume_mb = fields.Int() + memory_gb = fields.Float() + low_priority_capable = fields.Bool() + premium_io = fields.Bool() + supported_compute_types = fields.Str() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/auto_delete_setting.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/auto_delete_setting.py new file mode 100644 index 00000000..ca2bd2e1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/auto_delete_setting.py @@ -0,0 +1,38 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._common import AutoDeleteCondition +from azure.ai.ml.entities._assets.auto_delete_setting import AutoDeleteSetting + + +@experimental +class BaseAutoDeleteSettingSchema(metaclass=PatchedSchemaMeta): + @post_load + def make(self, data, **kwargs) -> "AutoDeleteSetting": + return AutoDeleteSetting(**data) + + +@experimental +class AutoDeleteConditionSchema(BaseAutoDeleteSettingSchema): + condition = StringTransformedEnum( + allowed_values=[condition.name for condition in AutoDeleteCondition], + casing_transform=camel_to_snake, + ) + + +@experimental +class ValueSchema(BaseAutoDeleteSettingSchema): + value = fields.Str() + + +@experimental +class AutoDeleteSettingSchema(AutoDeleteConditionSchema, ValueSchema): + pass diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/fields.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/fields.py new file mode 100644 index 00000000..fd7956b8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/fields.py @@ -0,0 +1,1029 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access,too-many-lines + +import copy +import logging +import os +import re +import traceback +import typing +from abc import abstractmethod +from pathlib import Path +from typing import List, Optional, Union + +from marshmallow import RAISE, fields +from marshmallow.exceptions import ValidationError +from marshmallow.fields import Field, Nested +from marshmallow.utils import FieldInstanceResolutionError, from_iso_datetime, resolve_field_instance + +from ..._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource, parse_name_label, parse_name_version +from ..._utils._experimental import _is_warning_cached +from ..._utils.utils import is_data_binding_expression, is_valid_node_name, load_file, load_yaml +from ...constants._common import ( + ARM_ID_PREFIX, + AZUREML_RESOURCE_PROVIDER, + BASE_PATH_CONTEXT_KEY, + CONDA_FILE, + DOCKER_FILE_NAME, + EXPERIMENTAL_FIELD_MESSAGE, + EXPERIMENTAL_LINK_MESSAGE, + FILE_PREFIX, + INTERNAL_REGISTRY_URI_FORMAT, + LOCAL_COMPUTE_TARGET, + LOCAL_PATH, + REGISTRY_URI_FORMAT, + RESOURCE_ID_FORMAT, + AzureMLResourceType, + DefaultOpenEncoding, +) +from ...entities._job.pipeline._attr_dict import try_get_non_arbitrary_attr +from ...exceptions import MlException, ValidationException +from ..core.schema import PathAwareSchema + +module_logger = logging.getLogger(__name__) +T = typing.TypeVar("T") + + +class StringTransformedEnum(Field): + def __init__(self, **kwargs): + # pop marshmallow unknown args to avoid warnings + self.allowed_values = kwargs.pop("allowed_values", None) + self.casing_transform = kwargs.pop("casing_transform", lambda x: x.lower()) + self.pass_original = kwargs.pop("pass_original", False) + super().__init__(**kwargs) + if isinstance(self.allowed_values, str): + self.allowed_values = [self.allowed_values] + self.allowed_values = [self.casing_transform(x) for x in self.allowed_values] + + def 
_jsonschema_type_mapping(self): + schema = {"type": "string", "enum": self.allowed_values} + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _serialize(self, value, attr, obj, **kwargs): + if not value: + return None + if isinstance(value, str) and self.casing_transform(value) in self.allowed_values: + return value if self.pass_original else self.casing_transform(value) + raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}") + + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, str) and self.casing_transform(value) in self.allowed_values: + return value if self.pass_original else self.casing_transform(value) + raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}") + + +class DumpableEnumField(StringTransformedEnum): + def __init__(self, **kwargs): + """Enum field that will raise exception when dumping.""" + kwargs.pop("casing_transform", None) + super(DumpableEnumField, self).__init__(casing_transform=lambda x: x, **kwargs) + + +class LocalPathField(fields.Str): + """A field that validates that the input is a local path. + + Can only be used as fields of PathAwareSchema. + """ + + default_error_messages = { + "invalid_path": "The filename, directory name, or volume label syntax is incorrect.", + "path_not_exist": "Can't find {allow_type} in resolved absolute path: {path}.", + } + + def __init__(self, allow_dir=True, allow_file=True, **kwargs): + self._allow_dir = allow_dir + self._allow_file = allow_file + self._pattern = kwargs.get("pattern", None) + super().__init__() + + def _jsonschema_type_mapping(self): + schema = {"type": "string", "arm_type": LOCAL_PATH} + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + if self._pattern: + schema["pattern"] = self._pattern + return schema + + # pylint: disable-next=docstring-missing-param + def _resolve_path(self, value: Union[str, os.PathLike]) -> Path: + """Resolve path to absolute path based on base_path in context. + + Will resolve the path if it's already an absolute path. + + :return: The resolved path + :rtype: Path + """ + try: + result = Path(value) + base_path = Path(self.context[BASE_PATH_CONTEXT_KEY]) + if not result.is_absolute(): + result = base_path / result + + # for non-path string like "azureml:/xxx", OSError can be raised in either + # resolve() or is_dir() or is_file() + result = result.resolve() + if (self._allow_dir and result.is_dir()) or (self._allow_file and result.is_file()): + return result + except OSError as e: + raise self.make_error("invalid_path") from e + raise self.make_error("path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type) + + @property + def allowed_path_type(self) -> str: + if self._allow_dir and self._allow_file: + return "directory or file" + if self._allow_dir: + return "directory" + return "file" + + def _validate(self, value): + # inherited validations like required, allow_none, etc. + super(LocalPathField, self)._validate(value) + + if value is None: + return + self._resolve_path(value) + + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]: + # do not block serializing None even if required or not allow_none. 
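+ # A hedged illustration (hypothetical paths, not from this module): with
+ # context[BASE_PATH_CONTEXT_KEY] == "/repo/project", dumping "src/train.py"
+ # yields "/repo/project/src/train.py", because relative paths are resolved
+ # against the schema's base path before serialization.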
+ if value is None: + return None + # always dump path as absolute path in string as base_path will be dropped after serialization + return super(LocalPathField, self)._serialize(self._resolve_path(value).as_posix(), attr, obj, **kwargs) + + +class SerializeValidatedUrl(fields.Url): + """This field will validate if value is an url during serialization, so that only valid urls can be serialized as + this schema. + + Use this schema instead of fields.Url when unioned with ArmStr or its subclasses like ArmVersionedStr, so that the + field can be serialized correctly after deserialization. azureml:xxx => xxx => azureml:xxx e.g. The field will still + always be serializable as any string can be serialized as an ArmStr. + """ + + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]: + if value is None: + return None + self._validate(value) + return super(SerializeValidatedUrl, self)._serialize(value, attr, obj, **kwargs) + + +class DataBindingStr(fields.Str): + """A string represents a binding to some data in pipeline job, e.g.: parent.jobs.inputs.input1, + parent.jobs.node1.outputs.output1.""" + + def _jsonschema_type_mapping(self): + schema = {"type": "string", "pattern": r"\$\{\{\s*(\S*)\s*\}\}"} + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _serialize(self, value, attr, obj, **kwargs): + # None value handling logic is inside _serialize but outside _validate/_deserialize + if value is None: + return None + + from azure.ai.ml.entities._job.pipeline._io import InputOutputBase + + if isinstance(value, InputOutputBase): + value = str(value) + + self._validate(value) + return super(DataBindingStr, self)._serialize(value, attr, obj, **kwargs) + + def _validate(self, value): + if is_data_binding_expression(value, is_singular=False): + return super(DataBindingStr, self)._validate(value) + raise ValidationError(f"Value passed is not a data binding string: {value}") + + +class NodeBindingStr(DataBindingStr): + """A string represents a binding to some node in pipeline job, e.g.: parent.jobs.node1.""" + + def _serialize(self, value, attr, obj, **kwargs): + # None value handling logic is inside _serialize but outside _validate/_deserialize + if value is None: + return None + + from azure.ai.ml.entities._builders import BaseNode + + if isinstance(value, BaseNode): + value = f"${{{{parent.jobs.{value.name}}}}}" + + self._validate(value) + return super(NodeBindingStr, self)._serialize(value, attr, obj, **kwargs) + + def _validate(self, value): + if is_data_binding_expression(value, is_singular=True): + return super(NodeBindingStr, self)._validate(value) + raise ValidationError(f"Value passed is not a node binding string: {value}") + + +class DateTimeStr(fields.Str): + """A string represents a datetime in ISO8601 format.""" + + def _jsonschema_type_mapping(self): + schema = {"type": "string"} + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + self._validate(value) + return super(DateTimeStr, self)._serialize(value, attr, obj, **kwargs) + + def _validate(self, value): + try: + from_iso_datetime(value) + except Exception as e: + raise ValidationError(f"Not a valid ISO8601-formatted datetime string: {value}") from e + + +class ArmStr(Field): + """A string represents an ARM ID for some AzureML resource.""" + + def __init__(self, **kwargs): + 
self.azureml_type = kwargs.pop("azureml_type", None) + self.pattern = kwargs.pop("pattern", r"^azureml:.+") + super().__init__(**kwargs) + + def _jsonschema_type_mapping(self): + schema = { + "type": "string", + "pattern": self.pattern, + "arm_type": self.azureml_type, + } + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _serialize(self, value, attr, obj, **kwargs): + if isinstance(value, str): + serialized_value = value if value.startswith(ARM_ID_PREFIX) else f"{ARM_ID_PREFIX}{value}" + return serialized_value + if value is None and not self.required: + return None + raise ValidationError(f"Non-string passed to ArmStr for {attr}") + + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, str) and value.startswith(ARM_ID_PREFIX): + name = value[len(ARM_ID_PREFIX) :] + return name + formatted_resource_id = RESOURCE_ID_FORMAT.format( + "<subscription_id>", + "<resource_group>", + AZUREML_RESOURCE_PROVIDER, + "<workspace_name>/", + ) + if self.azureml_type is not None: + azureml_type_suffix = self.azureml_type + else: + azureml_type_suffix = "<asset_type>" + "/<resource_name>/<version-if applicable>)" + raise ValidationError( + f"In order to specify an existing {self.azureml_type if self.azureml_type is not None else 'asset'}, " + "please provide either of the following prefixed with 'azureml:':\n" + "1. The full ARM ID for the resource, e.g." + f"azureml:{formatted_resource_id + azureml_type_suffix}\n" + "2. The short-hand name of the resource registered in the workspace, " + "eg: azureml:<short-hand-name>:<version-if applicable>. " + "For example, version 1 of the environment registered as " + "'my-env' in the workspace can be referenced as 'azureml:my-env:1'" + ) + + +class ArmVersionedStr(ArmStr): + """A string represents an ARM ID for some AzureML resource with version.""" + + def __init__(self, **kwargs): + self.allow_default_version = kwargs.pop("allow_default_version", False) + super().__init__(**kwargs) + + def _deserialize(self, value, attr, data, **kwargs): + arm_id = super()._deserialize(value, attr, data, **kwargs) + try: + AMLVersionedArmId(arm_id) + return arm_id + except ValidationException: + pass + + if is_ARM_id_for_resource(name=arm_id, resource_type=self.azureml_type): + msg = "id for {} is invalid" + raise ValidationError(message=msg.format(attr)) + + try: + name, label = parse_name_label(arm_id) + except ValidationException as e: + # Schema will try to deserialize the value with all possible Schema & catch ValidationError + # So raise ValidationError instead of ValidationException + raise ValidationError(e.message) from e + + version = None + if not label: + name, version = parse_name_version(arm_id) + + if not (label or version): + if self.allow_default_version: + return name + raise ValidationError(f"Either version or label is not provided for {attr} or the id is not valid.") + + if version: + return f"{name}:{version}" + return f"{name}@{label}" + + +class FileRefField(Field): + """A string represents a file reference in pipeline job, e.g.: file:./my_file.txt, file:../my_file.txt,""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _jsonschema_type_mapping(self): + schema = {"type": "string"} + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, str) and not value.startswith(FILE_PREFIX): + 
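+ # A hedged illustration (hypothetical value): "./component_spec.yaml" is
+ # resolved against the base path from the schema context, and the text of
+ # the referenced file is returned as the deserialized value.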
+            base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+            path = Path(value)
+            if not path.is_absolute():
+                path = base_path / path
+                path = path.resolve()
+            data = load_file(path)
+            return data
+        raise ValidationError(f"Non-file value is not supported for {attr}")
+
+    def _serialize(self, value: typing.Any, attr: str, obj: typing.Any, **kwargs):
+        raise ValidationError("Serialize on FileRefField is not supported.")
+
+
+class RefField(Field):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _jsonschema_type_mapping(self):
+        schema = {"type": "string"}
+        if self.name is not None:
+            schema["title"] = self.name
+        if self.dump_only:
+            schema["readonly"] = True
+        return schema
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        if isinstance(value, str) and (
+            value.startswith(FILE_PREFIX)
+            or (os.path.isdir(value) or os.path.isfile(value))
+            or value == DOCKER_FILE_NAME
+        ):  # "Dockerfile" w/o file: prefix doesn't register as a path
+            if value.startswith(FILE_PREFIX):
+                value = value[len(FILE_PREFIX) :]
+            base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+
+            path = Path(value)
+            if not path.is_absolute():
+                path = base_path / path
+                path = path.resolve()
+            if attr == CONDA_FILE:  # conda files should be loaded as dictionaries
+                data = load_yaml(path)
+            else:
+                data = load_file(path)
+            return data
+        raise ValidationError(f"Non-file value is not supported for {attr}")
+
+    def _serialize(self, value: typing.Any, attr: str, obj: typing.Any, **kwargs):
+        raise ValidationError("Serialize on RefField is not supported.")
+
+
+class NestedField(Nested):
+    """Anticipates the default coming in the next marshmallow version: unknown=RAISE."""
+
+    def __init__(self, *args, **kwargs):
+        if kwargs.get("unknown") is None:
+            kwargs["unknown"] = RAISE
+        super().__init__(*args, **kwargs)
+
+
+# Note: the order in which union fields are provided can currently cause a bug.
+# For example, the first line below works, but the second fails upon calling load_from_dict
+# with the error "AttributeError: 'list' object has no attribute 'get'":
+#     inputs = UnionField([fields.List(NestedField(DataSchema)), NestedField(DataSchema)])
+#     inputs = UnionField([NestedField(DataSchema), fields.List(NestedField(DataSchema))])
+class UnionField(fields.Field):
+    """A field that can be one of multiple types."""
+
+    def __init__(self, union_fields: List[fields.Field], is_strict=False, **kwargs):
+        super().__init__(**kwargs)
+        try:
+            # add the validation and make sure union_fields must be subclasses or instances of
+            # marshmallow.base.FieldABC
+            self._union_fields = [resolve_field_instance(cls_or_instance) for cls_or_instance in union_fields]
+            # TODO: make serialization/de-serialization work in the same way as json schema when is_strict is True
+            self.is_strict = is_strict  # When True, combine fields with oneOf instead of anyOf at schema generation
+        except FieldInstanceResolutionError as error:
+            raise ValueError(
+                'Elements of "union_fields" must be subclasses or instances of marshmallow.base.FieldABC.'
+            ) from error
+
+    @property
+    def union_fields(self):
+        return iter(self._union_fields)
+
+    def insert_union_field(self, field):
+        self._union_fields.insert(0, field)
+
+    # This sets the parent for the schema and also handles nesting.
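+    # Illustrative sketch (an assumption, not from the original source): a UnionField
+    # tries each candidate field in order until one (de)serializes successfully, e.g.
+    #     compute = UnionField([StringTransformedEnum(allowed_values=["local"]), fields.Str()])
+    # With is_strict=True, the generated JSON schema combines candidates with oneOf
+    # instead of anyOf.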
+    def _bind_to_schema(self, field_name, schema):
+        super()._bind_to_schema(field_name, schema)
+        self._union_fields = self._create_bind_fields(self._union_fields, field_name)
+
+    def _create_bind_fields(self, _fields, field_name):
+        new_union_fields = []
+        for field in _fields:
+            field = copy.deepcopy(field)
+            field._bind_to_schema(field_name, self)
+            new_union_fields.append(field)
+        return new_union_fields
+
+    def _serialize(self, value, attr, obj, **kwargs):
+        if value is None:
+            return None
+        errors = []
+        for field in self._union_fields:
+            try:
+                return field._serialize(value, attr, obj, **kwargs)
+            except ValidationError as e:
+                errors.extend(e.messages)
+            except (TypeError, ValueError, AttributeError, ValidationException) as e:
+                errors.extend([str(e)])
+        raise ValidationError(message=errors, field_name=attr)
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        errors = []
+        for schema in self._union_fields:
+            try:
+                return schema.deserialize(value, attr, data, **kwargs)
+            except ValidationError as e:
+                errors.append(e.normalized_messages())
+            except ValidationException as e:
+                # ValidationException is explicitly raised in project code, so it is usually easy to locate
+                # from the error message
+                errors.append([str(e)])
+            except (FileNotFoundError, TypeError) as e:
+                # FileNotFoundError and TypeError can be raised in system code, so we need to add more information
+                # TODO: consider if it's possible to handle those errors directly in their related
+                # code instead of in UnionField
+                trace = traceback.format_exc().splitlines()
+                if len(trace) >= 3:
+                    errors.append([f"{trace[-1]} from {trace[-3]} {trace[-2]}"])
+                else:
+                    errors.append([f"{e.__class__.__name__}: {e}"])
+            finally:
+                # Revert the base path to the original path when a job schema fails to deserialize a job. For
+                # example, when loading a parallel job with a component file reference starting with the FILE
+                # prefix, CommandSchema may first load the component yaml according to
+                # AnonymousCommandComponentSchema, and YamlFileSchema will update the base path. When
+                # CommandSchema fails to load, ParallelSchema then loads the component yaml according to
+                # AnonymousParallelComponentSchema, but the base path is now incorrect and a path-not-found
+                # error is raised when loading the component yaml file.
+                if (
+                    hasattr(schema, "name")
+                    and schema.name == "jobs"
+                    and hasattr(schema, "schema")
+                    and isinstance(schema.schema, PathAwareSchema)
+                ):
+                    # use the old base path to recover the original base path
+                    schema.schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.old_base_path
+                    # recover the base path of the parent schema
+                    schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.context[BASE_PATH_CONTEXT_KEY]
+        raise ValidationError(errors, field_name=attr)
+
+
+class TypeSensitiveUnionField(UnionField):
+    """A union field that simplifies error messages based on the value's type field when
+    serialization/deserialization fails.
+
+    If the value has no type field, error messages from type-sensitive fields are skipped. If the value has a
+    type that doesn't match any allowed type, raise "Value {} not in set {}". If the value has a type that
+    matches at least one allowed value, raise the first matching error.
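+
+    Illustrative example (an assumption for clarity, not from the original source): with
+    type_sensitive_fields_dict={"command": [...], "sweep": [...]}, a failing value of
+    {"type": "sweep", ...} only surfaces errors raised by the "sweep" candidate fields.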
+ """ + + def __init__( + self, + type_sensitive_fields_dict: typing.Dict[str, List[fields.Field]], + *, + plain_union_fields: Optional[List[fields.Field]] = None, + allow_load_from_file: bool = True, + type_field_name="type", + **kwargs, + ): + """param type_sensitive_fields_dict: a dict of type name to list of + type sensitive fields param plain_union_fields: list of fields that + will be used if value doesn't have type field type plain_union_fields: + List[fields.Field] param allow_load_from_file: whether to allow load + from file, default to True type allow_load_from_file: bool param + type_field_name: field name of type field, default value is "type" type + type_field_name: str.""" + self._type_sensitive_fields_dict = {} + self._allow_load_from_yaml = allow_load_from_file + + union_fields = plain_union_fields or [] + for type_name, type_sensitive_fields in type_sensitive_fields_dict.items(): + union_fields.extend(type_sensitive_fields) + self._type_sensitive_fields_dict[type_name] = [ + resolve_field_instance(cls_or_instance) for cls_or_instance in type_sensitive_fields + ] + + super(TypeSensitiveUnionField, self).__init__(union_fields, **kwargs) + self._type_field_name = type_field_name + + def _bind_to_schema(self, field_name, schema): + super()._bind_to_schema(field_name, schema) + for ( + type_name, + type_sensitive_fields, + ) in self._type_sensitive_fields_dict.items(): + self._type_sensitive_fields_dict[type_name] = self._create_bind_fields(type_sensitive_fields, field_name) + + @property + def type_field_name(self) -> str: + return self._type_field_name + + @property + def allowed_types(self) -> List[str]: + return list(self._type_sensitive_fields_dict.keys()) + + # pylint: disable-next=docstring-missing-param + def insert_type_sensitive_field(self, type_name, field): + """Insert a new type sensitive field for a specific type.""" + if type_name not in self._type_sensitive_fields_dict: + self._type_sensitive_fields_dict[type_name] = [] + self._type_sensitive_fields_dict[type_name].insert(0, field) + self.insert_union_field(field) + + # pylint: disable-next=docstring-missing-param + def _simplified_error_base_on_type(self, e, value, attr) -> Exception: + """Returns a simplified error based on value type + + :return: Returns + * e if value doesn't havetype + * ValidationError("Value {} not in set {}") if value type not in allowed types + * First Matched Error message if value has type and type matches atleast one field + :rtype: Exception + """ + value_type = try_get_non_arbitrary_attr(value, self.type_field_name) + if value_type is None: + # if value has no type field, raise original error + return e + if value_type not in self.allowed_types: + # if value has type field but its value doesn't match any allowed value, raise ValidationError directly + return ValidationError( + message={self.type_field_name: f"Value {value_type!r} passed is not in set {self.allowed_types}"}, + field_name=attr, + ) + filtered_messages = [] + # if value has type field and its value match at least 1 allowed value, raise first matched + for error in e.messages: + # for non-nested schema, their error message will be {"_schema": ["xxx"]} + if len(error) == 1 and "_schema" in error: + continue + # for nested schema, type field won't be within error only if type field value is matched + # then return first matched error message + if self.type_field_name in error: + continue + filtered_messages.append(error) + + if len(filtered_messages) == 0: + # shouldn't happen + return e + # TODO: consider if we should 
keep all filtered messages + return ValidationError(message=filtered_messages[0], field_name=attr) + + def _serialize(self, value, attr, obj, **kwargs): + union_fields = self._union_fields[:] + value_type = try_get_non_arbitrary_attr(value, self.type_field_name) + if value_type is not None and value_type in self.allowed_types: + target_fields = self._type_sensitive_fields_dict[value_type] + if len(target_fields) == 1: + return target_fields[0]._serialize(value, attr, obj, **kwargs) + self._union_fields = target_fields + + try: + return super(TypeSensitiveUnionField, self)._serialize(value, attr, obj, **kwargs) + except ValidationError as e: + raise self._simplified_error_base_on_type(e, value, attr) + finally: + self._union_fields = union_fields + + def _try_load_from_yaml(self, value): + target_path = value + if target_path.startswith(FILE_PREFIX): + target_path = target_path[len(FILE_PREFIX) :] + try: + import yaml + + base_path = Path(self.context[BASE_PATH_CONTEXT_KEY]) + target_path = Path(target_path) + if not target_path.is_absolute(): + target_path = base_path / target_path + target_path.resolve() + if target_path.is_file(): + self.context[BASE_PATH_CONTEXT_KEY] = target_path.parent + with target_path.open(encoding=DefaultOpenEncoding.READ) as f: + return yaml.safe_load(f) + except Exception: # pylint: disable=W0718 + pass + return value + + def _deserialize(self, value, attr, data, **kwargs): + try: + return super(TypeSensitiveUnionField, self)._deserialize(value, attr, data, **kwargs) + except ValidationError as e: + if isinstance(value, str) and self._allow_load_from_yaml: + value = self._try_load_from_yaml(value) + raise self._simplified_error_base_on_type(e, value, attr) + + +def ComputeField(**kwargs) -> Field: + """ + :return: The compute field + :rtype: Field + """ + return UnionField( + [ + StringTransformedEnum(allowed_values=[LOCAL_COMPUTE_TARGET]), + ArmStr(azureml_type=AzureMLResourceType.COMPUTE), + # Case for virtual clusters + fields.Str(), + ], + metadata={"description": "The compute resource."}, + **kwargs, + ) + + +def CodeField(**kwargs) -> Field: + """ + :return: The code field + :rtype: Field + """ + return UnionField( + [ + LocalPathField(), + SerializeValidatedUrl(), + GitStr(), + RegistryStr(azureml_type=AzureMLResourceType.CODE), + InternalRegistryStr(azureml_type=AzureMLResourceType.CODE), + # put arm versioned string at last order as it can deserialize any string into "azureml:<origin>" + ArmVersionedStr(azureml_type=AzureMLResourceType.CODE), + ], + metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."}, + **kwargs, + ) + + +def EnvironmentField(*, extra_fields: List[Field] = None, **kwargs): + """Function to return a union field for environment. + + :keyword extra_fields: Extra fields to be added to the union field + :paramtype extra_fields: List[Field] + :return: The environment field + :rtype: Field + """ + extra_fields = extra_fields or [] + # local import to avoid circular dependency + from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema + + return UnionField( + [ + NestedField(AnonymousEnvironmentSchema), + RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT), + ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True), + ] + + extra_fields, + **kwargs, + ) + + +def DistributionField(**kwargs): + """Function to return a union field for distribution. 
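+
+    Covers the PyTorch, TensorFlow, and MPI distribution schemas, plus the experimental
+    Ray distribution schema.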
+ + :return: The distribution field + :rtype: Field + """ + from azure.ai.ml._schema.job.distribution import ( + MPIDistributionSchema, + PyTorchDistributionSchema, + RayDistributionSchema, + TensorFlowDistributionSchema, + ) + + return UnionField( + [ + NestedField(PyTorchDistributionSchema, **kwargs), + NestedField(TensorFlowDistributionSchema, **kwargs), + NestedField(MPIDistributionSchema, **kwargs), + ExperimentalField(NestedField(RayDistributionSchema, **kwargs)), + ] + ) + + +def PrimitiveValueField(**kwargs): + """Function to return a union field for primitive value. + + :return: The primitive value field + :rtype: Field + """ + return UnionField( + [ + # Note: order matters here - to make sure value parsed correctly. + # By default when strict is false, marshmallow downcasts float to int. + # Setting it to true will throw a validation error when loading a float to int. + # https://github.com/marshmallow-code/marshmallow/pull/755 + # Use DumpableIntegerField to make sure there will be validation error when + # loading/dumping a float to int. + # note that this field can serialize bool instance but cannot deserialize bool instance. + DumpableIntegerField(strict=True), + # Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float) + DumpableFloatField(strict=True), + # put string schema after Int and Float to make sure they won't dump to string + fields.Str(), + # fields.Bool comes last since it'll parse anything non-falsy to True + fields.Bool(), + ], + **kwargs, + ) + + +class VersionField(Field): + """A string represents a version, e.g.: 1, 1.0, 1.0.0. + Will always convert to string to ensure that "1.0" won't be converted to 1. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def _jsonschema_type_mapping(self): + schema = {"anyOf": [{"type": "string"}, {"type": "integer"}]} + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _deserialize(self, value, attr, data, **kwargs) -> str: + if isinstance(value, str): + return value + if isinstance(value, (int, float)): + return str(value) + msg = f"Type {type(value)} is not supported for version." + raise MlException(message=msg, no_personal_data_message=msg) + + +class NumberVersionField(VersionField): + """A string represents a version, e.g.: 1, 1.0, 1.0.0. + Will always convert to string to ensure that "1.0" won't be converted to 1. 
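+
+    Illustrative examples (assumptions, not from the original source): with
+    NumberVersionField(lower_bound="1.0", upper_bound="3.0"), "2.1.4" is accepted, while
+    "3.0" (the upper bound is exclusive) and "0.9" (below the lower bound) are rejected.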
+ """ + + default_error_messages = { + "max_version": "Version {input} is greater than or equal to upper bound {bound}.", + "min_version": "Version {input} is smaller than lower bound {bound}.", + "invalid": "Number version must be integers concatenated by '.', like 1.0.1.", + } + + def __init__(self, *args, upper_bound: Optional[str] = None, lower_bound: Optional[str] = None, **kwargs) -> None: + self._upper = None if upper_bound is None else self._version_to_tuple(upper_bound) + self._lower = None if lower_bound is None else self._version_to_tuple(lower_bound) + super().__init__(*args, **kwargs) + + def _version_to_tuple(self, value: str): + try: + return tuple(int(v) for v in str(value).split(".")) + except ValueError as e: + raise self.make_error("invalid") from e + + def _validate(self, value): + super()._validate(value) + value_tuple = self._version_to_tuple(value) + if self._upper is not None and value_tuple >= self._upper: + raise self.make_error("max_version", input=value, bound=self._upper) + if self._lower is not None and value_tuple < self._lower: + raise self.make_error("min_version", input=value, bound=self._lower) + + +class DumpableIntegerField(fields.Integer): + """A int field that cannot serialize other type of values to int if self.strict.""" + + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]: + if self.strict and not isinstance(value, int): + # this implementation can serialize bool to bool + raise self.make_error("invalid", input=value) + return super()._serialize(value, attr, obj, **kwargs) + + +class DumpableFloatField(fields.Float): + """A float field that cannot serialize other type of values to float if self.strict.""" + + def __init__( + self, + *, + strict: bool = False, + allow_nan: bool = False, + as_string: bool = False, + **kwargs, + ): + self.strict = strict + super().__init__(allow_nan=allow_nan, as_string=as_string, **kwargs) + + def _validated(self, value): + if self.strict and not isinstance(value, float): + raise self.make_error("invalid", input=value) + return super()._validated(value) + + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]: + return super()._serialize(self._validated(value), attr, obj, **kwargs) + + +class DumpableStringField(fields.String): + """A string field that cannot serialize other type of values to string if self.strict.""" + + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]: + if not isinstance(value, str): + raise ValidationError("Given value is not a string") + return super()._serialize(value, attr, obj, **kwargs) + + +class ExperimentalField(fields.Field): + def __init__(self, experimental_field: fields.Field, **kwargs): + super().__init__(**kwargs) + try: + self._experimental_field = resolve_field_instance(experimental_field) + self.required = experimental_field.required + except FieldInstanceResolutionError as error: + raise ValueError( + '"experimental_field" must be subclasses or instances of marshmallow.base.FieldABC.' + ) from error + + @property + def experimental_field(self): + return self._experimental_field + + # This sets the parent for the schema and also handles nesting. 
+ def _bind_to_schema(self, field_name, schema): + super()._bind_to_schema(field_name, schema) + self._experimental_field._bind_to_schema(field_name, schema) + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + return self._experimental_field._serialize(value, attr, obj, **kwargs) + + def _deserialize(self, value, attr, data, **kwargs): + if value is not None: + message = "Field '{0}': {1} {2}".format(attr, EXPERIMENTAL_FIELD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE) + if not _is_warning_cached(message): + module_logger.warning(message) + + return self._experimental_field._deserialize(value, attr, data, **kwargs) + + +class RegistryStr(Field): + """A string represents a registry ID for some AzureML resource.""" + + def __init__(self, **kwargs): + self.azureml_type = kwargs.pop("azureml_type", None) + super().__init__(**kwargs) + + def _jsonschema_type_mapping(self): + schema = { + "type": "string", + "pattern": "^azureml://registries/.*", + "arm_type": self.azureml_type, + } + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _serialize(self, value, attr, obj, **kwargs): + if isinstance(value, str) and value.startswith(REGISTRY_URI_FORMAT): + return f"{value}" + if value is None and not self.required: + return None + raise ValidationError(f"Non-string passed to RegistryStr for {attr}") + + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, str) and value.startswith(REGISTRY_URI_FORMAT): + return value + raise ValidationError( + f"In order to specify an existing {self.azureml_type}, " + "please provide the correct registry path prefixed with 'azureml://':\n" + ) + + +class InternalRegistryStr(RegistryStr): + """A string represents a registry ID for some internal AzureML resource.""" + + def _jsonschema_type_mapping(self): + schema = super()._jsonschema_type_mapping() + schema["pattern"] = "^azureml://feeds/.*" + return schema + + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, str) and value.startswith(INTERNAL_REGISTRY_URI_FORMAT): + value = value.replace(INTERNAL_REGISTRY_URI_FORMAT, REGISTRY_URI_FORMAT, 1) + return super()._deserialize(value, attr, data, **kwargs) + + +class PythonFuncNameStr(fields.Str): + """A string represents a python function name.""" + + @abstractmethod + def _get_field_name(self) -> str: + """Returns field name, used for error message.""" + + # pylint: disable-next=docstring-missing-param + def _deserialize(self, value, attr, data, **kwargs) -> str: + """Validate component name. + + :return: The component name + :rtype: str + """ + name = super()._deserialize(value, attr, data, **kwargs) + pattern = r"^[a-z][a-z\d_]*$" + if not re.match(pattern, name): + raise ValidationError( + f"{self._get_field_name()} name should only contain " + "lower letter, number, underscore and start with a lower letter. " + f"Currently got {name}." + ) + return name + + +class PipelineNodeNameStr(fields.Str): + """A string represents a pipeline node name.""" + + @abstractmethod + def _get_field_name(self) -> str: + """Returns field name, used for error message.""" + + # pylint: disable-next=docstring-missing-param + def _deserialize(self, value, attr, data, **kwargs) -> str: + """Validate component name. 
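+
+        Illustrative examples (assumptions, not from the original source): "my_node_1"
+        and "_step" pass validation; "1st-node" is rejected.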
+
+        :return: The component name
+        :rtype: str
+        """
+        name = super()._deserialize(value, attr, data, **kwargs)
+        if not is_valid_node_name(name):
+            raise ValidationError(
+                f"{self._get_field_name()} name should be a valid python identifier "
+                "(lower letters, numbers, underscore and start with a letter or underscore). "
+                f"Currently got {name}."
+            )
+        return name
+
+
+class GitStr(fields.Str):
+    """A string that represents a git path."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _jsonschema_type_mapping(self):
+        schema = {"type": "string", "pattern": "^git+"}
+        if self.name is not None:
+            schema["title"] = self.name
+        if self.dump_only:
+            schema["readonly"] = True
+        return schema
+
+    def _serialize(self, value, attr, obj, **kwargs):
+        if isinstance(value, str) and value.startswith("git+"):
+            return f"{value}"
+        if value is None and not self.required:
+            return None
+        raise ValidationError(f"Non-string passed to GitStr for {attr}")
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        if isinstance(value, str) and value.startswith("git+"):
+            return value
+        raise ValidationError("In order to specify a git path, please provide the correct path prefixed with 'git+'.\n")
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/intellectual_property.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/intellectual_property.py
new file mode 100644
index 00000000..2ae47130
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/intellectual_property.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._assets import IPProtectionLevel
+from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty
+
+
+@experimental
+class BaseIntellectualPropertySchema(metaclass=PatchedSchemaMeta):
+    @post_load
+    def make(self, data, **kwargs) -> "IntellectualProperty":
+        return IntellectualProperty(**data)
+
+
+@experimental
+class ProtectionLevelSchema(BaseIntellectualPropertySchema):
+    protection_level = StringTransformedEnum(
+        allowed_values=[level.name for level in IPProtectionLevel],
+        casing_transform=camel_to_snake,
+    )
+
+
+@experimental
+class PublisherSchema(BaseIntellectualPropertySchema):
+    publisher = fields.Str()
+
+
+@experimental
+class IntellectualPropertySchema(ProtectionLevelSchema, PublisherSchema):
+    pass
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/resource.py
new file mode 100644
index 00000000..dbbc6f63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/resource.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +# pylint: disable=unused-argument,protected-access + +import logging + +from marshmallow import fields, post_dump, post_load, pre_dump + +from ...constants._common import BASE_PATH_CONTEXT_KEY +from .schema import YamlFileSchema + +module_logger = logging.getLogger(__name__) + + +class ResourceSchema(YamlFileSchema): + name = fields.Str(attribute="name") + id = fields.Str(attribute="id") + description = fields.Str(attribute="description") + tags = fields.Dict(keys=fields.Str, attribute="tags") + + @post_load(pass_original=True) + def pass_source_path(self, data, original, **kwargs): + path = self._resolve_path(original, base_path=self._previous_base_path) + if path is not None: + from ...entities import Resource + + if isinstance(data, dict): + # data will be used in Resource.__init__ + data["source_path"] = path.as_posix() + elif isinstance(data, Resource): + # some resource will make dict into object in their post_load + # not sure if it's a better way to unify them + data._source_path = path + return data + + @pre_dump + def update_base_path_pre_dump(self, data, **kwargs): + # inherit from parent if base_path is not set + if data.base_path: + self._previous_base_path = self.context[BASE_PATH_CONTEXT_KEY] + self.context[BASE_PATH_CONTEXT_KEY] = data.base_path + return data + + @post_dump + def reset_base_path_post_dump(self, data, **kwargs): + if self._previous_base_path is not None: + # pop state + self.context[BASE_PATH_CONTEXT_KEY] = self._previous_base_path + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema.py new file mode 100644 index 00000000..062575bc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema.py @@ -0,0 +1,123 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import copy +import logging +from pathlib import Path +from typing import Optional + +from marshmallow import fields, post_load, pre_load +from pydash import objects + +from azure.ai.ml._schema.core.schema_meta import PatchedBaseSchema, PatchedSchemaMeta +from azure.ai.ml._utils.utils import load_yaml +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, FILE_PREFIX, PARAMS_OVERRIDE_KEY +from azure.ai.ml.exceptions import MlException + +module_logger = logging.getLogger(__name__) + + +class PathAwareSchema(PatchedBaseSchema, metaclass=PatchedSchemaMeta): + schema_ignored = fields.Str(data_key="$schema", dump_only=True) + + def __init__(self, *args, **kwargs): + # this will make context of all PathAwareSchema child class point to one object + self.context = kwargs.get("context", None) + if self.context is None or self.context.get(BASE_PATH_CONTEXT_KEY, None) is None: + msg = "Base path for reading files is required when building PathAwareSchema" + raise MlException(message=msg, no_personal_data_message=msg) + # set old base path, note it's an Path object and point to the same object with + # self.context.get(BASE_PATH_CONTEXT_KEY) + self.old_base_path = self.context.get(BASE_PATH_CONTEXT_KEY) + super().__init__(*args, **kwargs) + + @pre_load + def add_param_overrides(self, data, **kwargs): + # Removing params override from context so that overriding is done once on the yaml + # child schema should not override the params. 
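+        # Illustrative shape (an assumption, not from the original source):
+        #     params_override = [{"name": "my-job"}, {"inputs.lr": "0.01"}]
+        # i.e. a list of {dotted.path: value} dicts applied onto the raw yaml dict below.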
+        params_override = self.context.pop(PARAMS_OVERRIDE_KEY, None)
+        if params_override is not None:
+            for override in params_override:
+                for param, val in override.items():
+                    # Check that none of the intermediary levels are string references (azureml/file)
+                    param_tokens = param.split(".")
+                    test_layer = data
+                    for layer in param_tokens:
+                        if test_layer is None:
+                            continue
+                        if isinstance(test_layer, str):
+                            msg = f"Cannot use '--set' on properties defined by reference strings: --set {param}"
+                            raise MlException(
+                                message=msg,
+                                no_personal_data_message=msg,
+                            )
+                        test_layer = test_layer.get(layer, None)
+                    objects.set_(data, param, val)
+        return data
+
+    @pre_load
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def trim_dump_only(self, data, **kwargs):
+        """Marshmallow raises if dump_only fields are present in the schema. This is not desirable for our use
+        case, where read-only properties can be present in the yaml and should simply be ignored, while we should
+        still raise when an unknown field is present, to prevent typos.
+        """
+        if isinstance(data, str) or data is None:
+            return data
+        for key, value in self.fields.items():
+            if value.dump_only:
+                schema_key = value.data_key or key
+                if data.get(schema_key, None) is not None:
+                    data.pop(schema_key)
+        return data
+
+
+class YamlFileSchema(PathAwareSchema):
+    """Base class that allows derived classes to be built from paths to separate yaml files in place of inline
+    yaml definitions.
+
+    This is transparent to any parent schema containing a nested schema of the derived class; the parent does not
+    need a union type for the schema, since a yaml file string will be resolved by the pre_load method into a
+    dictionary. On loading the child yaml, the base path is updated and used for loading sub-child files.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._previous_base_path = None
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def _resolve_path(cls, data, base_path) -> Optional[Path]:
+        if isinstance(data, str) and data.startswith(FILE_PREFIX):
+            # Use directly if absolute path
+            path = Path(data[len(FILE_PREFIX) :])
+            if not path.is_absolute():
+                path = Path(base_path) / path
+                path = path.resolve()
+            return path
+        return None
+
+    @pre_load
+    def load_from_file(self, data, **kwargs):
+        path = self._resolve_path(data, Path(self.context[BASE_PATH_CONTEXT_KEY]))
+        if path is not None:
+            self._previous_base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+            # Push update
+            # deepcopy self.context[BASE_PATH_CONTEXT_KEY] to update the old base path
+            self.old_base_path = copy.deepcopy(self.context[BASE_PATH_CONTEXT_KEY])
+            self.context[BASE_PATH_CONTEXT_KEY] = path.parent
+
+            data = load_yaml(path)
+            return data
+        return data
+
+    # Schemas are read depth-first, so push/pop to update the current path
+    @post_load
+    def reset_base_path_post_load(self, data, **kwargs):
+        if self._previous_base_path is not None:
+            # pop state
+            self.context[BASE_PATH_CONTEXT_KEY] = self._previous_base_path
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema_meta.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema_meta.py
new file mode 100644
index 00000000..d352137c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema_meta.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from collections import OrderedDict + +from marshmallow import RAISE +from marshmallow.decorators import post_dump +from marshmallow.schema import Schema, SchemaMeta + +module_logger = logging.getLogger(__name__) + + +class PatchedMeta: + ordered = True + unknown = RAISE + + +class PatchedBaseSchema(Schema): + class Meta: + unknown = RAISE + ordered = True + + @post_dump + # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype + def remove_none(self, data, **kwargs): + """Prevents from dumping attributes that are None, thus making the dump more compact.""" + return OrderedDict((key, value) for key, value in data.items() if value is not None) + + +class PatchedSchemaMeta(SchemaMeta): + """Currently there is an open issue in marshmallow, that the "unknown" property is not inherited. + + We use a metaclass to inject a Meta class into all our Schema classes. + """ + + def __new__(mcs, name, bases, dct): + meta = dct.get("Meta") + if meta is None: + dct["Meta"] = PatchedMeta + else: + if not hasattr(meta, "unknown"): + dct["Meta"].unknown = RAISE + if not hasattr(meta, "ordered"): + dct["Meta"].ordered = True + + if PatchedBaseSchema not in bases: + bases = bases + (PatchedBaseSchema,) + klass = super().__new__(mcs, name, bases, dct) + return klass diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py new file mode 100644 index 00000000..24cc357c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py @@ -0,0 +1,63 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump, validates + +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml._vendor.azure_resources.models._resource_management_client_enums import ResourceIdentityType +from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration + + +class IdentitySchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=[ + ResourceIdentityType.SYSTEM_ASSIGNED, + ResourceIdentityType.USER_ASSIGNED, + ResourceIdentityType.NONE, + # ResourceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED, # This is for post PuPr + ], + casing_transform=camel_to_snake, + metadata={"description": "resource identity type."}, + ) + principal_id = fields.Str() + tenant_id = fields.Str() + user_assigned_identities = fields.List(fields.Dict(keys=fields.Str(), values=fields.Str())) + + @validates("user_assigned_identities") + def validate_user_assigned_identities(self, data, **kwargs): + if len(data) > 1: + raise ValidationError(f"Only 1 user assigned identity is currently supported, {len(data)} found") + + @post_load + def make(self, data, **kwargs): + user_assigned_identities_list = [] + user_assigned_identities = data.pop("user_assigned_identities", None) + if user_assigned_identities: + for identity in user_assigned_identities: + user_assigned_identities_list.append( + ManagedIdentityConfiguration( + resource_id=identity.get("resource_id", None), + client_id=identity.get("client_id", None), + object_id=identity.get("object_id", None), + ) + ) + data["user_assigned_identities"] = user_assigned_identities_list + return IdentityConfiguration(**data) + + @pre_dump + def predump(self, data, **kwargs): + if data.user_assigned_identities: + ids = [] + for _id in data.user_assigned_identities: + item = {} + item["resource_id"] = _id.resource_id + item["principal_id"] = _id.principal_id + item["client_id"] = _id.client_id + ids.append(item) + data.user_assigned_identities = ids + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/__init__.py new file mode 100644 index 00000000..11687396 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/__init__.py @@ -0,0 +1,28 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore + +from azure.ai.ml._schema.job.creation_context import CreationContextSchema + +from .base_job import BaseJobSchema +from .command_job import CommandJobSchema +from .import_job import ImportJobSchema +from .parallel_job import ParallelJobSchema +from .parameterized_command import ParameterizedCommandSchema +from .parameterized_parallel import ParameterizedParallelSchema +from .parameterized_spark import ParameterizedSparkSchema +from .spark_job import SparkJobSchema + +__all__ = [ + "BaseJobSchema", + "ParameterizedCommandSchema", + "ParameterizedParallelSchema", + "CommandJobSchema", + "ImportJobSchema", + "SparkJobSchema", + "ParallelJobSchema", + "CreationContextSchema", + "ParameterizedSparkSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/base_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/base_job.py new file mode 100644 index 00000000..852d3921 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/base_job.py @@ -0,0 +1,69 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import logging + +from marshmallow import fields + +from azure.ai.ml._schema.core.fields import ArmStr, ComputeField, NestedField, UnionField +from azure.ai.ml._schema.core.resource import ResourceSchema +from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema +from azure.ai.ml.constants._common import AzureMLResourceType + +from .creation_context import CreationContextSchema +from .services import ( + JobServiceSchema, + SshJobServiceSchema, + VsCodeJobServiceSchema, + TensorBoardJobServiceSchema, + JupyterLabJobServiceSchema, +) + +module_logger = logging.getLogger(__name__) + + +class BaseJobSchema(ResourceSchema): + creation_context = NestedField(CreationContextSchema, dump_only=True) + services = fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + NestedField(SshJobServiceSchema), + NestedField(TensorBoardJobServiceSchema), + NestedField(VsCodeJobServiceSchema), + NestedField(JupyterLabJobServiceSchema), + # JobServiceSchema should be the last in the list. + # To support types not set by users like Custom, Tracking, Studio. + NestedField(JobServiceSchema), + ], + is_strict=True, + ), + ) + name = fields.Str() + id = ArmStr(azureml_type=AzureMLResourceType.JOB, dump_only=True, required=False) + display_name = fields.Str(required=False) + tags = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True)) + status = fields.Str(dump_only=True) + experiment_name = fields.Str() + properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True)) + description = fields.Str() + log_files = fields.Dict( + keys=fields.Str(), + values=fields.Str(), + dump_only=True, + metadata={ + "description": ( + "The list of log files associated with this run. This section is only populated " + "by the service and will be ignored if contained in a yaml sent to the service " + "(e.g. 
via `az ml job create` ...)" + ) + }, + ) + compute = ComputeField(required=False) + identity = UnionField( + [ + NestedField(ManagedIdentitySchema), + NestedField(AMLTokenIdentitySchema), + NestedField(UserIdentitySchema), + ] + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/command_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/command_job.py new file mode 100644 index 00000000..9cce7de7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/command_job.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField +from azure.ai.ml.constants import JobType + +from .base_job import BaseJobSchema +from .job_limits import CommandJobLimitsSchema +from .parameterized_command import ParameterizedCommandSchema + + +class CommandJobSchema(ParameterizedCommandSchema, BaseJobSchema): + type = StringTransformedEnum(allowed_values=JobType.COMMAND) + # do not promote it as CommandComponent has no field named 'limits' + limits = NestedField(CommandJobLimitsSchema) + parameters = fields.Dict(dump_only=True) + inputs = InputsField() + outputs = OutputsField() + parent_job_name = fields.Str() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/creation_context.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/creation_context.py new file mode 100644 index 00000000..79956e1c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/creation_context.py @@ -0,0 +1,16 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class CreationContextSchema(metaclass=PatchedSchemaMeta): + created_at = fields.DateTime() + created_by = fields.Str() + created_by_type = fields.Str() + last_modified_at = fields.DateTime() + last_modified_by = fields.Str() + last_modified_by_type = fields.Str() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/data_transfer_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/data_transfer_job.py new file mode 100644 index 00000000..6ea54df6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/data_transfer_job.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+from marshmallow import validates, ValidationError, fields
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml._schema.job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._component import DataTransferTaskType, DataCopyMode
+
+from ..core.fields import ComputeField, StringTransformedEnum, UnionField
+from .base_job import BaseJobSchema
+
+
+class DataTransferCopyJobSchema(BaseJobSchema):
+    type = StringTransformedEnum(required=True, allowed_values=JobType.DATA_TRANSFER)
+    task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True)
+    data_copy_mode = StringTransformedEnum(
+        allowed_values=[DataCopyMode.MERGE_WITH_OVERWRITE, DataCopyMode.FAIL_IF_CONFLICT]
+    )
+    compute = ComputeField()
+    inputs = InputsField()
+    outputs = OutputsField()
+
+
+class DataTransferImportJobSchema(BaseJobSchema):
+    type = StringTransformedEnum(required=True, allowed_values=JobType.DATA_TRANSFER)
+    task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True)
+    compute = ComputeField()
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=NestedField(nested=OutputSchema, allow_none=False),
+        metadata={"description": "Outputs of a data transfer job."},
+    )
+    source = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+
+    @validates("outputs")
+    def outputs_key(self, value):
+        if len(value) != 1 or list(value.keys())[0] != "sink":
+            raise ValidationError(
+                f"The outputs field only supports one output, named 'sink', for task type "
+                f"{DataTransferTaskType.IMPORT_DATA}."
+            )
+
+
+class DataTransferExportJobSchema(BaseJobSchema):
+    type = StringTransformedEnum(required=True, allowed_values=JobType.DATA_TRANSFER)
+    task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA], required=True)
+    compute = ComputeField()
+    inputs = InputsField(allow_none=False)
+    sink = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+
+    @validates("inputs")
+    def inputs_key(self, value):
+        if len(value) != 1 or list(value.keys())[0] != "source":
+            raise ValidationError(
+                f"The inputs field only supports one input, named 'source', for task type "
+                f"{DataTransferTaskType.EXPORT_DATA}."
+            )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py
new file mode 100644
index 00000000..475792a3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py
@@ -0,0 +1,104 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml.constants import DistributionType +from azure.ai.ml._utils._experimental import experimental + +from ..core.schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class MPIDistributionSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=DistributionType.MPI) + process_count_per_instance = fields.Int() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml import MpiDistribution + + data.pop("type", None) + return MpiDistribution(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml import MpiDistribution + + if not isinstance(data, MpiDistribution): + raise ValidationError("Cannot dump non-MpiDistribution object into MpiDistributionSchema") + return data + + +class TensorFlowDistributionSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=DistributionType.TENSORFLOW) + parameter_server_count = fields.Int() + worker_count = fields.Int() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml import TensorFlowDistribution + + data.pop("type", None) + return TensorFlowDistribution(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml import TensorFlowDistribution + + if not isinstance(data, TensorFlowDistribution): + raise ValidationError("Cannot dump non-TensorFlowDistribution object into TensorFlowDistributionSchema") + return data + + +class PyTorchDistributionSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=DistributionType.PYTORCH) + process_count_per_instance = fields.Int() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml import PyTorchDistribution + + data.pop("type", None) + return PyTorchDistribution(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml import PyTorchDistribution + + if not isinstance(data, PyTorchDistribution): + raise ValidationError("Cannot dump non-PyTorchDistribution object into PyTorchDistributionSchema") + return data + + +@experimental +class RayDistributionSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(required=True, allowed_values=DistributionType.RAY) + port = fields.Int() + address = fields.Str() + include_dashboard = fields.Bool() + dashboard_port = fields.Int() + head_node_additional_args = fields.Str() + worker_node_additional_args = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml import RayDistribution + + data.pop("type", None) + return RayDistribution(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml import RayDistribution + + if not isinstance(data, RayDistribution): + raise ValidationError("Cannot dump non-RayDistribution object into RayDistributionSchema") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/identity.py new file mode 100644 index 00000000..2f2be676 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/identity.py @@ -0,0 +1,67 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2023_04_01_preview.models import ( + ConnectionAuthType, + IdentityConfigurationType, +) +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, +) + +from ..core.schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class ManagedIdentitySchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + required=True, + allowed_values=[IdentityConfigurationType.MANAGED, ConnectionAuthType.MANAGED_IDENTITY], + casing_transform=camel_to_snake, + ) + client_id = fields.Str() + object_id = fields.Str() + msi_resource_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + data.pop("type") + return ManagedIdentityConfiguration(**data) + + +class AMLTokenIdentitySchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + required=True, + allowed_values=IdentityConfigurationType.AML_TOKEN, + casing_transform=camel_to_snake, + ) + + @post_load + def make(self, data, **kwargs): + data.pop("type") + return AmlTokenConfiguration(**data) + + +class UserIdentitySchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + required=True, + allowed_values=IdentityConfigurationType.USER_IDENTITY, + casing_transform=camel_to_snake, + ) + + @post_load + def make(self, data, **kwargs): + data.pop("type") + return UserIdentityConfiguration(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/import_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/import_job.py new file mode 100644 index 00000000..8f7c3908 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/import_job.py @@ -0,0 +1,54 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml._schema.job.input_output_entry import OutputSchema +from azure.ai.ml.constants import ImportSourceType, JobType + +from .base_job import BaseJobSchema + + +class DatabaseImportSourceSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=[ + ImportSourceType.AZURESQLDB, + ImportSourceType.AZURESYNAPSEANALYTICS, + ImportSourceType.SNOWFLAKE, + ], + required=True, + ) + connection = fields.Str(required=True) + query = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._job.import_job import DatabaseImportSource + + return DatabaseImportSource(**data) + + +class FileImportSourceSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=[ImportSourceType.S3], required=True) + connection = fields.Str(required=True) + path = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._job.import_job import FileImportSource + + return FileImportSource(**data) + + +class ImportJobSchema(BaseJobSchema): + class Meta: + exclude = ["compute"] # compute property not applicable to import job + + type = StringTransformedEnum(allowed_values=JobType.IMPORT) + source = UnionField([NestedField(DatabaseImportSourceSchema), NestedField(FileImportSourceSchema)], required=True) + output = NestedField(OutputSchema, required=True) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_entry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_entry.py new file mode 100644 index 00000000..1300ab07 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_entry.py @@ -0,0 +1,256 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + StringTransformedEnum, + UnionField, + LocalPathField, + NestedField, + VersionField, +) + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta, PathAwareSchema +from azure.ai.ml.constants._common import ( + AssetTypes, + AzureMLResourceType, + InputOutputModes, +) +from azure.ai.ml.constants._component import ExternalDataType + +module_logger = logging.getLogger(__name__) + + +class InputSchema(metaclass=PatchedSchemaMeta): + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._inputs_outputs import Input + + return Input(**data) + + @pre_dump + def check_dict(self, data, **kwargs): + from azure.ai.ml.entities._inputs_outputs import Input + + if isinstance(data, Input): + return data + raise ValidationError("InputSchema needs type Input to dump") + + +def generate_path_property(azureml_type): + return UnionField( + [ + ArmVersionedStr(azureml_type=azureml_type), + fields.Str(metadata={"pattern": r"^(http(s)?):.*"}), + fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}), + LocalPathField(pattern=r"^file:.*"), + LocalPathField( + pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*", + ), + ], + is_strict=True, + ) + + +def generate_path_on_compute_property(azureml_type): + return UnionField( + [ + LocalPathField(pattern=r"^file:.*"), + ], + is_strict=True, + ) + + +def generate_datastore_property(): + metadata = { + "description": "Name of the datastore to upload local paths to.", + "arm_type": AzureMLResourceType.DATASTORE, + } + return fields.Str(metadata=metadata, required=False) + + +class ModelInputSchema(InputSchema): + mode = StringTransformedEnum( + allowed_values=[ + InputOutputModes.DOWNLOAD, + InputOutputModes.RO_MOUNT, + InputOutputModes.DIRECT, + ], + required=False, + ) + type = StringTransformedEnum( + allowed_values=[ + AssetTypes.CUSTOM_MODEL, + AssetTypes.MLFLOW_MODEL, + AssetTypes.TRITON_MODEL, + ] + ) + path = generate_path_property(azureml_type=AzureMLResourceType.MODEL) + datastore = generate_datastore_property() + + +class DataInputSchema(InputSchema): + mode = StringTransformedEnum( + allowed_values=[ + InputOutputModes.DOWNLOAD, + InputOutputModes.RO_MOUNT, + InputOutputModes.DIRECT, + ], + required=False, + ) + type = StringTransformedEnum( + allowed_values=[ + AssetTypes.URI_FILE, + AssetTypes.URI_FOLDER, + ] + ) + path = generate_path_property(azureml_type=AzureMLResourceType.DATA) + path_on_compute = generate_path_on_compute_property(azureml_type=AzureMLResourceType.DATA) + datastore = generate_datastore_property() + + +class MLTableInputSchema(InputSchema): + mode = StringTransformedEnum( + allowed_values=[ + InputOutputModes.DOWNLOAD, + InputOutputModes.RO_MOUNT, + InputOutputModes.EVAL_MOUNT, + InputOutputModes.EVAL_DOWNLOAD, + InputOutputModes.DIRECT, + ], + required=False, + ) + type = StringTransformedEnum(allowed_values=[AssetTypes.MLTABLE]) + path = generate_path_property(azureml_type=AzureMLResourceType.DATA) + path_on_compute = generate_path_on_compute_property(azureml_type=AzureMLResourceType.DATA) + datastore = generate_datastore_property() + + +class InputLiteralValueSchema(metaclass=PatchedSchemaMeta): + value = UnionField([fields.Str(), fields.Bool(), fields.Int(), fields.Float()]) + + @post_load + def make(self, data, **kwargs): + return data["value"] + + 
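+    # Illustrative note (an assumption, not from the original source): loading
+    # {"value": 42} through this schema yields the bare literal 42, since make()
+    # unwraps the dict after validation.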
@pre_dump + def check_dict(self, data, **kwargs): + if hasattr(data, "value"): + return data + raise ValidationError("InputLiteralValue must have a field value") + + +class OutputSchema(PathAwareSchema): + name = fields.Str() + version = VersionField() + mode = StringTransformedEnum( + allowed_values=[ + InputOutputModes.MOUNT, + InputOutputModes.UPLOAD, + InputOutputModes.RW_MOUNT, + InputOutputModes.DIRECT, + ], + required=False, + ) + type = StringTransformedEnum( + allowed_values=[ + AssetTypes.URI_FILE, + AssetTypes.URI_FOLDER, + AssetTypes.CUSTOM_MODEL, + AssetTypes.MLFLOW_MODEL, + AssetTypes.MLTABLE, + AssetTypes.TRITON_MODEL, + ] + ) + path = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._inputs_outputs import Output + + return Output(**data) + + @pre_dump + def check_dict(self, data, **kwargs): + from azure.ai.ml.entities._inputs_outputs import Output + + if isinstance(data, Output): + return data + # Assists with union schema + raise ValidationError("OutputSchema needs type Output to dump") + + +class StoredProcedureParamsSchema(metaclass=PatchedSchemaMeta): + name = fields.Str() + value = fields.Str() + type = fields.Str() + + @pre_dump + def check_dict(self, data, **kwargs): + for key in self.dump_fields.keys(): # pylint: disable=no-member + if data.get(key, None) is None: + msg = "StoredProcedureParams must have a {!r} value." + raise ValidationError(msg.format(key)) + return data + + +class DatabaseSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=[ExternalDataType.DATABASE], required=True) + table_name = fields.Str() + query = fields.Str( + metadata={"description": "The sql query command."}, + ) + stored_procedure = fields.Str() + stored_procedure_params = fields.List(NestedField(StoredProcedureParamsSchema)) + + connection = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.data_transfer import Database + + data.pop("type", None) + return Database(**data) + + @pre_dump + def check_dict(self, data, **kwargs): + from azure.ai.ml.data_transfer import Database + + if isinstance(data, Database): + return data + msg = "DatabaseSchema needs type Database to dump, but got {!r}." + raise ValidationError(msg.format(type(data))) + + +class FileSystemSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=[ + ExternalDataType.FILE_SYSTEM, + ], + ) + path = fields.Str() + + connection = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.data_transfer import FileSystem + + data.pop("type", None) + return FileSystem(**data) + + @pre_dump + def check_dict(self, data, **kwargs): + from azure.ai.ml.data_transfer import FileSystem + + if isinstance(data, FileSystem): + return data + msg = "FileSystemSchema needs type FileSystem to dump, but got {!r}." + raise ValidationError(msg.format(type(data))) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_fields_provider.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_fields_provider.py new file mode 100644 index 00000000..7fb2e8e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_fields_provider.py @@ -0,0 +1,50 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
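input_output_entry.py repeats one pattern throughout: `post_load` turns a validated dict into an entity, and a `pre_dump` guard raises ValidationError for foreign types so that a UnionField can safely probe several candidate schemas when dumping. A self-contained sketch of that round trip, with a stand-in Input dataclass rather than the SDK entity:

    from dataclasses import dataclass
    from marshmallow import Schema, fields, post_load, pre_dump, ValidationError

    @dataclass
    class Input:  # stand-in for azure.ai.ml.entities Input
        path: str
        mode: str = "ro_mount"

    class InputSchema(Schema):
        path = fields.Str(required=True)
        mode = fields.Str()

        @post_load
        def make(self, data, **kwargs):
            return Input(**data)  # load: dict -> entity

        @pre_dump
        def check_dict(self, data, **kwargs):
            if isinstance(data, Input):  # dump: refuse foreign objects so a
                return data              # union can move on to the next schema
            raise ValidationError("InputSchema needs type Input to dump")

    schema = InputSchema()
    entity = schema.load({"path": "azureml:my-data:1", "mode": "download"})
    print(schema.dump(entity))  # {'path': 'azureml:my-data:1', 'mode': 'download'}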
+# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema._utils.data_binding_expression import support_data_binding_expression_for_fields +from azure.ai.ml._schema.core.fields import NestedField, PrimitiveValueField, UnionField +from azure.ai.ml._schema.job.input_output_entry import ( + DataInputSchema, + InputLiteralValueSchema, + MLTableInputSchema, + ModelInputSchema, + OutputSchema, +) + + +def InputsField(*, support_databinding: bool = False, **kwargs): + value_fields = [ + NestedField(DataInputSchema), + NestedField(ModelInputSchema), + NestedField(MLTableInputSchema), + NestedField(InputLiteralValueSchema), + PrimitiveValueField(is_strict=False), + # This ordering of types for the values keyword is intentional. The ordering of types + # determines what order schema values are matched and cast in. Changing the current ordering can + # result in values being mis-cast such as 1.0 translating into True. + ] + + # As is_strict is set to True, 1 and only 1 value field must be matched. + # root level data-binding expression has already been covered by PrimitiveValueField; + # If support_databinding is True, we should only add data-binding expression support for nested fields. + if support_databinding: + for field_obj in value_fields: + if isinstance(field_obj, NestedField): + support_data_binding_expression_for_fields(field_obj.schema) + + return fields.Dict( + keys=fields.Str(), + values=UnionField(value_fields, metadata={"description": "Inputs to a job."}, is_strict=True, **kwargs), + ) + + +def OutputsField(**kwargs): + return fields.Dict( + keys=fields.Str(), + values=NestedField(nested=OutputSchema, allow_none=True), + metadata={"description": "Outputs of a job."}, + **kwargs + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_port.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_port.py new file mode 100644 index 00000000..f37b2a16 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_port.py @@ -0,0 +1,29 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import fields, post_load, validate + +from azure.ai.ml.entities import InputPort + +from ..core.schema import PatchedSchemaMeta + +module_logger = logging.getLogger(__name__) + + +class InputPortSchema(metaclass=PatchedSchemaMeta): + type_string = fields.Str( + data_key="type", + validate=validate.OneOf(["path", "number", "null"]), + dump_default="null", + ) + default = fields.Str() + optional = fields.Bool() + + @post_load + def make(self, data, **kwargs): + return InputPort(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_limits.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_limits.py new file mode 100644 index 00000000..850e9b3d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_limits.py @@ -0,0 +1,45 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
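The ordering comment inside InputsField above deserves a concrete demonstration: marshmallow's Boolean field coerces 1.0 to True (1.0 is hash-equal to 1, which sits in Boolean's default truthy set), so a first-match union that probes Bool before Float mis-casts floats exactly as described. Runnable with plain marshmallow:

    from marshmallow import ValidationError, fields

    candidates = [fields.Boolean(), fields.Float()]  # wrong order on purpose

    def first_match(value):
        for field in candidates:
            try:
                return field.deserialize(value)
            except ValidationError:
                continue
        raise ValueError("no field matched")

    print(first_match(1.0))                  # True -- Boolean matched first and mis-cast the float
    print(fields.Float().deserialize(1.0))   # 1.0  -- probing Float first keeps it numeric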
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load, validate + +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta + + +class CommandJobLimitsSchema(metaclass=PatchedSchemaMeta): + timeout = fields.Int() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import CommandJobLimits + + return CommandJobLimits(**data) + + +class SweepJobLimitsSchema(metaclass=PatchedSchemaMeta): + max_concurrent_trials = fields.Int(metadata={"description": "Sweep Job max concurrent trials."}) + max_total_trials = fields.Int( + metadata={"description": "Sweep Job max total trials."}, + required=True, + ) + timeout = fields.Int( + metadata={"description": "The max run duration in Seconds, after which the job will be cancelled."} + ) + trial_timeout = fields.Int(metadata={"description": "Sweep Job Trial timeout value."}) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.sweep import SweepJobLimits + + return SweepJobLimits(**data) + + +class DoWhileLimitsSchema(metaclass=PatchedSchemaMeta): + max_iteration_count = fields.Int( + metadata={"description": "The max iteration for do_while loop."}, + validate=validate.Range(min=1, max=1000), + required=True, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_output.py new file mode 100644 index 00000000..80679119 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_output.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging + +from marshmallow import fields + +from azure.ai.ml._schema.core.fields import ArmStr +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml.constants._common import AzureMLResourceType + +module_logger = logging.getLogger(__name__) + + +class JobOutputSchema(metaclass=PatchedSchemaMeta): + datastore_id = ArmStr(azureml_type=AzureMLResourceType.DATASTORE) + path = fields.Str() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parallel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parallel_job.py new file mode 100644 index 00000000..c539e407 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parallel_job.py @@ -0,0 +1,15 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
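The `validate.Range(min=1, max=1000)` on DoWhileLimitsSchema above is stock marshmallow; a quick stand-alone sketch (not the SDK class) of how the bound behaves at load time:

    from marshmallow import Schema, fields, validate, ValidationError

    class DoWhileLimits(Schema):
        max_iteration_count = fields.Int(required=True, validate=validate.Range(min=1, max=1000))

    print(DoWhileLimits().load({"max_iteration_count": 5}))   # accepted
    try:
        DoWhileLimits().load({"max_iteration_count": 0})      # below the minimum
    except ValidationError as err:
        print(err.messages)  # the range violation is reported per field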
+# ---------------------------------------------------------
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml.constants import JobType
+
+from .base_job import BaseJobSchema
+from .parameterized_parallel import ParameterizedParallelSchema
+
+
+class ParallelJobSchema(ParameterizedParallelSchema, BaseJobSchema):
+    type = StringTransformedEnum(allowed_values=JobType.PARALLEL)
+    inputs = InputsField()
+    outputs = OutputsField()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_command.py
new file mode 100644
index 00000000..1c011bc9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_command.py
@@ -0,0 +1,41 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import (
+    CodeField,
+    DistributionField,
+    EnvironmentField,
+    ExperimentalField,
+    NestedField,
+)
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job.input_output_entry import InputLiteralValueSchema
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml._schema.queue_settings import QueueSettingsSchema
+
+from ..core.fields import UnionField
+
+
+class ParameterizedCommandSchema(PathAwareSchema):
+    command = fields.Str(
+        metadata={
+            # pylint: disable=line-too-long
+            "description": "The command to run and the parameters passed. This string may contain placeholders for inputs in {}."
+        },
+        required=True,
+    )
+    code = CodeField()
+    environment = EnvironmentField(required=True)
+    environment_variables = UnionField(
+        [
+            fields.Dict(keys=fields.Str(), values=fields.Str()),
+            # Used for binding environment variables
+            NestedField(InputLiteralValueSchema),
+        ]
+    )
+    resources = NestedField(JobResourceConfigurationSchema)
+    distribution = DistributionField()
+    queue_settings = ExperimentalField(NestedField(QueueSettingsSchema))
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_parallel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_parallel.py
new file mode 100644
index 00000000..bb5cd063
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_parallel.py
@@ -0,0 +1,72 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields
+
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._schema.core.fields import DumpableEnumField, NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job.input_output_entry import InputLiteralValueSchema
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml.constants._common import LoggingLevel
+
+from ..core.fields import UnionField
+
+
+class ParameterizedParallelSchema(PathAwareSchema):
+    logging_level = DumpableEnumField(
+        allowed_values=[LoggingLevel.DEBUG, LoggingLevel.INFO, LoggingLevel.WARN],
+        dump_default=LoggingLevel.INFO,
+        metadata={
+            "description": (
+                "A string of the logging level name, which is defined in 'logging'. "
+                "Possible values are 'WARNING', 'INFO', and 'DEBUG'."
+            )
+        },
+    )
+    task = NestedField(ComponentParallelTaskSchema, unknown=INCLUDE)
+    mini_batch_size = fields.Str(
+        metadata={"description": "The batch size of the current job."},
+    )
+    partition_keys = fields.List(
+        fields.Str(), metadata={"description": "The keys used to partition input data into mini-batches."}
+    )
+    input_data = fields.Str()
+    resources = NestedField(JobResourceConfigurationSchema)
+    retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
+    max_concurrency_per_instance = fields.Integer(
+        dump_default=1,
+        metadata={"description": "The max parallelism that each compute instance has."},
+    )
+    error_threshold = fields.Integer(
+        dump_default=-1,
+        metadata={
+            "description": (
+                "The number of item processing failures that should be ignored. "
+                "If the error_threshold is reached, the job terminates. "
+                "For a list of files as inputs, one item means one file reference. "
+                "This setting doesn't apply to command parallelization."
+            )
+        },
+    )
+    mini_batch_error_threshold = fields.Integer(
+        dump_default=-1,
+        metadata={
+            "description": (
+                "The number of mini-batch processing failures that should be ignored. "
+                "If the mini_batch_error_threshold is reached, the job terminates. "
+                "For a list of files as inputs, one item means one file reference. "
+                "This setting can be used by either command or python function parallelization. "
+                "Only one error_threshold setting can be used in one job."
+            )
+        },
+    )
+    environment_variables = UnionField(
+        [
+            fields.Dict(keys=fields.Str(), values=fields.Str()),
+            # Used for binding environment variables
+            NestedField(InputLiteralValueSchema),
+        ]
+    )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_spark.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_spark.py
new file mode 100644
index 00000000..49e9560a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_spark.py
@@ -0,0 +1,151 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=unused-argument
+
+import re
+from typing import Any, Dict, List
+
+from marshmallow import ValidationError, fields, post_dump, post_load, pre_dump, pre_load, validates
+
+from azure.ai.ml._schema.core.fields import CodeField, EnvironmentField, NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+from ..core.fields import UnionField
+
+re_memory_pattern = re.compile("^\\d+[kKmMgGtTpP]$")
+
+
+class SparkEntryFileSchema(metaclass=PatchedSchemaMeta):
+    file = fields.Str(required=True)
+    # add spark_job_entry_type and make it dump-only to align with the model definition;
+    # this ensures we get the expected value when calling spark._from_rest_object()
+    spark_job_entry_type = fields.Str(dump_only=True)
+
+    @pre_dump
+    def to_dict(self, data, **kwargs):
+        return {"file": data.entry}
+
+
+class SparkEntryClassSchema(metaclass=PatchedSchemaMeta):
+    class_name = fields.Str(required=True)
+    # add spark_job_entry_type and make it dump-only to align with the model definition;
+    # this ensures we get the expected value when calling spark._from_rest_object()
+    spark_job_entry_type = fields.Str(dump_only=True)
+
+    @pre_dump
+    def to_dict(self, data, **kwargs):
+        return {"class_name": data.entry}
+
+
+CONF_KEY_MAP = {
+    "driver_cores": "spark.driver.cores",
+    "driver_memory": "spark.driver.memory",
+    "executor_cores": "spark.executor.cores",
+    "executor_memory": "spark.executor.memory",
+    "executor_instances": "spark.executor.instances",
+    "dynamic_allocation_enabled": "spark.dynamicAllocation.enabled",
+    "dynamic_allocation_min_executors": "spark.dynamicAllocation.minExecutors",
+    "dynamic_allocation_max_executors": "spark.dynamicAllocation.maxExecutors",
+}
+
+
+def no_duplicates(name: str, value: List):
+    if len(value) != len(set(value)):
+        raise ValidationError(f"{name} must not contain duplicate entries.")
+
+
+class ParameterizedSparkSchema(PathAwareSchema):
+    code = CodeField()
+    entry = UnionField(
+        [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
+        required=True,
+        metadata={"description": "Entry."},
+    )
+    py_files = fields.List(fields.Str(required=True))
+    jars = fields.List(fields.Str(required=True))
+    files = fields.List(fields.Str(required=True))
+    archives = fields.List(fields.Str(required=True))
+    conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
+    properties = fields.Dict(keys=fields.Str(), values=fields.Raw())
+    environment = EnvironmentField(allow_none=True)
+    args = fields.Str(metadata={"description": "Command line arguments."})
+
+    @validates("py_files")
+    def no_duplicate_py_files(self, value):
+        no_duplicates("py_files", value)
+
+    @validates("jars")
+    def no_duplicate_jars(self, value):
+        no_duplicates("jars", value)
+
+    @validates("files")
+    def no_duplicate_files(self, value):
+        no_duplicates("files", value)
+
+    @validates("archives")
+    def no_duplicate_archives(self, value):
+        no_duplicates("archives", value)
+
+    @pre_load
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def map_conf_field_names(self, data, **kwargs):
+        """Map the field names in the conf dictionary.
+
+        This function must be called after YamlFileSchema.load_from_file.
+        Because marshmallow executes the pre_load functions in alphabetical order (marshmallow\\schema.py:L166;
+        hook functions are registered into cls._hooks in alphabetical order), we must make sure this function's
+        name sorts alphabetically after "load_from_file".
+        """
+        # TODO: it's dangerous to depend on an alphabetical order, we'd better move related logic out of Schema.
+        conf = data["conf"] if "conf" in data else None
+        if conf is not None:
+            for field_key, dict_key in CONF_KEY_MAP.items():
+                value = conf.get(dict_key, None)
+                if dict_key in conf and value is not None:
+                    del conf[dict_key]
+                    conf[field_key] = value
+            data["conf"] = conf
+        return data
+
+    @post_dump(pass_original=True)
+    def serialize_field_names(self, data: Dict[str, Any], original_data: Dict[str, Any], **kwargs):
+        conf = data["conf"] if "conf" in data else {}
+        if original_data.conf is not None and conf is not None:
+            for field_name, value in original_data.conf.items():
+                if field_name not in conf:
+                    if isinstance(value, str) and value.isdigit():
+                        value = int(value)
+                    conf[field_name] = value
+        if conf is not None:
+            for field_name, dict_name in CONF_KEY_MAP.items():
+                val = conf.get(field_name, None)
+                if field_name in conf and val is not None:
+                    if isinstance(val, str) and val.isdigit():
+                        val = int(val)
+                    del conf[field_name]
+                    conf[dict_name] = val
+        data["conf"] = conf
+        return data
+
+    @post_load
+    def demote_conf_fields(self, data, **kwargs):
+        conf = data["conf"] if "conf" in data else None
+        if conf is not None:
+            for field_name, _ in CONF_KEY_MAP.items():
+                value = conf.get(field_name, None)
+                if field_name in conf and value is not None:
+                    del conf[field_name]
+                    data[field_name] = value
+        return data
+
+    @pre_dump
+    def promote_conf_fields(self, data: object, **kwargs):
+        # copy promoted fields from the root object back into 'conf'
+        conf = data.conf or {}
+        for field_name, _ in CONF_KEY_MAP.items():
+            value = data.__getattribute__(field_name)
+            if value is not None:
+                conf[field_name] = value
+        data.__setattr__("conf", conf)
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py
new file mode 100644
index 00000000..f6fed8c2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py
@@ -0,0 +1,100 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
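The CONF_KEY_MAP machinery in parameterized_spark.py is symmetric: on load, dotted Spark keys such as `spark.driver.cores` are renamed to schema field names, and on dump they are renamed back. A simplified pure-Python sketch of the two directions (it omits the digit-string-to-int coercion the real hooks also perform):

    CONF_KEY_MAP = {
        "driver_cores": "spark.driver.cores",
        "executor_memory": "spark.executor.memory",
    }

    def to_field_names(conf: dict) -> dict:
        # load direction: dotted Spark keys -> schema field names
        out = dict(conf)
        for field_key, dotted_key in CONF_KEY_MAP.items():
            if out.get(dotted_key) is not None:
                out[field_key] = out.pop(dotted_key)
        return out

    def to_spark_keys(conf: dict) -> dict:
        # dump direction: schema field names -> dotted Spark keys
        out = dict(conf)
        for field_key, dotted_key in CONF_KEY_MAP.items():
            if out.get(field_key) is not None:
                out[dotted_key] = out.pop(field_key)
        return out

    loaded = to_field_names({"spark.driver.cores": 2, "spark.executor.memory": "2g"})
    print(loaded)                 # {'driver_cores': 2, 'executor_memory': '2g'}
    print(to_spark_keys(loaded))  # round-trips back to the dotted form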
+# ---------------------------------------------------------
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.entities._job.job_service import (
+    JobService,
+    SshJobService,
+    JupyterLabJobService,
+    VsCodeJobService,
+    TensorBoardJobService,
+)
+from azure.ai.ml.constants._job.job import JobServiceTypeNames
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField
+
+from ..core.schema import PathAwareSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobServiceBaseSchema(PathAwareSchema):
+    port = fields.Int()
+    endpoint = fields.Str(dump_only=True)
+    status = fields.Str(dump_only=True)
+    nodes = fields.Str()
+    error_message = fields.Str(dump_only=True)
+    properties = fields.Dict()
+
+
+class JobServiceSchema(JobServiceBaseSchema):
+    """Supports transformation of job services passed as a dict type, as well as internal job services
+    (Custom, Tracking, Studio) set by the system."""
+
+    type = UnionField(
+        [
+            StringTransformedEnum(
+                allowed_values=JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC,
+                pass_original=True,
+            ),
+            fields.Str(),
+        ]
+    )
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        data.pop("type", None)
+        return JobService(**data)
+
+
+class TensorBoardJobServiceSchema(JobServiceBaseSchema):
+    type = StringTransformedEnum(
+        allowed_values=JobServiceTypeNames.EntityNames.TENSOR_BOARD,
+        pass_original=True,
+    )
+    log_dir = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        data.pop("type", None)
+        return TensorBoardJobService(**data)
+
+
+class SshJobServiceSchema(JobServiceBaseSchema):
+    type = StringTransformedEnum(
+        allowed_values=JobServiceTypeNames.EntityNames.SSH,
+        pass_original=True,
+    )
+    ssh_public_keys = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        data.pop("type", None)
+        return SshJobService(**data)
+
+
+class VsCodeJobServiceSchema(JobServiceBaseSchema):
+    type = StringTransformedEnum(
+        allowed_values=JobServiceTypeNames.EntityNames.VS_CODE,
+        pass_original=True,
+    )
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        data.pop("type", None)
+        return VsCodeJobService(**data)
+
+
+class JupyterLabJobServiceSchema(JobServiceBaseSchema):
+    type = StringTransformedEnum(
+        allowed_values=JobServiceTypeNames.EntityNames.JUPYTER_LAB,
+        pass_original=True,
+    )
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        data.pop("type", None)
+        return JupyterLabJobService(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/spark_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/spark_job.py
new file mode 100644
index 00000000..f9363175
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/spark_job.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml._schema.spark_resource_configuration import SparkResourceConfigurationSchema
+from azure.ai.ml.constants import JobType
+
+from ..core.fields import ComputeField, StringTransformedEnum, UnionField
+from .base_job import BaseJobSchema
+from .parameterized_spark import ParameterizedSparkSchema
+
+
+class SparkJobSchema(ParameterizedSparkSchema, BaseJobSchema):
+    type = StringTransformedEnum(required=True, allowed_values=JobType.SPARK)
+    compute = ComputeField()
+    inputs = InputsField()
+    outputs = OutputsField()
+    resources = NestedField(SparkResourceConfigurationSchema)
+    identity = UnionField(
+        [
+            NestedField(ManagedIdentitySchema),
+            NestedField(AMLTokenIdentitySchema),
+            NestedField(UserIdentitySchema),
+        ]
+    )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resource_configuration.py
new file mode 100644
index 00000000..859eef31
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resource_configuration.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import UnionField
+
+from .resource_configuration import ResourceConfigurationSchema
+
+
+class JobResourceConfigurationSchema(ResourceConfigurationSchema):
+    locations = fields.List(fields.Str())
+    shm_size = fields.Str(
+        metadata={
+            "description": (
+                "The size of the docker container's shared memory block. "
+                "This should be in the format of `<number><unit>`, where number has "
+                "to be greater than 0 and the unit can be one of "
+                "`b` (bytes), `k` (kilobytes), `m` (megabytes), or `g` (gigabytes)."
+            )
+        }
+    )
+    max_instance_count = fields.Int(
+        metadata={"description": "The maximum number of instances to make available to this job."}
+    )
+    docker_args = UnionField(
+        [
+            fields.Str(metadata={"description": "Arguments to pass to the Docker run command."}),
+            fields.List(fields.Str()),
+        ]
+    )
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import JobResourceConfiguration
+
+        return JobResourceConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resources.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resources.py
new file mode 100644
index 00000000..49e6eaa0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resources.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
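The shm_size description in job_resource_configuration.py pins down a `<number><unit>` grammar. A hedged one-function sketch of a validator for exactly that documented format; the SDK itself may enforce this elsewhere (compare the looser re_memory_pattern in parameterized_spark.py):

    import re

    # matches the documented shm_size grammar: a positive number plus one of b/k/m/g
    _SHM_SIZE = re.compile(r"^[1-9]\d*[bkmg]$")

    def is_valid_shm_size(value: str) -> bool:
        return _SHM_SIZE.match(value) is not None

    assert is_valid_shm_size("2g")
    assert not is_valid_shm_size("0g")   # number must be greater than 0
    assert not is_valid_shm_size("2x")   # unknown unit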
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta + + +class JobResourcesSchema(metaclass=PatchedSchemaMeta): + instance_types = fields.List( + fields.Str(), metadata={"description": "The instance type to make available to this job."} + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import JobResources + + return JobResources(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py new file mode 100644 index 00000000..bd7fd69c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py @@ -0,0 +1,19 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class AlertNotificationSchema(metaclass=PatchedSchemaMeta): + emails = fields.List(fields.Str) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.alert_notification import AlertNotification + + return AlertNotification(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py new file mode 100644 index 00000000..483b4ac5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class ComputeConfigurationSchema(metaclass=PatchedSchemaMeta): + compute_type = fields.Str(allowed_values=["ServerlessSpark"]) + + +class ServerlessSparkComputeSchema(ComputeConfigurationSchema): + runtime_version = fields.Str() + instance_type = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.compute import ServerlessSparkCompute + + return ServerlessSparkCompute(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py new file mode 100644 index 00000000..d5a6a4f9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py @@ -0,0 +1,52 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml.constants._monitoring import MonitorDatasetContext +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema, DataInputSchema + + +class MonitorInputDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext]) + target_columns = fields.Dict() + job_type = fields.Str() + uri = fields.Str() + + +class FixedInputDataSchema(MonitorInputDataSchema): + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.input_data import FixedInputData + + return FixedInputData(**data) + + +class TrailingInputDataSchema(MonitorInputDataSchema): + window_size = fields.Str() + window_offset = fields.Str() + pre_processing_component_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.input_data import TrailingInputData + + return TrailingInputData(**data) + + +class StaticInputDataSchema(MonitorInputDataSchema): + pre_processing_component_id = fields.Str() + window_start = fields.String() + window_end = fields.String() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.input_data import StaticInputData + + return StaticInputData(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py new file mode 100644 index 00000000..3fe52c9d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py @@ -0,0 +1,53 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
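TrailingInputData above positions a data window relative to the current time via `window_size` and `window_offset` duration strings. A simplified day-granularity sketch of one plausible reading, in which the offset shifts the window end back from "now"; the service defines the authoritative semantics, so treat this as illustration only:

    import re
    from datetime import datetime, timedelta, timezone

    def _days(duration: str) -> timedelta:
        # simplified ISO-8601 parser: day-granularity durations like "P7D" only
        match = re.fullmatch(r"P(\d+)D", duration)
        if not match:
            raise ValueError(f"unsupported duration: {duration!r}")
        return timedelta(days=int(match.group(1)))

    def trailing_window(window_size: str, window_offset: str):
        end = datetime.now(timezone.utc) - _days(window_offset)
        return end - _days(window_size), end

    start, end = trailing_window("P7D", "P1D")  # the 7 days ending yesterday
    print(start, end)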
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml.constants._monitoring import AZMONITORING +from azure.ai.ml._schema.monitoring.target import MonitoringTargetSchema +from azure.ai.ml._schema.monitoring.compute import ServerlessSparkComputeSchema +from azure.ai.ml._schema.monitoring.signals import ( + DataDriftSignalSchema, + DataQualitySignalSchema, + PredictionDriftSignalSchema, + FeatureAttributionDriftSignalSchema, + CustomMonitoringSignalSchema, + GenerationSafetyQualitySchema, + ModelPerformanceSignalSchema, + GenerationTokenStatisticsSchema, +) +from azure.ai.ml._schema.monitoring.alert_notification import AlertNotificationSchema +from azure.ai.ml._schema.core.fields import NestedField, UnionField, StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class MonitorDefinitionSchema(metaclass=PatchedSchemaMeta): + compute = NestedField(ServerlessSparkComputeSchema) + monitoring_target = NestedField(MonitoringTargetSchema) + monitoring_signals = fields.Dict( + keys=fields.Str(), + values=UnionField( + union_fields=[ + NestedField(DataDriftSignalSchema), + NestedField(DataQualitySignalSchema), + NestedField(PredictionDriftSignalSchema), + NestedField(FeatureAttributionDriftSignalSchema), + NestedField(CustomMonitoringSignalSchema), + NestedField(GenerationSafetyQualitySchema), + NestedField(ModelPerformanceSignalSchema), + NestedField(GenerationTokenStatisticsSchema), + ] + ), + ) + alert_notification = UnionField( + union_fields=[StringTransformedEnum(allowed_values=AZMONITORING), NestedField(AlertNotificationSchema)] + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.definition import MonitorDefinition + + return MonitorDefinition(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py new file mode 100644 index 00000000..a2034d33 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py @@ -0,0 +1,11 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.monitoring.monitor_definition import MonitorDefinitionSchema +from azure.ai.ml._schema.schedule.schedule import ScheduleSchema + + +class MonitorScheduleSchema(ScheduleSchema): + create_monitor = NestedField(MonitorDefinitionSchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py new file mode 100644 index 00000000..4f55393b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py @@ -0,0 +1,348 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load, pre_dump, ValidationError + +from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, MLTableInputSchema +from azure.ai.ml.constants._common import AzureMLResourceType +from azure.ai.ml.constants._monitoring import ( + MonitorSignalType, + ALL_FEATURES, + MonitorModelType, + MonitorDatasetContext, + FADColumnNames, +) +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._schema.core.fields import ArmVersionedStr, NestedField, UnionField, StringTransformedEnum +from azure.ai.ml._schema.monitoring.thresholds import ( + DataDriftMetricThresholdSchema, + DataQualityMetricThresholdSchema, + PredictionDriftMetricThresholdSchema, + FeatureAttributionDriftMetricThresholdSchema, + ModelPerformanceMetricThresholdSchema, + CustomMonitoringMetricThresholdSchema, + GenerationSafetyQualityMetricThresholdSchema, + GenerationTokenStatisticsMonitorMetricThresholdSchema, +) + + +class DataSegmentSchema(metaclass=PatchedSchemaMeta): + feature_name = fields.Str() + feature_values = fields.List(fields.Str) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataSegment + + return DataSegment(**data) + + +class MonitorFeatureFilterSchema(metaclass=PatchedSchemaMeta): + top_n_feature_importance = fields.Int() + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import MonitorFeatureFilter + + if not isinstance(data, MonitorFeatureFilter): + raise ValidationError("Cannot dump non-MonitorFeatureFilter object into MonitorFeatureFilter") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import MonitorFeatureFilter + + return MonitorFeatureFilter(**data) + + +class BaselineDataRangeSchema(metaclass=PatchedSchemaMeta): + window_start = fields.Str() + window_end = fields.Str() + lookback_window_size = fields.Str() + lookback_window_offset = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import BaselineDataRange + + return BaselineDataRange(**data) + + +class ProductionDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext]) + pre_processing_component = fields.Str() + data_window = NestedField(BaselineDataRangeSchema) + data_column_names = fields.Dict(keys=fields.Str(), values=fields.Str()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import ProductionData + + return ProductionData(**data) + + +class ReferenceDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext]) + pre_processing_component = fields.Str() + target_column_name = fields.Str() + data_window = NestedField(BaselineDataRangeSchema) + data_column_names = fields.Dict(keys=fields.Str(), values=fields.Str()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import ReferenceData + + return ReferenceData(**data) + + +class MonitoringSignalSchema(metaclass=PatchedSchemaMeta): + production_data = 
NestedField(ProductionDataSchema) + reference_data = NestedField(ReferenceDataSchema) + properties = fields.Dict() + alert_enabled = fields.Bool() + + +class DataSignalSchema(MonitoringSignalSchema): + features = UnionField( + union_fields=[ + NestedField(MonitorFeatureFilterSchema), + StringTransformedEnum(allowed_values=ALL_FEATURES), + fields.List(fields.Str), + ] + ) + feature_type_override = fields.Dict() + + +class DataDriftSignalSchema(DataSignalSchema): + type = StringTransformedEnum(allowed_values=MonitorSignalType.DATA_DRIFT, required=True) + metric_thresholds = NestedField(DataDriftMetricThresholdSchema) + data_segment = NestedField(DataSegmentSchema) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataDriftSignal + + if not isinstance(data, DataDriftSignal): + raise ValidationError("Cannot dump non-DataDriftSignal object into DataDriftSignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataDriftSignal + + data.pop("type", None) + return DataDriftSignal(**data) + + +class DataQualitySignalSchema(DataSignalSchema): + type = StringTransformedEnum(allowed_values=MonitorSignalType.DATA_QUALITY, required=True) + metric_thresholds = NestedField(DataQualityMetricThresholdSchema) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataQualitySignal + + if not isinstance(data, DataQualitySignal): + raise ValidationError("Cannot dump non-DataQualitySignal object into DataQualitySignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataQualitySignal + + data.pop("type", None) + return DataQualitySignal(**data) + + +class PredictionDriftSignalSchema(MonitoringSignalSchema): + type = StringTransformedEnum(allowed_values=MonitorSignalType.PREDICTION_DRIFT, required=True) + metric_thresholds = NestedField(PredictionDriftMetricThresholdSchema) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import PredictionDriftSignal + + if not isinstance(data, PredictionDriftSignal): + raise ValidationError("Cannot dump non-PredictionDriftSignal object into PredictionDriftSignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import PredictionDriftSignal + + data.pop("type", None) + return PredictionDriftSignal(**data) + + +class ModelSignalSchema(MonitoringSignalSchema): + model_type = StringTransformedEnum(allowed_values=[MonitorModelType.CLASSIFICATION, MonitorModelType.REGRESSION]) + + +class FADProductionDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext]) + data_column_names = fields.Dict( + keys=StringTransformedEnum(allowed_values=[o.value for o in FADColumnNames]), values=fields.Str() + ) + pre_processing_component = fields.Str() + data_window = NestedField(BaselineDataRangeSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import FADProductionData + + return FADProductionData(**data) + + +class FeatureAttributionDriftSignalSchema(metaclass=PatchedSchemaMeta): + production_data = fields.List(NestedField(FADProductionDataSchema)) + reference_data = NestedField(ReferenceDataSchema) + 
alert_enabled = fields.Bool() + type = StringTransformedEnum(allowed_values=MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT, required=True) + metric_thresholds = NestedField(FeatureAttributionDriftMetricThresholdSchema) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import FeatureAttributionDriftSignal + + if not isinstance(data, FeatureAttributionDriftSignal): + raise ValidationError( + "Cannot dump non-FeatureAttributionDriftSignal object into FeatureAttributionDriftSignal" + ) + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import FeatureAttributionDriftSignal + + data.pop("type", None) + return FeatureAttributionDriftSignal(**data) + + +class ModelPerformanceSignalSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=MonitorSignalType.MODEL_PERFORMANCE, required=True) + production_data = NestedField(ProductionDataSchema) + reference_data = NestedField(ReferenceDataSchema) + data_segment = NestedField(DataSegmentSchema) + alert_enabled = fields.Bool() + metric_thresholds = NestedField(ModelPerformanceMetricThresholdSchema) + properties = fields.Dict() + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import ModelPerformanceSignal + + if not isinstance(data, ModelPerformanceSignal): + raise ValidationError("Cannot dump non-ModelPerformanceSignal object into ModelPerformanceSignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import ModelPerformanceSignal + + data.pop("type", None) + return ModelPerformanceSignal(**data) + + +class ConnectionSchema(metaclass=PatchedSchemaMeta): + environment_variables = fields.Dict(keys=fields.Str(), values=fields.Str()) + secret_config = fields.Dict(keys=fields.Str(), values=fields.Str()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import Connection + + return Connection(**data) + + +class CustomMonitoringSignalSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=MonitorSignalType.CUSTOM, required=True) + connection = NestedField(ConnectionSchema) + component_id = ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT) + metric_thresholds = fields.List(NestedField(CustomMonitoringMetricThresholdSchema)) + input_data = fields.Dict(keys=fields.Str(), values=NestedField(ReferenceDataSchema)) + alert_enabled = fields.Bool() + inputs = fields.Dict( + keys=fields.Str, values=UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + ) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import CustomMonitoringSignal + + if not isinstance(data, CustomMonitoringSignal): + raise ValidationError("Cannot dump non-CustomMonitoringSignal object into CustomMonitoringSignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import CustomMonitoringSignal + + data.pop("type", None) + return CustomMonitoringSignal(**data) + + +class LlmDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_column_names = fields.Dict() + data_window = NestedField(BaselineDataRangeSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import LlmData + + return 
LlmData(**data)
+
+
+class GenerationSafetyQualitySchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(allowed_values=MonitorSignalType.GENERATION_SAFETY_QUALITY, required=True)
+    production_data = fields.List(NestedField(LlmDataSchema))
+    connection_id = fields.Str()
+    metric_thresholds = NestedField(GenerationSafetyQualityMetricThresholdSchema)
+    alert_enabled = fields.Bool()
+    properties = fields.Dict()
+    sampling_rate = fields.Float()
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal
+
+        if not isinstance(data, GenerationSafetyQualitySignal):
+            raise ValidationError(
+                "Cannot dump non-GenerationSafetyQualitySignal object into GenerationSafetyQualitySignal"
+            )
+        return data
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal
+
+        data.pop("type", None)
+        return GenerationSafetyQualitySignal(**data)
+
+
+class GenerationTokenStatisticsSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(allowed_values=MonitorSignalType.GENERATION_TOKEN_STATISTICS, required=True)
+    production_data = NestedField(LlmDataSchema)
+    metric_thresholds = NestedField(GenerationTokenStatisticsMonitorMetricThresholdSchema)
+    alert_enabled = fields.Bool()
+    properties = fields.Dict()
+    sampling_rate = fields.Float()
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        from azure.ai.ml.entities._monitoring.signals import GenerationTokenStatisticsSignal
+
+        if not isinstance(data, GenerationTokenStatisticsSignal):
+            raise ValidationError(
+                "Cannot dump non-GenerationTokenStatisticsSignal object into GenerationTokenStatisticsSignal"
+            )
+        return data
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._monitoring.signals import GenerationTokenStatisticsSignal
+
+        data.pop("type", None)
+        return GenerationTokenStatisticsSignal(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py
new file mode 100644
index 00000000..6d3032ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + + +from azure.ai.ml.constants._common import AzureMLResourceType +from azure.ai.ml.constants._monitoring import MonitorTargetTasks +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._schema.core.fields import ArmVersionedStr, StringTransformedEnum + + +class MonitoringTargetSchema(metaclass=PatchedSchemaMeta): + model_id = ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL) + ml_task = StringTransformedEnum(allowed_values=[o.value for o in MonitorTargetTasks]) + endpoint_deployment_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.target import MonitoringTarget + + return MonitoringTarget(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py new file mode 100644 index 00000000..b7970fca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py @@ -0,0 +1,211 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument, name-too-long + +from marshmallow import fields, post_load + +from azure.ai.ml.constants._monitoring import MonitorFeatureType +from azure.ai.ml._schema.core.fields import StringTransformedEnum, NestedField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class MetricThresholdSchema(metaclass=PatchedSchemaMeta): + threshold = fields.Number() + + +class NumericalDriftMetricsSchema(metaclass=PatchedSchemaMeta): + jensen_shannon_distance = fields.Number() + normalized_wasserstein_distance = fields.Number() + population_stability_index = fields.Number() + two_sample_kolmogorov_smirnov_test = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import NumericalDriftMetrics + + return NumericalDriftMetrics(**data) + + +class CategoricalDriftMetricsSchema(metaclass=PatchedSchemaMeta): + jensen_shannon_distance = fields.Number() + population_stability_index = fields.Number() + pearsons_chi_squared_test = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import CategoricalDriftMetrics + + return CategoricalDriftMetrics(**data) + + +class DataDriftMetricThresholdSchema(MetricThresholdSchema): + data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL]) + + numerical = NestedField(NumericalDriftMetricsSchema) + categorical = NestedField(CategoricalDriftMetricsSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import DataDriftMetricThreshold + + return DataDriftMetricThreshold(**data) + + +class DataQualityMetricsNumericalSchema(metaclass=PatchedSchemaMeta): + null_value_rate = fields.Number() + data_type_error_rate = fields.Number() + out_of_bounds_rate = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricsNumerical + + return DataQualityMetricsNumerical(**data) + + +class DataQualityMetricsCategoricalSchema(metaclass=PatchedSchemaMeta): + null_value_rate = fields.Number() + 
data_type_error_rate = fields.Number() + out_of_bounds_rate = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricsCategorical + + return DataQualityMetricsCategorical(**data) + + +class DataQualityMetricThresholdSchema(MetricThresholdSchema): + data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL]) + numerical = NestedField(DataQualityMetricsNumericalSchema) + categorical = NestedField(DataQualityMetricsCategoricalSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricThreshold + + return DataQualityMetricThreshold(**data) + + +class PredictionDriftMetricThresholdSchema(MetricThresholdSchema): + data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL]) + numerical = NestedField(NumericalDriftMetricsSchema) + categorical = NestedField(CategoricalDriftMetricsSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import PredictionDriftMetricThreshold + + return PredictionDriftMetricThreshold(**data) + + +# pylint: disable-next=name-too-long +class FeatureAttributionDriftMetricThresholdSchema(MetricThresholdSchema): + normalized_discounted_cumulative_gain = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import FeatureAttributionDriftMetricThreshold + + return FeatureAttributionDriftMetricThreshold(**data) + + +class ModelPerformanceClassificationThresholdsSchema(metaclass=PatchedSchemaMeta): + accuracy = fields.Number() + precision = fields.Number() + recall = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceClassificationThresholds + + return ModelPerformanceClassificationThresholds(**data) + + +class ModelPerformanceRegressionThresholdsSchema(metaclass=PatchedSchemaMeta): + mae = fields.Number() + mse = fields.Number() + rmse = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceRegressionThresholds + + return ModelPerformanceRegressionThresholds(**data) + + +class ModelPerformanceMetricThresholdSchema(MetricThresholdSchema): + classification = NestedField(ModelPerformanceClassificationThresholdsSchema) + regression = NestedField(ModelPerformanceRegressionThresholdsSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceMetricThreshold + + return ModelPerformanceMetricThreshold(**data) + + +class CustomMonitoringMetricThresholdSchema(MetricThresholdSchema): + metric_name = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import CustomMonitoringMetricThreshold + + return CustomMonitoringMetricThreshold(**data) + + +class GenerationSafetyQualityMetricThresholdSchema(metaclass=PatchedSchemaMeta): # pylint: disable=name-too-long + groundedness = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_groundedness_pass_rate", "acceptable_groundedness_score_per_instance"] + ), + values=fields.Number(), + ) + relevance = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_relevance_pass_rate", "acceptable_relevance_score_per_instance"] + ), + 
values=fields.Number(), + ) + coherence = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_coherence_pass_rate", "acceptable_coherence_score_per_instance"] + ), + values=fields.Number(), + ) + fluency = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_fluency_pass_rate", "acceptable_fluency_score_per_instance"] + ), + values=fields.Number(), + ) + similarity = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_similarity_pass_rate", "acceptable_similarity_score_per_instance"] + ), + values=fields.Number(), + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import GenerationSafetyQualityMonitoringMetricThreshold + + return GenerationSafetyQualityMonitoringMetricThreshold(**data) + + +class GenerationTokenStatisticsMonitorMetricThresholdSchema( + metaclass=PatchedSchemaMeta +): # pylint: disable=name-too-long + totaltoken = fields.Dict( + keys=StringTransformedEnum(allowed_values=["total_token_count", "total_token_count_per_group"]), + values=fields.Number(), + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import GenerationTokenStatisticsMonitorMetricThreshold + + return GenerationTokenStatisticsMonitorMetricThreshold(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/__init__.py new file mode 100644 index 00000000..a19931cd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/__init__.py @@ -0,0 +1,17 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=unused-import +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .component_job import ( + CommandSchema, + ImportSchema, + ParallelSchema, + SparkSchema, + DataTransferCopySchema, + DataTransferImportSchema, + DataTransferExportSchema, +) +from .pipeline_job import PipelineJobSchema +from .settings import PipelineJobSettingsSchema diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py new file mode 100644 index 00000000..4b815db7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py @@ -0,0 +1,148 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
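The generation-safety threshold schemas above use Dict fields whose keys are themselves enum-constrained via StringTransformedEnum. StringTransformedEnum is azure-internal; plain marshmallow expresses the same shape with a validated key field, as in this stand-alone sketch:

    from marshmallow import Schema, fields, validate, ValidationError

    class GroundednessThresholds(Schema):
        groundedness = fields.Dict(
            keys=fields.Str(validate=validate.OneOf(
                ["aggregated_groundedness_pass_rate", "acceptable_groundedness_score_per_instance"]
            )),
            values=fields.Number(),
        )

    ok = GroundednessThresholds().load(
        {"groundedness": {"aggregated_groundedness_pass_rate": 0.9}}
    )
    print(ok)
    try:
        GroundednessThresholds().load({"groundedness": {"typo_key": 0.9}})
    except ValidationError as err:
        print(err.messages)  # the invalid key is rejected at load time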
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+from typing import List
+
+from marshmallow import fields, post_dump, post_load, pre_dump
+
+from azure.ai.ml._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
+from azure.ai.ml._schema.automl import AutoMLClassificationSchema, AutoMLForecastingSchema, AutoMLRegressionSchema
+from azure.ai.ml._schema.automl.image_vertical.image_classification import (
+    ImageClassificationMultilabelSchema,
+    ImageClassificationSchema,
+)
+from azure.ai.ml._schema.automl.image_vertical.image_object_detection import (
+    ImageInstanceSegmentationSchema,
+    ImageObjectDetectionSchema,
+)
+from azure.ai.ml._schema.automl.nlp_vertical.text_classification import TextClassificationSchema
+from azure.ai.ml._schema.automl.nlp_vertical.text_classification_multilabel import TextClassificationMultilabelSchema
+from azure.ai.ml._schema.automl.nlp_vertical.text_ner import TextNerSchema
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, UnionField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema, OutputSchema
+from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr
+
+
+class AutoMLNodeMixin(PathAwareSchema):
+    """Inherit this mixin to change AutoML job schemas into AutoML node schemas.
+
+    e.g. compute is required for an AutoML job but not for an AutoML node in a pipeline.
+    Note: inherit this before BaseJobSchema to make sure the optional overrides take effect.
+    """
+
+    def __init__(self, **kwargs):
+        super(AutoMLNodeMixin, self).__init__(**kwargs)
+        # update field objects to add data binding support; task_type & type won't be bound as data binding
+        support_data_binding_expression_for_fields(self, attrs_to_skip=["task_type", "type"])
+
+    compute = ComputeField(required=False)
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField([NestedField(OutputSchema), OutputBindingStr], allow_none=True),
+    )
+
+    @pre_dump
+    def resolve_outputs(self, job: "AutoMLJob", **kwargs):
+        # Try to resolve the object's inputs & outputs and return a resolved copy
+        import copy
+
+        result = copy.copy(job)
+        result._outputs = job._build_outputs()
+        return result
+
+    @post_dump(pass_original=True)
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def resolve_nested_data(self, job_dict: dict, job: "AutoMLJob", **kwargs):
+        """Resolve nested data into flattened format."""
+        from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+
+        if not isinstance(job, AutoMLJob):
+            return job_dict
+        # change outputs to REST output dicts
+        job_dict["outputs"] = job._to_rest_outputs()
+        return job_dict
+
+    @post_load
+    def make(self, data, **kwargs):
+        data["task"] = data.pop("task_type")
+        return data
+
+
+class AutoMLClassificationNodeSchema(AutoMLNodeMixin, AutoMLClassificationSchema):
+    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+    test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLRegressionNodeSchema(AutoMLNodeMixin, AutoMLRegressionSchema):
+    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+    test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class 
AutoMLForecastingNodeSchema(AutoMLNodeMixin, AutoMLForecastingSchema): + training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + + +class AutoMLTextClassificationNode(AutoMLNodeMixin, TextClassificationSchema): + training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + + +class AutoMLTextClassificationMultilabelNode(AutoMLNodeMixin, TextClassificationMultilabelSchema): + training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + + +class AutoMLTextNerNode(AutoMLNodeMixin, TextNerSchema): + training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + + +class ImageClassificationMulticlassNodeSchema(AutoMLNodeMixin, ImageClassificationSchema): + training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + + +class ImageClassificationMultilabelNodeSchema(AutoMLNodeMixin, ImageClassificationMultilabelSchema): + training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + + +class ImageObjectDetectionNodeSchema(AutoMLNodeMixin, ImageObjectDetectionSchema): + training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + + +class ImageInstanceSegmentationNodeSchema(AutoMLNodeMixin, ImageInstanceSegmentationSchema): + training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)]) + + +def AutoMLNodeSchema(**kwargs) -> List[fields.Field]: + """Get the list of all nested schema for all AutoML nodes. + + :return: The list of fields + :rtype: List[fields.Field] + """ + return [ + # region: automl node schemas + NestedField(AutoMLClassificationNodeSchema, **kwargs), + NestedField(AutoMLRegressionNodeSchema, **kwargs), + NestedField(AutoMLForecastingNodeSchema, **kwargs), + # Vision + NestedField(ImageClassificationMulticlassNodeSchema, **kwargs), + NestedField(ImageClassificationMultilabelNodeSchema, **kwargs), + NestedField(ImageObjectDetectionNodeSchema, **kwargs), + NestedField(ImageInstanceSegmentationNodeSchema, **kwargs), + # NLP + NestedField(AutoMLTextClassificationNode, **kwargs), + NestedField(AutoMLTextClassificationMultilabelNode, **kwargs), + NestedField(AutoMLTextNerNode, **kwargs), + # endregion + ] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py new file mode 100644 index 00000000..8f179479 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py @@ -0,0 +1,554 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
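# --- illustrative sketch (editorial, not part of the vendored source) -----
# A rough idea of the pipeline YAML entries the AutoML node schemas above
# load; the node name, metric, and data values below are hypothetical:
#
#   jobs:
#     automl_step:
#       type: automl
#       task: classification
#       primary_metric: accuracy
#       training_data: ${{parent.inputs.training_data}}
#       validation_data:
#         type: mltable
#         path: azureml:my_validation_data:1
# ---------------------------------------------------------------------------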
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging + +from marshmallow import INCLUDE, ValidationError, fields, post_dump, post_load, pre_dump, validates + +from ..._schema.component import ( + AnonymousCommandComponentSchema, + AnonymousDataTransferCopyComponentSchema, + AnonymousImportComponentSchema, + AnonymousParallelComponentSchema, + AnonymousSparkComponentSchema, + ComponentFileRefField, + ComponentYamlRefField, + DataTransferCopyComponentFileRefField, + ImportComponentFileRefField, + ParallelComponentFileRefField, + SparkComponentFileRefField, +) +from ..._utils.utils import is_data_binding_expression +from ...constants._common import AzureMLResourceType +from ...constants._component import DataTransferTaskType, NodeType +from ...entities._inputs_outputs import Input +from ...entities._job.pipeline._attr_dict import _AttrDict +from ...exceptions import ValidationException +from .._sweep.parameterized_sweep import ParameterizedSweepSchema +from .._utils.data_binding_expression import support_data_binding_expression_for_fields +from ..component.flow import FlowComponentSchema +from ..core.fields import ( + ArmVersionedStr, + ComputeField, + EnvironmentField, + NestedField, + RegistryStr, + StringTransformedEnum, + TypeSensitiveUnionField, + UnionField, +) +from ..core.schema import PathAwareSchema +from ..job import ParameterizedCommandSchema, ParameterizedParallelSchema, ParameterizedSparkSchema +from ..job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema +from ..job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema +from ..job.input_output_fields_provider import InputsField +from ..job.job_limits import CommandJobLimitsSchema +from ..job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema +from ..job.services import ( + JobServiceSchema, + JupyterLabJobServiceSchema, + SshJobServiceSchema, + TensorBoardJobServiceSchema, + VsCodeJobServiceSchema, +) +from ..pipeline.pipeline_job_io import OutputBindingStr +from ..spark_resource_configuration import SparkResourceConfigurationForNodeSchema + +module_logger = logging.getLogger(__name__) + + +# do inherit PathAwareSchema to support relative path & default partial load (allow None value if not specified) +class BaseNodeSchema(PathAwareSchema): + """Base schema for all node schemas.""" + + unknown = INCLUDE + + inputs = InputsField(support_databinding=True) + outputs = fields.Dict( + keys=fields.Str(), + values=UnionField([OutputBindingStr, NestedField(OutputSchema)], allow_none=True), + ) + properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True)) + comment = fields.Str() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # data binding expression is not supported inside component field, while validation error + # message will be very long when component is an object as error message will include + # str(component), so just add component to skip list. The same to trial in Sweep. 
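# (editorial illustration) A data binding expression is a literal such as
# "${{parent.inputs.learning_rate}}" or "${{parent.jobs.train.outputs.model}}";
# with the support added below, a field like `compute` may hold either its
# normal typed value or such a string.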
+ support_data_binding_expression_for_fields(self, ["type", "component", "trial", "inputs"]) + + @post_dump(pass_original=True) + # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype + def add_user_setting_attr_dict(self, data, original_data, **kwargs): # pylint: disable=unused-argument + """Support serializing unknown fields for pipeline node.""" + if isinstance(original_data, _AttrDict): + user_setting_attr_dict = original_data._get_attrs() + # TODO: dump _AttrDict values to serializable data like dict instead of original object + # skip fields that are already serialized + for key, value in user_setting_attr_dict.items(): + if key not in data: + data[key] = value + return data + + # an alternative would be set schema property to be load_only, but sub-schemas like CommandSchema usually also + # inherit from other schema classes which also have schema property. Set post dump here would be more efficient. + @post_dump() + def remove_meaningless_key_for_node( + self, + data, + **kwargs, # pylint: disable=unused-argument + ): + data.pop("$schema", None) + return data + + +def _delete_type_for_binding(io): + for key in io: + if isinstance(io[key], Input) and io[key].path and is_data_binding_expression(io[key].path): + io[key].type = None + + +def _resolve_inputs(result, original_job): + result._inputs = original_job._build_inputs() + # delete type for literal binding input + _delete_type_for_binding(result._inputs) + + +def _resolve_outputs(result, original_job): + result._outputs = original_job._build_outputs() + # delete type for literal binding output + _delete_type_for_binding(result._outputs) + + +def _resolve_inputs_outputs(job): + # Try resolve object's inputs & outputs and return a resolved new object + import copy + + result = copy.copy(job) + _resolve_inputs(result, job) + _resolve_outputs(result, job) + + return result + + +class CommandSchema(BaseNodeSchema, ParameterizedCommandSchema): + """Schema for Command.""" + + # pylint: disable=unused-argument + component = TypeSensitiveUnionField( + { + NodeType.COMMAND: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousCommandComponentSchema, unknown=INCLUDE), + # component file reference + ComponentFileRefField(), + ], + }, + plain_union_fields=[ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + # code is directly linked to component.code, so no need to validate or dump it + code = fields.Str(allow_none=True, load_only=True) + type = StringTransformedEnum(allowed_values=[NodeType.COMMAND]) + compute = ComputeField() + # do not promote it as CommandComponent has no field named 'limits' + limits = NestedField(CommandJobLimitsSchema) + # Change required fields to optional + command = fields.Str( + metadata={ + "description": "The command run and the parameters passed. \ + This string may contain place holders of inputs in {}. " + }, + load_only=True, + ) + environment = EnvironmentField() + services = fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + NestedField(SshJobServiceSchema), + NestedField(JupyterLabJobServiceSchema), + NestedField(TensorBoardJobServiceSchema), + NestedField(VsCodeJobServiceSchema), + # JobServiceSchema should be the last in the list. + # To support types not set by users like Custom, Tracking, Studio. 
+ NestedField(JobServiceSchema), + ], + is_strict=True, + ), + ) + identity = UnionField( + [ + NestedField(ManagedIdentitySchema), + NestedField(AMLTokenIdentitySchema), + NestedField(UserIdentitySchema), + ] + ) + + @post_load + def make(self, data, **kwargs) -> "Command": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.command_func import command + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + command_node = command(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, command._validate_io e.g. raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return command_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class SweepSchema(BaseNodeSchema, ParameterizedSweepSchema): + """Schema for Sweep.""" + + # pylint: disable=unused-argument + type = StringTransformedEnum(allowed_values=[NodeType.SWEEP]) + compute = ComputeField() + trial = TypeSensitiveUnionField( + { + NodeType.SWEEP: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousCommandComponentSchema, unknown=INCLUDE), + # component file reference + ComponentFileRefField(), + ], + }, + plain_union_fields=[ + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + + @post_load + def make(self, data, **kwargs) -> "Sweep": + from azure.ai.ml.entities._builders import Sweep, parse_inputs_outputs + + # parse inputs/outputs + data = parse_inputs_outputs(data) + return Sweep(**data) + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class ParallelSchema(BaseNodeSchema, ParameterizedParallelSchema): + """ + Schema for Parallel. + """ + + # pylint: disable=unused-argument + compute = ComputeField() + component = TypeSensitiveUnionField( + { + NodeType.PARALLEL: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousParallelComponentSchema, unknown=INCLUDE), + # component file reference + ParallelComponentFileRefField(), + ], + NodeType.FLOW_PARALLEL: [ + NestedField(FlowComponentSchema, unknown=INCLUDE, dump_only=True), + ComponentYamlRefField(), + ], + }, + plain_union_fields=[ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + identity = UnionField( + [ + NestedField(ManagedIdentitySchema), + NestedField(AMLTokenIdentitySchema), + NestedField(UserIdentitySchema), + ] + ) + type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL]) + + @post_load + def make(self, data, **kwargs) -> "Parallel": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.parallel_func import parallel_run_function + + data = parse_inputs_outputs(data) + parallel_node = parallel_run_function(**data) + return parallel_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class ImportSchema(BaseNodeSchema): + """ + Schema for Import. 
+ """ + + # pylint: disable=unused-argument + component = TypeSensitiveUnionField( + { + NodeType.IMPORT: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousImportComponentSchema, unknown=INCLUDE), + # component file reference + ImportComponentFileRefField(), + ], + }, + plain_union_fields=[ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + type = StringTransformedEnum(allowed_values=[NodeType.IMPORT]) + + @post_load + def make(self, data, **kwargs) -> "Import": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.import_func import import_job + + # parse inputs/outputs + data = parse_inputs_outputs(data) + import_node = import_job(**data) + return import_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class SparkSchema(BaseNodeSchema, ParameterizedSparkSchema): + """ + Schema for Spark. + """ + + # pylint: disable=unused-argument + component = TypeSensitiveUnionField( + { + NodeType.SPARK: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousSparkComponentSchema, unknown=INCLUDE), + # component file reference + SparkComponentFileRefField(), + ], + }, + plain_union_fields=[ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + type = StringTransformedEnum(allowed_values=[NodeType.SPARK]) + compute = ComputeField() + resources = NestedField(SparkResourceConfigurationForNodeSchema) + entry = UnionField( + [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)], + metadata={"description": "Entry."}, + ) + py_files = fields.List(fields.Str()) + jars = fields.List(fields.Str()) + files = fields.List(fields.Str()) + archives = fields.List(fields.Str()) + identity = UnionField( + [ + NestedField(ManagedIdentitySchema), + NestedField(AMLTokenIdentitySchema), + NestedField(UserIdentitySchema), + ] + ) + + # code is directly linked to component.code, so no need to validate or dump it + code = fields.Str(allow_none=True, load_only=True) + + @post_load + def make(self, data, **kwargs) -> "Spark": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.spark_func import spark + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + spark_node = spark(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, command._validate_io e.g. raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return spark_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class DataTransferCopySchema(BaseNodeSchema): + """ + Schema for DataTransferCopy. 
+ """ + + # pylint: disable=unused-argument + component = TypeSensitiveUnionField( + { + NodeType.DATA_TRANSFER: [ + # inline component or component file reference starting with FILE prefix + NestedField(AnonymousDataTransferCopyComponentSchema, unknown=INCLUDE), + # component file reference + DataTransferCopyComponentFileRefField(), + ], + }, + plain_union_fields=[ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True) + type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER], required=True) + compute = ComputeField() + + @post_load + def make(self, data, **kwargs) -> "DataTransferCopy": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.data_transfer_func import copy_data + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + data_transfer_node = copy_data(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, data_transfer._validate_io e.g. raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return data_transfer_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class DataTransferImportSchema(BaseNodeSchema): + # pylint: disable=unused-argument + component = UnionField( + [ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True) + type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER], required=True) + compute = ComputeField() + source = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False) + outputs = fields.Dict( + keys=fields.Str(), values=UnionField([OutputBindingStr, NestedField(OutputSchema)]), allow_none=False + ) + + @validates("inputs") + def inputs_key(self, value): + raise ValidationError(f"inputs field is not a valid filed in task type " f"{DataTransferTaskType.IMPORT_DATA}.") + + @validates("outputs") + def outputs_key(self, value): + if len(value) != 1 or list(value.keys())[0] != "sink": + raise ValidationError( + f"outputs field only support one output called sink in task type " + f"{DataTransferTaskType.IMPORT_DATA}." + ) + + @post_load + def make(self, data, **kwargs) -> "DataTransferImport": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.data_transfer_func import import_data + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + data_transfer_node = import_data(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, data_transfer._validate_io e.g. 
raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return data_transfer_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) + + +class DataTransferExportSchema(BaseNodeSchema): + # pylint: disable=unused-argument + component = UnionField( + [ + # for registry type assets + RegistryStr(), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + ], + required=True, + ) + task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA]) + type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER]) + compute = ComputeField() + inputs = InputsField(support_databinding=True, allow_none=False) + sink = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False) + + @validates("inputs") + def inputs_key(self, value): + if len(value) != 1 or list(value.keys())[0] != "source": + raise ValidationError( + f"inputs field only support one input called source in task type " + f"{DataTransferTaskType.EXPORT_DATA}." + ) + + @validates("outputs") + def outputs_key(self, value): + raise ValidationError( + f"outputs field is not a valid filed in task type " f"{DataTransferTaskType.EXPORT_DATA}." + ) + + @post_load + def make(self, data, **kwargs) -> "DataTransferExport": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.data_transfer_func import export_data + + # parse inputs/outputs + data = parse_inputs_outputs(data) + try: + data_transfer_node = export_data(**data) + except ValidationException as e: + # It may raise ValidationError during initialization, data_transfer._validate_io e.g. raise ValidationError + # instead in marshmallow function, so it won't break SchemaValidatable._schema_validate + raise ValidationError(e.message) from e + return data_transfer_node + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): + return _resolve_inputs_outputs(job) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py new file mode 100644 index 00000000..a1d2901c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py @@ -0,0 +1,48 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from marshmallow import fields, post_dump, ValidationError + +from azure.ai.ml._schema import StringTransformedEnum +from azure.ai.ml._schema.core.fields import DataBindingStr, NodeBindingStr, UnionField +from azure.ai.ml._schema.pipeline.control_flow_job import ControlFlowSchema +from azure.ai.ml.constants._component import ControlFlowType + + +# ConditionNodeSchema did not inherit from BaseNodeSchema since it doesn't have inputs/outputs like other nodes. 
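# --- illustrative sketch (editorial, not part of the vendored source) -----
# The YAML shape that ConditionNodeSchema below accepts, assuming the private
# preview if-else feature is enabled; the node names are hypothetical:
#
#   conditionnode:
#     type: if_else
#     condition: ${{parent.jobs.check_step.outputs.result}}
#     true_block: ${{parent.jobs.node_a}}
#     false_block: ${{parent.jobs.node_b}}
# ---------------------------------------------------------------------------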
+class ConditionNodeSchema(ControlFlowSchema): + type = StringTransformedEnum(allowed_values=[ControlFlowType.IF_ELSE]) + condition = UnionField([DataBindingStr(), fields.Bool()]) + true_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())]) + false_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())]) + + @post_dump + def simplify_blocks(self, data, **kwargs): # pylint: disable=unused-argument + # simplify true_block and false_block to single node if there is only one node in the list + # this is to make sure the request to backend won't change after we support list true/false blocks + block_keys = ["true_block", "false_block"] + for block in block_keys: + if isinstance(data.get(block), list) and len(data.get(block)) == 1: + data[block] = data.get(block)[0] + + # validate blocks intersection + def _normalize_blocks(key): + blocks = data.get(key, []) + if blocks: + if not isinstance(blocks, list): + blocks = [blocks] + else: + blocks = [] + return blocks + + true_block = _normalize_blocks("true_block") + false_block = _normalize_blocks("false_block") + + if not true_block and not false_block: + raise ValidationError("True block and false block cannot be empty at the same time.") + + intersection = set(true_block).intersection(set(false_block)) + if intersection: + raise ValidationError(f"True block and false block cannot contain same nodes: {intersection}") + + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/control_flow_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/control_flow_job.py new file mode 100644 index 00000000..3d1e3e4a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/control_flow_job.py @@ -0,0 +1,147 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +import copy +import json + +from marshmallow import INCLUDE, fields, pre_dump, pre_load + +from azure.ai.ml._schema.core.fields import DataBindingStr, NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml.constants._component import ControlFlowType + +from ..job.input_output_entry import OutputSchema +from ..job.input_output_fields_provider import InputsField +from ..job.job_limits import DoWhileLimitsSchema +from .component_job import _resolve_outputs +from .pipeline_job_io import OutputBindingStr + +# pylint: disable=protected-access + + +class ControlFlowSchema(PathAwareSchema): + unknown = INCLUDE + + +class BaseLoopSchema(ControlFlowSchema): + unknown = INCLUDE + body = DataBindingStr() + + @pre_dump + def convert_control_flow_body_to_binding_str(self, data, **kwargs): # pylint: disable= unused-argument + result = copy.copy(data) + # Update body object to data_binding_str + result._body = data._get_body_binding_str() + return result + + +class DoWhileSchema(BaseLoopSchema): + # pylint: disable=unused-argument + type = StringTransformedEnum(allowed_values=[ControlFlowType.DO_WHILE]) + condition = UnionField( + [ + DataBindingStr(), + fields.Str(), + ] + ) + mapping = fields.Dict( + keys=fields.Str(), + values=UnionField( + [ + fields.List(fields.Str()), + fields.Str(), + ] + ), + required=True, + ) + limits = NestedField(DoWhileLimitsSchema, required=True) + + @pre_dump + def resolve_inputs_outputs(self, data, **kwargs): + # Try resolve object's mapping and condition and return a resolved new object + result = copy.copy(data) + mapping = {} + for k, v in result.mapping.items(): + v = v if isinstance(v, list) else [v] + mapping[k] = [item._port_name for item in v] + result._mapping = mapping + + try: + result._condition = result._condition._port_name + except AttributeError: + result._condition = result._condition + + return result + + @pre_dump + def convert_control_flow_body_to_binding_str(self, data, **kwargs): + return super(DoWhileSchema, self).convert_control_flow_body_to_binding_str(data, **kwargs) + + +class ParallelForSchema(BaseLoopSchema): + type = StringTransformedEnum(allowed_values=[ControlFlowType.PARALLEL_FOR]) + items = UnionField( + [ + fields.Dict(keys=fields.Str(), values=InputsField()), + fields.List(InputsField()), + # put str in last to make sure other type items won't become string when dumps. 
+ # TODO: only support binding here + fields.Str(), + ], + required=True, + ) + max_concurrency = fields.Int() + outputs = fields.Dict( + keys=fields.Str(), + values=UnionField([OutputBindingStr, NestedField(OutputSchema)], allow_none=True), + ) + + @pre_load + def load_items(self, data, **kwargs): # pylint: disable= unused-argument + # load items from json to convert the assets in it to rest + try: + items = data["items"] + if isinstance(items, str): + items = json.loads(items) + data["items"] = items + except Exception: # pylint: disable=W0718 + pass + return data + + @pre_dump + def convert_control_flow_body_to_binding_str(self, data, **kwargs): + return super(ParallelForSchema, self).convert_control_flow_body_to_binding_str(data, **kwargs) + + @pre_dump + def resolve_outputs(self, job, **kwargs): # pylint: disable=unused-argument + result = copy.copy(job) + _resolve_outputs(result, job) + return result + + @pre_dump + def serialize_items(self, data, **kwargs): # pylint: disable= unused-argument + # serialize items to json string to avoid being removed by _dump_for_validation + from azure.ai.ml.entities._job.pipeline._io import InputOutputBase + + def _binding_handler(obj): + if isinstance(obj, InputOutputBase): + return str(obj) + return repr(obj) + + result = copy.copy(data) + if isinstance(result.items, (dict, list)): + # use str to serialize input/output builder + result._items = json.dumps(result.items, default=_binding_handler) + return result + + +class FLScatterGatherSchema(ControlFlowSchema): + # TODO determine serialization, or if this is actually needed + + # @pre_dump + def serialize_items(self, data, **kwargs): + pass + + # @pre_dump + def resolve_outputs(self, job, **kwargs): + pass diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_command_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_command_job.py new file mode 100644 index 00000000..c2b96f85 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_command_job.py @@ -0,0 +1,31 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.fields import ComputeField, EnvironmentField, NestedField, UnionField +from azure.ai.ml._schema.job.command_job import CommandJobSchema +from azure.ai.ml._schema.job.input_output_entry import OutputSchema + +module_logger = logging.getLogger(__name__) + + +class PipelineCommandJobSchema(CommandJobSchema): + compute = ComputeField() + environment = EnvironmentField() + outputs = fields.Dict( + keys=fields.Str(), + values=UnionField([NestedField(OutputSchema), fields.Str()], allow_none=True), + ) + + @post_load + def make(self, data: Any, **kwargs: Any): + from azure.ai.ml.entities import CommandJob + + return CommandJob(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py new file mode 100644 index 00000000..05096e99 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py @@ -0,0 +1,297 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access +from copy import deepcopy + +import yaml +from marshmallow import INCLUDE, fields, post_load, pre_dump + +from azure.ai.ml._schema._utils.utils import _resolve_group_inputs_for_component +from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema +from azure.ai.ml._schema.component.component import ComponentSchema +from azure.ai.ml._schema.component.input_output import OutputPortSchema, PrimitiveOutputSchema +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + FileRefField, + NestedField, + PipelineNodeNameStr, + RegistryStr, + StringTransformedEnum, + TypeSensitiveUnionField, + UnionField, +) +from azure.ai.ml._schema.pipeline.automl_node import AutoMLNodeSchema +from azure.ai.ml._schema.pipeline.component_job import ( + BaseNodeSchema, + CommandSchema, + DataTransferCopySchema, + DataTransferExportSchema, + DataTransferImportSchema, + ImportSchema, + ParallelSchema, + SparkSchema, + SweepSchema, + _resolve_inputs_outputs, +) +from azure.ai.ml._schema.pipeline.condition_node import ConditionNodeSchema +from azure.ai.ml._schema.pipeline.control_flow_job import DoWhileSchema, ParallelForSchema +from azure.ai.ml._schema.pipeline.pipeline_command_job import PipelineCommandJobSchema +from azure.ai.ml._schema.pipeline.pipeline_datatransfer_job import ( + PipelineDataTransferCopyJobSchema, + PipelineDataTransferExportJobSchema, + PipelineDataTransferImportJobSchema, +) +from azure.ai.ml._schema.pipeline.pipeline_import_job import PipelineImportJobSchema +from azure.ai.ml._schema.pipeline.pipeline_parallel_job import PipelineParallelJobSchema +from azure.ai.ml._schema.pipeline.pipeline_spark_job import PipelineSparkJobSchema +from azure.ai.ml._utils.utils import is_private_preview_enabled +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType +from azure.ai.ml.constants._component import ( + CONTROL_FLOW_TYPES, + ComponentSource, + ControlFlowType, + DataTransferTaskType, + NodeType, +) + + +class NodeNameStr(PipelineNodeNameStr): + def _get_field_name(self) -> str: + return "Pipeline node" + + +def PipelineJobsField(): + pipeline_enable_job_type = { + NodeType.COMMAND: [ + NestedField(CommandSchema, unknown=INCLUDE), + NestedField(PipelineCommandJobSchema), + ], + NodeType.IMPORT: [ + NestedField(ImportSchema, unknown=INCLUDE), + NestedField(PipelineImportJobSchema), + ], + NodeType.SWEEP: [NestedField(SweepSchema, unknown=INCLUDE)], + NodeType.PARALLEL: [ + # ParallelSchema support parallel pipeline yml with "component" + NestedField(ParallelSchema, unknown=INCLUDE), + NestedField(PipelineParallelJobSchema, unknown=INCLUDE), + ], + NodeType.PIPELINE: [NestedField("PipelineSchema", unknown=INCLUDE)], + NodeType.AUTOML: AutoMLNodeSchema(unknown=INCLUDE), + NodeType.SPARK: [ + NestedField(SparkSchema, unknown=INCLUDE), + NestedField(PipelineSparkJobSchema), + ], + } + + # Note: the private node types only available when private preview flag opened before init of pipeline job + # schema class. 
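# (editorial note) is_private_preview_enabled() reads an environment flag; to
# the best of our knowledge it is toggled as below, but treat the exact
# variable name as an assumption:
#
#   export AZURE_ML_CLI_PRIVATE_FEATURES_ENABLED=true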
+    if is_private_preview_enabled():
+        pipeline_enable_job_type[ControlFlowType.DO_WHILE] = [NestedField(DoWhileSchema, unknown=INCLUDE)]
+        pipeline_enable_job_type[ControlFlowType.IF_ELSE] = [NestedField(ConditionNodeSchema, unknown=INCLUDE)]
+        pipeline_enable_job_type[ControlFlowType.PARALLEL_FOR] = [NestedField(ParallelForSchema, unknown=INCLUDE)]
+
+    # TODO: put the data_transfer logic last to avoid error message conflicts; item opened to track:
+    # https://msdata.visualstudio.com/Vienna/_workitems/edit/2244262/
+    pipeline_enable_job_type[NodeType.DATA_TRANSFER] = [
+        TypeSensitiveUnionField(
+            {
+                DataTransferTaskType.COPY_DATA: [
+                    NestedField(DataTransferCopySchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferCopyJobSchema),
+                ],
+                DataTransferTaskType.IMPORT_DATA: [
+                    NestedField(DataTransferImportSchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferImportJobSchema),
+                ],
+                DataTransferTaskType.EXPORT_DATA: [
+                    NestedField(DataTransferExportSchema, unknown=INCLUDE),
+                    NestedField(PipelineDataTransferExportJobSchema),
+                ],
+            },
+            type_field_name="task",
+            unknown=INCLUDE,
+        )
+    ]
+
+    pipeline_job_field = fields.Dict(
+        keys=NodeNameStr(),
+        values=TypeSensitiveUnionField(pipeline_enable_job_type),
+    )
+    return pipeline_job_field
+
+
+# pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+def _post_load_pipeline_jobs(context, data: dict) -> dict:
+    """Silently convert Job in pipeline jobs to node."""
+    from azure.ai.ml.entities._builders import parse_inputs_outputs
+    from azure.ai.ml.entities._builders.condition_node import ConditionNode
+    from azure.ai.ml.entities._builders.do_while import DoWhile
+    from azure.ai.ml.entities._builders.parallel_for import ParallelFor
+    from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+    from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
+
+    # parse inputs/outputs
+    data = parse_inputs_outputs(data)
+    # convert JobNode to Component here
+    jobs = data.get("jobs", {})
+
+    for key, job_instance in jobs.items():
+        if isinstance(job_instance, dict):
+            # convert AutoML job dict to instance
+            if job_instance.get("type") == NodeType.AUTOML:
+                job_instance = AutoMLJob._create_instance_from_schema_dict(
+                    loaded_data=job_instance,
+                )
+            elif job_instance.get("type") in CONTROL_FLOW_TYPES:
+                # Set source to yaml job for control flow node.
+                job_instance["_source"] = ComponentSource.YAML_JOB
+
+                job_type = job_instance.get("type")
+                if job_type == ControlFlowType.IF_ELSE:
+                    # Convert to if-else node.
+                    job_instance = ConditionNode._create_instance_from_schema_dict(loaded_data=job_instance)
+                elif job_instance.get("type") == ControlFlowType.DO_WHILE:
+                    # Convert to do-while node.
+                    job_instance = DoWhile._create_instance_from_schema_dict(
+                        pipeline_jobs=jobs, loaded_data=job_instance
+                    )
+                elif job_instance.get("type") == ControlFlowType.PARALLEL_FOR:
+                    # Convert to parallel-for node.
+                    job_instance = ParallelFor._create_instance_from_schema_dict(
+                        pipeline_jobs=jobs, loaded_data=job_instance
+                    )
+            jobs[key] = job_instance
+
+    for key, job_instance in jobs.items():
+        # Translate the job to a node if it is translatable and overrides _to_node.
+        if isinstance(job_instance, ComponentTranslatableMixin) and "_to_node" in type(job_instance).__dict__:
+            # set source as YAML
+            job_instance = job_instance._to_node(
+                context=context,
+                pipeline_job_dict=data,
+            )
+            if job_instance.type == NodeType.DATA_TRANSFER and job_instance.task != DataTransferTaskType.COPY_DATA:
+                job_instance._source = ComponentSource.BUILTIN
+            else:
+                job_instance.component._source = ComponentSource.YAML_JOB
+                job_instance._source = job_instance.component._source
+            jobs[key] = job_instance
+        # update job instance name to key
+        job_instance.name = key
+    return data
+
+
+class PipelineComponentSchema(ComponentSchema):
+    type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
+    jobs = PipelineJobsField()
+
+    # primitive output is only supported for command component & pipeline component
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(PrimitiveOutputSchema, unknown=INCLUDE),
+                NestedField(OutputPortSchema),
+            ]
+        ),
+    )
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        return _post_load_pipeline_jobs(self.context, data)
+
+
+class RestPipelineComponentSchema(PipelineComponentSchema):
+    """When a component is loaded from REST, don't validate its name, since
+    existing components may have invalid names."""
+
+    name = fields.Str(required=True)
+
+
+class _AnonymousPipelineComponentSchema(AnonymousAssetSchema, PipelineComponentSchema):
+    """Anonymous pipeline component schema.
+
+    Note that inline definition of an anonymous pipeline component is not
+    supported directly. Inheritance follows the order AnonymousAssetSchema,
+    PipelineComponentSchema because we need name and version to be
+    dump_only (marshmallow collects fields following the method resolution
+    order).
+    """
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._component.pipeline_component import PipelineComponent
+
+        # pipeline jobs post-processing is required before init of the pipeline component: it converts control
+        # node dicts to entities.
+        # however, @post_load invocation order is not guaranteed, so we need to call it explicitly here.
+        _post_load_pipeline_jobs(self.context, data)
+
+        return PipelineComponent(
+            base_path=self.context[BASE_PATH_CONTEXT_KEY],
+            **data,
+        )
+
+
+class PipelineComponentFileRefField(FileRefField):
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def _serialize(self, value, attr, obj, **kwargs):
+        """FileRefField does not support serialization.
+
+        Call the AnonymousPipelineComponent schema to serialize instead. This
+        method is overridden because Pipeline needs to be dumpable.
+        """
+        # Update base_path to parent path of component file.
+        component_schema_context = deepcopy(self.context)
+        value = _resolve_group_inputs_for_component(value)
+        return _AnonymousPipelineComponentSchema(context=component_schema_context)._serialize(value, **kwargs)
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        # Get component info from component yaml file.
+        data = super()._deserialize(value, attr, data, **kwargs)
+        component_dict = yaml.safe_load(data)
+        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+        # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context) + component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent + component = _AnonymousPipelineComponentSchema(context=component_schema_context).load( + component_dict, unknown=INCLUDE + ) + component._source_path = source_path + component._source = ComponentSource.YAML_COMPONENT + return component + + +# Note: PipelineSchema is defined here instead of component_job.py is to +# resolve circular import and support recursive schema. +class PipelineSchema(BaseNodeSchema): + # pylint: disable=unused-argument + # do not support inline define a pipeline node + component = UnionField( + [ + # for registry type assets + RegistryStr(azureml_type=AzureMLResourceType.COMPONENT), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + # component file reference + PipelineComponentFileRefField(), + ], + required=True, + ) + type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE]) + + @post_load + def make(self, data, **kwargs) -> "Pipeline": + from azure.ai.ml.entities._builders import parse_inputs_outputs + from azure.ai.ml.entities._builders.pipeline import Pipeline + + data = parse_inputs_outputs(data) + return Pipeline(**data) + + @pre_dump + def resolve_inputs_outputs(self, data, **kwargs): + return _resolve_inputs_outputs(data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_datatransfer_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_datatransfer_job.py new file mode 100644 index 00000000..a63e687d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_datatransfer_job.py @@ -0,0 +1,55 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.fields import NestedField, UnionField +from azure.ai.ml._schema.job.input_output_entry import OutputSchema +from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr +from azure.ai.ml._schema.job.data_transfer_job import ( + DataTransferCopyJobSchema, + DataTransferImportJobSchema, + DataTransferExportJobSchema, +) + +module_logger = logging.getLogger(__name__) + + +class PipelineDataTransferCopyJobSchema(DataTransferCopyJobSchema): + outputs = fields.Dict( + keys=fields.Str(), + values=UnionField([NestedField(OutputSchema), OutputBindingStr], allow_none=True), + ) + + @post_load + def make(self, data: Any, **kwargs: Any): + from azure.ai.ml.entities._job.data_transfer.data_transfer_job import DataTransferCopyJob + + return DataTransferCopyJob(**data) + + +class PipelineDataTransferImportJobSchema(DataTransferImportJobSchema): + outputs = fields.Dict( + keys=fields.Str(), + values=UnionField([NestedField(OutputSchema), OutputBindingStr], allow_none=True), + ) + + @post_load + def make(self, data: Any, **kwargs: Any): + from azure.ai.ml.entities._job.data_transfer.data_transfer_job import DataTransferImportJob + + return DataTransferImportJob(**data) + + +class PipelineDataTransferExportJobSchema(DataTransferExportJobSchema): + @post_load + def make(self, data: Any, **kwargs: Any): + from azure.ai.ml.entities._job.data_transfer.data_transfer_job import DataTransferExportJob + + return DataTransferExportJob(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_import_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_import_job.py new file mode 100644 index 00000000..ae338597 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_import_job.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import post_load + +from azure.ai.ml._schema.job.import_job import ImportJobSchema + +module_logger = logging.getLogger(__name__) + + +class PipelineImportJobSchema(ImportJobSchema): + class Meta: + exclude = ["compute"] # compute property not applicable to import job + + @post_load + def make(self, data: Any, **kwargs: Any): + from azure.ai.ml.entities._job.import_job import ImportJob + + return ImportJob(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py new file mode 100644 index 00000000..46daeb92 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py @@ -0,0 +1,76 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
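# --- illustrative sketch (editorial, not part of the vendored source) -----
# How a pipeline job governed by PipelineJobSchema below is typically loaded
# through the public azure-ai-ml API; "pipeline.yml" is a hypothetical file:
#
#   from azure.ai.ml import load_job
#
#   pipeline_job = load_job(source="pipeline.yml")
#   print(pipeline_job.settings)  # populated via PipelineJobSettingsSchema
# ---------------------------------------------------------------------------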
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging + +from marshmallow import INCLUDE, ValidationError, post_load, pre_dump, pre_load + +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + ComputeField, + NestedField, + RegistryStr, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.job import BaseJobSchema +from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField +from azure.ai.ml._schema.pipeline.component_job import _resolve_inputs_outputs +from azure.ai.ml._schema.pipeline.pipeline_component import ( + PipelineComponentFileRefField, + PipelineJobsField, + _post_load_pipeline_jobs, +) +from azure.ai.ml._schema.pipeline.settings import PipelineJobSettingsSchema +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import AzureMLResourceType + +module_logger = logging.getLogger(__name__) + + +class PipelineJobSchema(BaseJobSchema): + type = StringTransformedEnum(allowed_values=[JobType.PIPELINE]) + compute = ComputeField() + settings = NestedField(PipelineJobSettingsSchema, unknown=INCLUDE) + # Support databinding in inputs as we support macro like ${{name}} + inputs = InputsField(support_databinding=True) + outputs = OutputsField() + jobs = PipelineJobsField() + component = UnionField( + [ + # for registry type assets + RegistryStr(azureml_type=AzureMLResourceType.COMPONENT), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + # component file reference + PipelineComponentFileRefField(), + ], + ) + + @pre_dump() + def backup_jobs_and_remove_component(self, job, **kwargs): + # pylint: disable=protected-access + job_copy = _resolve_inputs_outputs(job) + if not isinstance(job_copy.component, str): + # If component is pipeline component object, + # copy jobs to job and remove component. + if not job_copy._jobs: + job_copy._jobs = job_copy.component.jobs + job_copy.component = None + return job_copy + + @pre_load() + def check_exclusive_fields(self, data: dict, **kwargs) -> dict: + error_msg = "'jobs' and 'component' are mutually exclusive fields in pipeline job." + # When loading from yaml, data["component"] must be a local path (str) + # Otherwise, data["component"] can be a PipelineComponent so data["jobs"] must exist + if isinstance(data.get("component"), str) and data.get("jobs"): + raise ValidationError(error_msg) + return data + + @post_load + def make(self, data: dict, **kwargs) -> dict: + return _post_load_pipeline_jobs(self.context, data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job_io.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job_io.py new file mode 100644 index 00000000..3fb6a7b7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job_io.py @@ -0,0 +1,51 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
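# --- illustrative sketch (editorial, not part of the vendored source) -----
# OutputBindingStr below accepts strings matching the component output binding
# pattern, e.g. in pipeline YAML (names hypothetical):
#
#   outputs:
#     trained_model: ${{parent.jobs.train_step.outputs.model_output}}
#
# An Output object whose `path` is such a binding and whose mode is unset is
# serialized back to the bare binding string.
# ---------------------------------------------------------------------------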
+# --------------------------------------------------------- + +import logging +import re + +from marshmallow import ValidationError, fields + +from azure.ai.ml.constants._component import ComponentJobConstants + +module_logger = logging.getLogger(__name__) + + +class OutputBindingStr(fields.Field): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _jsonschema_type_mapping(self): + schema = {"type": "string", "pattern": ComponentJobConstants.OUTPUT_PATTERN} + if self.name is not None: + schema["title"] = self.name + if self.dump_only: + schema["readonly"] = True + return schema + + def _serialize(self, value, attr, obj, **kwargs): + if isinstance(value, str) and re.match(ComponentJobConstants.OUTPUT_PATTERN, value): + return value + # _to_job_output in io.py will return Output + # add this branch to judge whether original value is a simple binding or Output + if ( + isinstance(value.path, str) + and re.match(ComponentJobConstants.OUTPUT_PATTERN, value.path) + and value.mode is None + ): + return value.path + raise ValidationError(f"Invalid output binding string '{value}' passed") + + def _deserialize(self, value, attr, data, **kwargs): + if ( + isinstance(value, dict) + and "path" in value + and "mode" not in value + and "name" not in value + and "version" not in value + ): + value = value["path"] + if isinstance(value, str) and re.match(ComponentJobConstants.OUTPUT_PATTERN, value): + return value + raise ValidationError(f"Invalid output binding string '{value}' passed") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_parallel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_parallel_job.py new file mode 100644 index 00000000..3b30fb66 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_parallel_job.py @@ -0,0 +1,40 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import post_load + +from azure.ai.ml._schema.core.fields import ComputeField, EnvironmentField, StringTransformedEnum +from azure.ai.ml._schema.job import ParameterizedParallelSchema +from azure.ai.ml._schema.pipeline.component_job import BaseNodeSchema + +from ...constants._component import NodeType + +module_logger = logging.getLogger(__name__) + + +# parallel job inherits parallel attributes from ParameterizedParallelSchema and node functionality from BaseNodeSchema +class PipelineParallelJobSchema(BaseNodeSchema, ParameterizedParallelSchema): + """Schema for ParallelJob in PipelineJob/PipelineComponent.""" + + type = StringTransformedEnum(allowed_values=NodeType.PARALLEL) + compute = ComputeField() + environment = EnvironmentField() + + @post_load + def make(self, data: Any, **kwargs: Any): + """Construct a ParallelJob from deserialized data. + + :param data: The deserialized data. + :type data: dict[str, Any] + :return: A ParallelJob. 
+ :rtype: azure.ai.ml.entities._job.parallel.ParallelJob + """ + from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob + + return ParallelJob(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_spark_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_spark_job.py new file mode 100644 index 00000000..69d58255 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_spark_job.py @@ -0,0 +1,29 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +import logging +from typing import Any + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.fields import NestedField, UnionField +from azure.ai.ml._schema.job.input_output_entry import OutputSchema +from azure.ai.ml._schema.job.spark_job import SparkJobSchema + +module_logger = logging.getLogger(__name__) + + +class PipelineSparkJobSchema(SparkJobSchema): + outputs = fields.Dict( + keys=fields.Str(), + values=UnionField([NestedField(OutputSchema), fields.Str()], allow_none=True), + ) + + @post_load + def make(self, data: Any, **kwargs: Any): + from azure.ai.ml.entities._job.spark_job import SparkJob + + return SparkJob(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/settings.py new file mode 100644 index 00000000..1e5227b0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/settings.py @@ -0,0 +1,42 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
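# --- illustrative sketch (editorial, not part of the vendored source) -----
# The settings block that PipelineJobSettingsSchema below loads, e.g.:
#
#   settings:
#     default_compute: azureml:cpu-cluster
#     default_datastore: azureml:workspaceblobstore
#     continue_on_step_failure: false
#
# or, built directly with the public entity class (values hypothetical):
#
#   from azure.ai.ml.entities import PipelineJobSettings
#
#   settings = PipelineJobSettings(
#       default_compute="cpu-cluster",
#       continue_on_step_failure=False,
#   )
# ---------------------------------------------------------------------------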
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import INCLUDE, Schema, fields, post_dump, post_load + +from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum, UnionField +from azure.ai.ml._schema.pipeline.pipeline_component import NodeNameStr +from azure.ai.ml._utils.utils import is_private_preview_enabled +from azure.ai.ml.constants._common import AzureMLResourceType, SERVERLESS_COMPUTE + + +class PipelineJobSettingsSchema(Schema): + class Meta: + unknown = INCLUDE + + default_datastore = ArmStr(azureml_type=AzureMLResourceType.DATASTORE) + default_compute = UnionField( + [ + StringTransformedEnum(allowed_values=[SERVERLESS_COMPUTE]), + ArmStr(azureml_type=AzureMLResourceType.COMPUTE), + ] + ) + continue_on_step_failure = fields.Bool() + force_rerun = fields.Bool() + + # move init/finalize under private preview flag to hide them in spec + if is_private_preview_enabled(): + on_init = NodeNameStr() + on_finalize = NodeNameStr() + + @post_load + def make(self, data, **kwargs) -> "PipelineJobSettings": + from azure.ai.ml.entities import PipelineJobSettings + + return PipelineJobSettings(**data) + + @post_dump + def remove_none(self, data, **kwargs): + return {key: value for key, value in data.items() if value is not None} diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/queue_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/queue_settings.py new file mode 100644 index 00000000..3196a00c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/queue_settings.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import post_load +from azure.ai.ml.constants._job.job import JobPriorityValues, JobTierNames +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class QueueSettingsSchema(metaclass=PatchedSchemaMeta): + job_tier = StringTransformedEnum( + allowed_values=JobTierNames.ALLOWED_NAMES, + ) + priority = StringTransformedEnum( + allowed_values=JobPriorityValues.ALLOWED_VALUES, + ) + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + from azure.ai.ml.entities import QueueSettings + + return QueueSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/__init__.py new file mode 100644 index 00000000..9c2fe189 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/__init__.py @@ -0,0 +1,9 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
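# --- illustrative sketch (editorial, not part of the vendored source) -----
# QueueSettingsSchema above maps YAML such as
#
#   queue_settings:
#     job_tier: standard
#
# to the public QueueSettings entity (the tier value here is hypothetical):
#
#   from azure.ai.ml.entities import QueueSettings
#
#   qs = QueueSettings(job_tier="standard")
# ---------------------------------------------------------------------------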
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
+
+from .registry import RegistrySchema
+
+__all__ = ["RegistrySchema"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry.py
new file mode 100644
index 00000000..17233195
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import DumpableStringField, NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.intellectual_property import PublisherSchema
+from azure.ai.ml._schema.core.resource import ResourceSchema
+from azure.ai.ml._schema.workspace.identity import IdentitySchema
+from azure.ai.ml._utils.utils import snake_to_pascal
+from azure.ai.ml.constants._common import PublicNetworkAccess
+from azure.ai.ml.constants._registry import AcrAccountSku
+from azure.ai.ml.entities._registry.registry_support_classes import SystemCreatedAcrAccount
+
+from .registry_region_arm_details import RegistryRegionDetailsSchema
+from .system_created_acr_account import SystemCreatedAcrAccountSchema
+from .util import acr_format_validator
+
+
+# Based on the 10-01-preview api
+class RegistrySchema(ResourceSchema):
+    # Inherits name, id, tags, and description fields from ResourceSchema
+
+    # Values from RegistryTrackedResource (Client name: Registry)
+    location = fields.Str(required=True)
+
+    # Values from Registry (Client name: RegistryProperties)
+    public_network_access = StringTransformedEnum(
+        allowed_values=[PublicNetworkAccess.DISABLED, PublicNetworkAccess.ENABLED],
+        casing_transform=snake_to_pascal,
+    )
+    replication_locations = fields.List(NestedField(RegistryRegionDetailsSchema))
+    intellectual_property = NestedField(PublisherSchema)
+    # This is an ACR account which will be applied to every registryRegionArmDetail defined
+    # in replication_locations. This is different from the internal swagger
+    # definition, which has a per-region list of ACR accounts.
+    # Per-region ACR account configuration is NOT possible through yaml configs for now.
+    container_registry = UnionField(
+        [DumpableStringField(validate=acr_format_validator), NestedField(SystemCreatedAcrAccountSchema)],
+        required=False,
+        is_strict=True,
+        load_default=SystemCreatedAcrAccount(acr_account_sku=AcrAccountSku.PREMIUM),
+    )
+
+    # Values that can only be set by the system, never by the user.
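+    # These fields are dump_only, so they appear when a registry is serialized
+    # but are not loaded from user-supplied yaml.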
+ identity = NestedField(IdentitySchema, dump_only=True) + kind = fields.Str(dump_only=True) + sku = fields.Str(dump_only=True) + managed_resource_group = fields.Str(dump_only=True) + mlflow_registry_uri = fields.Str(dump_only=True) + discovery_url = fields.Str(dump_only=True) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry_region_arm_details.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry_region_arm_details.py new file mode 100644 index 00000000..c861b94c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry_region_arm_details.py @@ -0,0 +1,61 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema.core.fields import DumpableStringField, NestedField, UnionField +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._registry import StorageAccountType +from azure.ai.ml.entities._registry.registry_support_classes import SystemCreatedStorageAccount + +from .system_created_storage_account import SystemCreatedStorageAccountSchema +from .util import storage_account_validator + + +# Differs from the swagger def in that the acr_details can only be supplied as a +# single registry-wide instance, rather than a per-region list. +@experimental +class RegistryRegionDetailsSchema(metaclass=PatchedSchemaMeta): + # Commenting this out for the time being. + # We do not want to surface the acr_config as a per-region configurable + # field. Instead we want to simplify the UX and surface it as a non-list, + # top-level value called 'container_registry'. + # We don't even want to show the per-region acr accounts when displaying a + # registry to the user, so this isn't even left as a dump-only field. + """acr_config = fields.List( + UnionField( + [DumpableStringField(validate=acr_format_validator), NestedField(SystemCreatedAcrAccountSchema)], + dump_only=True, + is_strict=True, + ) + )""" + location = fields.Str() + storage_config = UnionField( + [ + NestedField(SystemCreatedStorageAccountSchema), + fields.List(DumpableStringField(validate=storage_account_validator)), + ], + is_strict=True, + load_default=SystemCreatedStorageAccount( + storage_account_hns=False, storage_account_type=StorageAccountType.STANDARD_LRS + ), + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import RegistryRegionDetails + + data.pop("type", None) + return RegistryRegionDetails(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities import RegistryRegionDetails + + if not isinstance(data, RegistryRegionDetails): + raise ValidationError("Cannot dump non-RegistryRegionDetails object into RegistryRegionDetailsSchema") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_acr_account.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_acr_account.py new file mode 100644 index 00000000..08b78c2e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_acr_account.py @@ -0,0 +1,35 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema import StringTransformedEnum +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants._registry import AcrAccountSku + + +@experimental +class SystemCreatedAcrAccountSchema(metaclass=PatchedSchemaMeta): + arm_resource_id = fields.Str(dump_only=True) + acr_account_sku = StringTransformedEnum( + allowed_values=[sku.value for sku in AcrAccountSku], casing_transform=lambda x: x.lower() + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import SystemCreatedAcrAccount + + data.pop("type", None) + return SystemCreatedAcrAccount(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities import SystemCreatedAcrAccount + + if not isinstance(data, SystemCreatedAcrAccount): + raise ValidationError("Cannot dump non-SystemCreatedAcrAccount object into SystemCreatedAcrAccountSchema") + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_storage_account.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_storage_account.py new file mode 100644 index 00000000..cdbbcd67 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_storage_account.py @@ -0,0 +1,40 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import ValidationError, fields, post_load, pre_dump + +from azure.ai.ml._schema import StringTransformedEnum +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml.constants._registry import StorageAccountType + + +class SystemCreatedStorageAccountSchema(metaclass=PatchedSchemaMeta): + arm_resource_id = fields.Str(dump_only=True) + storage_account_hns = fields.Bool(load_default=False) + storage_account_type = StringTransformedEnum( + load_default=StorageAccountType.STANDARD_LRS, + allowed_values=[accountType.value for accountType in StorageAccountType], + casing_transform=lambda x: x.lower(), + ) + replication_count = fields.Int(load_default=1, validate=lambda count: count > 0) + replicated_ids = fields.List(fields.Str(), dump_only=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import SystemCreatedStorageAccount + + data.pop("type", None) + return SystemCreatedStorageAccount(**data) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities import SystemCreatedStorageAccount + + if not isinstance(data, SystemCreatedStorageAccount): + raise ValidationError( + "Cannot dump non-SystemCreatedStorageAccount object into SystemCreatedStorageAccountSchema" + ) + return data diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/util.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/util.py new file mode 100644 index 00000000..19c01e9a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/util.py @@ -0,0 +1,15 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+# Simple helper methods to avoid re-using lambdas everywhere
+
+from azure.ai.ml.constants._registry import ACR_ACCOUNT_FORMAT, STORAGE_ACCOUNT_FORMAT
+
+
+def storage_account_validator(storage_id: str):
+    return STORAGE_ACCOUNT_FORMAT.match(storage_id) is not None
+
+
+def acr_format_validator(acr_id: str):
+    return ACR_ACCOUNT_FORMAT.match(acr_id) is not None
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/resource_configuration.py
new file mode 100644
index 00000000..fece59a2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/resource_configuration.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class ResourceConfigurationSchema(metaclass=PatchedSchemaMeta):
+    instance_count = fields.Int()
+    instance_type = fields.Str(metadata={"description": "The instance type to make available to this job."})
+    properties = fields.Dict(keys=fields.Str())
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import ResourceConfiguration
+
+        return ResourceConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py
new file mode 100644
index 00000000..084f8a5b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py
@@ -0,0 +1,144 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+import copy
+from typing import Optional
+
+import yaml
+from marshmallow import INCLUDE, ValidationError, fields, post_load, pre_load
+
+from azure.ai.ml._schema import CommandJobSchema
+from azure.ai.ml._schema.core.fields import (
+    ArmStr,
+    ComputeField,
+    EnvironmentField,
+    FileRefField,
+    NestedField,
+    StringTransformedEnum,
+    UnionField,
+)
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml._schema.pipeline.settings import PipelineJobSettingsSchema
+from azure.ai.ml._utils.utils import load_file, merge_dict
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+
+_SCHEDULED_JOB_UPDATES_KEY = "scheduled_job_updates"
+
+
+class CreateJobFileRefField(FileRefField):
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def _serialize(self, value, attr, obj, **kwargs):
+        """FileRefField does not support serialize.
+
+        This method is overridden because a job must be dumpable from within a schedule.
+        """
+        from azure.ai.ml.entities._builders import BaseNode
+
+        if isinstance(value, BaseNode):
+            # Dump as Job to avoid missing field.
+            value = value._to_job()
+        return value._to_dict()
+
+    def _deserialize(self, value, attr, data, **kwargs) -> "Job":
+        # Get the job definition from the referenced yaml file.
+        data = super()._deserialize(value, attr, data, **kwargs)
+        job_dict = yaml.safe_load(data)
+
+        from azure.ai.ml.entities import Job
+
+        return Job._load(
+            data=job_dict,
+            yaml_path=self.context[BASE_PATH_CONTEXT_KEY] / value,
+            **kwargs,
+        )
+
+
+class BaseCreateJobSchema(BaseJobSchema):
+    compute = ComputeField()
+    job = UnionField(
+        [
+            ArmStr(azureml_type=AzureMLResourceType.JOB),
+            CreateJobFileRefField,
+        ],
+        required=True,
+    )
+
+    # pylint: disable-next=docstring-missing-param
+    def _get_job_instance_for_remote_job(self, id: Optional[str], data: Optional[dict], **kwargs) -> "Job":
+        """Get a job instance to store updates for a remote job.
+
+        :return: The remote job
+        :rtype: Job
+        """
+        from azure.ai.ml.entities import Job
+
+        data = {} if data is None else data
+        if "type" not in data:
+            raise ValidationError("'type' must be specified when scheduling a remote job with updates.")
+        # Create a job instance if the job is an ARM id
+        job_instance = Job._load(
+            data=data,
+            **kwargs,
+        )
+        # Set back the id and base path to the created job
+        job_instance._id = id
+        job_instance._base_path = self.context[BASE_PATH_CONTEXT_KEY]
+        return job_instance
+
+    @pre_load
+    def pre_load(self, data, **kwargs):  # pylint: disable=unused-argument
+        if isinstance(data, dict):
+            # Put the raw updates into the context.
+            # dict type indicates there are updates to the scheduled job.
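+            # Everything except the job reference itself is stashed here; make()
+            # later merges these updates back into the loaded job definition.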
+            copied_data = copy.deepcopy(data)
+            copied_data.pop("job", None)
+            self.context[_SCHEDULED_JOB_UPDATES_KEY] = copied_data
+        return data
+
+    @post_load
+    def make(self, data: dict, **kwargs) -> "Job":
+        from azure.ai.ml.entities import Job
+
+        # Get the loaded job
+        job = data.pop("job")
+        # Get the raw dict data before load
+        raw_data = self.context.get(_SCHEDULED_JOB_UPDATES_KEY, {})
+        if isinstance(job, Job):
+            if job._source_path is None:
+                raise ValidationError("Could not load job for schedule without '_source_path' set.")
+            # Load the local job again with updated values
+            job_dict = yaml.safe_load(load_file(job._source_path))
+            return Job._load(
+                data=merge_dict(job_dict, raw_data),
+                yaml_path=job._source_path,
+                **kwargs,
+            )
+        # Create a job instance for a remote job
+        return self._get_job_instance_for_remote_job(job, raw_data, **kwargs)
+
+
+class PipelineCreateJobSchema(BaseCreateJobSchema):
+    # Note: Here we do not inherit PipelineJobSchema, as we don't need its post_load and pre_load hooks.
+    type = StringTransformedEnum(allowed_values=[JobType.PIPELINE])
+    inputs = InputsField()
+    outputs = OutputsField()
+    settings = NestedField(PipelineJobSettingsSchema, unknown=INCLUDE)
+
+
+class CommandCreateJobSchema(BaseCreateJobSchema, CommandJobSchema):
+    class Meta:
+        # Refer to https://github.com/Azure/azureml_run_specification/blob/master
+        # /specs/job-endpoint.md#properties-in-difference-job-types
+        # code and command cannot be set at runtime
+        exclude = ["code", "command"]
+
+    environment = EnvironmentField()
+
+
+class SparkCreateJobSchema(BaseCreateJobSchema):
+    type = StringTransformedEnum(allowed_values=[JobType.SPARK])
+    conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
+    environment = EnvironmentField(allow_none=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/schedule.py
new file mode 100644
index 00000000..fbde3e9b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/schedule.py
@@ -0,0 +1,44 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema.core.fields import ArmStr, NestedField, UnionField +from azure.ai.ml._schema.core.resource import ResourceSchema +from azure.ai.ml._schema.job import CreationContextSchema +from azure.ai.ml._schema.schedule.create_job import ( + CommandCreateJobSchema, + CreateJobFileRefField, + PipelineCreateJobSchema, + SparkCreateJobSchema, +) +from azure.ai.ml._schema.schedule.trigger import CronTriggerSchema, RecurrenceTriggerSchema +from azure.ai.ml.constants._common import AzureMLResourceType + + +class ScheduleSchema(ResourceSchema): + name = fields.Str(attribute="name", required=True) + display_name = fields.Str(attribute="display_name") + trigger = UnionField( + [ + NestedField(CronTriggerSchema), + NestedField(RecurrenceTriggerSchema), + ], + ) + creation_context = NestedField(CreationContextSchema, dump_only=True) + is_enabled = fields.Boolean(dump_only=True) + provisioning_state = fields.Str(dump_only=True) + properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True)) + + +class JobScheduleSchema(ScheduleSchema): + create_job = UnionField( + [ + ArmStr(azureml_type=AzureMLResourceType.JOB), + CreateJobFileRefField, + NestedField(PipelineCreateJobSchema), + NestedField(CommandCreateJobSchema), + NestedField(SparkCreateJobSchema), + ] + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/trigger.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/trigger.py new file mode 100644 index 00000000..37147d48 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/trigger.py @@ -0,0 +1,82 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_dump, post_load
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import RecurrenceFrequency, TriggerType, WeekDay
+from azure.ai.ml._schema.core.fields import (
+    DateTimeStr,
+    DumpableIntegerField,
+    NestedField,
+    StringTransformedEnum,
+    UnionField,
+)
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants import TimeZone
+
+
+class TriggerSchema(metaclass=PatchedSchemaMeta):
+    start_time = UnionField([fields.DateTime(), DateTimeStr()])
+    end_time = UnionField([fields.DateTime(), DateTimeStr()])
+    time_zone = fields.Str()
+
+    @post_dump(pass_original=True)
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def resolve_time_zone(self, data, original_data, **kwargs):  # pylint: disable= unused-argument
+        """
+        Auto-conversion renders a TimeZone enum object as a string like "TimeZone.UTC",
+        while the valid result should be "UTC".
+        """
+        if isinstance(original_data.time_zone, TimeZone):
+            data["time_zone"] = original_data.time_zone.value
+        return data
+
+
+class CronTriggerSchema(TriggerSchema):
+    type = StringTransformedEnum(allowed_values=TriggerType.CRON, required=True)
+    expression = fields.Str(required=True)
+
+    @post_load
+    def make(self, data, **kwargs) -> "CronTrigger":  # pylint: disable= unused-argument
+        from azure.ai.ml.entities import CronTrigger
+
+        data.pop("type")
+        return CronTrigger(**data)
+
+
+class RecurrencePatternSchema(metaclass=PatchedSchemaMeta):
+    hours = UnionField([DumpableIntegerField(), fields.List(fields.Int())], required=True)
+    minutes = UnionField([DumpableIntegerField(), fields.List(fields.Int())], required=True)
+    week_days = UnionField(
+        [
+            StringTransformedEnum(allowed_values=[o.value for o in WeekDay]),
+            fields.List(StringTransformedEnum(allowed_values=[o.value for o in WeekDay])),
+        ]
+    )
+    month_days = UnionField(
+        [
+            fields.Int(),
+            fields.List(fields.Int()),
+        ]
+    )
+
+    @post_load
+    def make(self, data, **kwargs) -> "RecurrencePattern":  # pylint: disable= unused-argument
+        from azure.ai.ml.entities import RecurrencePattern
+
+        return RecurrencePattern(**data)
+
+
+class RecurrenceTriggerSchema(TriggerSchema):
+    type = StringTransformedEnum(allowed_values=TriggerType.RECURRENCE, required=True)
+    frequency = StringTransformedEnum(allowed_values=[o.value for o in RecurrenceFrequency], required=True)
+    interval = fields.Int(required=True)
+    schedule = NestedField(RecurrencePatternSchema())
+
+    @post_load
+    def make(self, data, **kwargs) -> "RecurrenceTrigger":  # pylint: disable= unused-argument
+        from azure.ai.ml.entities import RecurrenceTrigger
+
+        data.pop("type")
+        return RecurrenceTrigger(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/spark_resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/spark_resource_configuration.py
new file mode 100644
index 00000000..8571adf1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/spark_resource_configuration.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NumberVersionField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class SparkResourceConfigurationSchema(metaclass=PatchedSchemaMeta):
+    """Schema for SparkResourceConfiguration."""
+
+    instance_type = fields.Str(metadata={"description": "Optional type of VM used, as supported by the compute target."})
+    runtime_version = NumberVersionField()
+
+    @post_load
+    def make(self, data, **kwargs):
+        """Construct a SparkResourceConfiguration object from the marshalled data.
+
+        :param data: The marshalled data.
+        :type data: dict[str, str]
+        :return: A SparkResourceConfiguration object.
+        :rtype: ~azure.ai.ml.entities.SparkResourceConfiguration
+        """
+        from azure.ai.ml.entities import SparkResourceConfiguration
+
+        return SparkResourceConfiguration(**data)
+
+
+class SparkResourceConfigurationForNodeSchema(SparkResourceConfigurationSchema):
+    """
+    Schema for SparkResourceConfiguration used for node configuration, where validation
+    logic needs to live in the schema.
+    """
+
+    instance_type = StringTransformedEnum(
+        allowed_values=[
+            "standard_e4s_v3",
+            "standard_e8s_v3",
+            "standard_e16s_v3",
+            "standard_e32s_v3",
+            "standard_e64s_v3",
+        ],
+        required=True,
+        metadata={"description": "Type of VM used, as supported by the compute target."},
+    )
+    runtime_version = NumberVersionField(
+        required=True,
+    )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/__init__.py
new file mode 100644
index 00000000..dc8b82e2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/__init__.py
@@ -0,0 +1,11 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
+
+from .workspace import WorkspaceSchema
+from .ai_workspaces.project import ProjectSchema
+from .ai_workspaces.hub import HubSchema
+
+__all__ = ["WorkspaceSchema", "ProjectSchema", "HubSchema"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/capability_host.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/capability_host.py
new file mode 100644
index 00000000..cdccb24c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/capability_host.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class CapabilityHostSchema(PathAwareSchema):
+    name = fields.Str()
+    description = fields.Str()
+    capability_host_kind = fields.Str()
+    vector_store_connections = fields.List(fields.Str(), required=False)
+    ai_services_connections = fields.List(fields.Str(), required=False)
+    storage_connections = fields.List(fields.Str(), required=False)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/hub.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/hub.py
new file mode 100644
index 00000000..94a7c380
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/hub.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._schema.workspace import WorkspaceSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import WorkspaceKind
+
+
+@experimental
+class HubSchema(WorkspaceSchema):
+    # additional_workspace_storage_accounts: this field exists in the API, but is
+    # unused, and thus is not surfaced yet.
+    kind = StringTransformedEnum(required=True, allowed_values=WorkspaceKind.HUB)
+    default_resource_group = fields.Str(required=False)
+    associated_workspaces = fields.List(fields.Str(), required=False, dump_only=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/project.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/project.py
new file mode 100644
index 00000000..86daa735
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/project.py
@@ -0,0 +1,16 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._schema.workspace import WorkspaceSchema
+from azure.ai.ml.constants import WorkspaceKind
+
+
+@experimental
+class ProjectSchema(WorkspaceSchema):
+    kind = StringTransformedEnum(required=True, allowed_values=WorkspaceKind.PROJECT)
+    hub_id = fields.Str(required=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py
new file mode 100644
index 00000000..fa462cfb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore + +from .workspace_connection import WorkspaceConnectionSchema +from .connection_subtypes import ( + AzureBlobStoreConnectionSchema, + MicrosoftOneLakeConnectionSchema, + AzureOpenAIConnectionSchema, + AzureAIServicesConnectionSchema, + AzureAISearchConnectionSchema, + AzureContentSafetyConnectionSchema, + AzureSpeechServicesConnectionSchema, + APIKeyConnectionSchema, + OpenAIConnectionSchema, + SerpConnectionSchema, + ServerlessConnectionSchema, + OneLakeArtifactSchema, +) + +__all__ = [ + "WorkspaceConnectionSchema", + "AzureBlobStoreConnectionSchema", + "MicrosoftOneLakeConnectionSchema", + "AzureOpenAIConnectionSchema", + "AzureAIServicesConnectionSchema", + "AzureAISearchConnectionSchema", + "AzureContentSafetyConnectionSchema", + "AzureSpeechServicesConnectionSchema", + "APIKeyConnectionSchema", + "OpenAIConnectionSchema", + "SerpConnectionSchema", + "ServerlessConnectionSchema", + "OneLakeArtifactSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py new file mode 100644 index 00000000..d04b3e76 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py @@ -0,0 +1,225 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load +from marshmallow.exceptions import ValidationError +from marshmallow.decorators import pre_load + +from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._common import ConnectionTypes +from azure.ai.ml._schema.workspace.connections.one_lake_artifacts import OneLakeArtifactSchema +from azure.ai.ml._schema.workspace.connections.credentials import ( + SasTokenConfigurationSchema, + ServicePrincipalConfigurationSchema, + AccountKeyConfigurationSchema, + AadCredentialConfigurationSchema, +) +from azure.ai.ml.entities import AadCredentialConfiguration +from .workspace_connection import WorkspaceConnectionSchema + + +class AzureBlobStoreConnectionSchema(WorkspaceConnectionSchema): + # type and credentials limited + type = StringTransformedEnum( + allowed_values=ConnectionCategory.AZURE_BLOB, casing_transform=camel_to_snake, required=True + ) + credentials = UnionField( + [ + NestedField(SasTokenConfigurationSchema), + NestedField(AccountKeyConfigurationSchema), + NestedField(AadCredentialConfigurationSchema), + ], + required=False, + load_default=AadCredentialConfiguration(), + ) + + url = fields.Str() + + account_name = fields.Str(required=True, allow_none=False) + container_name = fields.Str(required=True, allow_none=False) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import AzureBlobStoreConnection + + return AzureBlobStoreConnection(**data) + + +class MicrosoftOneLakeConnectionSchema(WorkspaceConnectionSchema): + type = StringTransformedEnum( + allowed_values=ConnectionCategory.AZURE_ONE_LAKE, casing_transform=camel_to_snake, required=True + ) + credentials = 
UnionField(
+        [NestedField(ServicePrincipalConfigurationSchema), NestedField(AadCredentialConfigurationSchema)],
+        required=False,
+        load_default=AadCredentialConfiguration(),
+    )
+    artifact = NestedField(OneLakeArtifactSchema, required=False, allow_none=True)
+
+    endpoint = fields.Str(required=False)
+    one_lake_workspace_name = fields.Str(required=False)
+
+    @pre_load
+    def check_for_target(self, data, **kwargs):
+        target = data.get("target", None)
+        artifact = data.get("artifact", None)
+        endpoint = data.get("endpoint", None)
+        one_lake_workspace_name = data.get("one_lake_workspace_name", None)
+        # If the user is using a target, then they don't need the artifact and one lake workspace name.
+        # This is distinct from when the user sets the 'endpoint' value, which is also used to construct
+        # the target. If the target is already present, then the loaded connection YAML was probably produced
+        # by dumping an extant connection.
+        if target is None:
+            if artifact is None:
+                raise ValidationError("If target is unset, then artifact must be set")
+            if endpoint is None:
+                raise ValidationError("If target is unset, then endpoint must be set")
+            if one_lake_workspace_name is None:
+                raise ValidationError("If target is unset, then one_lake_workspace_name must be set")
+        return data
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import MicrosoftOneLakeConnection
+
+        return MicrosoftOneLakeConnection(**data)
+
+
+class AzureOpenAIConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionCategory.AZURE_OPEN_AI, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    api_version = fields.Str(required=False, allow_none=True)
+
+    azure_endpoint = fields.Str()
+    open_ai_resource_id = fields.Str(required=False, allow_none=True)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureOpenAIConnection
+
+        return AzureOpenAIConnection(**data)
+
+
+class AzureAIServicesConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionTypes.AZURE_AI_SERVICES, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    endpoint = fields.Str()
+    ai_services_resource_id = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureAIServicesConnection
+
+        return AzureAIServicesConnection(**data)
+
+
+class AzureAISearchConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionTypes.AZURE_SEARCH, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    endpoint = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureAISearchConnection
+
+        return AzureAISearchConnection(**data)
+
+
+class AzureContentSafetyConnectionSchema(WorkspaceConnectionSchema):
+    # type and credentials limited
+    type = StringTransformedEnum(
+        allowed_values=ConnectionTypes.AZURE_CONTENT_SAFETY, casing_transform=camel_to_snake, required=True
+    )
+    api_key = fields.Str(required=False, allow_none=True)
+    endpoint = fields.Str()
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import AzureContentSafetyConnection
+
+        return AzureContentSafetyConnection(**data)
+
+
+class 
AzureSpeechServicesConnectionSchema(WorkspaceConnectionSchema): + # type and credentials limited + type = StringTransformedEnum( + allowed_values=ConnectionTypes.AZURE_SPEECH_SERVICES, casing_transform=camel_to_snake, required=True + ) + api_key = fields.Str(required=False, allow_none=True) + endpoint = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import AzureSpeechServicesConnection + + return AzureSpeechServicesConnection(**data) + + +class APIKeyConnectionSchema(WorkspaceConnectionSchema): + # type and credentials limited + type = StringTransformedEnum( + allowed_values=ConnectionCategory.API_KEY, casing_transform=camel_to_snake, required=True + ) + api_key = fields.Str(required=True) + api_base = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import APIKeyConnection + + return APIKeyConnection(**data) + + +class OpenAIConnectionSchema(WorkspaceConnectionSchema): + # type and credentials limited + type = StringTransformedEnum( + allowed_values=ConnectionCategory.OPEN_AI, casing_transform=camel_to_snake, required=True + ) + api_key = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import OpenAIConnection + + return OpenAIConnection(**data) + + +class SerpConnectionSchema(WorkspaceConnectionSchema): + # type and credentials limited + type = StringTransformedEnum(allowed_values=ConnectionCategory.SERP, casing_transform=camel_to_snake, required=True) + api_key = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import SerpConnection + + return SerpConnection(**data) + + +class ServerlessConnectionSchema(WorkspaceConnectionSchema): + # type and credentials limited + type = StringTransformedEnum( + allowed_values=ConnectionCategory.SERVERLESS, casing_transform=camel_to_snake, required=True + ) + api_key = fields.Str(required=True) + endpoint = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import ServerlessConnection + + return ServerlessConnection(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py new file mode 100644 index 00000000..52213c08 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py @@ -0,0 +1,178 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +##### DEV NOTE: For some reason, these schemas correlate to the classes defined in ~azure.ai.ml.entities._credentials. +# There used to be a credentials.py file in ~azure.ai.ml.entities.workspace.connections, +# but it was, as far as I could tell, never used. So I removed it and added this comment. 
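+#
+# A minimal usage sketch (hypothetical token value): loading
+#     {"type": "sas", "sas_token": "<token>"}
+# through SasTokenConfigurationSchema() yields a SasTokenConfiguration entity,
+# because every post_load hook below pops the discriminating "type" key and
+# forwards the remaining fields to the matching entity constructor.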
+ +from typing import Dict + +from marshmallow import fields, post_load + +from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionAuthType +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.entities._credentials import ( + ManagedIdentityConfiguration, + PatTokenConfiguration, + SasTokenConfiguration, + ServicePrincipalConfiguration, + UsernamePasswordConfiguration, + AccessKeyConfiguration, + ApiKeyConfiguration, + AccountKeyConfiguration, + AadCredentialConfiguration, + NoneCredentialConfiguration, +) + + +class WorkspaceCredentialsSchema(metaclass=PatchedSchemaMeta): + type = fields.Str() + + +class PatTokenConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.PAT, + casing_transform=camel_to_snake, + required=True, + ) + pat = fields.Str() + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> PatTokenConfiguration: + data.pop("type") + return PatTokenConfiguration(**data) + + +class SasTokenConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.SAS, + casing_transform=camel_to_snake, + required=True, + ) + sas_token = fields.Str() + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> SasTokenConfiguration: + data.pop("type") + return SasTokenConfiguration(**data) + + +class UsernamePasswordConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.USERNAME_PASSWORD, + casing_transform=camel_to_snake, + required=True, + ) + username = fields.Str() + password = fields.Str() + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> UsernamePasswordConfiguration: + data.pop("type") + return UsernamePasswordConfiguration(**data) + + +class ManagedIdentityConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.MANAGED_IDENTITY, + casing_transform=camel_to_snake, + required=True, + ) + client_id = fields.Str() + resource_id = fields.Str() + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> ManagedIdentityConfiguration: + data.pop("type") + return ManagedIdentityConfiguration(**data) + + +class ServicePrincipalConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.SERVICE_PRINCIPAL, + casing_transform=camel_to_snake, + required=True, + ) + + client_id = fields.Str() + client_secret = fields.Str() + tenant_id = fields.Str() + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> ServicePrincipalConfiguration: + data.pop("type") + return ServicePrincipalConfiguration(**data) + + +class AccessKeyConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.ACCESS_KEY, + casing_transform=camel_to_snake, + required=True, + ) + access_key_id = fields.Str() + secret_access_key = fields.Str() + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> AccessKeyConfiguration: + data.pop("type") + return AccessKeyConfiguration(**data) + + +class ApiKeyConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.API_KEY, + casing_transform=camel_to_snake, + required=True, + ) + key = fields.Str() + + @post_load + def make(self, data: Dict[str, str], 
**kwargs) -> ApiKeyConfiguration: + data.pop("type") + return ApiKeyConfiguration(**data) + + +class AccountKeyConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.ACCOUNT_KEY, + casing_transform=camel_to_snake, + required=True, + ) + account_key = fields.Str() + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> AccountKeyConfiguration: + data.pop("type") + return AccountKeyConfiguration(**data) + + +class AadCredentialConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.AAD, + casing_transform=camel_to_snake, + required=True, + ) + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> AadCredentialConfiguration: + data.pop("type") + return AadCredentialConfiguration(**data) + + +class NoneCredentialConfigurationSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=ConnectionAuthType.NONE, + casing_transform=camel_to_snake, + required=True, + ) + + @post_load + def make(self, data: Dict[str, str], **kwargs) -> NoneCredentialConfiguration: + data.pop("type") + return NoneCredentialConfiguration(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py new file mode 100644 index 00000000..563a9359 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py @@ -0,0 +1,26 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.fields import StringTransformedEnum +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants._common import OneLakeArtifactTypes + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class OneLakeArtifactSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=OneLakeArtifactTypes.ONE_LAKE, casing_transform=camel_to_snake, required=True + ) + name = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import OneLakeConnectionArtifact + + return OneLakeConnectionArtifact(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py new file mode 100644 index 00000000..20863a5a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py @@ -0,0 +1,86 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.resource import ResourceSchema
+from azure.ai.ml._schema.job import CreationContextSchema
+from azure.ai.ml._schema.workspace.connections.credentials import (
+    AccountKeyConfigurationSchema,
+    ManagedIdentityConfigurationSchema,
+    PatTokenConfigurationSchema,
+    SasTokenConfigurationSchema,
+    ServicePrincipalConfigurationSchema,
+    UsernamePasswordConfigurationSchema,
+    AccessKeyConfigurationSchema,
+    ApiKeyConfigurationSchema,
+    AadCredentialConfigurationSchema,
+    NoneCredentialConfigurationSchema,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import ConnectionTypes
+from azure.ai.ml.entities import NoneCredentialConfiguration, AadCredentialConfiguration
+
+
+class WorkspaceConnectionSchema(ResourceSchema):
+    # Inherits name, id, tags, and description fields from ResourceSchema
+    creation_context = NestedField(CreationContextSchema, dump_only=True)
+    type = StringTransformedEnum(
+        allowed_values=[
+            ConnectionCategory.GIT,
+            ConnectionCategory.CONTAINER_REGISTRY,
+            ConnectionCategory.PYTHON_FEED,
+            ConnectionCategory.S3,
+            ConnectionCategory.SNOWFLAKE,
+            ConnectionCategory.AZURE_SQL_DB,
+            ConnectionCategory.AZURE_SYNAPSE_ANALYTICS,
+            ConnectionCategory.AZURE_MY_SQL_DB,
+            ConnectionCategory.AZURE_POSTGRES_DB,
+            ConnectionTypes.CUSTOM,
+            ConnectionTypes.AZURE_DATA_LAKE_GEN_2,
+        ],
+        casing_transform=camel_to_snake,
+        required=True,
+    )
+
+    # 'required=False' is only approximately true: some connection types require
+    # this field, some don't, and some rename it for client familiarity reasons.
+    target = fields.Str(required=False)
+
+    credentials = UnionField(
+        [
+            NestedField(PatTokenConfigurationSchema),
+            NestedField(SasTokenConfigurationSchema),
+            NestedField(UsernamePasswordConfigurationSchema),
+            NestedField(ManagedIdentityConfigurationSchema),
+            NestedField(ServicePrincipalConfigurationSchema),
+            NestedField(AccessKeyConfigurationSchema),
+            NestedField(ApiKeyConfigurationSchema),
+            NestedField(AccountKeyConfigurationSchema),
+            NestedField(AadCredentialConfigurationSchema),
+            NestedField(NoneCredentialConfigurationSchema),
+        ],
+        required=False,
+        load_default=NoneCredentialConfiguration(),
+    )
+
+    is_shared = fields.Bool(load_default=True)
+    metadata = fields.Dict(required=False)
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import WorkspaceConnection
+
+        # Most non-subclassed connections default to a none credential if none
+        # is provided. ADLS Gen 2 connections default to AAD with this code.
+        if (
+            data.get("type", None) == ConnectionTypes.AZURE_DATA_LAKE_GEN_2
+            and data.get("credentials", None) == NoneCredentialConfiguration()
+        ):
+            data["credentials"] = AadCredentialConfiguration()
+        return WorkspaceConnection(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/customer_managed_key.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/customer_managed_key.py
new file mode 100644
index 00000000..459507fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/customer_managed_key.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. 
All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class CustomerManagedKeySchema(metaclass=PatchedSchemaMeta): + key_vault = fields.Str() + key_uri = fields.Url() + cosmosdb_id = fields.Str() + storage_id = fields.Str() + search_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import CustomerManagedKey + + return CustomerManagedKey(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/endpoint_connection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/endpoint_connection.py new file mode 100644 index 00000000..ba926d9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/endpoint_connection.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class EndpointConnectionSchema(metaclass=PatchedSchemaMeta): + subscription_id = fields.UUID() + resource_group = fields.Str() + location = fields.Str() + vnet_name = fields.Str() + subnet_name = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities import EndpointConnection + + return EndpointConnection(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py new file mode 100644 index 00000000..d0348c3b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py @@ -0,0 +1,79 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields +from marshmallow.decorators import post_load, pre_dump + +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel +from azure.ai.ml.constants._workspace import ManagedServiceIdentityType +from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration + + +class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta): + principal_id = fields.Str(required=False) + client_id = fields.Str(required=False) + resource_id = fields.Str(required=False) + + @post_load + def make(self, data, **kwargs): + return ManagedIdentityConfiguration(**data) + + +class IdentitySchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum( + allowed_values=[ + ManagedServiceIdentityType.SYSTEM_ASSIGNED, + ManagedServiceIdentityType.USER_ASSIGNED, + ManagedServiceIdentityType.NONE, + ManagedServiceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED, + ], + casing_transform=camel_to_snake, + metadata={"description": "resource identity type."}, + ) + principal_id = fields.Str(required=False) + tenant_id = fields.Str(required=False) + user_assigned_identities = fields.Dict( + keys=fields.Str(required=True), values=NestedField(UserAssignedIdentitySchema, allow_none=True), allow_none=True + ) + + @pre_dump + def predump(self, data, **kwargs): + if data and isinstance(data, IdentityConfiguration): + data.user_assigned_identities = self.uai_list2dict(data.user_assigned_identities) + return data + + @post_load + def make(self, data, **kwargs): + if data.get("user_assigned_identities", False): + data["user_assigned_identities"] = self.uai_dict2list(data.pop("user_assigned_identities")) + data["type"] = snake_to_camel(data.pop("type")) + return IdentityConfiguration(**data) + + def uai_dict2list(self, uai_dict): + res = [] + for resource_id, meta in uai_dict.items(): + if not isinstance(meta, ManagedIdentityConfiguration): + continue + c_id = meta.client_id + p_id = meta.principal_id + res.append(ManagedIdentityConfiguration(resource_id=resource_id, client_id=c_id, principal_id=p_id)) + return res + + def uai_list2dict(self, uai_list): + res = {} + if uai_list and isinstance(uai_list, list): + for uai in uai_list: + if not isinstance(uai, ManagedIdentityConfiguration): + continue + meta = {} + if uai.client_id: + meta["client_id"] = uai.client_id + if uai.principal_id: + meta["principal_id"] = uai.principal_id + res[uai.resource_id] = meta + return res if res else None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/network_acls.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/network_acls.py new file mode 100644 index 00000000..e9e5e8ec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/network_acls.py @@ -0,0 +1,63 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from marshmallow import ValidationError, fields, post_load, validates_schema + +from azure.ai.ml._schema.core.schema import PathAwareSchema +from azure.ai.ml.entities._workspace.network_acls import DefaultActionType, IPRule, NetworkAcls + + +class IPRuleSchema(PathAwareSchema): + """Schema for IPRule.""" + + value = fields.Str(required=True) + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + """Create an IPRule object from the marshmallow schema. + + :param data: The data from which the IPRule is being loaded. + :type data: OrderedDict[str, Any] + :returns: An IPRule object. + :rtype: azure.ai.ml.entities._workspace.network_acls.NetworkAcls.IPRule + """ + return IPRule(**data) + + +class NetworkAclsSchema(PathAwareSchema): + """Schema for NetworkAcls. + + :param default_action: Specifies the default action when no IP rules are matched. + :type default_action: str + :param ip_rules: Rules governing the accessibility of a resource from a specific IP address or IP range. + :type ip_rules: Optional[List[IPRule]] + """ + + default_action = fields.Str(required=True) + ip_rules = fields.List(fields.Nested(IPRuleSchema), allow_none=True) + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + """Create a NetworkAcls object from the marshmallow schema. + + :param data: The data from which the NetworkAcls is being loaded. + :type data: OrderedDict[str, Any] + :returns: A NetworkAcls object. + :rtype: azure.ai.ml.entities._workspace.network_acls.NetworkAcls + """ + return NetworkAcls(**data) + + @validates_schema + def validate_schema(self, data, **kwargs): # pylint: disable=unused-argument + """Validate the NetworkAcls schema. + + :param data: The data to validate. + :type data: OrderedDict[str, Any] + :raises ValidationError: If the schema is invalid. + """ + if data["default_action"] not in set([DefaultActionType.DENY, DefaultActionType.ALLOW]): + raise ValidationError("Invalid value for default_action. Must be 'Deny' or 'Allow'.") + + if data["default_action"] == DefaultActionType.DENY and not data.get("ip_rules"): + raise ValidationError("ip_rules must be provided when default_action is 'Deny'.") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py new file mode 100644 index 00000000..f228ee3e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py @@ -0,0 +1,224 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py
new file mode 100644
index 00000000..f228ee3e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py
@@ -0,0 +1,224 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,no-else-return
+
+from marshmallow import EXCLUDE, fields
+from marshmallow.decorators import post_load, pre_dump
+
+from azure.ai.ml._schema import ExperimentalField
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import _snake_to_camel, camel_to_snake
+from azure.ai.ml.constants._workspace import FirewallSku, IsolationMode, OutboundRuleCategory
+from azure.ai.ml.entities._workspace.networking import (
+    FqdnDestination,
+    ManagedNetwork,
+    PrivateEndpointDestination,
+    ServiceTagDestination,
+)
+
+
+class ManagedNetworkStatusSchema(metaclass=PatchedSchemaMeta):
+    spark_ready = fields.Bool(dump_only=True)
+    status = fields.Str(dump_only=True)
+
+
+class FqdnOutboundRuleSchema(metaclass=PatchedSchemaMeta):
+    name = fields.Str(required=True)
+    parent_rule_names = fields.List(fields.Str(), dump_only=True)
+    type = fields.Constant("fqdn")
+    destination = fields.Str(required=True)
+    category = StringTransformedEnum(
+        allowed_values=[
+            OutboundRuleCategory.REQUIRED,
+            OutboundRuleCategory.RECOMMENDED,
+            OutboundRuleCategory.USER_DEFINED,
+        ],
+        casing_transform=camel_to_snake,
+        metadata={"description": "outbound rule category."},
+        dump_only=True,
+    )
+    status = fields.Str(dump_only=True)
+
+    @post_load
+    def createdestobject(self, data, **kwargs):
+        dest = data.get("destination")
+        category = data.get("category", OutboundRuleCategory.USER_DEFINED)
+        name = data.get("name")
+        status = data.get("status", None)
+        return FqdnDestination(
+            name=name,
+            destination=dest,
+            category=_snake_to_camel(category),
+            status=status,
+        )
+
+
+class ServiceTagDestinationSchema(metaclass=PatchedSchemaMeta):
+    service_tag = fields.Str(required=True)
+    protocol = fields.Str(required=True)
+    port_ranges = fields.Str(required=True)
+    address_prefixes = fields.List(fields.Str())
+
+
+class ServiceTagOutboundRuleSchema(metaclass=PatchedSchemaMeta):
+    name = fields.Str(required=True)
+    parent_rule_names = fields.List(fields.Str(), dump_only=True)
+    type = fields.Constant("service_tag")
+    destination = NestedField(ServiceTagDestinationSchema, required=True)
+    category = StringTransformedEnum(
+        allowed_values=[
+            OutboundRuleCategory.REQUIRED,
+            OutboundRuleCategory.RECOMMENDED,
+            OutboundRuleCategory.USER_DEFINED,
+        ],
+        casing_transform=camel_to_snake,
+        metadata={"description": "outbound rule category."},
+        dump_only=True,
+    )
+    status = fields.Str(dump_only=True)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        # Flatten the entity's destination attributes into the nested dict
+        # shape that this schema dumps.
+        data.destination = self.service_tag_dest2dict(
+            data.service_tag, data.protocol, data.port_ranges, data.address_prefixes
+        )
+        return data
+
+    @post_load
+    def createdestobject(self, data, **kwargs):
+        dest = data.get("destination")
+        category = data.get("category", OutboundRuleCategory.USER_DEFINED)
+        name = data.get("name")
+        status = data.get("status", None)
+        return ServiceTagDestination(
+            name=name,
+            service_tag=dest["service_tag"],
+            protocol=dest["protocol"],
+            port_ranges=dest["port_ranges"],
+            address_prefixes=dest.get("address_prefixes", None),
+            category=_snake_to_camel(category),
+            status=status,
+        )
+
+    def service_tag_dest2dict(self, service_tag, protocol, port_ranges, address_prefixes):
+        return {
+            "service_tag": service_tag,
+            "protocol": protocol,
+            "port_ranges": port_ranges,
+            "address_prefixes": address_prefixes,
+        }
+
+
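For illustration only (editor's note, not part of the vendored diff): loading a service_tag rule returns the entity built in createdestobject, not a dict. A minimal sketch with made-up rule values.

from azure.ai.ml._schema.workspace.networking import ServiceTagOutboundRuleSchema

rule = ServiceTagOutboundRuleSchema().load(
    {
        "name": "allow-storage",  # hypothetical rule name
        "type": "service_tag",
        "destination": {"service_tag": "Storage", "protocol": "TCP", "port_ranges": "443"},
    }
)
print(type(rule).__name__)  # ServiceTagDestination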
+class PrivateEndpointDestinationSchema(metaclass=PatchedSchemaMeta):
+    service_resource_id = fields.Str(required=True)
+    subresource_target = fields.Str(required=True)
+    spark_enabled = fields.Bool(required=True)
+
+
+class PrivateEndpointOutboundRuleSchema(metaclass=PatchedSchemaMeta):
+    name = fields.Str(required=True)
+    parent_rule_names = fields.List(fields.Str(), dump_only=True)
+    type = fields.Constant("private_endpoint")
+    destination = NestedField(PrivateEndpointDestinationSchema, required=True)
+    fqdns = fields.List(fields.Str())
+    category = StringTransformedEnum(
+        allowed_values=[
+            OutboundRuleCategory.REQUIRED,
+            OutboundRuleCategory.RECOMMENDED,
+            OutboundRuleCategory.USER_DEFINED,
+            OutboundRuleCategory.DEPENDENCY,
+        ],
+        casing_transform=camel_to_snake,
+        metadata={"description": "outbound rule category."},
+        dump_only=True,
+    )
+    status = fields.Str(dump_only=True)
+
+    @pre_dump
+    def predump(self, data, **kwargs):
+        # Flatten the entity's destination attributes into the nested dict
+        # shape that this schema dumps.
+        data.destination = self.pe_dest2dict(data.service_resource_id, data.subresource_target, data.spark_enabled)
+        return data
+
+    @post_load
+    def createdestobject(self, data, **kwargs):
+        dest = data.get("destination")
+        category = data.get("category", OutboundRuleCategory.USER_DEFINED)
+        name = data.get("name")
+        status = data.get("status", None)
+        fqdns = data.get("fqdns", None)
+        return PrivateEndpointDestination(
+            name=name,
+            service_resource_id=dest["service_resource_id"],
+            subresource_target=dest["subresource_target"],
+            spark_enabled=dest["spark_enabled"],
+            category=_snake_to_camel(category),
+            status=status,
+            fqdns=fqdns,
+        )
+
+    def pe_dest2dict(self, service_resource_id, subresource_target, spark_enabled):
+        return {
+            "service_resource_id": service_resource_id,
+            "subresource_target": subresource_target,
+            "spark_enabled": spark_enabled,
+        }
+
+
+class ManagedNetworkSchema(metaclass=PatchedSchemaMeta):
+    isolation_mode = StringTransformedEnum(
+        allowed_values=[
+            IsolationMode.DISABLED,
+            IsolationMode.ALLOW_INTERNET_OUTBOUND,
+            IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND,
+        ],
+        casing_transform=camel_to_snake,
+        metadata={"description": "isolation mode for the workspace managed network."},
+    )
+    outbound_rules = fields.List(
+        UnionField(
+            [
+                NestedField(PrivateEndpointOutboundRuleSchema, allow_none=False, unknown=EXCLUDE),
+                NestedField(ServiceTagOutboundRuleSchema, allow_none=False, unknown=EXCLUDE),
+                NestedField(
+                    FqdnOutboundRuleSchema, allow_none=False, unknown=EXCLUDE
+                ),  # must stay last; otherwise the union field would match any rule whose destination is a plain string
+            ],
+            allow_none=False,
+            is_strict=True,
+        ),
+        allow_none=True,
+    )
+    firewall_sku = ExperimentalField(
+        StringTransformedEnum(
+            allowed_values=[
+                FirewallSku.STANDARD,
+                FirewallSku.BASIC,
+            ],
+            casing_transform=camel_to_snake,
+            metadata={"description": "Firewall sku for FQDN rules in AllowOnlyApprovedOutbound mode"},
+        )
+    )
+    network_id = fields.Str(required=False, dump_only=True)
+    status = NestedField(ManagedNetworkStatusSchema, allow_none=False, unknown=EXCLUDE)
+
+    @post_load
+    def make(self, data, **kwargs):
+        outbound_rules = data.get("outbound_rules", False)
+
+        firewall_sku = data.get("firewall_sku", False)
+        firewall_sku_value = _snake_to_camel(data["firewall_sku"]) if firewall_sku else FirewallSku.STANDARD
+
+        if outbound_rules:
+            return ManagedNetwork(
+                isolation_mode=_snake_to_camel(data["isolation_mode"]),
+                outbound_rules=outbound_rules,
+                firewall_sku=firewall_sku_value,
+            )
+        else:
+            return ManagedNetwork(
+                isolation_mode=_snake_to_camel(data["isolation_mode"]),
+                firewall_sku=firewall_sku_value,
+            )
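For illustration only (editor's note, not part of the vendored diff): the UnionField tries each nested rule schema in order, so a rule whose destination is a plain string falls through to FqdnOutboundRuleSchema, and make() returns a ManagedNetwork entity. A minimal sketch with a hypothetical rule.

from azure.ai.ml._schema.workspace.networking import ManagedNetworkSchema

network = ManagedNetworkSchema().load(
    {
        "isolation_mode": "allow_only_approved_outbound",
        "outbound_rules": [
            {"name": "allow-pypi", "type": "fqdn", "destination": "pypi.org"},  # hypothetical rule
        ],
    }
)
print(type(network).__name__)  # ManagedNetwork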
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/private_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/private_endpoint.py
new file mode 100644
index 00000000..0235a4a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/private_endpoint.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+from .endpoint_connection import EndpointConnectionSchema
+
+
+class PrivateEndpointSchema(metaclass=PatchedSchemaMeta):
+    approval_type = fields.Str()
+    connections = fields.Dict(keys=fields.Str(), values=NestedField(EndpointConnectionSchema))
+
+    @post_load
+    def make(self, data, **kwargs):
+        # Lazy import of the entity class (avoids importing entities at module load).
+        from azure.ai.ml.entities import PrivateEndpoint
+
+        return PrivateEndpoint(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/serverless_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/serverless_compute.py
new file mode 100644
index 00000000..5137e57f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/serverless_compute.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+from marshmallow.decorators import post_load, validates
+
+from azure.ai.ml._schema._utils.utils import ArmId
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml.entities._workspace.serverless_compute import ServerlessComputeSettings
+
+
+class ServerlessComputeSettingsSchema(PathAwareSchema):
+    """Schema for ServerlessComputeSettings.
+
+    :param custom_subnet: The custom subnet to use for serverless computes created in the workspace.
+    :type custom_subnet: Optional[ArmId]
+    :param no_public_ip: Whether to disable public IP for the compute. Only valid if custom_subnet is defined.
+    :type no_public_ip: bool
+    """
+
+    custom_subnet = fields.Str(allow_none=True)
+    no_public_ip = fields.Bool(load_default=False)
+
+    @post_load
+    def make(self, data, **_kwargs) -> ServerlessComputeSettings:
+        """Create a ServerlessComputeSettings object from the marshmallow schema.
+
+        :param data: The data from which the ServerlessComputeSettings are being loaded.
+        :type data: OrderedDict[str, Any]
+        :returns: A ServerlessComputeSettings object.
+        :rtype: azure.ai.ml.entities._workspace.serverless_compute.ServerlessComputeSettings
+        """
+        custom_subnet = data.pop("custom_subnet", None)
+        if custom_subnet == "None":
+            custom_subnet = None  # For loading from YAML when the user wants to trigger a removal
+        no_public_ip = data.pop("no_public_ip", False)
+        return ServerlessComputeSettings(custom_subnet=custom_subnet, no_public_ip=no_public_ip)
+
+    @validates("custom_subnet")
+    def validate_custom_subnet(self, data: str, **_kwargs):
+        """Validate that custom_subnet matches the ARM ID format or is a None-like value.
+
+        :param data: The candidate custom_subnet to validate.
+        :type data: str
+        :raises ValidationError: If the custom_subnet is not formatted as an ARM ID.
+        """
+        if data == "None" or data is None:
+            # The literal string "None" is accepted here and converted to None in make().
+            pass
+        else:
+            # Verify that we can transform it to an ArmId if it is not None.
+            ArmId(data)
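For illustration only (editor's note, not part of the vendored diff): the literal string "None" passes validate_custom_subnet and becomes None in make(), which is how a YAML file clears an existing subnet. A minimal sketch; the "base_path" context key is the same assumption as above.

from azure.ai.ml._schema.workspace.serverless_compute import ServerlessComputeSettingsSchema

schema = ServerlessComputeSettingsSchema(context={"base_path": "."})  # assumed context key
settings = schema.load({"custom_subnet": "None", "no_public_ip": False})
print(settings.custom_subnet)  # None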
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/workspace.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/workspace.py
new file mode 100644
index 00000000..1df06f97
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/workspace.py
@@ -0,0 +1,49 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import EXCLUDE, fields
+
+from azure.ai.ml._schema._utils.utils import validate_arm_str
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.workspace.customer_managed_key import CustomerManagedKeySchema
+from azure.ai.ml._schema.workspace.identity import IdentitySchema
+from azure.ai.ml._schema.workspace.network_acls import NetworkAclsSchema
+from azure.ai.ml._schema.workspace.networking import ManagedNetworkSchema
+from azure.ai.ml._schema.workspace.serverless_compute import ServerlessComputeSettingsSchema
+from azure.ai.ml._utils.utils import snake_to_pascal
+from azure.ai.ml.constants._common import PublicNetworkAccess
+
+
+class WorkspaceSchema(PathAwareSchema):
+    name = fields.Str(required=True)
+    location = fields.Str()
+    id = fields.Str(dump_only=True)
+    resource_group = fields.Str()
+    description = fields.Str()
+    discovery_url = fields.Str()
+    display_name = fields.Str()
+    hbi_workspace = fields.Bool()
+    storage_account = fields.Str(validate=validate_arm_str)
+    container_registry = fields.Str(validate=validate_arm_str)
+    key_vault = fields.Str(validate=validate_arm_str)
+    application_insights = fields.Str(validate=validate_arm_str)
+    customer_managed_key = NestedField(CustomerManagedKeySchema)
+    tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+    mlflow_tracking_uri = fields.Str(dump_only=True)
+    image_build_compute = fields.Str()
+    public_network_access = StringTransformedEnum(
+        allowed_values=[PublicNetworkAccess.DISABLED, PublicNetworkAccess.ENABLED],
+        casing_transform=snake_to_pascal,
+    )
+    network_acls = NestedField(NetworkAclsSchema)
+    system_datastores_auth_mode = fields.Str()
+    identity = NestedField(IdentitySchema)
+    primary_user_assigned_identity = fields.Str()
+    workspace_hub = fields.Str(validate=validate_arm_str)
+    managed_network = NestedField(ManagedNetworkSchema, unknown=EXCLUDE)
+    provision_network_now = fields.Bool()
+    enable_data_isolation = fields.Bool()
+    allow_roleassignment_on_rg = fields.Bool()
+    serverless_compute = NestedField(ServerlessComputeSettingsSchema)
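For illustration only (editor's note, not part of the vendored diff): WorkspaceSchema declares no post_load hook in this file, so load() returns the validated field dict; entity construction happens elsewhere in the SDK. A minimal sketch with placeholder values; the "base_path" context key is assumed as above.

from azure.ai.ml._schema.workspace.workspace import WorkspaceSchema

schema = WorkspaceSchema(context={"base_path": "."})  # assumed context key
ws_dict = schema.load(
    {
        "name": "my-workspace",
        "location": "eastus",
        "public_network_access": "enabled",  # transformed via snake_to_pascal
        "managed_network": {"isolation_mode": "allow_internet_outbound"},
    }
)
print(ws_dict["name"])  # my-workspace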