author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_schema
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/__init__.py  60
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_path_schemas.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_schema.py  40
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/__init__.py  9
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/data_import.py  22
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/schedule.py  39
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/__init__.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem.py  40
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem_credentials.py  53
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/adls_gen1.py  41
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/azure_storage.py  97
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/credentials.py  99
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/one_lake.py  49
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py  92
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py  26
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py  132
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py  52
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py  36
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py  51
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py  46
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py  56
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py  70
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py  28
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/code_configuration_schema.py  25
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/deployment.py  48
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_asset_schema.py  26
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_collector_schema.py  39
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/deployment_collection_schema.py  32
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/event_hub_schema.py  31
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/liveness_probe.py  28
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/online_deployment.py  79
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/oversize_data_config_schema.py  31
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/payload_response_schema.py  24
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_logging_schema.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_settings_schema.py  26
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_requirements_schema.py  28
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_settings_schema.py  32
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/scale_settings_schema.py  51
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py  17
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py  84
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py  27
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py  29
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py  29
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py  15
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py  27
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py  28
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py  41
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py  66
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/__init__.py  25
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/delay_metadata_schema.py  21
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_schema.py  29
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_backfill_schema.py  22
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_schema.py  27
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_specification_schema.py  19
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_transformation_code_metadata_schema.py  22
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_window_schema.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_metadata_schema.py  33
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_properties_schema.py  55
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/materialization_settings_schema.py  37
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_metadata_schema.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_process_code_metadata_schema.py  20
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/timestamp_column_metadata_schema.py  20
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/__init__.py  15
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/compute_runtime_schema.py  19
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/feature_store_schema.py  43
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/materialization_store_schema.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/__init__.py  13
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/data_column_schema.py  26
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/feature_store_entity_schema.py  26
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py  19
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py  54
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py  18
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py  17
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py  35
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py  21
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py  73
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/__init__.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/notification_schema.py  24
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py  9
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py  6
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py  21
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py  63
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py  60
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py  62
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py  77
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py  18
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py  31
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py  103
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py  95
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py  88
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py  94
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/artifact.py  24
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/asset.py  42
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/code_asset.py  47
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/data.py  25
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/environment.py  160
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/federated_learning_silo.py  24
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/index.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py  65
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/base_environment_source.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/inference_server.py  51
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_configuration.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package.py  41
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package_input.py  81
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/online_inference_configuration.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/route.py  22
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/workspace_asset_reference.py  27
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/__init__.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_job.py  21
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_vertical.py  18
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/featurization_settings.py  74
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/forecasting_settings.py  66
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_classification.py  66
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_limit_settings.py  21
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_distribution_settings.py  216
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_settings.py  96
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_object_detection.py  66
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_sweep_settings.py  27
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_vertical.py  19
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_fixed_parameters.py  33
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_parameter_subspace.py  106
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_sweep_settings.py  27
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical.py  24
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical_limit_settings.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification.py  36
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification_multilabel.py  36
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_ner.py  35
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/classification.py  37
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/forecasting.py  38
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/regression.py  36
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical.py  29
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical_limit_settings.py  28
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py  122
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py  48
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py  137
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py  143
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py  257
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py  107
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py  74
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py  126
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py  108
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py  22
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py  13
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py  79
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py  47
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute_node_info.py  15
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/attached_compute.py  12
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py  85
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute_instance.py  83
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/custom_applications.py  60
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/kubernetes_compute.py  16
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/schedule.py  118
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/setup_scripts.py  33
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/synapsespark_compute.py  49
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/usage.py  42
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/virtual_machine_compute.py  34
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/vm_size.py  19
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/auto_delete_setting.py  38
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/fields.py  1029
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/intellectual_property.py  38
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/resource.py  51
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema.py  123
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema_meta.py  53
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py  63
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/__init__.py  28
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/base_job.py  69
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/command_job.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/creation_context.py  16
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/data_transfer_job.py  60
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py  104
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/identity.py  67
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/import_job.py  54
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_entry.py  256
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_fields_provider.py  50
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_port.py  29
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_limits.py  45
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_output.py  18
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parallel_job.py  15
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_command.py  41
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_parallel.py  72
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_spark.py  151
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py  100
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/spark_job.py  28
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resource_configuration.py  38
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resources.py  21
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py  19
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py  52
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py  53
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py  348
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py  25
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py  211
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/__init__.py  17
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py  148
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py  554
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py  48
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/control_flow_job.py  147
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_command_job.py  31
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py  297
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_datatransfer_job.py  55
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_import_job.py  25
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py  76
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job_io.py  51
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_parallel_job.py  40
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_spark_job.py  29
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/settings.py  42
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/queue_settings.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/__init__.py  9
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry.py  53
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry_region_arm_details.py  61
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_acr_account.py  35
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_storage_account.py  40
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/util.py  15
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/resource_configuration.py  21
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py  144
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/schedule.py  44
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/trigger.py  82
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/spark_resource_configuration.py  52
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/__init__.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/capability_host.py  18
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/hub.py  18
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/project.py  16
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py  37
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py  225
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py  178
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py  26
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py  86
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/customer_managed_key.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/endpoint_connection.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py  79
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/network_acls.py  63
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py  224
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/private_endpoint.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/serverless_compute.py  52
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/workspace.py  49
257 files changed, 14301 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/__init__.py
new file mode 100644
index 00000000..115a65bb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/__init__.py
@@ -0,0 +1,60 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from ._data_import import DataImportSchema
+from ._sweep import SweepJobSchema
+from .assets.code_asset import AnonymousCodeAssetSchema, CodeAssetSchema
+from .assets.data import DataSchema
+from .assets.environment import AnonymousEnvironmentSchema, EnvironmentSchema
+from .assets.index import IndexAssetSchema
+from .assets.model import ModelSchema
+from .assets.workspace_asset_reference import WorkspaceAssetReferenceSchema
+from .component import CommandComponentSchema
+from .core.fields import (
+ ArmStr,
+ ArmVersionedStr,
+ ExperimentalField,
+ NestedField,
+ RegistryStr,
+ StringTransformedEnum,
+ UnionField,
+)
+from .core.schema import PathAwareSchema, YamlFileSchema
+from .core.schema_meta import PatchedSchemaMeta
+from .job import CommandJobSchema, ParallelJobSchema, SparkJobSchema
+
+# TODO: enable in PuP
+# from .job import ImportJobSchema
+# from .component import ImportComponentSchema
+
+__all__ = [
+ # "ImportJobSchema",
+ # "ImportComponentSchema",
+ "ArmStr",
+ "ArmVersionedStr",
+ "DataSchema",
+ "StringTransformedEnum",
+ "CodeAssetSchema",
+ "CommandJobSchema",
+ "SparkJobSchema",
+ "ParallelJobSchema",
+ "EnvironmentSchema",
+ "AnonymousEnvironmentSchema",
+ "NestedField",
+ "PatchedSchemaMeta",
+ "PathAwareSchema",
+ "ModelSchema",
+ "SweepJobSchema",
+ "UnionField",
+ "YamlFileSchema",
+ "CommandComponentSchema",
+ "AnonymousCodeAssetSchema",
+ "ExperimentalField",
+ "RegistryStr",
+ "WorkspaceAssetReferenceSchema",
+ "DataImportSchema",
+ "IndexAssetSchema",
+]
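
The __init__ above is a re-export hub: PatchedSchemaMeta, NestedField, and UnionField are the building blocks most schemas in this commit are assembled from. A minimal sketch of how they compose (class names below are invented for illustration, not part of this commit):

from marshmallow import fields

from azure.ai.ml._schema import NestedField, PatchedSchemaMeta, UnionField


class FileRefSchema(metaclass=PatchedSchemaMeta):
    # PatchedSchemaMeta turns this plain class into a marshmallow schema.
    file = fields.Str(required=True)


class FolderRefSchema(metaclass=PatchedSchemaMeta):
    folder = fields.Str(required=True)


class SourceSchema(metaclass=PatchedSchemaMeta):
    # UnionField tries each member field in turn, so a payload may supply
    # either {"file": ...} or {"folder": ...} under the "source" key.
    source = UnionField([NestedField(FileRefSchema), NestedField(FolderRefSchema)])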
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_path_schemas.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_path_schemas.py
new file mode 100644
index 00000000..0156743e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_path_schemas.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class MLTableMetadataPathFileSchema(metaclass=PatchedSchemaMeta):
+ file = fields.Str(
+ metadata={"description": "This specifies path of file containing data."},
+ required=True,
+ )
+
+
+class MLTableMetadataPathFolderSchema(metaclass=PatchedSchemaMeta):
+ folder = fields.Str(
+ metadata={"description": "This specifies path of folder containing data."},
+ required=True,
+ )
+
+
+class MLTableMetadataPathPatternSchema(metaclass=PatchedSchemaMeta):
+ pattern = fields.Str(
+ metadata={
+ "description": "This specifies a search pattern to allow globbing of files and folders containing data."
+ },
+ required=True,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_schema.py
new file mode 100644
index 00000000..99861bc3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data/mltable_metadata_schema.py
@@ -0,0 +1,40 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NestedField, UnionField
+from azure.ai.ml._schema.core.schema import YamlFileSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+from .mltable_metadata_path_schemas import (
+ MLTableMetadataPathFileSchema,
+ MLTableMetadataPathFolderSchema,
+ MLTableMetadataPathPatternSchema,
+)
+
+
+class MLTableMetadataSchema(YamlFileSchema):
+ paths = fields.List(
+ UnionField(
+ [
+ NestedField(MLTableMetadataPathFileSchema()),
+ NestedField(MLTableMetadataPathFolderSchema()),
+ NestedField(MLTableMetadataPathPatternSchema()),
+ ]
+ ),
+ required=True,
+ )
+ transformations = fields.List(fields.Raw(), required=False)
+
+ @post_load
+ def make(self, data: Dict, **kwargs):
+ from azure.ai.ml.entities._data.mltable_metadata import MLTableMetadata, MLTableMetadataPath
+
+ paths = [MLTableMetadataPath(pathDict=pathDict) for pathDict in data.pop("paths")]
+ return MLTableMetadata(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data, paths=paths)
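
MLTableMetadataSchema requires a list of paths, each item matching exactly one of the three path schemas above, plus an optional list of raw transformation steps; its post_load hook rewraps each path dict as an MLTableMetadataPath. A hypothetical payload it would accept (path values and the transformation step are invented):

mltable_dict = {
    "paths": [
        {"file": "./titanic.csv"},
        {"folder": "./archive/"},
        {"pattern": "./logs/*.jsonl"},
    ],
    # 'transformations' is declared as raw fields, so any list passes schema
    # validation; this particular step name is illustrative only.
    "transformations": [{"read_delimited": {"delimiter": ","}}],
}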
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/__init__.py
new file mode 100644
index 00000000..28719d1f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/__init__.py
@@ -0,0 +1,9 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .data_import import DataImportSchema
+
+__all__ = ["DataImportSchema"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/data_import.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/data_import.py
new file mode 100644
index 00000000..a731e1da
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/data_import.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import post_load
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.job.input_output_entry import DatabaseSchema, FileSystemSchema
+from azure.ai.ml._utils._experimental import experimental
+from ..core.fields import UnionField
+from ..assets.data import DataSchema
+
+
+@experimental
+class DataImportSchema(DataSchema):
+ source = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._data_import.data_import import DataImport
+
+ return DataImport(**data)
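
DataImportSchema narrows DataSchema by requiring a source that matches either DatabaseSchema or FileSystemSchema. A sketch of a conforming payload; the keys inside "source" are assumptions about those two schemas, which are defined in input_output_entry.py rather than in this file:

data_import_dict = {
    "name": "region_snapshot",
    "type": "mltable",
    "path": "azureml://datastores/workspaceblobstore/paths/snapshots/",
    "source": {
        # Assumed DatabaseSchema shape; the FileSystemSchema branch would
        # carry a path instead of a query.
        "type": "database",
        "connection": "azureml:my_sql_connection",
        "query": "SELECT * FROM region",
    },
}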
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/schedule.py
new file mode 100644
index 00000000..20a7e3d2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_data_import/schedule.py
@@ -0,0 +1,39 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import yaml
+
+from azure.ai.ml._schema.core.fields import NestedField, FileRefField
+from azure.ai.ml._schema.schedule.schedule import ScheduleSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from ..core.fields import UnionField
+from .data_import import DataImportSchema
+
+
+class ImportDataFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs) -> "DataImport":
+ # Load the data import definition from the referenced YAML file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ data_import_dict = yaml.safe_load(data)
+
+ from azure.ai.ml.entities._data_import.data_import import DataImport
+
+ return DataImport._load(
+ data=data_import_dict,
+ yaml_path=self.context[BASE_PATH_CONTEXT_KEY] / value,
+ **kwargs,
+ )
+
+
+@experimental
+class ImportDataScheduleSchema(ScheduleSchema):
+ import_data = UnionField(
+ [
+ ImportDataFileRefField,
+ NestedField(DataImportSchema),
+ ]
+ )
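
ImportDataScheduleSchema lets import_data be given either inline (validated by DataImportSchema) or as a relative path, which ImportDataFileRefField reads, parses with yaml.safe_load, and hands to DataImport._load with the referencing file's directory as base path. A sketch of the two spellings (the trigger shape is assumed from ScheduleSchema, which is not shown in this diff):

schedule_with_ref = {
    "name": "nightly_import",
    "trigger": {"type": "cron", "expression": "0 2 * * *"},  # assumed shape
    "import_data": "./data_import.yml",  # resolved relative to this YAML file
}

schedule_inline = {
    "name": "nightly_import",
    "trigger": {"type": "cron", "expression": "0 2 * * *"},  # assumed shape
    "import_data": data_import_dict,  # the inline form sketched earlier
}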
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/__init__.py
new file mode 100644
index 00000000..18774380
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/__init__.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
+
+from .adls_gen1 import AzureDataLakeGen1Schema
+from .azure_storage import AzureBlobSchema, AzureDataLakeGen2Schema, AzureFileSchema, AzureStorageSchema
+from .credentials import (
+ AccountKeySchema,
+ BaseTenantCredentialSchema,
+ CertificateSchema,
+ NoneCredentialsSchema,
+ SasTokenSchema,
+ ServicePrincipalSchema,
+)
+
+__all__ = [
+ "AccountKeySchema",
+ "AzureBlobSchema",
+ "AzureDataLakeGen1Schema",
+ "AzureDataLakeGen2Schema",
+ "AzureFileSchema",
+ "AzureStorageSchema",
+ "BaseTenantCredentialSchema",
+ "CertificateSchema",
+ "NoneCredentialsSchema",
+ "SasTokenSchema",
+ "ServicePrincipalSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem.py
new file mode 100644
index 00000000..1f0a9710
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem.py
@@ -0,0 +1,40 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import DatastoreType
+from azure.ai.ml._schema.core.fields import NestedField, PathAwareSchema, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+
+from ._on_prem_credentials import KerberosKeytabSchema, KerberosPasswordSchema
+
+
+class HdfsSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ id = fields.Str(dump_only=True)
+ type = StringTransformedEnum(
+ allowed_values=DatastoreType.HDFS,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ hdfs_server_certificate = fields.Str()
+ name_node_address = fields.Str(required=True)
+ protocol = fields.Str()
+ credentials = UnionField(
+ [NestedField(KerberosPasswordSchema), NestedField(KerberosKeytabSchema)],
+ required=True,
+ )
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Dict())
+
+ @post_load
+ def make(self, data: Dict[str, Any], **kwargs) -> "HdfsDatastore":
+ from azure.ai.ml.entities._datastore._on_prem import HdfsDatastore
+
+ return HdfsDatastore(**data)
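
HdfsSchema unions the two Kerberos credential schemas from _on_prem_credentials.py, so a definition must carry either a password or a keytab alongside the shared realm/KDC/principal fields. A hypothetical definition it would accept (all values invented; "hdfs" assumes the camel_to_snake transform of DatastoreType.HDFS):

hdfs_datastore = {
    "name": "my_hdfs",
    "type": "hdfs",
    "name_node_address": "namenode.contoso.local:9000",
    "protocol": "http",
    "credentials": {
        "kerberos_realm": "CONTOSO.LOCAL",
        "kerberos_kdc_address": "kdc.contoso.local",
        "kerberos_principal": "svc-aml",
        # KerberosPasswordSchema branch; use "kerberos_keytab" for the other.
        "kerberos_password": "<password>",
    },
}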
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem_credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem_credentials.py
new file mode 100644
index 00000000..ada92afc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/_on_prem_credentials.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Dict
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class BaseKerberosCredentials(metaclass=PatchedSchemaMeta):
+ kerberos_realm = fields.Str(required=True)
+ kerberos_kdc_address = fields.Str(required=True)
+ kerberos_principal = fields.Str(required=True)
+
+
+class KerberosPasswordSchema(BaseKerberosCredentials):
+ kerberos_password = fields.Str(required=True)
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> "KerberosPasswordCredentials":
+ from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosPasswordCredentials
+
+ return KerberosPasswordCredentials(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosPasswordCredentials
+
+ if not isinstance(data, KerberosPasswordCredentials):
+ raise ValidationError("Cannot dump non-KerberosPasswordCredentials object into KerberosPasswordCredentials")
+ return data
+
+
+class KerberosKeytabSchema(BaseKerberosCredentials):
+ kerberos_keytab = fields.Str(required=True)
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> "KerberosKeytabCredentials":
+ from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosKeytabCredentials
+
+ return KerberosKeytabCredentials(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosKeytabCredentials
+
+ if not isinstance(data, KerberosKeytabCredentials):
+ raise ValidationError("Cannot dump non-KerberosKeytabCredentials object into KerberosKeytabCredentials")
+ return data
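
The pre_dump isinstance guards are what make these schemas safe inside a UnionField: on serialization the union tries member schemas in turn, and a schema that silently dumped a foreign object as an empty dict would win incorrectly. A standalone marshmallow sketch of the pattern (names invented, not from this package):

from marshmallow import Schema, ValidationError, fields, pre_dump


class Keytab:
    def __init__(self, keytab: str):
        self.kerberos_keytab = keytab


class KeytabSchema(Schema):
    kerberos_keytab = fields.Str(required=True)

    @pre_dump
    def reject_foreign_objects(self, data, **kwargs):
        # Fail fast so a union of schemas can fall through to the right one.
        if not isinstance(data, Keytab):
            raise ValidationError("Cannot dump a non-Keytab object")
        return data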
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/adls_gen1.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/adls_gen1.py
new file mode 100644
index 00000000..7a575fc6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/adls_gen1.py
@@ -0,0 +1,41 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType
+from azure.ai.ml._schema.core.fields import NestedField, PathAwareSchema, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+
+from .credentials import CertificateSchema, NoneCredentialsSchema, ServicePrincipalSchema
+
+
+class AzureDataLakeGen1Schema(PathAwareSchema):
+ name = fields.Str(required=True)
+ id = fields.Str(dump_only=True)
+ type = StringTransformedEnum(
+ allowed_values=DatastoreType.AZURE_DATA_LAKE_GEN1,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ store_name = fields.Str(required=True)
+ credentials = UnionField(
+ [
+ NestedField(ServicePrincipalSchema),
+ NestedField(CertificateSchema),
+ NestedField(NoneCredentialsSchema),
+ ]
+ )
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Dict())
+
+ @post_load
+ def make(self, data: Dict[str, Any], **kwargs) -> "AzureDataLakeGen1Datastore":
+ from azure.ai.ml.entities import AzureDataLakeGen1Datastore
+
+ return AzureDataLakeGen1Datastore(**data)
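
A hypothetical ADLS Gen1 definition matching AzureDataLakeGen1Schema; the credentials block uses the service-principal shape, one of the three branches the UnionField accepts (values invented; the type string assumes camel_to_snake of AZURE_DATA_LAKE_GEN1):

adls_gen1_datastore = {
    "name": "my_adls_gen1",
    "type": "azure_data_lake_gen1",
    "store_name": "contosoadls",
    "credentials": {  # ServicePrincipalSchema branch of the union
        "tenant_id": "<tenant-guid>",
        "client_id": "<client-guid>",
        "client_secret": "<secret>",
    },
}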
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/azure_storage.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/azure_storage.py
new file mode 100644
index 00000000..ffe8c61c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/azure_storage.py
@@ -0,0 +1,97 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType
+from azure.ai.ml._schema.core.fields import NestedField, PathAwareSchema, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+
+from .credentials import (
+ AccountKeySchema,
+ CertificateSchema,
+ NoneCredentialsSchema,
+ SasTokenSchema,
+ ServicePrincipalSchema,
+)
+
+
+class AzureStorageSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ id = fields.Str(dump_only=True)
+ account_name = fields.Str(required=True)
+ endpoint = fields.Str()
+ protocol = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+
+class AzureFileSchema(AzureStorageSchema):
+ type = StringTransformedEnum(
+ allowed_values=DatastoreType.AZURE_FILE,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ file_share_name = fields.Str(required=True)
+ credentials = UnionField(
+ [
+ NestedField(AccountKeySchema),
+ NestedField(SasTokenSchema),
+ NestedField(NoneCredentialsSchema),
+ ]
+ )
+
+ @post_load
+ def make(self, data: Dict[str, Any], **kwargs) -> "AzureFileDatastore": # type: ignore[name-defined]
+ from azure.ai.ml.entities import AzureFileDatastore
+
+ return AzureFileDatastore(**data)
+
+
+class AzureBlobSchema(AzureStorageSchema):
+ type = StringTransformedEnum(
+ allowed_values=DatastoreType.AZURE_BLOB,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ container_name = fields.Str(required=True)
+ credentials = UnionField(
+ [
+ NestedField(AccountKeySchema),
+ NestedField(SasTokenSchema),
+ NestedField(NoneCredentialsSchema),
+ ],
+ )
+
+ @post_load
+ def make(self, data: Dict[str, Any], **kwargs) -> "AzureBlobDatastore": # type: ignore[name-defined]
+ from azure.ai.ml.entities import AzureBlobDatastore
+
+ return AzureBlobDatastore(**data)
+
+
+class AzureDataLakeGen2Schema(AzureStorageSchema):
+ type = StringTransformedEnum(
+ allowed_values=DatastoreType.AZURE_DATA_LAKE_GEN2,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ filesystem = fields.Str(required=True)
+ credentials = UnionField(
+ [
+ NestedField(ServicePrincipalSchema),
+ NestedField(CertificateSchema),
+ NestedField(NoneCredentialsSchema),
+ ]
+ )
+
+ @post_load
+ def make(self, data: Dict[str, Any], **kwargs) -> "AzureDataLakeGen2Datastore":
+ from azure.ai.ml.entities import AzureDataLakeGen2Datastore
+
+ return AzureDataLakeGen2Datastore(**data)
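
The three concrete storage schemas differ only in their type value, their container field (file_share_name, container_name, filesystem), and which credential schemas they union over; everything else comes from the shared AzureStorageSchema base. A hypothetical blob definition (values invented):

blob_datastore = {
    "name": "my_blob_store",
    "type": "azure_blob",
    "account_name": "contosostorage",  # from the AzureStorageSchema base
    "endpoint": "core.windows.net",    # optional base field
    "container_name": "datasets",      # AzureBlobSchema-specific
    "credentials": {"account_key": "<key>"},  # AccountKeySchema branch
}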
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/credentials.py
new file mode 100644
index 00000000..a4b46aa0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/credentials.py
@@ -0,0 +1,99 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import ValidationError, fields, post_load, pre_dump, pre_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.entities._credentials import (
+ AccountKeyConfiguration,
+ CertificateConfiguration,
+ NoneCredentialConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+)
+
+
+class NoneCredentialsSchema(metaclass=PatchedSchemaMeta):
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> NoneCredentialConfiguration:
+ return NoneCredentialConfiguration(**data)
+
+
+class AccountKeySchema(metaclass=PatchedSchemaMeta):
+ account_key = fields.Str(required=True)
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> AccountKeyConfiguration:
+ return AccountKeyConfiguration(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ if not isinstance(data, AccountKeyConfiguration):
+ raise ValidationError("Cannot dump non-AccountKeyCredentials object into AccountKeyCredentials")
+ return data
+
+
+class SasTokenSchema(metaclass=PatchedSchemaMeta):
+ sas_token = fields.Str(required=True)
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> SasTokenConfiguration:
+ return SasTokenConfiguration(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ if not isinstance(data, SasTokenConfiguration):
+ raise ValidationError("Cannot dump non-SasTokenCredentials object into SasTokenCredentials")
+ return data
+
+
+class BaseTenantCredentialSchema(metaclass=PatchedSchemaMeta):
+ authority_url = fields.Str()
+ resource_url = fields.Str()
+ tenant_id = fields.Str(required=True)
+ client_id = fields.Str(required=True)
+
+ @pre_load
+ def accept_backward_compatible_keys(self, data, **kwargs):
+ acceptable_keys = [key for key in data.keys() if key in ("authority_url", "authority_uri")]
+ if len(acceptable_keys) > 1:
+ raise ValidationError(
+ "Cannot specify both 'authority_url' and 'authority_uri'. Please use 'authority_url'."
+ )
+ if acceptable_keys:
+ data["authority_url"] = data.pop(acceptable_keys[0])
+ return data
+
+
+class ServicePrincipalSchema(BaseTenantCredentialSchema):
+ client_secret = fields.Str(required=True)
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> ServicePrincipalConfiguration:
+ return ServicePrincipalConfiguration(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ if not isinstance(data, ServicePrincipalConfiguration):
+ raise ValidationError("Cannot dump non-ServicePrincipalCredentials object into ServicePrincipalCredentials")
+ return data
+
+
+class CertificateSchema(BaseTenantCredentialSchema):
+ certificate = fields.Str()
+ thumbprint = fields.Str(required=True)
+
+ @post_load
+ def make(self, data: Dict[str, Any], **kwargs) -> CertificateConfiguration:
+ return CertificateConfiguration(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ if not isinstance(data, CertificateConfiguration):
+ raise ValidationError("Cannot dump non-CertificateCredentials object into CertificateCredentials")
+ return data
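
The accept_backward_compatible_keys hook runs before validation and rewrites the legacy authority_uri key, so older YAML keeps loading while only authority_url ever reaches the declared fields. An illustration (values invented):

legacy_payload = {
    "tenant_id": "<tenant-guid>",
    "client_id": "<client-guid>",
    "client_secret": "<secret>",
    "authority_uri": "https://login.microsoftonline.com",  # legacy spelling
}
# After the pre_load hook, marshmallow validates the equivalent of:
# {..., "authority_url": "https://login.microsoftonline.com"}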
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/one_lake.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/one_lake.py
new file mode 100644
index 00000000..4b5e7b66
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_datastore/one_lake.py
@@ -0,0 +1,49 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import Schema, fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType, OneLakeArtifactType
+from azure.ai.ml._schema.core.fields import NestedField, PathAwareSchema, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+
+from .credentials import NoneCredentialsSchema, ServicePrincipalSchema
+
+
+class OneLakeArtifactSchema(Schema):
+ name = fields.Str(required=True)
+ type = StringTransformedEnum(allowed_values=OneLakeArtifactType.LAKE_HOUSE, casing_transform=camel_to_snake)
+
+
+class OneLakeSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ id = fields.Str(dump_only=True)
+ type = StringTransformedEnum(
+ allowed_values=DatastoreType.ONE_LAKE,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ # required fields for OneLake
+ one_lake_workspace_name = fields.Str(required=True)
+ endpoint = fields.Str(required=True)
+ artifact = NestedField(OneLakeArtifactSchema)
+ # ServicePrincipal and identity-based access (the NoneCredentialsSchema branch) are the two supported credential shapes
+ credentials = UnionField(
+ [
+ NestedField(ServicePrincipalSchema),
+ NestedField(NoneCredentialsSchema),
+ ]
+ )
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+ @post_load
+ def make(self, data: Dict[str, Any], **kwargs) -> "OneLakeDatastore":
+ from azure.ai.ml.entities import OneLakeDatastore
+
+ return OneLakeDatastore(**data)
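
A hypothetical OneLake definition matching OneLakeSchema; artifact nests OneLakeArtifactSchema, and omitting credentials falls through to the NoneCredentialsSchema branch, i.e. identity-based access (all values invented; type strings assume the camel_to_snake transform):

onelake_datastore = {
    "name": "my_onelake",
    "type": "one_lake",
    "one_lake_workspace_name": "<fabric-workspace-guid>",
    "endpoint": "onelake.dfs.fabric.microsoft.com",
    "artifact": {"type": "lake_house", "name": "MyLakehouse"},
    # No "credentials" key: identity-based access via NoneCredentialsSchema.
}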
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py
new file mode 100644
index 00000000..7a69176b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py
@@ -0,0 +1,92 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,no-else-return
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+from marshmallow.exceptions import ValidationError
+from azure.ai.ml._schema import (
+ UnionField,
+ ArmVersionedStr,
+ ArmStr,
+ RegistryStr,
+)
+from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job.creation_context import CreationContextSchema
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction, BatchDeploymentType
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchDeploymentSchema(DeploymentSchema):
+ compute = ComputeField(required=False)
+ error_threshold = fields.Int(
+ metadata={
+ "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+ the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+ For FileDataset, this value is the count of file failures.\r\n
+ For TabularDataset, this value is the count of record failures.\r\n
+ If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+ }
+ )
+ retry_settings = NestedField(BatchRetrySettingsSchema)
+ mini_batch_size = fields.Int()
+ logging_level = fields.Str(
+ metadata={
+ "description": """A string of the logging level name, which is defined in 'logging'.
+ Possible values are 'warning', 'info', and 'debug'."""
+ }
+ )
+ output_action = StringTransformedEnum(
+ allowed_values=[
+ BatchDeploymentOutputAction.APPEND_ROW,
+ BatchDeploymentOutputAction.SUMMARY_ONLY,
+ ],
+ metadata={"description": "Indicates how batch inferencing will handle output."},
+ dump_default=BatchDeploymentOutputAction.APPEND_ROW,
+ )
+ output_file_name = fields.Str(metadata={"description": "Customized output file name for append_row output action."})
+ max_concurrency_per_instance = fields.Int(
+ metadata={"description": "Indicates maximum number of parallelism per instance."}
+ )
+ resources = NestedField(JobResourceConfigurationSchema)
+ type = StringTransformedEnum(
+ allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
+ )
+
+ job_definition = ArmStr(azureml_type=AzureMLResourceType.JOB)
+ component = UnionField(
+ [
+ RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ PipelineComponentFileRefField(),
+ ]
+ )
+ creation_context = NestedField(CreationContextSchema, dump_only=True)
+ provisioning_state = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import BatchDeployment, ModelBatchDeployment, PipelineComponentBatchDeployment
+
+ if "type" not in data:
+ return BatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+ elif data["type"] == BatchDeploymentType.PIPELINE:
+ return PipelineComponentBatchDeployment(**data)
+ elif data["type"] == BatchDeploymentType.MODEL:
+ return ModelBatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+ else:
+ raise ValidationError(
+ "Deployment type must be of type " + f"{BatchDeploymentType.PIPELINE} or {BatchDeploymentType.MODEL}."
+ )
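
The post_load hook above is a small factory: the same schema materializes three different entities depending on the optional type field, falling back to the legacy BatchDeployment when type is absent. A sketch of the dispatch (minimal payloads, assuming BatchDeploymentType's values are the lowercase strings "model" and "pipeline"):

legacy_deployment = {"name": "blue"}                      # -> BatchDeployment
model_deployment = {"name": "blue", "type": "model"}      # -> ModelBatchDeployment
pipeline_deployment = {"name": "nightly", "type": "pipeline"}  # -> PipelineComponentBatchDeployment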
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py
new file mode 100644
index 00000000..2a36352c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchRetrySettingsSchema(metaclass=PatchedSchemaMeta):
+ max_retries = fields.Int(
+ metadata={"description": "The number of maximum tries for a failed or timeout mini batch."},
+ )
+ timeout = fields.Int(metadata={"description": "The timeout, in seconds, for a mini batch."})
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> BatchRetrySettings:
+ return BatchRetrySettings(**data)
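
For reference, a sketch of how this schema turns a plain mapping into the entity (assuming the patched marshmallow schemas behave as standard schemas when instantiated standalone):

    from azure.ai.ml._schema._deployment.batch.batch_deployment_settings import BatchRetrySettingsSchema

    retry = BatchRetrySettingsSchema().load({"max_retries": 3, "timeout": 30})
    print(retry.max_retries, retry.timeout)  # 3 30
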
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py
new file mode 100644
index 00000000..a1496f1e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py
@@ -0,0 +1,132 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from typing import Any
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import (
+ BatchJob,
+ CustomModelJobInput,
+ CustomModelJobOutput,
+ DataVersion,
+ LiteralJobInput,
+ MLFlowModelJobInput,
+ MLFlowModelJobOutput,
+ MLTableJobInput,
+ MLTableJobOutput,
+ TritonModelJobInput,
+ TritonModelJobOutput,
+ UriFileJobInput,
+ UriFileJobOutput,
+ UriFolderJobInput,
+ UriFolderJobOutput,
+)
+from azure.ai.ml._schema.core.fields import ArmStr, NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants import AssetTypes
+from azure.ai.ml.constants._common import AzureMLResourceType, InputTypes
+from azure.ai.ml.constants._endpoint import EndpointYamlFields
+from azure.ai.ml.entities import ComputeConfiguration
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+from .compute_binding import ComputeBindingSchema
+
+
+class OutputDataSchema(metaclass=PatchedSchemaMeta):
+ datastore_id = ArmStr(azureml_type=AzureMLResourceType.DATASTORE)
+ path = fields.Str()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ return DataVersion(**data)
+
+
+class BatchJobSchema(PathAwareSchema):
+ compute = NestedField(ComputeBindingSchema)
+ dataset = fields.Str()
+ error_threshold = fields.Int()
+ input_data = fields.Dict()
+ mini_batch_size = fields.Int()
+ name = fields.Str(data_key="job_name")
+ output_data = fields.Dict()
+ output_dataset = NestedField(OutputDataSchema)
+ output_file_name = fields.Str()
+ retry_settings = NestedField(BatchRetrySettingsSchema)
+ properties = fields.Dict(data_key="properties")
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=too-many-branches
+ if data.get(EndpointYamlFields.BATCH_JOB_INPUT_DATA, None):
+ for key, input_data in data[EndpointYamlFields.BATCH_JOB_INPUT_DATA].items():
+ if isinstance(input_data, Input):
+ if input_data.type == AssetTypes.URI_FILE:
+ data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = UriFileJobInput(uri=input_data.path)
+ if input_data.type == AssetTypes.URI_FOLDER:
+ data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = UriFolderJobInput(uri=input_data.path)
+ if input_data.type == AssetTypes.TRITON_MODEL:
+ data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = TritonModelJobInput(
+ mode=input_data.mode, uri=input_data.path
+ )
+ if input_data.type == AssetTypes.MLFLOW_MODEL:
+ data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = MLFlowModelJobInput(
+ mode=input_data.mode, uri=input_data.path
+ )
+ if input_data.type == AssetTypes.MLTABLE:
+ data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = MLTableJobInput(
+ mode=input_data.mode, uri=input_data.path
+ )
+ if input_data.type == AssetTypes.CUSTOM_MODEL:
+ data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = CustomModelJobInput(
+ mode=input_data.mode, uri=input_data.path
+ )
+ if input_data.type in {
+ InputTypes.INTEGER,
+ InputTypes.NUMBER,
+ InputTypes.STRING,
+ InputTypes.BOOLEAN,
+ }:
+ data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = LiteralJobInput(value=input_data.default)
+ if data.get(EndpointYamlFields.BATCH_JOB_OUTPUT_DATA, None):
+ for key, output_data in data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA].items():
+ if isinstance(output_data, Output):
+ if output_data.type == AssetTypes.URI_FILE:
+ data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = UriFileJobOutput(
+ mode=output_data.mode, uri=output_data.path
+ )
+ if output_data.type == AssetTypes.URI_FOLDER:
+ data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = UriFolderJobOutput(
+ mode=output_data.mode, uri=output_data.path
+ )
+ if output_data.type == AssetTypes.TRITON_MODEL:
+ data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = TritonModelJobOutput(
+ mode=output_data.mode, uri=output_data.path
+ )
+ if output_data.type == AssetTypes.MLFLOW_MODEL:
+ data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = MLFlowModelJobOutput(
+ mode=output_data.mode, uri=output_data.path
+ )
+ if output_data.type == AssetTypes.MLTABLE:
+ data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = MLTableJobOutput(
+ mode=output_data.mode, uri=output_data.path
+ )
+ if output_data.type == AssetTypes.CUSTOM_MODEL:
+ data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = CustomModelJobOutput(
+ mode=output_data.mode, uri=output_data.path
+ )
+
+ if data.get(EndpointYamlFields.COMPUTE, None):
+ data[EndpointYamlFields.COMPUTE] = ComputeConfiguration(
+ **data[EndpointYamlFields.COMPUTE]
+ )._to_rest_object()
+
+ if data.get(EndpointYamlFields.RETRY_SETTINGS, None):
+ data[EndpointYamlFields.RETRY_SETTINGS] = data[EndpointYamlFields.RETRY_SETTINGS]._to_rest_object()
+
+ return BatchJob(**data)
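
The long if-chain in make maps SDK-side Input/Output entities onto the matching REST job input/output models one asset type at a time. A sketch of the uri_folder branch (hypothetical path; the REST models are internal):

    from azure.ai.ml.constants import AssetTypes
    from azure.ai.ml.entities._inputs_outputs import Input

    inp = Input(type=AssetTypes.URI_FOLDER, path="azureml://datastores/blob/paths/data")
    # Inside BatchJobSchema.make this Input is replaced by
    #     UriFolderJobInput(uri=inp.path)
    # whereas literal inputs (integer/number/string/boolean) become
    #     LiteralJobInput(value=inp.default)
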
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py
new file mode 100644
index 00000000..f0b22fd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import (
+ ArmVersionedStr,
+ PatchedSchemaMeta,
+ StringTransformedEnum,
+ UnionField,
+ ArmStr,
+ RegistryStr,
+)
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._job.job import JobType
+
+module_logger = logging.getLogger(__name__)
+
+
+# pylint: disable-next=name-too-long
+class BatchPipelineComponentDeploymentConfiguarationsSchema(metaclass=PatchedSchemaMeta):
+ component_id = fields.Str()
+ job = UnionField(
+ [
+ ArmStr(azureml_type=AzureMLResourceType.JOB),
+ PipelineComponentFileRefField(),
+ ]
+ )
+ component = UnionField(
+ [
+ RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ PipelineComponentFileRefField(),
+ ]
+ )
+ type = StringTransformedEnum(required=True, allowed_values=[JobType.PIPELINE])
+ settings = fields.Dict()
+ name = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument
+ from azure.ai.ml.entities._deployment.job_definition import JobDefinition
+
+ return JobDefinition(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py
new file mode 100644
index 00000000..2e4b0348
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import ValidationError, fields, validates_schema
+
+from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._common import LOCAL_COMPUTE_TARGET, AzureMLResourceType
+
+module_logger = logging.getLogger(__name__)
+
+
+class ComputeBindingSchema(metaclass=PatchedSchemaMeta):
+ target = UnionField(
+ [
+ StringTransformedEnum(allowed_values=[LOCAL_COMPUTE_TARGET]),
+ ArmStr(azureml_type=AzureMLResourceType.COMPUTE),
+ # Case for virtual clusters
+ ArmStr(azureml_type=AzureMLResourceType.VIRTUALCLUSTER),
+ ]
+ )
+ instance_count = fields.Integer()
+ instance_type = fields.Str(metadata={"description": "The instance type to make available to this job."})
+ location = fields.Str(metadata={"description": "The locations where this job may run."})
+ properties = fields.Dict(keys=fields.Str())
+
+ @validates_schema
+ def validate(self, data: Any, **kwargs):
+ if data.get("target") == LOCAL_COMPUTE_TARGET and data.get("instance_count", 1) != 1:
+ raise ValidationError("Local runs must have node count of 1.")
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py
new file mode 100644
index 00000000..269f1da7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import (
+ ArmVersionedStr,
+ PatchedSchemaMeta,
+ StringTransformedEnum,
+ UnionField,
+ ArmStr,
+ RegistryStr,
+)
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._job.job import JobType
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobDefinitionSchema(metaclass=PatchedSchemaMeta):
+ component_id = fields.Str()
+ job = UnionField(
+ [
+ ArmStr(azureml_type=AzureMLResourceType.JOB),
+ PipelineComponentFileRefField(),
+ ]
+ )
+ component = UnionField(
+ [
+ RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ PipelineComponentFileRefField(),
+ ]
+ )
+ type = StringTransformedEnum(required=True, allowed_values=[JobType.PIPELINE])
+ settings = fields.Dict()
+ name = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument
+ from azure.ai.ml.entities._deployment.job_definition import JobDefinition
+
+ return JobDefinition(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py
new file mode 100644
index 00000000..0dbd8463
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py
@@ -0,0 +1,46 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentType
+from azure.ai.ml._schema import ExperimentalField
+from .model_batch_deployment_settings import ModelBatchDeploymentSettingsSchema
+
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelBatchDeploymentSchema(DeploymentSchema):
+ compute = ComputeField(required=True)
+ error_threshold = fields.Int(
+ metadata={
+ "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+ the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+ For FileDataset, this value is the count of file failures.\r\n
+ For TabularDataset, this value is the count of record failures.\r\n
+ If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+ }
+ )
+ resources = NestedField(JobResourceConfigurationSchema)
+ type = StringTransformedEnum(
+ allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
+ )
+
+ settings = ExperimentalField(NestedField(ModelBatchDeploymentSettingsSchema))
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import ModelBatchDeployment
+
+ return ModelBatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py
new file mode 100644
index 00000000..e1945751
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py
@@ -0,0 +1,56 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelBatchDeploymentSettingsSchema(metaclass=PatchedSchemaMeta):
+ error_threshold = fields.Int(
+ metadata={
+ "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+ the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+ For FileDataset, this value is the count of file failures.\r\n
+ For TabularDataset, this value is the count of record failures.\r\n
+ If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+ }
+ )
+ instance_count = fields.Int()
+ retry_settings = NestedField(BatchRetrySettingsSchema)
+ mini_batch_size = fields.Int()
+ logging_level = fields.Str(
+ metadata={
+ "description": """A string of the logging level name, which is defined in 'logging'.
+ Possible values are 'warning', 'info', and 'debug'."""
+ }
+ )
+ output_action = StringTransformedEnum(
+ allowed_values=[
+ BatchDeploymentOutputAction.APPEND_ROW,
+ BatchDeploymentOutputAction.SUMMARY_ONLY,
+ ],
+ metadata={"description": "Indicates how batch inferencing will handle output."},
+ dump_default=BatchDeploymentOutputAction.APPEND_ROW,
+ )
+ output_file_name = fields.Str(metadata={"description": "Customized output file name for append_row output action."})
+ max_concurrency_per_instance = fields.Int(
+ metadata={"description": "Indicates maximum number of parallelism per instance."}
+ )
+ environment_variables = fields.Dict()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument
+ from azure.ai.ml.entities import ModelBatchDeploymentSettings
+
+ return ModelBatchDeploymentSettings(**data)
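
A minimal load through this settings schema might look as follows (a sketch; values are hypothetical, and output_action accepts the snake-cased enum values):

    from azure.ai.ml._schema._deployment.batch.model_batch_deployment_settings import (
        ModelBatchDeploymentSettingsSchema,
    )

    settings = ModelBatchDeploymentSettingsSchema().load(
        {"mini_batch_size": 10, "output_action": "append_row", "max_concurrency_per_instance": 2}
    )
    print(type(settings).__name__)  # ModelBatchDeploymentSettings
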
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py
new file mode 100644
index 00000000..4bc884b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py
@@ -0,0 +1,70 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema import (
+ ArmVersionedStr,
+ ArmStr,
+ UnionField,
+ RegistryStr,
+ NestedField,
+)
+from azure.ai.ml._schema.core.fields import PipelineNodeNameStr, TypeSensitiveUnionField, PathAwareSchema
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._component import NodeType
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineComponentBatchDeploymentSchema(PathAwareSchema):
+ name = fields.Str()
+ endpoint_name = fields.Str()
+ component = UnionField(
+ [
+ RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ PipelineComponentFileRefField(),
+ ]
+ )
+ settings = fields.Dict()
+ type = fields.Str()
+ job_definition = UnionField(
+ [
+ ArmStr(azureml_type=AzureMLResourceType.JOB),
+ NestedField("PipelineSchema", unknown=INCLUDE),
+ ]
+ )
+ tags = fields.Dict()
+ description = fields.Str(metadata={"description": "Description of the endpoint deployment."})
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument
+ from azure.ai.ml.entities._deployment.pipeline_component_batch_deployment import (
+ PipelineComponentBatchDeployment,
+ )
+
+ return PipelineComponentBatchDeployment(**data)
+
+
+class NodeNameStr(PipelineNodeNameStr):
+ def _get_field_name(self) -> str:
+ return "Pipeline node"
+
+
+def PipelineJobsField():
+ pipeline_enable_job_type = {NodeType.PIPELINE: [NestedField("PipelineSchema", unknown=INCLUDE)]}
+
+ pipeline_job_field = fields.Dict(
+ keys=NodeNameStr(),
+ values=TypeSensitiveUnionField(pipeline_enable_job_type),
+ )
+
+ return pipeline_job_field
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py
new file mode 100644
index 00000000..54661ada
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class RunSettingsSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str()
+ display_name = fields.Str()
+ experiment_name = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict()
+ settings = fields.Dict()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument
+ from azure.ai.ml.entities._deployment.run_settings import RunSettings
+
+ return RunSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/code_configuration_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/code_configuration_schema.py
new file mode 100644
index 00000000..e9b3eac4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/code_configuration_schema.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class CodeConfigurationSchema(PathAwareSchema):
+ code = fields.Str()
+ scoring_script = fields.Str()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import CodeConfiguration
+
+ return CodeConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/deployment.py
new file mode 100644
index 00000000..669a96ad
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/deployment.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema, EnvironmentSchema
+from azure.ai.ml._schema.assets.model import AnonymousModelSchema
+from azure.ai.ml._schema.core.fields import ArmVersionedStr, NestedField, PathAwareSchema, RegistryStr, UnionField
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+from .code_configuration_schema import CodeConfigurationSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class DeploymentSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ endpoint_name = fields.Str(required=True)
+ description = fields.Str(metadata={"description": "Description of the endpoint deployment."})
+ id = fields.Str()
+ tags = fields.Dict()
+ properties = fields.Dict()
+ model = UnionField(
+ [
+ RegistryStr(azureml_type=AzureMLResourceType.MODEL),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, allow_default_version=True),
+ NestedField(AnonymousModelSchema),
+ ],
+ metadata={"description": "Reference to the model asset for the endpoint deployment."},
+ )
+ code_configuration = NestedField(
+ CodeConfigurationSchema,
+ metadata={"description": "Code configuration for the endpoint deployment."},
+ )
+ environment = UnionField(
+ [
+ RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True),
+ NestedField(EnvironmentSchema),
+ NestedField(AnonymousEnvironmentSchema),
+ ]
+ )
+ environment_variables = fields.Dict(
+ metadata={"description": "Environment variables configuration for the deployment."}
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_asset_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_asset_schema.py
new file mode 100644
index 00000000..84bd37e3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_asset_schema.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class DataAssetSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str()
+ path = fields.Str()
+ version = fields.Str()
+ data_id = fields.Str()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument
+ from azure.ai.ml.entities._deployment.data_asset import DataAsset
+
+ return DataAsset(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_collector_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_collector_schema.py
new file mode 100644
index 00000000..633f96fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/data_collector_schema.py
@@ -0,0 +1,39 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load, validates, ValidationError
+
+from azure.ai.ml._schema import NestedField, PatchedSchemaMeta, StringTransformedEnum
+from azure.ai.ml._schema._deployment.online.request_logging_schema import RequestLoggingSchema
+from azure.ai.ml._schema._deployment.online.deployment_collection_schema import DeploymentCollectionSchema
+
+from azure.ai.ml.constants._common import RollingRate
+
+module_logger = logging.getLogger(__name__)
+
+
+class DataCollectorSchema(metaclass=PatchedSchemaMeta):
+ collections = fields.Dict(keys=fields.Str, values=NestedField(DeploymentCollectionSchema))
+ rolling_rate = StringTransformedEnum(
+ required=False,
+ allowed_values=[RollingRate.MINUTE, RollingRate.DAY, RollingRate.HOUR],
+ )
+ sampling_rate = fields.Float() # Should be copied to each of the collections
+ request_logging = NestedField(RequestLoggingSchema)
+
+ # pylint: disable=unused-argument
+ @validates("sampling_rate")
+ def validate_sampling_rate(self, value, **kwargs):
+ if value > 1.0 or value < 0.0:
+ raise ValidationError("Sampling rate must be an number in range (0.0-1.0)")
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument
+ from azure.ai.ml.entities._deployment.data_collector import DataCollector
+
+ return DataCollector(**data)
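
A sketch of the sampling_rate guard (direct schema use, hypothetical value):

    from marshmallow import ValidationError

    from azure.ai.ml._schema._deployment.online.data_collector_schema import DataCollectorSchema

    try:
        DataCollectorSchema().load({"sampling_rate": 1.5})
    except ValidationError as err:
        print(err.messages)  # values outside [0.0, 1.0] are rejected
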
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/deployment_collection_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/deployment_collection_schema.py
new file mode 100644
index 00000000..4be4a9cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/deployment_collection_schema.py
@@ -0,0 +1,32 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any
+
+from marshmallow import post_load, fields
+
+from azure.ai.ml._schema import PatchedSchemaMeta, StringTransformedEnum, NestedField, UnionField
+from azure.ai.ml._schema._deployment.online.data_asset_schema import DataAssetSchema
+from azure.ai.ml.constants._common import Boolean
+
+module_logger = logging.getLogger(__name__)
+
+
+class DeploymentCollectionSchema(metaclass=PatchedSchemaMeta):
+ enabled = StringTransformedEnum(required=True, allowed_values=[Boolean.TRUE, Boolean.FALSE])
+ data = UnionField(
+ [
+ NestedField(DataAssetSchema),
+ fields.Str(),
+ ]
+ )
+ client_id = fields.Str()
+
+ # pylint: disable=unused-argument
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities._deployment.deployment_collection import DeploymentCollection
+
+ return DeploymentCollection(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/event_hub_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/event_hub_schema.py
new file mode 100644
index 00000000..27b603de
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/event_hub_schema.py
@@ -0,0 +1,31 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import ValidationError, fields, post_load, validates
+
+from azure.ai.ml._schema import NestedField, PatchedSchemaMeta
+from azure.ai.ml._schema._deployment.online.oversize_data_config_schema import OversizeDataConfigSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class EventHubSchema(metaclass=PatchedSchemaMeta):
+ namespace = fields.Str()
+ oversize_data_config = NestedField(OversizeDataConfigSchema)
+
+ @validates("namespace")
+ def validate_namespace(self, value, **kwargs):
+ if len(value.split(".")) != 2:
+ raise ValidationError("Namespace must follow format of {namespace}.{name}")
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities._deployment.event_hub import EventHub
+
+ return EventHub(**data)
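
The namespace validator requires exactly one dot separator; a sketch (hypothetical names):

    from marshmallow import ValidationError

    from azure.ai.ml._schema._deployment.online.event_hub_schema import EventHubSchema

    EventHubSchema().load({"namespace": "myns.myhub"})  # accepted: {namespace}.{name}
    try:
        EventHubSchema().load({"namespace": "missing-dot"})
    except ValidationError as err:
        print(err.messages)
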
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/liveness_probe.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/liveness_probe.py
new file mode 100644
index 00000000..d1008b8b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/liveness_probe.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class LivenessProbeSchema(metaclass=PatchedSchemaMeta):
+ period = fields.Int()
+ initial_delay = fields.Int()
+ timeout = fields.Int()
+ success_threshold = fields.Int()
+ failure_threshold = fields.Int()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import ProbeSettings
+
+ return ProbeSettings(**data)
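
Both liveness_probe and readiness_probe reuse this schema, so one sketch covers both (hypothetical values):

    from azure.ai.ml._schema._deployment.online.liveness_probe import LivenessProbeSchema

    probe = LivenessProbeSchema().load({"period": 10, "initial_delay": 5, "timeout": 2})
    print(type(probe).__name__)  # ProbeSettings
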
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/online_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/online_deployment.py
new file mode 100644
index 00000000..7f0760fe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/online_deployment.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointComputeType
+from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
+from azure.ai.ml._schema._utils.utils import exit_if_registry_assets
+from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PublicNetworkAccess
+from azure.ai.ml._schema.job.creation_context import CreationContextSchema
+
+from .data_collector_schema import DataCollectorSchema
+from .liveness_probe import LivenessProbeSchema
+from .request_settings_schema import RequestSettingsSchema
+from .resource_requirements_schema import ResourceRequirementsSchema
+from .scale_settings_schema import DefaultScaleSettingsSchema, TargetUtilizationScaleSettingsSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class OnlineDeploymentSchema(DeploymentSchema):
+ app_insights_enabled = fields.Bool()
+ scale_settings = UnionField(
+ [
+ NestedField(DefaultScaleSettingsSchema),
+ NestedField(TargetUtilizationScaleSettingsSchema),
+ ]
+ )
+ request_settings = NestedField(RequestSettingsSchema)
+ liveness_probe = NestedField(LivenessProbeSchema)
+ readiness_probe = NestedField(LivenessProbeSchema)
+ provisioning_state = fields.Str()
+ instance_count = fields.Int()
+ type = StringTransformedEnum(
+ required=False,
+ allowed_values=[
+ EndpointComputeType.MANAGED.value,
+ EndpointComputeType.KUBERNETES.value,
+ ],
+ casing_transform=camel_to_snake,
+ )
+ model_mount_path = fields.Str()
+ instance_type = fields.Str()
+ data_collector = ExperimentalField(NestedField(DataCollectorSchema))
+
+
+class KubernetesOnlineDeploymentSchema(OnlineDeploymentSchema):
+ resources = NestedField(ResourceRequirementsSchema)
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import KubernetesOnlineDeployment
+
+ exit_if_registry_assets(data=data, caller="K8SDeployment")
+ return KubernetesOnlineDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+
+
+class ManagedOnlineDeploymentSchema(OnlineDeploymentSchema):
+ instance_type = fields.Str(required=True)
+ egress_public_network_access = StringTransformedEnum(
+ allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED]
+ )
+ private_network_connection = ExperimentalField(fields.Bool())
+ data_collector = NestedField(DataCollectorSchema)
+ creation_context = NestedField(CreationContextSchema, dump_only=True)
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import ManagedOnlineDeployment
+
+ return ManagedOnlineDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/oversize_data_config_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/oversize_data_config_schema.py
new file mode 100644
index 00000000..8103681a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/oversize_data_config_schema.py
@@ -0,0 +1,31 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any
+
+from marshmallow import ValidationError, fields, post_load, validates
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+from azure.ai.ml._utils._storage_utils import AzureMLDatastorePathUri
+
+module_logger = logging.getLogger(__name__)
+
+
+class OversizeDataConfigSchema(metaclass=PatchedSchemaMeta):
+ path = fields.Str()
+
+ # pylint: disable=unused-argument
+ @validates("path")
+ def validate_path(self, value, **kwargs):
+ datastore_path = AzureMLDatastorePathUri(value)
+ if datastore_path.uri_type != "Datastore":
+ raise ValidationError(f"Path '{value}' is not a properly formatted datastore path.")
+
+ # pylint: disable=unused-argument
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities._deployment.oversize_data_config import OversizeDataConfig
+
+ return OversizeDataConfig(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/payload_response_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/payload_response_schema.py
new file mode 100644
index 00000000..172af4f1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/payload_response_schema.py
@@ -0,0 +1,24 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any
+
+from marshmallow import post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta, StringTransformedEnum
+from azure.ai.ml.constants._common import Boolean
+
+module_logger = logging.getLogger(__name__)
+
+
+class PayloadResponseSchema(metaclass=PatchedSchemaMeta):
+ enabled = StringTransformedEnum(required=True, allowed_values=[Boolean.TRUE, Boolean.FALSE])
+
+ # pylint: disable=unused-argument
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities._deployment.payload_response import PayloadResponse
+
+ return PayloadResponse(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_logging_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_logging_schema.py
new file mode 100644
index 00000000..4ac0b466
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_logging_schema.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class RequestLoggingSchema(metaclass=PatchedSchemaMeta):
+ capture_headers = fields.List(fields.Str())
+
+ # pylint: disable=unused-argument
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities._deployment.request_logging import RequestLogging
+
+ return RequestLogging(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_settings_schema.py
new file mode 100644
index 00000000..887a71c5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/request_settings_schema.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class RequestSettingsSchema(metaclass=PatchedSchemaMeta):
+ request_timeout_ms = fields.Int(required=False)
+ max_concurrent_requests_per_instance = fields.Int(required=False)
+ max_queue_wait_ms = fields.Int(required=False)
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import OnlineRequestSettings
+
+ return OnlineRequestSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_requirements_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_requirements_schema.py
new file mode 100644
index 00000000..7f43d91f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_requirements_schema.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import post_load
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+from .resource_settings_schema import ResourceSettingsSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class ResourceRequirementsSchema(metaclass=PatchedSchemaMeta):
+ requests = NestedField(ResourceSettingsSchema)
+ limits = NestedField(ResourceSettingsSchema)
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> "ResourceRequirementsSettings":
+ from azure.ai.ml.entities import ResourceRequirementsSettings
+
+ return ResourceRequirementsSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_settings_schema.py
new file mode 100644
index 00000000..21a229ad
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/resource_settings_schema.py
@@ -0,0 +1,32 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load, pre_load
+
+from azure.ai.ml._schema._utils.utils import replace_key_in_odict
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class ResourceSettingsSchema(metaclass=PatchedSchemaMeta):
+ cpu = fields.String()
+ memory = fields.String()
+ gpu = fields.String()
+
+ @pre_load
+ def conversion(self, data: Any, **kwargs) -> Any:
+ data = replace_key_in_odict(data, "nvidia.com/gpu", "gpu")
+ return data
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import ResourceSettings
+
+ return ResourceSettings(**data)
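
The pre_load hook normalizes the Kubernetes-style GPU resource key before field validation, so either spelling can appear in the source mapping; a sketch:

    from azure.ai.ml._schema._deployment.online.resource_settings_schema import ResourceSettingsSchema

    # "nvidia.com/gpu" is rewritten to "gpu" before the fields are loaded.
    settings = ResourceSettingsSchema().load({"cpu": "100m", "memory": "2Gi", "nvidia.com/gpu": "1"})
    print(settings.gpu)  # 1
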
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/scale_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/scale_settings_schema.py
new file mode 100644
index 00000000..6c5c5283
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/online/scale_settings_schema.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_10_01.models import ScaleType
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+module_logger = logging.getLogger(__name__)
+
+
+class DefaultScaleSettingsSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=ScaleType.DEFAULT,
+ casing_transform=camel_to_snake,
+ data_key="type",
+ )
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> "DefaultScaleSettings":
+ from azure.ai.ml.entities import DefaultScaleSettings
+
+ return DefaultScaleSettings(**data)
+
+
+class TargetUtilizationScaleSettingsSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=ScaleType.TARGET_UTILIZATION,
+ casing_transform=camel_to_snake,
+ data_key="type",
+ )
+ polling_interval = fields.Int()
+ target_utilization_percentage = fields.Int()
+ min_instances = fields.Int()
+ max_instances = fields.Int()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> "TargetUtilizationScaleSettings":
+ from azure.ai.ml.entities import TargetUtilizationScaleSettings
+
+ return TargetUtilizationScaleSettings(**data)
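
Because each scale-settings schema requires a distinct, casing-transformed type value, the UnionField in OnlineDeploymentSchema can dispatch between them unambiguously; a sketch (assuming the ScaleType values snake-case to "default" and "target_utilization"):

    from azure.ai.ml._schema._deployment.online.scale_settings_schema import (
        DefaultScaleSettingsSchema,
        TargetUtilizationScaleSettingsSchema,
    )

    DefaultScaleSettingsSchema().load({"type": "default"})
    TargetUtilizationScaleSettingsSchema().load(
        {"type": "target_utilization", "min_instances": 1, "max_instances": 3}
    )
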
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py
new file mode 100644
index 00000000..437d8743
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .distillation_job import DistillationJobSchema
+from .endpoint_request_settings import EndpointRequestSettingsSchema
+from .prompt_settings import PromptSettingsSchema
+from .teacher_model_settings import TeacherModelSettingsSchema
+
+__all__ = [
+ "DistillationJobSchema",
+ "PromptSettingsSchema",
+ "EndpointRequestSettingsSchema",
+ "TeacherModelSettingsSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py
new file mode 100644
index 00000000..d72f2457
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py
@@ -0,0 +1,84 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema._distillation.prompt_settings import PromptSettingsSchema
+from azure.ai.ml._schema._distillation.teacher_model_settings import TeacherModelSettingsSchema
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ LocalPathField,
+ NestedField,
+ RegistryStr,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, ModelInputSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField
+from azure.ai.ml._schema.job_resource_configuration import ResourceConfigurationSchema
+from azure.ai.ml._schema.workspace.connections import ServerlessConnectionSchema, WorkspaceConnectionSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import DataGenerationTaskType, DataGenerationType, JobType
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+
+@experimental
+class DistillationJobSchema(BaseJobSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.DISTILLATION)
+ data_generation_type = StringTransformedEnum(
+ allowed_values=[DataGenerationType.LABEL_GENERATION, DataGenerationType.DATA_GENERATION],
+ required=True,
+ )
+ data_generation_task_type = StringTransformedEnum(
+ allowed_values=[
+ DataGenerationTaskType.NLI,
+ DataGenerationTaskType.NLU_QA,
+ DataGenerationTaskType.CONVERSATION,
+ DataGenerationTaskType.MATH,
+ DataGenerationTaskType.SUMMARIZATION,
+ ],
+ casing_transform=str.upper,
+ required=True,
+ )
+ teacher_model_endpoint_connection = UnionField(
+ [NestedField(WorkspaceConnectionSchema), NestedField(ServerlessConnectionSchema)], required=True
+ )
+ student_model = UnionField(
+ [
+ NestedField(ModelInputSchema),
+ RegistryStr(azureml_type=AzureMLResourceType.MODEL),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, allow_default_version=True),
+ ],
+ required=True,
+ )
+ training_data = UnionField(
+ [
+ NestedField(DataInputSchema),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+ fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+ fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+ LocalPathField(pattern=r"^file:.*"),
+ LocalPathField(
+ pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+ ),
+ ]
+ )
+ validation_data = UnionField(
+ [
+ NestedField(DataInputSchema),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+ fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+ fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+ LocalPathField(pattern=r"^file:.*"),
+ LocalPathField(
+ pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+ ),
+ ]
+ )
+ teacher_model_settings = NestedField(TeacherModelSettingsSchema)
+ prompt_settings = NestedField(PromptSettingsSchema)
+ hyperparameters = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+ resources = NestedField(ResourceConfigurationSchema)
+ outputs = OutputsField()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py
new file mode 100644
index 00000000..960e7d2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class EndpointRequestSettingsSchema(metaclass=PatchedSchemaMeta):
+ request_batch_size = fields.Int()
+ min_endpoint_success_ratio = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ """Post-load processing of the schema data
+
+ :param data: Dictionary of parsed values from the yaml.
+ :type data: typing.Dict
+ :return: EndpointRequestSettings made from the yaml
+ :rtype: EndpointRequestSettings
+ """
+ from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
+
+ return EndpointRequestSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py
new file mode 100644
index 00000000..3b21908a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class PromptSettingsSchema(metaclass=PatchedSchemaMeta):
+ enable_chain_of_thought = fields.Bool()
+ enable_chain_of_density = fields.Bool()
+ max_len_summary = fields.Int()
+ # custom_prompt = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ """Post-load processing of the schema data
+
+ :param data: Dictionary of parsed values from the yaml.
+ :type data: typing.Dict
+ :return: PromptSettings made from the yaml
+ :rtype: PromptSettings
+ """
+ from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
+
+ return PromptSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py
new file mode 100644
index 00000000..ecf32047
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema._distillation.endpoint_request_settings import EndpointRequestSettingsSchema
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class TeacherModelSettingsSchema(metaclass=PatchedSchemaMeta):
+ inference_parameters = fields.Dict(keys=fields.Str(), values=fields.Raw())
+ endpoint_request_settings = NestedField(EndpointRequestSettingsSchema)
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ """Post-load processing of the schema data
+
+ :param data: Dictionary of parsed values from the yaml.
+ :type data: typing.Dict
+ :return: TeacherModelSettings made from the yaml
+ :rtype: TeacherModelSettings
+ """
+ from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
+
+ return TeacherModelSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py
new file mode 100644
index 00000000..e9538cbb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/__init__.py
@@ -0,0 +1,15 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+
+from .batch.batch_endpoint import BatchEndpointSchema
+from .online.online_endpoint import KubernetesOnlineEndpointSchema, ManagedOnlineEndpointSchema
+
+__all__ = [
+ "BatchEndpointSchema",
+ "KubernetesOnlineEndpointSchema",
+ "ManagedOnlineEndpointSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py
new file mode 100644
index 00000000..0bee2493
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import post_load
+
+from azure.ai.ml._schema._endpoint.batch.batch_endpoint_defaults import BatchEndpointsDefaultsSchema
+from azure.ai.ml._schema._endpoint.endpoint import EndpointSchema
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchEndpointSchema(EndpointSchema):
+ defaults = NestedField(BatchEndpointsDefaultsSchema)
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import BatchEndpoint
+
+ return BatchEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py
new file mode 100644
index 00000000..49699bb0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/batch/batch_endpoint_defaults.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpointDefaults
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchEndpointsDefaultsSchema(metaclass=PatchedSchemaMeta):
+ deployment_name = fields.Str(
+ metadata={
+ "description": """Name of the deployment that will be default for the endpoint.
+ This deployment will end up getting 100% traffic when the endpoint scoring URL is invoked."""
+ }
+ )
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ return BatchEndpointDefaults(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py
new file mode 100644
index 00000000..1ff43338
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/endpoint.py
@@ -0,0 +1,41 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+
+from marshmallow import fields, validate
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthMode
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.identity import IdentitySchema
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._endpoint import EndpointConfigurations
+
+module_logger = logging.getLogger(__name__)
+
+
+class EndpointSchema(PathAwareSchema):
+ id = fields.Str()
+ name = fields.Str(required=True, validate=validate.Regexp(EndpointConfigurations.NAME_REGEX_PATTERN))
+ description = fields.Str(metadata={"description": "Description of the inference endpoint."})
+ tags = fields.Dict()
+ provisioning_state = fields.Str(metadata={"description": "Provisioning state for the endpoint."})
+ properties = fields.Dict()
+ auth_mode = StringTransformedEnum(
+ allowed_values=[
+ EndpointAuthMode.AML_TOKEN,
+ EndpointAuthMode.KEY,
+ EndpointAuthMode.AAD_TOKEN,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={
+ "description": """authentication method: no auth, key based or azure ml token based.
+ aad_token is only valid for batch endpoint."""
+ },
+ )
+ scoring_uri = fields.Str(metadata={"description": "The endpoint URI that can be used for scoring."})
+ location = fields.Str()
+ openapi_uri = fields.Str(metadata={"description": "Endpoint OpenAPI URI."})
+ identity = NestedField(IdentitySchema)
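Side note on auth_mode: the StringTransformedEnum compares the YAML value against the snake_case form of the REST enum values, so (assuming camel_to_snake behaves as its name suggests) the accepted spellings are:

    from azure.ai.ml._utils.utils import camel_to_snake

    camel_to_snake("AMLToken")  # "aml_token"
    camel_to_snake("Key")       # "key"
    camel_to_snake("AADToken")  # "aad_token"
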
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py
new file mode 100644
index 00000000..84b34636
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py
@@ -0,0 +1,66 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import ValidationError, fields, post_load, validates
+
+from azure.ai.ml._schema._endpoint.endpoint import EndpointSchema
+from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType, PublicNetworkAccess
+
+module_logger = logging.getLogger(__name__)
+
+
+class OnlineEndpointSchema(EndpointSchema):
+ traffic = fields.Dict(
+ keys=fields.Str(),
+ values=fields.Int(),
+ metadata={
+ "description": """a dict with key as deployment name and value as traffic percentage.
+ The values need to sum to 100 """
+ },
+ )
+ kind = fields.Str(dump_only=True)
+
+ mirror_traffic = fields.Dict(
+ keys=fields.Str(),
+ values=fields.Int(),
+ metadata={
+ "description": """a dict with key as deployment name and value as traffic percentage.
+ Only one key will be accepted and value needs to be less than or equal to 50%"""
+ },
+ )
+
+ @validates("traffic")
+ def validate_traffic(self, data, **kwargs):
+ if sum(data.values()) > 100:
+ raise ValidationError("Traffic rule percentages must sum to less than or equal to 100%")
+
+
+class KubernetesOnlineEndpointSchema(OnlineEndpointSchema):
+ provisioning_state = fields.Str(metadata={"description": "Status of the endpoint provisioning operation."})
+ compute = ArmStr(azureml_type=AzureMLResourceType.COMPUTE)
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import KubernetesOnlineEndpoint
+
+ return KubernetesOnlineEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+
+
+class ManagedOnlineEndpointSchema(OnlineEndpointSchema):
+ provisioning_state = fields.Str()
+ public_network_access = StringTransformedEnum(
+ allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED]
+ )
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any) -> Any:
+ from azure.ai.ml.entities import ManagedOnlineEndpoint
+
+ return ManagedOnlineEndpoint(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
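A sketch of the traffic validator in action (hypothetical names; assumes the same base-path context as above): a split that sums past 100 is rejected during load.

    from marshmallow import ValidationError

    from azure.ai.ml._schema._endpoint.online.online_endpoint import ManagedOnlineEndpointSchema
    from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY

    schema = ManagedOnlineEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "."})
    try:
        schema.load({"name": "my-endpoint", "traffic": {"blue": 60, "green": 50}})
    except ValidationError as err:
        print(err.messages)  # the 110% split trips validate_traffic
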
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/__init__.py
new file mode 100644
index 00000000..69c1cdbd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/__init__.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .delay_metadata_schema import DelayMetadataSchema
+from .feature_schema import FeatureSchema
+from .feature_set_schema import FeatureSetSchema
+from .featureset_spec_metadata_schema import FeaturesetSpecMetadataSchema
+from .feature_set_specification_schema import FeatureSetSpecificationSchema
+from .materialization_settings_schema import MaterializationSettingsSchema
+from .source_metadata_schema import SourceMetadataSchema
+from .timestamp_column_metadata_schema import TimestampColumnMetadataSchema
+
+__all__ = [
+ "DelayMetadataSchema",
+ "FeatureSchema",
+ "FeatureSetSchema",
+ "FeaturesetSpecMetadataSchema",
+ "FeatureSetSpecificationSchema",
+ "MaterializationSettingsSchema",
+ "SourceMetadataSchema",
+ "TimestampColumnMetadataSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/delay_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/delay_metadata_schema.py
new file mode 100644
index 00000000..5ad78a7a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/delay_metadata_schema.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class DelayMetadataSchema(metaclass=PatchedSchemaMeta):
+ days = fields.Int(required=False)
+ hours = fields.Int(required=False)
+ minutes = fields.Int(required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_set.delay_metadata import DelayMetadata
+
+ return DelayMetadata(**data)
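Because PatchedSchemaMeta-based classes are used as ordinary marshmallow schemas (see the fields.Nested references elsewhere in this package), a delay can be loaded directly; all three fields are optional, assuming DelayMetadata defaults any omitted component.

    from azure.ai.ml._schema._feature_set.delay_metadata_schema import DelayMetadataSchema

    delay = DelayMetadataSchema().load({"days": 1, "hours": 6})  # -> DelayMetadata entity
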
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_schema.py
new file mode 100644
index 00000000..6d248270
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_schema.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class FeatureSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(
+ required=True,
+ allow_none=False,
+ )
+ data_type = fields.Str(
+ required=True,
+ allow_none=False,
+ data_key="type",
+ )
+ description = fields.Str(required=False)
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str(), required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_set.feature import Feature
+
+ return Feature(description=data.pop("description", None), **data)
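Note the data_key="type" mapping: the YAML key is "type", but the loaded attribute is data_type. A sketch with a hypothetical feature:

    from azure.ai.ml._schema._feature_set.feature_schema import FeatureSchema

    feature = FeatureSchema().load({"name": "transaction_amount", "type": "double"})
    # "type" populates data_type; post_load builds a Feature entity
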
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_backfill_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_backfill_schema.py
new file mode 100644
index 00000000..0ee5af8e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_backfill_schema.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema._feature_set.feature_window_schema import FeatureWindowSchema
+from azure.ai.ml._schema._feature_set.materialization_settings_schema import MaterializationComputeResourceSchema
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import YamlFileSchema
+
+
+class FeatureSetBackfillSchema(YamlFileSchema):
+ name = fields.Str(required=True)
+ version = fields.Str(required=True)
+ feature_window = NestedField(FeatureWindowSchema)
+ description = fields.Str()
+ tags = fields.Dict()
+ resource = NestedField(MaterializationComputeResourceSchema)
+ spark_configuration = fields.Dict()
+ data_status = fields.List(fields.Str())
+ job_id = fields.Str()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_schema.py
new file mode 100644
index 00000000..08722402
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_schema.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_dump, validate
+
+from azure.ai.ml._schema import NestedField
+from azure.ai.ml._schema.core.schema import YamlFileSchema
+
+from .feature_set_specification_schema import FeatureSetSpecificationSchema
+from .materialization_settings_schema import MaterializationSettingsSchema
+
+
+class FeatureSetSchema(YamlFileSchema):
+ name = fields.Str(required=True, allow_none=False)
+ version = fields.Str(required=True, allow_none=False)
+ latest_version = fields.Str(dump_only=True)
+ specification = NestedField(FeatureSetSpecificationSchema, required=True, allow_none=False)
+ entities = fields.List(fields.Str, required=True, allow_none=False)
+ stage = fields.Str(validate=validate.OneOf(["Development", "Production", "Archived"]), dump_default="Development")
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+ materialization_settings = NestedField(MaterializationSettingsSchema)
+
+ @post_dump
+ def remove_empty_values(self, data, **kwargs): # pylint: disable=unused-argument
+ return {key: value for key, value in data.items() if value}
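A load sketch with hypothetical values (as a YamlFileSchema, this is assumed to accept a base-path context like the other path-aware schemas). There is no post_load hook here, so load returns the parsed dict; the post_dump hook above strips empty values, so unset optional fields never appear in dumped YAML.

    from azure.ai.ml._schema._feature_set.feature_set_schema import FeatureSetSchema
    from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY

    data = {
        "name": "transactions",               # hypothetical
        "version": "1",
        "specification": {"path": "./spec"},  # hypothetical local spec folder
        "entities": ["azureml:account:1"],    # hypothetical entity reference
    }
    parsed = FeatureSetSchema(context={BASE_PATH_CONTEXT_KEY: "."}).load(data)
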
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_specification_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_specification_schema.py
new file mode 100644
index 00000000..64b399fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_set_specification_schema.py
@@ -0,0 +1,19 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class FeatureSetSpecificationSchema(metaclass=PatchedSchemaMeta):
+ path = fields.Str(required=True, allow_none=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_set.feature_set_specification import FeatureSetSpecification
+
+ return FeatureSetSpecification(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_transformation_code_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_transformation_code_metadata_schema.py
new file mode 100644
index 00000000..8b173865
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_transformation_code_metadata_schema.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class FeatureTransformationCodeMetadataSchema(metaclass=PatchedSchemaMeta):
+ path = fields.Str(required=False)
+ transformer_class = fields.Str(required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_set.feature_transformation_code_metadata import (
+ FeatureTransformationCodeMetadata,
+ )
+
+ return FeatureTransformationCodeMetadata(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_window_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_window_schema.py
new file mode 100644
index 00000000..d114c731
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/feature_window_schema.py
@@ -0,0 +1,11 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+from azure.ai.ml._schema.core.schema import YamlFileSchema
+
+
+class FeatureWindowSchema(YamlFileSchema):
+ feature_window_end = fields.Str()
+ feature_window_start = fields.Str()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_metadata_schema.py
new file mode 100644
index 00000000..251ccd6e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_metadata_schema.py
@@ -0,0 +1,33 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import YamlFileSchema
+from azure.ai.ml._schema._feature_store_entity.data_column_schema import DataColumnSchema
+
+from .source_metadata_schema import SourceMetadataSchema
+from .delay_metadata_schema import DelayMetadataSchema
+from .feature_schema import FeatureSchema
+from .feature_transformation_code_metadata_schema import FeatureTransformationCodeMetadataSchema
+
+
+class FeaturesetSpecMetadataSchema(YamlFileSchema):
+ source = fields.Nested(SourceMetadataSchema, required=True)
+ feature_transformation_code = fields.Nested(FeatureTransformationCodeMetadataSchema, required=False)
+ features = fields.List(NestedField(FeatureSchema), required=True, allow_none=False)
+ index_columns = fields.List(NestedField(DataColumnSchema), required=False)
+ source_lookback = fields.Nested(DelayMetadataSchema, required=False)
+ temporal_join_lookback = fields.Nested(DelayMetadataSchema, required=False)
+
+ @post_load
+ def make(self, data: Dict, **kwargs):
+ from azure.ai.ml.entities._feature_set.featureset_spec_metadata import FeaturesetSpecMetadata
+
+ return FeaturesetSpecMetadata(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_properties_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_properties_schema.py
new file mode 100644
index 00000000..e3a56542
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/featureset_spec_properties_schema.py
@@ -0,0 +1,55 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta, YamlFileSchema
+
+from .source_process_code_metadata_schema import SourceProcessCodeSchema
+from .timestamp_column_metadata_schema import TimestampColumnMetadataSchema
+
+
+# pylint: disable-next=name-too-long
+class FeatureTransformationCodePropertiesSchema(metaclass=PatchedSchemaMeta):
+ path = fields.Str(data_key="Path")
+ transformer_class = fields.Str(data_key="TransformerClass")
+
+
+class DelayMetadataPropertiesSchema(metaclass=PatchedSchemaMeta):
+ days = fields.Int(data_key="Days")
+ hours = fields.Int(data_key="Hours")
+ minutes = fields.Int(data_key="Minutes")
+
+
+class FeaturePropertiesSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(data_key="FeatureName")
+ data_type = fields.Str(data_key="DataType")
+ description = fields.Str(data_key="Description")
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str(), data_key="Tags")
+
+
+class ColumnPropertiesSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(data_key="ColumnName")
+ type = fields.Str(data_key="DataType")
+
+
+class SourcePropertiesSchema(metaclass=PatchedSchemaMeta):
+ type = fields.Str(required=True)
+ path = fields.Str(required=False)
+ timestamp_column = fields.Nested(TimestampColumnMetadataSchema, data_key="timestampColumn")
+ source_delay = fields.Nested(DelayMetadataPropertiesSchema, data_key="sourceDelay")
+ source_process_code = fields.Nested(SourceProcessCodeSchema)
+ dict = fields.Dict(keys=fields.Str(), values=fields.Str(), data_key="kwargs")
+
+
+class FeaturesetSpecPropertiesSchema(YamlFileSchema):
+ source = fields.Nested(SourcePropertiesSchema, data_key="source")
+ feature_transformation_code = fields.Nested(
+ FeatureTransformationCodePropertiesSchema, data_key="featureTransformationCode"
+ )
+ features = fields.List(NestedField(FeaturePropertiesSchema), data_key="features")
+ index_columns = fields.List(NestedField(ColumnPropertiesSchema), data_key="indexColumns")
+ source_lookback = fields.Nested(DelayMetadataPropertiesSchema, data_key="sourceLookback")
+ temporal_join_lookback = fields.Nested(DelayMetadataPropertiesSchema, data_key="temporalJoinLookback")
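The data_key values above translate the camelCase keys of the service payload into snake_case attributes. FeaturePropertiesSchema has no post_load hook, so loading returns a plain dict:

    from azure.ai.ml._schema._feature_set.featureset_spec_properties_schema import (
        FeaturePropertiesSchema,
    )

    loaded = FeaturePropertiesSchema().load({"FeatureName": "amount", "DataType": "double"})
    # -> {"name": "amount", "data_type": "double"}
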
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/materialization_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/materialization_settings_schema.py
new file mode 100644
index 00000000..8cf68b67
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/materialization_settings_schema.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import NestedField
+from azure.ai.ml._schema._notification.notification_schema import NotificationSchema
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._schema.schedule.trigger import RecurrenceTriggerSchema
+
+
+class MaterializationComputeResourceSchema(metaclass=PatchedSchemaMeta):
+ instance_type = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource
+
+ return MaterializationComputeResource(instance_type=data.pop("instance_type"), **data)
+
+
+class MaterializationSettingsSchema(metaclass=PatchedSchemaMeta):
+ schedule = NestedField(RecurrenceTriggerSchema)
+ notification = NestedField(NotificationSchema)
+ resource = NestedField(MaterializationComputeResourceSchema)
+ spark_configuration = fields.Dict()
+ offline_enabled = fields.Boolean()
+ online_enabled = fields.Boolean()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_set.materialization_settings import MaterializationSettings
+
+ return MaterializationSettings(**data)
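A sketch for the compute-resource sub-schema (the instance type value is hypothetical):

    from azure.ai.ml._schema._feature_set.materialization_settings_schema import (
        MaterializationComputeResourceSchema,
    )

    resource = MaterializationComputeResourceSchema().load({"instance_type": "standard_e8s_v3"})
    # post_load wraps it in a MaterializationComputeResource entity
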
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_metadata_schema.py
new file mode 100644
index 00000000..345c9084
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_metadata_schema.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+from .delay_metadata_schema import DelayMetadataSchema
+from .source_process_code_metadata_schema import SourceProcessCodeSchema
+from .timestamp_column_metadata_schema import TimestampColumnMetadataSchema
+
+
+class SourceMetadataSchema(metaclass=PatchedSchemaMeta):
+ type = fields.Str(required=True)
+ path = fields.Str(required=False)
+ timestamp_column = fields.Nested(TimestampColumnMetadataSchema, required=False)
+ source_delay = fields.Nested(DelayMetadataSchema, required=False)
+ source_process_code = fields.Nested(SourceProcessCodeSchema, load_only=True, required=False)
+ dict = fields.Dict(keys=fields.Str(), values=fields.Str(), data_key="kwargs", load_only=True, required=False)
+
+ @post_load
+ def make(self, data: Dict, **kwargs):
+ from azure.ai.ml.entities._feature_set.source_metadata import SourceMetadata
+
+ return SourceMetadata(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_process_code_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_process_code_metadata_schema.py
new file mode 100644
index 00000000..b8b93739
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/source_process_code_metadata_schema.py
@@ -0,0 +1,20 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class SourceProcessCodeSchema(metaclass=PatchedSchemaMeta):
+ path = fields.Str(required=True, allow_none=False)
+ process_class = fields.Str(required=True, allow_none=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_set.source_process_code_metadata import SourceProcessCodeMetadata
+
+ return SourceProcessCodeMetadata(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/timestamp_column_metadata_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/timestamp_column_metadata_schema.py
new file mode 100644
index 00000000..6d7982be
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_set/timestamp_column_metadata_schema.py
@@ -0,0 +1,20 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class TimestampColumnMetadataSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(required=True)
+ format = fields.Str(required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_set.timestamp_column_metadata import TimestampColumnMetadata
+
+ return TimestampColumnMetadata(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/__init__.py
new file mode 100644
index 00000000..5e7d7822
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/__init__.py
@@ -0,0 +1,15 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .compute_runtime_schema import ComputeRuntimeSchema
+from .feature_store_schema import FeatureStoreSchema
+from .materialization_store_schema import MaterializationStoreSchema
+
+__all__ = [
+ "ComputeRuntimeSchema",
+ "FeatureStoreSchema",
+ "MaterializationStoreSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/compute_runtime_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/compute_runtime_schema.py
new file mode 100644
index 00000000..48db586f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/compute_runtime_schema.py
@@ -0,0 +1,19 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class ComputeRuntimeSchema(metaclass=PatchedSchemaMeta):
+ spark_runtime_version = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._workspace.compute_runtime import ComputeRuntime
+
+ return ComputeRuntime(spark_runtime_version=data.pop("spark_runtime_version"))
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/feature_store_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/feature_store_schema.py
new file mode 100644
index 00000000..78fb0642
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/feature_store_schema.py
@@ -0,0 +1,43 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, EXCLUDE
+
+from azure.ai.ml._schema._utils.utils import validate_arm_str
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.workspace.customer_managed_key import CustomerManagedKeySchema
+from azure.ai.ml._schema.workspace.identity import IdentitySchema, UserAssignedIdentitySchema
+from azure.ai.ml._utils.utils import snake_to_pascal
+from azure.ai.ml.constants._common import PublicNetworkAccess
+from azure.ai.ml._schema.workspace.networking import ManagedNetworkSchema
+from .compute_runtime_schema import ComputeRuntimeSchema
+from .materialization_store_schema import MaterializationStoreSchema
+
+
+class FeatureStoreSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ compute_runtime = NestedField(ComputeRuntimeSchema)
+ offline_store = NestedField(MaterializationStoreSchema)
+ online_store = NestedField(MaterializationStoreSchema)
+ materialization_identity = NestedField(UserAssignedIdentitySchema)
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+ display_name = fields.Str()
+ location = fields.Str()
+ resource_group = fields.Str()
+ hbi_workspace = fields.Bool()
+ storage_account = fields.Str(validate=validate_arm_str)
+ container_registry = fields.Str(validate=validate_arm_str)
+ key_vault = fields.Str(validate=validate_arm_str)
+ application_insights = fields.Str(validate=validate_arm_str)
+ customer_managed_key = NestedField(CustomerManagedKeySchema)
+ image_build_compute = fields.Str()
+ public_network_access = StringTransformedEnum(
+ allowed_values=[PublicNetworkAccess.DISABLED, PublicNetworkAccess.ENABLED],
+ casing_transform=snake_to_pascal,
+ )
+ identity = NestedField(IdentitySchema)
+ primary_user_assigned_identity = fields.Str()
+ managed_network = NestedField(ManagedNetworkSchema, unknown=EXCLUDE)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/materialization_store_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/materialization_store_schema.py
new file mode 100644
index 00000000..091cd4eb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store/materialization_store_schema.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class MaterializationStoreSchema(metaclass=PatchedSchemaMeta):
+ type = fields.Str(required=True, allow_none=False)
+ target = fields.Str(required=True, allow_none=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_store.materialization_store import MaterializationStore
+
+ return MaterializationStore(
+ type=data.pop("type"),
+ target=data.pop("target"),
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/__init__.py
new file mode 100644
index 00000000..8fec3153
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/__init__.py
@@ -0,0 +1,13 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .data_column_schema import DataColumnSchema
+from .feature_store_entity_schema import FeatureStoreEntitySchema
+
+__all__ = [
+ "DataColumnSchema",
+ "FeatureStoreEntitySchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/data_column_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/data_column_schema.py
new file mode 100644
index 00000000..9fffc055
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/data_column_schema.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class DataColumnSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(
+ required=True,
+ allow_none=False,
+ )
+ type = fields.Str(
+ required=True,
+ allow_none=False,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._feature_store_entity.data_column import DataColumn
+
+ return DataColumn(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/feature_store_entity_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/feature_store_entity_schema.py
new file mode 100644
index 00000000..51505430
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_feature_store_entity/feature_store_entity_schema.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from marshmallow import fields, post_dump, validate
+
+from azure.ai.ml._schema import NestedField
+from azure.ai.ml._schema.core.schema import YamlFileSchema
+
+from .data_column_schema import DataColumnSchema
+
+
+class FeatureStoreEntitySchema(YamlFileSchema):
+ name = fields.Str(required=True, allow_none=False)
+ version = fields.Str(required=True, allow_none=False)
+ latest_version = fields.Str(dump_only=True)
+ index_columns = fields.List(NestedField(DataColumnSchema), required=True, allow_none=False)
+ stage = fields.Str(validate=validate.OneOf(["Development", "Production", "Archived"]), dump_default="Development")
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+ properties = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+ @post_dump
+ def remove_empty_values(self, data, **kwargs): # pylint: disable=unused-argument
+ return {key: value for key, value in data.items() if value}
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py
new file mode 100644
index 00000000..e47aa230
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py
@@ -0,0 +1,19 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .azure_openai_finetuning import AzureOpenAIFineTuningSchema
+from .azure_openai_hyperparameters import AzureOpenAIHyperparametersSchema
+from .custom_model_finetuning import CustomModelFineTuningSchema
+from .finetuning_job import FineTuningJobSchema
+from .finetuning_vertical import FineTuningVerticalSchema
+
+__all__ = [
+ "AzureOpenAIFineTuningSchema",
+ "AzureOpenAIHyperparametersSchema",
+ "CustomModelFineTuningSchema",
+ "FineTuningJobSchema",
+ "FineTuningVerticalSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py
new file mode 100644
index 00000000..f6d2a58d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py
@@ -0,0 +1,54 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+from marshmallow import post_load
+
+
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ModelProvider
+from azure.ai.ml._schema._finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparametersSchema
+from azure.ai.ml._schema._finetuning.finetuning_vertical import FineTuningVerticalSchema
+from azure.ai.ml.entities._job.finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparameters
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml.constants._job.finetuning import FineTuningConstants
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureOpenAIFineTuningSchema(FineTuningVerticalSchema):
+ # This is meant to match the yaml definition NOT the models defined in _restclient
+
+ model_provider = StringTransformedEnum(
+ required=True, allowed_values=ModelProvider.AZURE_OPEN_AI, casing_transform=camel_to_snake
+ )
+ hyperparameters = NestedField(AzureOpenAIHyperparametersSchema(), data_key=FineTuningConstants.HyperParameters)
+
+ @post_load
+ def post_load_processing(self, data: Dict, **kwargs) -> Dict[str, Any]:
+ """Post load processing for the schema.
+
+ :param data: Dictionary of parsed values from the yaml.
+ :type data: typing.Dict
+
+ :return Dictionary of parsed values from the yaml.
+ :rtype Dict[str, Any]
+ """
+ data.pop("model_provider")
+ hyperparameters = data.pop("hyperparameters", None)
+
+ if hyperparameters and not isinstance(hyperparameters, AzureOpenAIHyperparameters):
+ # Normalize the raw mapping before building the typed hyperparameters object.
+ hyperparameters_dict = dict(hyperparameters)
+ azure_openai_hyperparameters = AzureOpenAIHyperparameters(
+ batch_size=hyperparameters_dict.get("batch_size", None),
+ learning_rate_multiplier=hyperparameters_dict.get("learning_rate_multiplier", None),
+ n_epochs=hyperparameters_dict.get("n_epochs", None),
+ )
+ data["hyperparameters"] = azure_openai_hyperparameters
+ return data
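The net effect of the hook above: a plain hyperparameters mapping parsed from YAML is rebuilt as the typed entity. The equivalent by hand, with hypothetical values:

    from azure.ai.ml.entities._job.finetuning.azure_openai_hyperparameters import (
        AzureOpenAIHyperparameters,
    )

    hp = AzureOpenAIHyperparameters(batch_size=16, learning_rate_multiplier=0.1, n_epochs=3)
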
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py
new file mode 100644
index 00000000..f421188d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureOpenAIHyperparametersSchema(metaclass=PatchedSchemaMeta):
+ n_epochs = fields.Int()
+ learning_rate_multiplier = fields.Float()
+ batch_size = fields.Int()
+ # TODO: This should be Dict[str, str]; check the schema for the same.
+ # Not exposing it for now because we don't have a REST-layer representation exposed.
+ # Need to confirm with the team.
+ # additional_parameters = fields.Dict()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py
new file mode 100644
index 00000000..3e14dca4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+class SnakeCaseFineTuningTaskTypes:
+ CHAT_COMPLETION = "chat_completion"
+ TEXT_COMPLETION = "text_completion"
+ TEXT_CLASSIFICATION = "text_classification"
+ QUESTION_ANSWERING = "question_answering"
+ TEXT_SUMMARIZATION = "text_summarization"
+ TOKEN_CLASSIFICATION = "token_classification"
+ TEXT_TRANSLATION = "text_translation"
+ IMAGE_CLASSIFICATION = "image_classification"
+ IMAGE_INSTANCE_SEGMENTATION = "image_instance_segmentation"
+ IMAGE_OBJECT_DETECTION = "image_object_detection"
+ VIDEO_MULTI_OBJECT_TRACKING = "video_multi_object_tracking"
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py
new file mode 100644
index 00000000..9d5b22a7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py
@@ -0,0 +1,35 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ModelProvider
+from azure.ai.ml._schema._finetuning.finetuning_vertical import FineTuningVerticalSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class CustomModelFineTuningSchema(FineTuningVerticalSchema):
+ # This is meant to match the yaml definition NOT the models defined in _restclient
+
+ model_provider = StringTransformedEnum(required=True, allowed_values=ModelProvider.CUSTOM)
+ hyperparameters = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+
+ @post_load
+ def post_load_processing(self, data: Dict, **kwargs) -> Dict[str, Any]:
+ """Post-load processing for the schema.
+
+ :param data: Dictionary of parsed values from the yaml.
+ :type data: typing.Dict
+
+ :return: Dictionary of parsed values from the yaml.
+ :rtype: Dict[str, Any]
+ """
+
+ data.pop("model_provider")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py
new file mode 100644
index 00000000..e1b2270e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._schema.core.fields import (
+ NestedField,
+)
+from ..queue_settings import QueueSettingsSchema
+from ..job_resources import JobResourcesSchema
+
+# This is meant to match the yaml definition NOT the models defined in _restclient
+
+
+@experimental
+class FineTuningJobSchema(BaseJobSchema):
+ outputs = OutputsField()
+ queue_settings = NestedField(QueueSettingsSchema)
+ resources = NestedField(JobResourcesSchema)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py
new file mode 100644
index 00000000..10ac51ff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py
@@ -0,0 +1,73 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema._finetuning.finetuning_job import FineTuningJobSchema
+from azure.ai.ml._schema._finetuning.constants import SnakeCaseFineTuningTaskTypes
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ LocalPathField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml.constants import JobType
+from azure.ai.ml._utils.utils import snake_to_camel
+from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, ModelInputSchema
+from azure.ai.ml.constants._job.finetuning import FineTuningConstants
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+
+# This is meant to match the yaml definition NOT the models defined in _restclient
+
+
+@experimental
+class FineTuningVerticalSchema(FineTuningJobSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.FINE_TUNING)
+ model = NestedField(ModelInputSchema, required=True)
+ training_data = UnionField(
+ [
+ NestedField(DataInputSchema),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+ fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+ fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+ LocalPathField(pattern=r"^file:.*"),
+ LocalPathField(
+ pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+ ),
+ ]
+ )
+ validation_data = UnionField(
+ [
+ NestedField(DataInputSchema),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+ fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+ fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+ LocalPathField(pattern=r"^file:.*"),
+ LocalPathField(
+ pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+ ),
+ ]
+ )
+
+ task = StringTransformedEnum(
+ allowed_values=[
+ SnakeCaseFineTuningTaskTypes.CHAT_COMPLETION,
+ SnakeCaseFineTuningTaskTypes.TEXT_COMPLETION,
+ SnakeCaseFineTuningTaskTypes.TEXT_CLASSIFICATION,
+ SnakeCaseFineTuningTaskTypes.QUESTION_ANSWERING,
+ SnakeCaseFineTuningTaskTypes.TEXT_SUMMARIZATION,
+ SnakeCaseFineTuningTaskTypes.TOKEN_CLASSIFICATION,
+ SnakeCaseFineTuningTaskTypes.TEXT_TRANSLATION,
+ SnakeCaseFineTuningTaskTypes.IMAGE_CLASSIFICATION,
+ SnakeCaseFineTuningTaskTypes.IMAGE_INSTANCE_SEGMENTATION,
+ SnakeCaseFineTuningTaskTypes.IMAGE_OBJECT_DETECTION,
+ SnakeCaseFineTuningTaskTypes.VIDEO_MULTI_OBJECT_TRACKING,
+ ],
+ casing_transform=snake_to_camel,
+ data_key=FineTuningConstants.TaskType,
+ required=True,
+ )
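For orientation, the training_data / validation_data UnionField accepts several YAML shapes; the branches above match, in order, a data input mapping, an azureml: versioned ARM reference, http(s) and wasb(s) URIs, and local paths (all values below are hypothetical):

    training_data:
      path: azureml:my-training-data:1
    # or
    training_data: https://example.com/train.jsonl
    # or
    training_data: ./data/train.jsonl
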
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/__init__.py
new file mode 100644
index 00000000..b95c2d6d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/__init__.py
@@ -0,0 +1,11 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .notification_schema import NotificationSchema
+
+__all__ = [
+ "NotificationSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/notification_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/notification_schema.py
new file mode 100644
index 00000000..21245bc9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_notification/notification_schema.py
@@ -0,0 +1,24 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, validate, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class NotificationSchema(metaclass=PatchedSchemaMeta):
+ email_on = fields.List(
+ fields.Str(validate=validate.OneOf(["JobCompleted", "JobFailed", "JobCancelled"])),
+ required=True,
+ allow_none=False,
+ )
+ emails = fields.List(fields.Str, required=True, allow_none=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._notification.notification import Notification
+
+ return Notification(**data)
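A load sketch (hypothetical address); any email_on value outside the OneOf list fails validation:

    from azure.ai.ml._schema._notification.notification_schema import NotificationSchema

    notification = NotificationSchema().load(
        {"email_on": ["JobCompleted", "JobFailed"], "emails": ["ml-team@example.com"]}
    )  # post_load returns a Notification entity
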
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py
new file mode 100644
index 00000000..1d08c92a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/__init__.py
@@ -0,0 +1,9 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .sweep_job import SweepJobSchema
+
+__all__ = ["SweepJobSchema"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py
new file mode 100644
index 00000000..644c3046
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/_constants.py
@@ -0,0 +1,6 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+BASE_ERROR_MESSAGE = "Search space type not one of: "
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py
new file mode 100644
index 00000000..e48c9637
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/parameterized_sweep.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, PathAwareSchema
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+
+from ..job.job_limits import SweepJobLimitsSchema
+from ..queue_settings import QueueSettingsSchema
+from .sweep_fields_provider import EarlyTerminationField, SamplingAlgorithmField, SearchSpaceField
+from .sweep_objective import SweepObjectiveSchema
+
+
+class ParameterizedSweepSchema(PathAwareSchema):
+ """Shared schema for standalone and pipeline sweep job."""
+
+ sampling_algorithm = SamplingAlgorithmField()
+ search_space = SearchSpaceField()
+ objective = NestedField(
+ SweepObjectiveSchema,
+ required=True,
+ metadata={"description": "The name and optimization goal of the primary metric."},
+ )
+ early_termination = EarlyTerminationField()
+ limits = NestedField(
+ SweepJobLimitsSchema,
+ required=True,
+ )
+ queue_settings = ExperimentalField(NestedField(QueueSettingsSchema))
+ resources = NestedField(JobResourceConfigurationSchema)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py
new file mode 100644
index 00000000..d206a9b6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/__init__.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .choice import ChoiceSchema
+from .normal import IntegerQNormalSchema, NormalSchema, QNormalSchema
+from .randint import RandintSchema
+from .uniform import IntegerQUniformSchema, QUniformSchema, UniformSchema
+
+__all__ = [
+ "ChoiceSchema",
+ "NormalSchema",
+ "QNormalSchema",
+ "RandintSchema",
+ "UniformSchema",
+ "QUniformSchema",
+ "IntegerQUniformSchema",
+ "IntegerQNormalSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py
new file mode 100644
index 00000000..7e6b5a76
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/choice.py
@@ -0,0 +1,63 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema._sweep.search_space.normal import NormalSchema, QNormalSchema
+from azure.ai.ml._schema._sweep.search_space.randint import RandintSchema
+from azure.ai.ml._schema._sweep.search_space.uniform import QUniformSchema, UniformSchema
+from azure.ai.ml._schema.core.fields import (
+ DumpableIntegerField,
+ DumpableStringField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._job.sweep import SearchSpace
+
+
+class ChoiceSchema(metaclass=PatchedSchemaMeta):
+ values = fields.List(
+ UnionField(
+ [
+ DumpableIntegerField(strict=True),
+ DumpableStringField(),
+ fields.Float(),
+ fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField("ChoiceSchema"),
+ NestedField(NormalSchema()),
+ NestedField(QNormalSchema()),
+ NestedField(RandintSchema()),
+ NestedField(UniformSchema()),
+ NestedField(QUniformSchema()),
+ DumpableIntegerField(strict=True),
+ fields.Float(),
+ fields.Str(),
+ ]
+ ),
+ ),
+ ]
+ )
+ )
+ type = StringTransformedEnum(required=True, allowed_values=SearchSpace.CHOICE)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import Choice
+
+ return Choice(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import Choice
+
+ if not isinstance(data, Choice):
+ raise ValidationError("Cannot dump non-Choice object into ChoiceSchema")
+ return data
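A round-trip sketch: load builds a Choice entity, and the pre_dump guard refuses to dump anything else.

    from azure.ai.ml._schema._sweep.search_space.choice import ChoiceSchema

    schema = ChoiceSchema()
    choice = schema.load({"type": "choice", "values": [16, 32, 64]})  # -> Choice
    schema.dump(choice)  # fine
    # schema.dump({"values": [1]}) would raise ValidationError in predump
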
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py
new file mode 100644
index 00000000..b29f175e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/normal.py
@@ -0,0 +1,60 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import ValidationError, fields, post_load
+from marshmallow.decorators import pre_dump
+
+from azure.ai.ml._schema.core.fields import DumpableIntegerField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._common import TYPE
+from azure.ai.ml.constants._job.sweep import SearchSpace
+
+
+class NormalSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=SearchSpace.NORMAL_LOGNORMAL)
+ mu = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+ sigma = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import LogNormal, Normal
+
+ return Normal(**data) if data[TYPE] == SearchSpace.NORMAL else LogNormal(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import Normal
+
+ if not isinstance(data, Normal):
+ raise ValidationError("Cannot dump non-Normal object into NormalSchema")
+ return data
+
+
+class QNormalSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=SearchSpace.QNORMAL_QLOGNORMAL)
+ mu = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+ sigma = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+ q = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import QLogNormal, QNormal
+
+ return QNormal(**data) if data[TYPE] == SearchSpace.QNORMAL else QLogNormal(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import QLogNormal, QNormal
+
+ if not isinstance(data, (QNormal, QLogNormal)):
+ raise ValidationError("Cannot dump non-QNormal or non-QLogNormal object into QNormalSchema")
+ return data
+
+
+class IntegerQNormalSchema(QNormalSchema):
+ mu = DumpableIntegerField(strict=True, required=True)
+ sigma = DumpableIntegerField(strict=True, required=True)
+ q = DumpableIntegerField(strict=True, required=True)
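Editor's note: the make hook dispatches on the loaded type string, so one schema class serves both the normal and lognormal YAML variants, with QNormalSchema covering the quantized forms. A minimal sketch of the corresponding public entities (values illustrative):

from azure.ai.ml.sweep import LogNormal, Normal, QNormal

Normal(mu=0.0, sigma=1.0)      # loaded from {"type": "normal", "mu": 0.0, "sigma": 1.0}
LogNormal(mu=0.0, sigma=1.0)   # loaded from {"type": "lognormal", ...}
QNormal(mu=10, sigma=2, q=1)   # quantized variant handled by QNormalSchema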
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py
new file mode 100644
index 00000000..8df0d4f5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/randint.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._job.sweep import SearchSpace
+
+
+class RandintSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=SearchSpace.RANDINT)
+ upper = fields.Integer(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import Randint
+
+ return Randint(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import Randint
+
+ if not isinstance(data, Randint):
+ raise ValidationError("Cannot dump non-Randint object into RandintSchema")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py
new file mode 100644
index 00000000..2eb1d98f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/search_space/uniform.py
@@ -0,0 +1,62 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema._sweep._constants import BASE_ERROR_MESSAGE
+from azure.ai.ml._schema.core.fields import DumpableIntegerField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._common import TYPE
+from azure.ai.ml.constants._job.sweep import SearchSpace
+
+
+class UniformSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=SearchSpace.UNIFORM_LOGUNIFORM)
+ min_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+ max_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import LogUniform, Uniform
+
+ if not isinstance(data, (Uniform, LogUniform)):
+ raise ValidationError("Cannot dump non-Uniform or non-LogUniform object into UniformSchema")
+ if data.type.lower() not in SearchSpace.UNIFORM_LOGUNIFORM:
+ raise ValidationError(BASE_ERROR_MESSAGE + str(SearchSpace.UNIFORM_LOGUNIFORM))
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import LogUniform, Uniform
+
+ return Uniform(**data) if data[TYPE] == SearchSpace.UNIFORM else LogUniform(**data)
+
+
+class QUniformSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=SearchSpace.QUNIFORM_QLOGUNIFORM)
+ min_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+ max_value = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+ q = UnionField([DumpableIntegerField(strict=True), fields.Float()], required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import QLogUniform, QUniform
+
+ return QUniform(**data) if data[TYPE] == SearchSpace.QUNIFORM else QLogUniform(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import QLogUniform, QUniform
+
+ if not isinstance(data, (QUniform, QLogUniform)):
+ raise ValidationError("Cannot dump non-QUniform or non-QLogUniform object into UniformSchema")
+ return data
+
+
+class IntegerQUniformSchema(QUniformSchema):
+ min_value = DumpableIntegerField(strict=True, required=True)
+ max_value = DumpableIntegerField(strict=True, required=True)
+ q = DumpableIntegerField(strict=True, required=True)
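Editor's note: the uniform family follows the same type-based dispatch pattern, and the Integer* subclasses simply narrow every field to a strict integer. Sketch (public API, illustrative values):

from azure.ai.ml.sweep import QUniform, Uniform

Uniform(min_value=0.001, max_value=0.1)    # {"type": "uniform", ...}
QUniform(min_value=1, max_value=10, q=2)   # {"type": "quniform", ...}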
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py
new file mode 100644
index 00000000..e96d4fa2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_fields_provider.py
@@ -0,0 +1,77 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import SamplingAlgorithmType
+from azure.ai.ml._schema._sweep.search_space import (
+ ChoiceSchema,
+ NormalSchema,
+ QNormalSchema,
+ QUniformSchema,
+ RandintSchema,
+ UniformSchema,
+)
+from azure.ai.ml._schema._sweep.sweep_sampling_algorithm import (
+ BayesianSamplingAlgorithmSchema,
+ GridSamplingAlgorithmSchema,
+ RandomSamplingAlgorithmSchema,
+)
+from azure.ai.ml._schema._sweep.sweep_termination import (
+ BanditPolicySchema,
+ MedianStoppingPolicySchema,
+ TruncationSelectionPolicySchema,
+)
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+
+
+def SamplingAlgorithmField():
+ return UnionField(
+ [
+ SamplingAlgorithmTypeField(),
+ NestedField(RandomSamplingAlgorithmSchema()),
+ NestedField(GridSamplingAlgorithmSchema()),
+ NestedField(BayesianSamplingAlgorithmSchema()),
+ ]
+ )
+
+
+def SamplingAlgorithmTypeField():
+ return StringTransformedEnum(
+ required=True,
+ allowed_values=[
+ SamplingAlgorithmType.BAYESIAN,
+ SamplingAlgorithmType.GRID,
+ SamplingAlgorithmType.RANDOM,
+ ],
+ metadata={"description": "The sampling algorithm to use for the hyperparameter sweep."},
+ )
+
+
+def SearchSpaceField():
+ return fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(ChoiceSchema()),
+ NestedField(UniformSchema()),
+ NestedField(QUniformSchema()),
+ NestedField(NormalSchema()),
+ NestedField(QNormalSchema()),
+ NestedField(RandintSchema()),
+ ]
+ ),
+ metadata={"description": "The parameters to sweep over the trial."},
+ )
+
+
+def EarlyTerminationField():
+ return UnionField(
+ [
+ NestedField(BanditPolicySchema()),
+ NestedField(MedianStoppingPolicySchema()),
+ NestedField(TruncationSelectionPolicySchema()),
+ ],
+ metadata={"description": "The early termination policy to be applied to the Sweep runs."},
+ )
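Editor's note: SearchSpaceField maps each parameter name to any supported distribution, so a YAML search space loads into a plain dict of sweep entities. A hedged sketch of the equivalent in-code shape (names and values illustrative):

from azure.ai.ml.sweep import Choice, Uniform

search_space = {
    "batch_size": Choice(values=[16, 32, 64]),
    "learning_rate": Uniform(min_value=0.001, max_value=0.1),
}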
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py
new file mode 100644
index 00000000..f835ed0a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_job.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._schema._sweep.parameterized_sweep import ParameterizedSweepSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job import BaseJobSchema, ParameterizedCommandSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml.constants import JobType
+
+# This is meant to match the YAML definition, NOT the models defined in _restclient
+
+
+class SweepJobSchema(BaseJobSchema, ParameterizedSweepSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.SWEEP)
+ trial = NestedField(ParameterizedCommandSchema, required=True)
+ inputs = InputsField()
+ outputs = OutputsField()
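Editor's note: since this schema targets the YAML definition, a conforming job file would look roughly like the sketch below. The sampling_algorithm, objective, and search_space fields come from the inherited ParameterizedSweepSchema (not shown in this diff); all concrete values are illustrative.

type: sweep
trial:
  command: python train.py --lr ${{search_space.learning_rate}}
  environment: azureml:my-env:1
sampling_algorithm: random
objective:
  goal: maximize
  primary_metric: accuracy
search_space:
  learning_rate:
    type: uniform
    min_value: 0.001
    max_value: 0.1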
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py
new file mode 100644
index 00000000..fdc24fdf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_objective.py
@@ -0,0 +1,31 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2022_10_01.models import Goal
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+module_logger = logging.getLogger(__name__)
+
+
+class SweepObjectiveSchema(metaclass=PatchedSchemaMeta):
+ goal = StringTransformedEnum(
+ required=True,
+ allowed_values=[Goal.MINIMIZE, Goal.MAXIMIZE],
+ casing_transform=camel_to_snake,
+ )
+ primary_metric = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs) -> "Objective":
+ from azure.ai.ml.entities._job.sweep.objective import Objective
+
+ return Objective(**data)
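Editor's note: camel_to_snake means the REST enum values surface as minimize/maximize in YAML. Entity-level sketch, assuming Objective is re-exported from azure.ai.ml.sweep as elsewhere in this package:

from azure.ai.ml.sweep import Objective

objective = Objective(goal="maximize", primary_metric="accuracy")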
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py
new file mode 100644
index 00000000..2b8137b4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_sampling_algorithm.py
@@ -0,0 +1,103 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._restclient.v2023_02_01_preview.models import RandomSamplingAlgorithmRule, SamplingAlgorithmType
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+module_logger = logging.getLogger(__name__)
+
+
+class RandomSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=SamplingAlgorithmType.RANDOM,
+ casing_transform=camel_to_snake,
+ )
+
+ seed = fields.Int()
+
+ logbase = UnionField(
+ [
+ fields.Number(),
+ fields.Str(),
+ ],
+ data_key="logbase",
+ )
+
+ rule = StringTransformedEnum(
+ allowed_values=[
+ RandomSamplingAlgorithmRule.RANDOM,
+ RandomSamplingAlgorithmRule.SOBOL,
+ ],
+ casing_transform=camel_to_snake,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import RandomSamplingAlgorithm
+
+ data.pop("type")
+ return RandomSamplingAlgorithm(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import RandomSamplingAlgorithm
+
+ if not isinstance(data, RandomSamplingAlgorithm):
+ raise ValidationError("Cannot dump non-RandomSamplingAlgorithm object into RandomSamplingAlgorithm")
+ return data
+
+
+class GridSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=SamplingAlgorithmType.GRID,
+ casing_transform=camel_to_snake,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import GridSamplingAlgorithm
+
+ data.pop("type")
+ return GridSamplingAlgorithm(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import GridSamplingAlgorithm
+
+ if not isinstance(data, GridSamplingAlgorithm):
+ raise ValidationError("Cannot dump non-GridSamplingAlgorithm object into GridSamplingAlgorithm")
+ return data
+
+
+class BayesianSamplingAlgorithmSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=SamplingAlgorithmType.BAYESIAN,
+ casing_transform=camel_to_snake,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import BayesianSamplingAlgorithm
+
+ data.pop("type")
+ return BayesianSamplingAlgorithm(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import BayesianSamplingAlgorithm
+
+ if not isinstance(data, BayesianSamplingAlgorithm):
+ raise ValidationError("Cannot dump non-BayesianSamplingAlgorithm object into BayesianSamplingAlgorithm")
+ return data
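Editor's note: each make hook pops the type discriminator before constructing the entity, since the entity classes carry their own type internally. Sketch of the corresponding public objects (arguments optional and illustrative):

from azure.ai.ml.sweep import GridSamplingAlgorithm, RandomSamplingAlgorithm

random_alg = RandomSamplingAlgorithm(seed=999, rule="sobol")  # seed/rule/logbase all optional
grid_alg = GridSamplingAlgorithm()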
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py
new file mode 100644
index 00000000..08fa9145
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_sweep/sweep_termination.py
@@ -0,0 +1,95 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import EarlyTerminationPolicyType
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+module_logger = logging.getLogger(__name__)
+
+
+class EarlyTerminationPolicySchema(metaclass=PatchedSchemaMeta):
+ evaluation_interval = fields.Int(allow_none=True)
+ delay_evaluation = fields.Int(allow_none=True)
+
+
+class BanditPolicySchema(EarlyTerminationPolicySchema):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=EarlyTerminationPolicyType.BANDIT,
+ casing_transform=camel_to_snake,
+ )
+ slack_factor = fields.Float(allow_none=True)
+ slack_amount = fields.Float(allow_none=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import BanditPolicy
+
+ data.pop("type", None)
+ return BanditPolicy(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import BanditPolicy
+
+ if not isinstance(data, BanditPolicy):
+ raise ValidationError("Cannot dump non-BanditPolicy object into BanditPolicySchema")
+ return data
+
+
+class MedianStoppingPolicySchema(EarlyTerminationPolicySchema):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=EarlyTerminationPolicyType.MEDIAN_STOPPING,
+ casing_transform=camel_to_snake,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import MedianStoppingPolicy
+
+ data.pop("type", None)
+ return MedianStoppingPolicy(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import MedianStoppingPolicy
+
+ if not isinstance(data, MedianStoppingPolicy):
+ raise ValidationError("Cannot dump non-MedicanStoppingPolicy object into MedianStoppingPolicySchema")
+ return data
+
+
+class TruncationSelectionPolicySchema(EarlyTerminationPolicySchema):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=EarlyTerminationPolicyType.TRUNCATION_SELECTION,
+ casing_transform=camel_to_snake,
+ )
+ truncation_percentage = fields.Int(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import TruncationSelectionPolicy
+
+ data.pop("type", None)
+ return TruncationSelectionPolicy(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.sweep import TruncationSelectionPolicy
+
+ if not isinstance(data, TruncationSelectionPolicy):
+ raise ValidationError(
+ "Cannot dump non-TruncationSelectionPolicy object into TruncationSelectionPolicySchema"
+ )
+ return data
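Editor's note: all three policies share evaluation_interval and delay_evaluation from the base schema; only BanditPolicy adds slack settings, and only TruncationSelectionPolicy requires truncation_percentage. Sketch (public azure.ai.ml.sweep API, values illustrative):

from azure.ai.ml.sweep import BanditPolicy, TruncationSelectionPolicy

bandit = BanditPolicy(slack_factor=0.1, evaluation_interval=1, delay_evaluation=5)
truncation = TruncationSelectionPolicy(truncation_percentage=20, evaluation_interval=1)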
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py
new file mode 100644
index 00000000..611c80a2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py
@@ -0,0 +1,88 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Union
+
+from marshmallow import Schema, fields
+
+from azure.ai.ml._schema.core.fields import DataBindingStr, ExperimentalField, NestedField, UnionField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+
+DATA_BINDING_SUPPORTED_KEY = "_data_binding_supported"
+
+
+def _is_literal(field):
+ return not isinstance(field, (NestedField, fields.List, fields.Dict, UnionField))
+
+
+def _add_data_binding_to_field(field, attrs_to_skip, schema_stack):
+ if hasattr(field, DATA_BINDING_SUPPORTED_KEY) and getattr(field, DATA_BINDING_SUPPORTED_KEY):
+ return field
+ data_binding_field = DataBindingStr()
+ if isinstance(field, UnionField):
+ for field_obj in field.union_fields:
+ if not _is_literal(field_obj):
+ _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+ field.insert_union_field(data_binding_field)
+ elif isinstance(field, fields.Dict):
+ # handle dict, dict value can be None
+ if field.value_field is not None:
+ field.value_field = _add_data_binding_to_field(field.value_field, attrs_to_skip, schema_stack=schema_stack)
+ elif isinstance(field, fields.List):
+ # handle list
+ field.inner = _add_data_binding_to_field(field.inner, attrs_to_skip, schema_stack=schema_stack)
+ elif isinstance(field, ExperimentalField):
+ field = ExperimentalField(
+ _add_data_binding_to_field(field.experimental_field, attrs_to_skip, schema_stack=schema_stack),
+ data_key=field.data_key,
+ attribute=field.attribute,
+ dump_only=field.dump_only,
+ required=field.required,
+ allow_none=field.allow_none,
+ )
+ elif isinstance(field, NestedField):
+ # handle nested field
+ support_data_binding_expression_for_fields(field.schema, attrs_to_skip, schema_stack=schema_stack)
+ else:
+ # change basic fields to union
+ field = UnionField(
+ [data_binding_field, field],
+ data_key=field.data_key,
+ attribute=field.attribute,
+ dump_only=field.dump_only,
+ required=field.required,
+ allow_none=field.allow_none,
+ )
+
+ setattr(field, DATA_BINDING_SUPPORTED_KEY, True)
+ return field
+
+
+# pylint: disable-next=docstring-missing-param
+def support_data_binding_expression_for_fields( # pylint: disable=name-too-long
+ schema: Union[PathAwareSchema, Schema], attrs_to_skip=None, schema_stack=None
+):
+ """Update fields inside schema to support data binding string.
+
+ Only first layer of recursive schema is supported now.
+ """
+ if hasattr(schema, DATA_BINDING_SUPPORTED_KEY) and getattr(schema, DATA_BINDING_SUPPORTED_KEY):
+ return
+
+ setattr(schema, DATA_BINDING_SUPPORTED_KEY, True)
+
+ if attrs_to_skip is None:
+ attrs_to_skip = []
+ if schema_stack is None:
+ schema_stack = []
+ schema_type_name = type(schema).__name__
+ if schema_type_name in schema_stack:
+ return
+ schema_stack.append(schema_type_name)
+ for attr, field_obj in schema.load_fields.items():
+ if attr not in attrs_to_skip:
+ schema.load_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+ for attr, field_obj in schema.dump_fields.items():
+ if attr not in attrs_to_skip:
+ schema.dump_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+ schema_stack.pop()
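Editor's note: the net effect is that every literal field is wrapped into UnionField([DataBindingStr(), original]), so a field declared as fields.Int() will also load a binding expression. A minimal sketch on a toy schema, assuming DataBindingStr accepts standard ${{...}} expressions:

from marshmallow import Schema, fields

from azure.ai.ml._schema._utils.data_binding_expression import (
    support_data_binding_expression_for_fields,
)

class ToySchema(Schema):
    count = fields.Int()

schema = ToySchema()
support_data_binding_expression_for_fields(schema)
schema.load({"count": 4})                           # still loads a plain int
schema.load({"count": "${{parent.inputs.count}}"})  # now also accepted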
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
new file mode 100644
index 00000000..c1ee3568
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
@@ -0,0 +1,94 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import copy
+import logging
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Optional, Union
+
+from marshmallow.exceptions import ValidationError
+
+module_logger = logging.getLogger(__name__)
+
+
+class ArmId(str):
+ def __new__(cls, content):
+ validate_arm_str(content)
+ return str.__new__(cls, content)
+
+
+def validate_arm_str(arm_str: Union[ArmId, str]) -> bool:
+ """Validate whether the given string is in fact in the format of an ARM ID.
+
+ :param arm_str: The string to validate.
+ :type arm_str: Either a string (in case of incorrect formatting) or ArmID (in case of correct formatting).
+ :returns: True if the string is correctly formatted, False otherwise.
+ :rtype: bool
+ """
+ reg_str = (
+ r"/subscriptions/[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}?/resourcegroups/.*/providers/[a-z.a-z]*/[a-z]*/.*"
+ )
+ lowered = arm_str.lower()
+ match = re.match(reg_str, lowered)
+ if match and match.group() == lowered:
+ return True
+ raise ValidationError(f"ARM string {arm_str} is not formatted correctly.")
+
+
+def get_subnet_str(vnet_name: str, subnet: str, sub_id: Optional[str] = None, rg: Optional[str] = None) -> str:
+ if vnet_name and not subnet:
+ raise ValidationError("Subnet is required when vnet name is specified.")
+ try:
+ validate_arm_str(subnet)
+ return subnet
+ except ValidationError:
+ return (
+ f"/subscriptions/{sub_id}/resourceGroups/{rg}/"
+ f"providers/Microsoft.Network/virtualNetworks/{vnet_name}/subnets/{subnet}"
+ )
+
+
+def replace_key_in_odict(odict: OrderedDict, old_key: Any, new_key: Any):
+ if not odict or old_key not in odict:
+ return odict
+ return OrderedDict([(new_key, v) if k == old_key else (k, v) for k, v in odict.items()])
+
+
+# This is temporary until deployments(batch/K8S) support registry references
+def exit_if_registry_assets(data: Dict, caller: str) -> None:
+ startswith = "azureml://registries/"
+ if (
+ "environment" in data
+ and data["environment"]
+ and isinstance(data["environment"], str)
+ and data["environment"].startswith(startswith)
+ ):
+ raise ValidationError(f"Registry reference for environments is not supported for {caller}")
+ if "model" in data and data["model"] and isinstance(data["model"], str) and data["model"].startswith(startswith):
+ raise ValidationError(f"Registry reference for models is not supported for {caller}")
+ if (
+ "code_configuration" in data
+ and data["code_configuration"].code
+ and isinstance(data["code_configuration"].code, str)
+ and data["code_configuration"].code.startswith(startswith)
+ ):
+ raise ValidationError(f"Registry reference for code_configuration.code is not supported for {caller}")
+
+
+def _resolve_group_inputs_for_component(component, **kwargs): # pylint: disable=unused-argument
+ # Try resolve object's inputs & outputs and return a resolved new object
+ from azure.ai.ml.entities._inputs_outputs import GroupInput
+
+ result = copy.copy(component)
+
+ flatten_inputs = {}
+ for key, val in result.inputs.items():
+ if isinstance(val, GroupInput):
+ flatten_inputs.update(val.flatten(group_parameter_name=key))
+ continue
+ flatten_inputs[key] = val
+
+ # Flatten group inputs
+ result._inputs = flatten_inputs # pylint: disable=protected-access
+ return result
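Editor's note: validate_arm_str returns True only when the regex consumes the entire lowercased string; any other input raises ValidationError rather than returning False. An illustrative ARM ID that passes (subscription and resource names are placeholders):

arm_id = (
    "/subscriptions/00000000-0000-0000-0000-000000000000"
    "/resourceGroups/my-rg/providers/Microsoft.MachineLearningServices"
    "/workspaces/my-ws"
)
validate_arm_str(arm_id)  # True
ArmId(arm_id)             # the str subclass validates on construction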
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/artifact.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/artifact.py
new file mode 100644
index 00000000..fc107a78
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/artifact.py
@@ -0,0 +1,24 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+from .asset import AssetSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class ArtifactSchema(AssetSchema):
+ datastore = fields.Str(metadata={"description": "Name of the datastore to upload to."}, required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ data[BASE_PATH_CONTEXT_KEY] = self.context[BASE_PATH_CONTEXT_KEY]
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/asset.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/asset.py
new file mode 100644
index 00000000..09edb115
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/asset.py
@@ -0,0 +1,42 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields
+from marshmallow.decorators import pre_load
+
+from azure.ai.ml._schema.core.auto_delete_setting import AutoDeleteSettingSchema
+from azure.ai.ml._schema.core.fields import NestedField, VersionField, ExperimentalField
+from azure.ai.ml._schema.job.creation_context import CreationContextSchema
+
+from ..core.resource import ResourceSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class AssetSchema(ResourceSchema):
+ version = VersionField()
+ creation_context = NestedField(CreationContextSchema, dump_only=True)
+ latest_version = fields.Str(dump_only=True)
+ auto_delete_setting = ExperimentalField(NestedField(AutoDeleteSettingSchema))
+
+
+class AnonymousAssetSchema(AssetSchema):
+ version = VersionField(dump_only=True)
+ name = fields.Str(dump_only=True)
+
+ @pre_load
+ def warn_if_named(self, data, **kwargs):
+ if isinstance(data, str):
+ raise ValidationError("Anonymous assets must be defined inline")
+ name = data.pop("name", None)
+ data.pop("version", None)
+ if name is not None:
+ module_logger.warning(
+ "Warning: the provided asset name '%s' will not be used for anonymous registration.", name
+ )
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/code_asset.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/code_asset.py
new file mode 100644
index 00000000..0610caff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/code_asset.py
@@ -0,0 +1,47 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import ArmStr
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+
+from .artifact import ArtifactSchema
+from .asset import AnonymousAssetSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class CodeAssetSchema(ArtifactSchema):
+ id = ArmStr(azureml_type=AzureMLResourceType.CODE, dump_only=True)
+ path = fields.Str(
+ metadata={
+ "description": "A local path or a Blob URI pointing to a file or directory where code asset is located."
+ }
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets import Code
+
+ return Code(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+
+
+class AnonymousCodeAssetSchema(CodeAssetSchema, AnonymousAssetSchema):
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets import Code
+
+ return Code(is_anonymous=True, base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+
+ @pre_dump
+ def validate(self, data, **kwargs):
+ # AnonymousCodeAssetSchema does not support None or an ARM string (those fall back to ArmVersionedStr)
+ if data is None or not hasattr(data, "get"):
+ raise ValidationError("Code cannot be None")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/data.py
new file mode 100644
index 00000000..e14afd9b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/data.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, validate
+
+from azure.ai.ml.constants._common import AssetTypes
+
+from .artifact import ArtifactSchema
+from .asset import AnonymousAssetSchema
+
+
+class DataSchema(ArtifactSchema):
+ path = fields.Str(metadata={"description": "URI pointing to a file or folder."}, required=True)
+ properties = fields.Dict(dump_only=True)
+ type = fields.Str(
+ metadata={"description": "the type of data. Valid values are uri_file, uri_folder, or mltable."},
+ validate=validate.OneOf([AssetTypes.URI_FILE, AssetTypes.URI_FOLDER, AssetTypes.MLTABLE]),
+ dump_default=AssetTypes.URI_FOLDER,
+ error_messages={"validator_failed": "value must be uri_file, uri_folder, or mltable."},
+ )
+
+
+class AnonymousDataSchema(DataSchema, AnonymousAssetSchema):
+ pass
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/environment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/environment.py
new file mode 100644
index 00000000..3ca5333f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/environment.py
@@ -0,0 +1,160 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump, pre_load
+
+from azure.ai.ml._restclient.v2022_05_01.models import (
+ InferenceContainerProperties,
+ OperatingSystemType,
+ Route,
+)
+from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, UnionField, LocalPathField
+from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._common import (
+ ANONYMOUS_ENV_NAME,
+ BASE_PATH_CONTEXT_KEY,
+ CREATE_ENVIRONMENT_ERROR_MESSAGE,
+ AzureMLResourceType,
+ YAMLRefDocLinks,
+)
+
+from ..core.fields import ArmStr, RegistryStr, StringTransformedEnum, VersionField
+from .asset import AnonymousAssetSchema, AssetSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class BuildContextSchema(metaclass=PatchedSchemaMeta):
+ dockerfile_path = fields.Str()
+ path = UnionField(
+ [
+ LocalPathField(),
+ # build context also support http url
+ fields.URL(),
+ ]
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets.environment import BuildContext
+
+ return BuildContext(**data)
+
+
+class RouteSchema(metaclass=PatchedSchemaMeta):
+ port = fields.Int(required=True)
+ path = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ return Route(**data)
+
+
+class InferenceConfigSchema(metaclass=PatchedSchemaMeta):
+ liveness_route = NestedField(RouteSchema, required=True)
+ scoring_route = NestedField(RouteSchema, required=True)
+ readiness_route = NestedField(RouteSchema, required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ return InferenceContainerProperties(**data)
+
+
+class _BaseEnvironmentSchema(AssetSchema):
+ id = UnionField(
+ [
+ RegistryStr(dump_only=True),
+ ArmStr(azureml_type=AzureMLResourceType.ENVIRONMENT, dump_only=True),
+ ]
+ )
+ build = NestedField(
+ BuildContextSchema,
+ metadata={"description": "Docker build context to create the environment. Mutually exclusive with image"},
+ )
+ image = fields.Str()
+ conda_file = UnionField([fields.Raw(), fields.Str()])
+ inference_config = NestedField(InferenceConfigSchema)
+ os_type = StringTransformedEnum(
+ allowed_values=[OperatingSystemType.Linux, OperatingSystemType.Windows],
+ required=False,
+ )
+ datastore = fields.Str(
+ metadata={
+ "description": "Name of the datastore to upload to.",
+ "arm_type": AzureMLResourceType.DATASTORE,
+ },
+ required=False,
+ )
+ intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema), dump_only=True)
+
+ @pre_load
+ def pre_load(self, data, **kwargs):
+ if isinstance(data, str):
+ raise ValidationError("Environment schema data cannot be a string")
+ # validates that "channels" and "dependencies" are not included in the data creation.
+ # These properties should only be on environment conda files not in the environment creation file
+ if "channels" in data or "dependencies" in data:
+ environmentMessage = CREATE_ENVIRONMENT_ERROR_MESSAGE.format(YAMLRefDocLinks.ENVIRONMENT)
+ raise ValidationError(environmentMessage)
+ return data
+
+ @pre_dump
+ def validate(self, data, **kwargs):
+ from azure.ai.ml.entities._assets import Environment
+
+ if isinstance(data, Environment):
+ if data._intellectual_property: # pylint: disable=protected-access
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+ if data is None or not hasattr(data, "get"):
+ raise ValidationError("Environment cannot be None")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets import Environment
+
+ try:
+ obj = Environment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+ except FileNotFoundError as e:
+ # Environment.__init__() will raise FileNotFoundError if build.path is not found when trying to calculate
+ # the hash for anonymous. Raise ValidationError instead to collect all errors in schema validation.
+ raise ValidationError("Environment file not found: {}".format(e)) from e
+ return obj
+
+
+class EnvironmentSchema(_BaseEnvironmentSchema):
+ name = fields.Str(required=True)
+ version = VersionField()
+
+
+class AnonymousEnvironmentSchema(_BaseEnvironmentSchema, AnonymousAssetSchema):
+ @pre_load
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def trim_dump_only(self, data, **kwargs):
+ """trim_dump_only in PathAwareSchema removes all properties which are dump only.
+
+ By the time we reach this schema name and version properties are removed so no warning is shown. This method
+ overrides trim_dump_only in PathAwareSchema to check for name and version and raise warning if present. And then
+ calls the it
+ """
+ if isinstance(data, str) or data is None:
+ return data
+ name = data.pop("name", None)
+ data.pop("version", None)
+ # CliV2AnonymousEnvironment is a default name for anonymous environment
+ if name is not None and name != ANONYMOUS_ENV_NAME:
+ module_logger.warning(
+ "Warning: the provided asset name '%s' will not be used for anonymous registration",
+ name,
+ )
+ return super(AnonymousEnvironmentSchema, self).trim_dump_only(data, **kwargs)
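Editor's note: the pre_load guard rejects top-level "channels"/"dependencies" keys because conda content belongs in a referenced conda file, not inline in the environment YAML. Entity-level sketch of the two creation paths the schema supports (public azure.ai.ml.entities API; image name and paths illustrative):

from azure.ai.ml.entities import BuildContext, Environment

image_env = Environment(
    name="my-env",
    image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04",
    conda_file="./conda.yaml",
)
built_env = Environment(name="my-built-env", build=BuildContext(path="./docker"))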
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/federated_learning_silo.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/federated_learning_silo.py
new file mode 100644
index 00000000..80c4ba7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/federated_learning_silo.py
@@ -0,0 +1,24 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# TODO: determine where this file should live.
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.resource import YamlFileSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField
+
+
+# Inherits from YamlFileSchema instead of something more specific because
+# this does not represent a server-side resource.
+@experimental
+class FederatedLearningSiloSchema(YamlFileSchema):
+ """The YAML definition of a silo for describing a federated learning data target.
+ Unlike most SDK/CLI schemas, this schema does not represent an AML resource;
+ it is merely used to simplify the loading and validation of silos which are used
+ to create FL pipeline nodes.
+ """
+
+ compute = fields.Str()
+ datastore = fields.Str()
+ inputs = InputsField()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/index.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/index.py
new file mode 100644
index 00000000..4a97c0ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/index.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import ArmStr
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+
+from .artifact import ArtifactSchema
+
+
+class IndexAssetSchema(ArtifactSchema):
+ name = fields.Str(required=True, allow_none=False)
+ id = ArmStr(azureml_type=AzureMLResourceType.INDEX, dump_only=True)
+ stage = fields.Str(default="Development")
+ path = fields.Str(
+ required=True,
+ metadata={
+ "description": "A local path or a Blob URI pointing to a file or directory where index files are located."
+ },
+ )
+ properties = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets import Index
+
+ return Index(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py
new file mode 100644
index 00000000..60c17f63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/model.py
@@ -0,0 +1,65 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField
+from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job import CreationContextSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes, AzureMLResourceType
+
+from ..core.fields import ArmVersionedStr, StringTransformedEnum, VersionField
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ id = ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, dump_only=True)
+ type = StringTransformedEnum(
+ allowed_values=[
+ AssetTypes.CUSTOM_MODEL,
+ AssetTypes.MLFLOW_MODEL,
+ AssetTypes.TRITON_MODEL,
+ ],
+ metadata={"description": "The storage format for this entity. Used for NCD."},
+ )
+ path = fields.Str()
+ version = VersionField()
+ description = fields.Str()
+ properties = fields.Dict()
+ tags = fields.Dict()
+ stage = fields.Str()
+ utc_time_created = fields.DateTime(format="iso", dump_only=True)
+ flavors = fields.Dict()
+ creation_context = NestedField(CreationContextSchema, dump_only=True)
+ job_name = fields.Str(dump_only=True)
+ latest_version = fields.Str(dump_only=True)
+ datastore = fields.Str(metadata={"description": "Name of the datastore to upload to."}, required=False)
+ intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema, required=False), dump_only=True)
+ system_metadata = fields.Dict()
+
+ @pre_dump
+ def validate(self, data, **kwargs):
+ if data._intellectual_property: # pylint: disable=protected-access
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets import Model
+
+ return Model(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+
+
+class AnonymousModelSchema(ModelSchema):
+ name = fields.Str()
+ version = VersionField()
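Editor's note: entity-level sketch of what ModelSchema deserializes into (public API; the path and metadata are illustrative):

from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import Model

model = Model(
    name="my-model",
    path="./model",                 # local folder or datastore/blob URI
    type=AssetTypes.MLFLOW_MODEL,   # one of the allowed_values above
    description="Registered from a local MLflow model folder.",
)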
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/base_environment_source.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/base_environment_source.py
new file mode 100644
index 00000000..09e0a56c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/base_environment_source.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+
+
+module_logger = logging.getLogger(__name__)
+
+
+class BaseEnvironmentSourceSchema(PathAwareSchema):
+ type = fields.Str()
+ resource_id = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import BaseEnvironment
+
+ return BaseEnvironment(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/inference_server.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/inference_server.py
new file mode 100644
index 00000000..c6e38331
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/inference_server.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,no-else-return
+
+import logging
+
+from marshmallow import post_load
+from azure.ai.ml._schema._deployment.code_configuration_schema import CodeConfigurationSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml.constants._common import InferenceServerType
+from .online_inference_configuration import OnlineInferenceConfigurationSchema
+
+
+module_logger = logging.getLogger(__name__)
+
+
+class InferenceServerSchema(PathAwareSchema):
+ type = StringTransformedEnum(
+ allowed_values=[
+ InferenceServerType.AZUREML_ONLINE,
+ InferenceServerType.AZUREML_BATCH,
+ InferenceServerType.CUSTOM,
+ InferenceServerType.TRITON,
+ ],
+ required=True,
+ )
+ code_configuration = NestedField(CodeConfigurationSchema) # required for batch and online
+ inference_configuration = NestedField(OnlineInferenceConfigurationSchema) # required for custom and Triton
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import (
+ AzureMLOnlineInferencingServer,
+ AzureMLBatchInferencingServer,
+ CustomInferencingServer,
+ TritonInferencingServer,
+ )
+
+ if data["type"] == InferenceServerType.AZUREML_ONLINE:
+ return AzureMLOnlineInferencingServer(**data)
+ elif data["type"] == InferenceServerType.AZUREML_BATCH:
+ return AzureMLBatchInferencingServer(**data)
+ elif data["type"] == InferenceServerType.CUSTOM:
+ return CustomInferencingServer(**data)
+ elif data["type"] == InferenceServerType.TRITON:
+ return TritonInferencingServer(**data)
+ else:
+ return None
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_configuration.py
new file mode 100644
index 00000000..0e5a54a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_configuration.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelConfigurationSchema(PathAwareSchema):
+ mode = StringTransformedEnum(
+ allowed_values=[
+ "copy",
+ "download",
+ ]
+ )
+ mount_path = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import ModelConfiguration
+
+ return ModelConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package.py
new file mode 100644
index 00000000..142c85c8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package.py
@@ -0,0 +1,41 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.core.fields import UnionField, NestedField, StringTransformedEnum
+from .inference_server import InferenceServerSchema
+from .model_configuration import ModelConfigurationSchema
+from .model_package_input import ModelPackageInputSchema
+from .base_environment_source import BaseEnvironmentSourceSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelPackageSchema(PathAwareSchema):
+ target_environment = UnionField(
+ union_fields=[
+ fields.Dict(keys=StringTransformedEnum(allowed_values=["name"]), values=fields.Str()),
+ fields.Str(required=True),
+ ]
+ )
+ base_environment_source = NestedField(BaseEnvironmentSourceSchema)
+ inferencing_server = NestedField(InferenceServerSchema)
+ model_configuration = NestedField(ModelConfigurationSchema)
+ inputs = fields.List(NestedField(ModelPackageInputSchema))
+ tags = fields.Dict()
+ environment_variables = fields.Dict(
+ metadata={"description": "Environment variables configuration for the model package."}
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import ModelPackage
+
+ return ModelPackage(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package_input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package_input.py
new file mode 100644
index 00000000..a1a1dd8b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/model_package_input.py
@@ -0,0 +1,81 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField, NestedField
+
+module_logger = logging.getLogger(__name__)
+
+
+class PathBaseSchema(PathAwareSchema):
+ input_path_type = StringTransformedEnum(
+ allowed_values=[
+ "path_id",
+ "url",
+ "path_version",
+ ],
+ casing_transform=camel_to_snake,
+ )
+
+
+class PackageInputPathIdSchema(PathBaseSchema):
+ resource_id = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets._artifacts._package.model_package import PackageInputPathId
+
+ return PackageInputPathId(**data)
+
+
+class PackageInputPathUrlSchema(PathBaseSchema):
+ url = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets._artifacts._package.model_package import PackageInputPathUrl
+
+ return PackageInputPathUrl(**data)
+
+
+class PackageInputPathSchema(PathBaseSchema):
+ resource_name = fields.Str()
+ resource_version = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets._artifacts._package.model_package import PackageInputPathVersion
+
+ return PackageInputPathVersion(**data)
+
+
+class ModelPackageInputSchema(PathAwareSchema):
+ type = StringTransformedEnum(allowed_values=["uri_file", "uri_folder"], casing_transform=camel_to_snake)
+ mode = StringTransformedEnum(
+ allowed_values=[
+ "read_only_mount",
+ "download",
+ ],
+ casing_transform=camel_to_snake,
+ )
+ path = UnionField(
+ [
+ NestedField(PackageInputPathIdSchema),
+ NestedField(PackageInputPathUrlSchema),
+ NestedField(PackageInputPathSchema),
+ ]
+ )
+ mount_path = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets._artifacts._package.model_package import ModelPackageInput
+
+ return ModelPackageInput(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/online_inference_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/online_inference_configuration.py
new file mode 100644
index 00000000..b5c313ed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/online_inference_configuration.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from .route import RouteSchema
+
+
+module_logger = logging.getLogger(__name__)
+
+
+class OnlineInferenceConfigurationSchema(PathAwareSchema):
+ liveness_route = NestedField(RouteSchema)
+ readiness_route = NestedField(RouteSchema)
+ scoring_route = NestedField(RouteSchema)
+ entry_script = fields.Str()
+ configuration = fields.Dict()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets._artifacts._package.inferencing_server import (
+ OnlineInferenceConfiguration,
+ )
+
+ return OnlineInferenceConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/route.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/route.py
new file mode 100644
index 00000000..86f37e06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/package/route.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,bad-mcs-method-argument
+
+import logging
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class RouteSchema(metaclass=PatchedSchemaMeta):
+ port = fields.Str()
+ path = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets._artifacts._package.inferencing_server import Route
+
+ return Route(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/workspace_asset_reference.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/workspace_asset_reference.py
new file mode 100644
index 00000000..83d6d793
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/assets/workspace_asset_reference.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+from .asset import AssetSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class WorkspaceAssetReferenceSchema(AssetSchema):
+ destination_name = fields.Str()
+ destination_version = fields.Str()
+ source_asset_id = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._assets.workspace_asset_reference import WorkspaceAssetReference
+
+ return WorkspaceAssetReference(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/__init__.py
new file mode 100644
index 00000000..36befc7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/__init__.py
@@ -0,0 +1,30 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .automl_job import AutoMLJobSchema
+from .automl_vertical import AutoMLVerticalSchema
+from .featurization_settings import FeaturizationSettingsSchema, TableFeaturizationSettingsSchema
+from .forecasting_settings import ForecastingSettingsSchema
+from .table_vertical.classification import AutoMLClassificationSchema
+from .table_vertical.forecasting import AutoMLForecastingSchema
+from .table_vertical.regression import AutoMLRegressionSchema
+from .table_vertical.table_vertical import AutoMLTableVerticalSchema
+from .table_vertical.table_vertical_limit_settings import AutoMLTableLimitsSchema
+from .training_settings import TrainingSettingsSchema
+
+__all__ = [
+ "AutoMLJobSchema",
+ "AutoMLVerticalSchema",
+ "FeaturizationSettingsSchema",
+ "TableFeaturizationSettingsSchema",
+ "ForecastingSettingsSchema",
+ "AutoMLClassificationSchema",
+ "AutoMLForecastingSchema",
+ "AutoMLRegressionSchema",
+ "AutoMLTableVerticalSchema",
+ "AutoMLTableLimitsSchema",
+ "TrainingSettingsSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_job.py
new file mode 100644
index 00000000..ebec82c7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_job.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml._schema.queue_settings import QueueSettingsSchema
+from azure.ai.ml.constants import JobType
+
+
+class AutoMLJobSchema(BaseJobSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.AUTOML)
+ environment_id = fields.Str()
+ environment_variables = fields.Dict(keys=fields.Str(), values=fields.Str())
+ outputs = OutputsField()
+ resources = NestedField(JobResourceConfigurationSchema())
+ queue_settings = ExperimentalField(NestedField(QueueSettingsSchema))
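For orientation, a hedged sketch of the YAML surface this schema accepts, written as a Python dict. The field names come from the schema above; the nested values are illustrative assumptions, not taken from SDK documentation:

automl_job_yaml = {
    "type": "automl",  # required discriminator (JobType.AUTOML)
    "environment_variables": {"MY_FLAG": "1"},
    "outputs": {"best_model": {"type": "mlflow_model"}},  # illustrative output entry
    "resources": {"instance_count": 1},  # illustrative; see JobResourceConfigurationSchema
}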
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_vertical.py
new file mode 100644
index 00000000..2cf3bb83
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/automl_vertical.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import LogVerbosity
+from azure.ai.ml._schema.automl.automl_job import AutoMLJobSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema
+from azure.ai.ml._utils.utils import camel_to_snake
+
+
+class AutoMLVerticalSchema(AutoMLJobSchema):
+ log_verbosity = StringTransformedEnum(
+ allowed_values=[o.value for o in LogVerbosity],
+ casing_transform=camel_to_snake,
+ load_default=LogVerbosity.INFO,
+ )
+ training_data = UnionField([NestedField(MLTableInputSchema)])
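camel_to_snake is what lets the YAML use snake_case values while the REST enums stay CamelCase; StringTransformedEnum applies it to every allowed value. A standalone sketch of the transform (the SDK's real helper lives in azure.ai.ml._utils.utils and may differ in edge cases):

import re

def camel_to_snake(text: str) -> str:
    # "NormMacroRecall" -> "norm_macro_recall"
    return re.sub(r"(?<!^)(?=[A-Z])", "_", text).lower()

assert camel_to_snake("NormMacroRecall") == "norm_macro_recall"
assert camel_to_snake("Info") == "info"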
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/featurization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/featurization_settings.py
new file mode 100644
index 00000000..19998e45
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/featurization_settings.py
@@ -0,0 +1,74 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields as flds
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import BlockedTransformers
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants, AutoMLTransformerParameterKeys
+
+
+class ColumnTransformerSchema(metaclass=PatchedSchemaMeta):
+ fields = flds.List(flds.Str())
+ parameters = flds.Dict(
+ keys=flds.Str(),
+ values=UnionField([flds.Float(), flds.Str()], allow_none=True, load_default=None),
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.automl import ColumnTransformer
+
+ return ColumnTransformer(**data)
+
+
+class FeaturizationSettingsSchema(metaclass=PatchedSchemaMeta):
+ dataset_language = flds.Str()
+
+
+class NlpFeaturizationSettingsSchema(FeaturizationSettingsSchema):
+ dataset_language = flds.Str()
+
+ @post_load
+ def make(self, data, **kwargs) -> "NlpFeaturizationSettings":
+ from azure.ai.ml.automl import NlpFeaturizationSettings
+
+ return NlpFeaturizationSettings(**data)
+
+
+class TableFeaturizationSettingsSchema(FeaturizationSettingsSchema):
+ mode = StringTransformedEnum(
+ allowed_values=[
+ AutoMLConstants.AUTO,
+ AutoMLConstants.OFF,
+ AutoMLConstants.CUSTOM,
+ ],
+ load_default=AutoMLConstants.AUTO,
+ )
+ blocked_transformers = flds.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in BlockedTransformers],
+ casing_transform=camel_to_snake,
+ )
+ )
+ column_name_and_types = flds.Dict(keys=flds.Str(), values=flds.Str())
+ transformer_params = flds.Dict(
+ keys=StringTransformedEnum(
+ allowed_values=[o.value for o in AutoMLTransformerParameterKeys],
+ casing_transform=camel_to_snake,
+ ),
+ values=flds.List(NestedField(ColumnTransformerSchema())),
+ )
+ enable_dnn_featurization = flds.Bool()
+
+ @post_load
+ def make(self, data, **kwargs) -> "TabularFeaturizationSettings":
+ from azure.ai.ml.automl import TabularFeaturizationSettings
+
+ return TabularFeaturizationSettings(**data)
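transformer_params maps a transformer key to a list of column transformers, each naming target fields plus parameters. A hedged sketch of a mapping this schema is meant to load; the key and parameter names are illustrative and not verified against AutoMLTransformerParameterKeys:

featurization_yaml = {
    "mode": "custom",
    "transformer_params": {
        "imputer": [  # key must camel_to_snake-match an AutoMLTransformerParameterKeys value
            {"fields": ["age"], "parameters": {"strategy": "median"}},
        ],
    },
    "enable_dnn_featurization": False,
}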
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/forecasting_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/forecasting_settings.py
new file mode 100644
index 00000000..56033e14
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/forecasting_settings.py
@@ -0,0 +1,66 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import FeatureLags as FeatureLagsMode
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ForecastHorizonMode,
+ SeasonalityMode,
+ ShortSeriesHandlingConfiguration,
+ TargetAggregationFunction,
+ TargetLagsMode,
+ TargetRollingWindowSizeMode,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import UseStl as STLMode
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class ForecastingSettingsSchema(metaclass=PatchedSchemaMeta):
+ country_or_region_for_holidays = fields.Str()
+ cv_step_size = fields.Int()
+ forecast_horizon = UnionField(
+ [
+ StringTransformedEnum(allowed_values=[ForecastHorizonMode.AUTO]),
+ fields.Int(),
+ ]
+ )
+ target_lags = UnionField(
+ [
+ StringTransformedEnum(allowed_values=[TargetLagsMode.AUTO]),
+ fields.Int(),
+ fields.List(fields.Int()),
+ ]
+ )
+ target_rolling_window_size = UnionField(
+ [
+ StringTransformedEnum(allowed_values=[TargetRollingWindowSizeMode.AUTO]),
+ fields.Int(),
+ ]
+ )
+ time_column_name = fields.Str()
+ time_series_id_column_names = UnionField([fields.Str(), fields.List(fields.Str())])
+ frequency = fields.Str()
+ feature_lags = StringTransformedEnum(allowed_values=[FeatureLagsMode.NONE, FeatureLagsMode.AUTO])
+ seasonality = UnionField(
+ [
+ StringTransformedEnum(allowed_values=[SeasonalityMode.AUTO]),
+ fields.Int(),
+ ]
+ )
+ short_series_handling_config = StringTransformedEnum(
+ allowed_values=[o.value for o in ShortSeriesHandlingConfiguration]
+ )
+ use_stl = StringTransformedEnum(allowed_values=[STLMode.NONE, STLMode.SEASON, STLMode.SEASON_TREND])
+ target_aggregate_function = StringTransformedEnum(allowed_values=[o.value for o in TargetAggregationFunction])
+ features_unknown_at_forecast_time = UnionField([fields.Str(), fields.List(fields.Str())])
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._job.automl.tabular.forecasting_settings import ForecastingSettings
+
+ return ForecastingSettings(**data)
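Several of these fields are "auto or explicit value" unions, so both spellings below should load through the same field (values illustrative):

forecasting_auto = {"time_column_name": "date", "forecast_horizon": "auto"}
forecasting_explicit = {
    "time_column_name": "date",
    "forecast_horizon": 14,  # plain int branch of the UnionField
    "target_lags": [1, 7],   # list-of-int branch
}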
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_classification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_classification.py
new file mode 100644
index 00000000..c539f037
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_classification.py
@@ -0,0 +1,66 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ClassificationMultilabelPrimaryMetrics,
+ ClassificationPrimaryMetrics,
+ TaskType,
+)
+from azure.ai.ml._schema.automl.image_vertical.image_model_distribution_settings import (
+ ImageModelDistributionSettingsClassificationSchema,
+)
+from azure.ai.ml._schema.automl.image_vertical.image_model_settings import ImageModelSettingsClassificationSchema
+from azure.ai.ml._schema.automl.image_vertical.image_vertical import ImageVerticalSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class ImageClassificationBaseSchema(ImageVerticalSchema):
+ training_parameters = NestedField(ImageModelSettingsClassificationSchema())
+ search_space = fields.List(NestedField(ImageModelDistributionSettingsClassificationSchema()))
+
+
+class ImageClassificationSchema(ImageClassificationBaseSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.IMAGE_CLASSIFICATION,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationPrimaryMetrics],
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ClassificationPrimaryMetrics.Accuracy),
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
+
+
+class ImageClassificationMultilabelSchema(ImageClassificationBaseSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.IMAGE_CLASSIFICATION_MULTILABEL,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationMultilabelPrimaryMetrics],
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ClassificationMultilabelPrimaryMetrics.IOU),
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
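Note that these task schemas do not build an entity in post_load: they drop the task_type discriminator and hand back plain kwargs for the calling job factory. A minimal sketch of the hook's effect (the consuming job class name is assumed, not taken from this file):

data = {"task_type": "image_classification", "primary_metric": "accuracy"}
data.pop("task_type")  # discriminator only; not a kwarg of the job entity
# the caller then does something like ImageClassificationJob(**data)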
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_limit_settings.py
new file mode 100644
index 00000000..3f5c73e8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_limit_settings.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class ImageLimitsSchema(metaclass=PatchedSchemaMeta):
+ max_concurrent_trials = fields.Int()
+ max_trials = fields.Int()
+ timeout_minutes = fields.Int() # type duration
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.automl import ImageLimitSettings
+
+ return ImageLimitSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_distribution_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_distribution_settings.py
new file mode 100644
index 00000000..9f784038
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_distribution_settings.py
@@ -0,0 +1,216 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_dump, post_load, pre_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ LearningRateScheduler,
+ ModelSize,
+ StochasticOptimizer,
+ ValidationMetricType,
+)
+from azure.ai.ml._schema._sweep.search_space import (
+ ChoiceSchema,
+ IntegerQNormalSchema,
+ IntegerQUniformSchema,
+ NormalSchema,
+ QNormalSchema,
+ QUniformSchema,
+ RandintSchema,
+ UniformSchema,
+)
+from azure.ai.ml._schema.core.fields import (
+ DumpableIntegerField,
+ DumpableStringField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+
+def choice_schema_of_type(cls, **kwargs):
+ class CustomChoiceSchema(ChoiceSchema):
+ values = fields.List(cls(**kwargs))
+
+ return CustomChoiceSchema()
+
+
+def choice_and_single_value_schema_of_type(cls, **kwargs):
+ # Order matters: UnionField tries its members sequentially, and when dumping
+ # [Bool, Choice[Bool]] the plain Bool field coerces even a choice dict to True.
+ # Listing the Choice schema first (and the plain "type" field last) avoids that.
+ return UnionField([NestedField(choice_schema_of_type(cls, **kwargs)), cls(**kwargs)])
+
+
+FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD = UnionField(
+ [
+ fields.Float(),
+ DumpableIntegerField(strict=True),
+ NestedField(choice_schema_of_type(DumpableIntegerField, strict=True)),
+ NestedField(choice_schema_of_type(fields.Float)),
+ NestedField(UniformSchema()),
+ NestedField(QUniformSchema()),
+ NestedField(NormalSchema()),
+ NestedField(QNormalSchema()),
+ NestedField(RandintSchema()),
+ ]
+)
+
+INT_SEARCH_SPACE_DISTRIBUTION_FIELD = UnionField(
+ [
+ DumpableIntegerField(strict=True),
+ NestedField(choice_schema_of_type(DumpableIntegerField, strict=True)),
+ NestedField(RandintSchema()),
+ NestedField(IntegerQUniformSchema()),
+ NestedField(IntegerQNormalSchema()),
+ ]
+)
+
+STRING_SEARCH_SPACE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(DumpableStringField)
+BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(fields.Bool)
+
+model_size_enum_args = {"allowed_values": [o.value for o in ModelSize], "casing_transform": camel_to_snake}
+learning_rate_scheduler_enum_args = {
+ "allowed_values": [o.value for o in LearningRateScheduler],
+ "casing_transform": camel_to_snake,
+}
+optimizer_enum_args = {"allowed_values": [o.value for o in StochasticOptimizer], "casing_transform": camel_to_snake}
+validation_metric_enum_args = {
+ "allowed_values": [o.value for o in ValidationMetricType],
+ "casing_transform": camel_to_snake,
+}
+
+
+MODEL_SIZE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(StringTransformedEnum, **model_size_enum_args)
+LEARNING_RATE_SCHEDULER_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(
+ StringTransformedEnum, **learning_rate_scheduler_enum_args
+)
+OPTIMIZER_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(StringTransformedEnum, **optimizer_enum_args)
+VALIDATION_METRIC_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(
+ StringTransformedEnum, **validation_metric_enum_args
+)
+
+
+class ImageModelDistributionSettingsSchema(metaclass=PatchedSchemaMeta):
+ ams_gradient = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD
+ augmentations = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD
+ beta1 = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ beta2 = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ distributed = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD
+ early_stopping = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD
+ early_stopping_delay = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ early_stopping_patience = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ evaluation_frequency = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ enable_onnx_normalization = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD
+ gradient_accumulation_step = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ layers_to_freeze = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ learning_rate = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ learning_rate_scheduler = LEARNING_RATE_SCHEDULER_DISTRIBUTION_FIELD
+ momentum = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ nesterov = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD
+ number_of_epochs = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ number_of_workers = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ optimizer = OPTIMIZER_DISTRIBUTION_FIELD
+ random_seed = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ step_lr_gamma = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ step_lr_step_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ training_batch_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ validation_batch_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ warmup_cosine_lr_cycles = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ warmup_cosine_lr_warmup_epochs = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ weight_decay = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+
+
+# pylint: disable-next=name-too-long
+class ImageModelDistributionSettingsClassificationSchema(ImageModelDistributionSettingsSchema):
+ model_name = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD
+ training_crop_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ validation_crop_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ validation_resize_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ weighted_loss = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+
+ @post_dump
+ def conversion(self, data, **kwargs):
+ if self.context.get("inside_pipeline", False): # pylint: disable=no-member
+ # An AutoML job inside a pipeline is converted via load(dump) rather than an
+ # explicit to_rest_object() call on the SDK job, so for pipeline jobs this
+ # method converts each Sweep Distribution dict to its string form after dump.
+ # For a standalone AutoML job, the same conversion happens in
+ # image_classification_job._to_rest_object().
+ from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_dict_to_str_dict
+
+ data = _convert_sweep_dist_dict_to_str_dict(data)
+ return data
+
+ @pre_load
+ def before_make(self, data, **kwargs):
+ if self.context.get("inside_pipeline", False): # pylint: disable=no-member
+ from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_str_to_dict
+
+ # Converting Sweep Distribution str to Sweep Distribution dict for complying with search_space schema.
+ data = _convert_sweep_dist_str_to_dict(data)
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.automl import ImageClassificationSearchSpace
+
+ return ImageClassificationSearchSpace(**data)
+
+
+# pylint: disable-next=name-too-long
+class ImageModelDistributionSettingsDetectionCommonSchema(ImageModelDistributionSettingsSchema):
+ box_detections_per_image = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ box_score_threshold = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ image_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ max_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ min_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ model_size = MODEL_SIZE_DISTRIBUTION_FIELD
+ multi_scale = BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD
+ nms_iou_threshold = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ tile_grid_size = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD
+ tile_overlap_ratio = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ tile_predictions_nms_threshold = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ validation_iou_threshold = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ validation_metric_type = VALIDATION_METRIC_DISTRIBUTION_FIELD
+
+ @post_dump
+ def conversion(self, data, **kwargs):
+ if self.context.get("inside_pipeline", False): # pylint: disable=no-member
+ # An AutoML job inside a pipeline is converted via load(dump) rather than an
+ # explicit to_rest_object() call on the SDK job, so for pipeline jobs this
+ # method converts each Sweep Distribution dict to its string form after dump.
+ # For a standalone AutoML job, the same conversion happens in
+ # image_object_detection_job._to_rest_object().
+ from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_dict_to_str_dict
+
+ data = _convert_sweep_dist_dict_to_str_dict(data)
+ return data
+
+ @pre_load
+ def before_make(self, data, **kwargs):
+ if self.context.get("inside_pipeline", False): # pylint: disable=no-member
+ from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_str_to_dict
+
+ # Converting Sweep Distribution str to Sweep Distribution dict for complying with search_space schema.
+ data = _convert_sweep_dist_str_to_dict(data)
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.automl import ImageObjectDetectionSearchSpace
+
+ return ImageObjectDetectionSearchSpace(**data)
+
+
+# pylint: disable-next=name-too-long
+class ImageModelDistributionSettingsObjectDetectionSchema(ImageModelDistributionSettingsDetectionCommonSchema):
+ model_name = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD
+
+
+# pylint: disable-next=name-too-long
+class ImageModelDistributionSettingsInstanceSegmentationSchema(ImageModelDistributionSettingsObjectDetectionSchema):
+ model_name = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD
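The ordering concern noted in choice_and_single_value_schema_of_type above is easy to reproduce with plain marshmallow: a truthy dict dumped through a Bool field is coerced to True instead of failing, so a union that tried Bool before Choice would never reach the Choice schema. A minimal sketch:

from marshmallow import Schema, fields

class Demo(Schema):
    flag = fields.Boolean()

# A choice spec is a truthy dict, so the plain Bool field dumps it as True.
print(Demo().dump({"flag": {"type": "choice", "values": [True, False]}}))
# {'flag': True} -- hence the Choice schema is tried first in the union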
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_settings.py
new file mode 100644
index 00000000..7c88e628
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_model_settings.py
@@ -0,0 +1,96 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ LearningRateScheduler,
+ ModelSize,
+ StochasticOptimizer,
+ ValidationMetricType,
+)
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+
+class ImageModelSettingsSchema(metaclass=PatchedSchemaMeta):
+ ams_gradient = fields.Bool()
+ advanced_settings = fields.Str()
+ beta1 = fields.Float()
+ beta2 = fields.Float()
+ checkpoint_frequency = fields.Int()
+ checkpoint_run_id = fields.Str()
+ distributed = fields.Bool()
+ early_stopping = fields.Bool()
+ early_stopping_delay = fields.Int()
+ early_stopping_patience = fields.Int()
+ evaluation_frequency = fields.Int()
+ enable_onnx_normalization = fields.Bool()
+ gradient_accumulation_step = fields.Int()
+ layers_to_freeze = fields.Int()
+ learning_rate = fields.Float()
+ learning_rate_scheduler = StringTransformedEnum(
+ allowed_values=[o.value for o in LearningRateScheduler],
+ casing_transform=camel_to_snake,
+ )
+ model_name = fields.Str()
+ momentum = fields.Float()
+ nesterov = fields.Bool()
+ number_of_epochs = fields.Int()
+ number_of_workers = fields.Int()
+ optimizer = StringTransformedEnum(
+ allowed_values=[o.value for o in StochasticOptimizer],
+ casing_transform=camel_to_snake,
+ )
+ random_seed = fields.Int()
+ step_lr_gamma = fields.Float()
+ step_lr_step_size = fields.Int()
+ training_batch_size = fields.Int()
+ validation_batch_size = fields.Int()
+ warmup_cosine_lr_cycles = fields.Float()
+ warmup_cosine_lr_warmup_epochs = fields.Int()
+ weight_decay = fields.Float()
+
+
+class ImageModelSettingsClassificationSchema(ImageModelSettingsSchema):
+ training_crop_size = fields.Int()
+ validation_crop_size = fields.Int()
+ validation_resize_size = fields.Int()
+ weighted_loss = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsClassification
+
+ return ImageModelSettingsClassification(**data)
+
+
+class ImageModelSettingsObjectDetectionSchema(ImageModelSettingsSchema):
+ box_detections_per_image = fields.Int()
+ box_score_threshold = fields.Float()
+ image_size = fields.Int()
+ max_size = fields.Int()
+ min_size = fields.Int()
+ model_size = StringTransformedEnum(allowed_values=[o.value for o in ModelSize], casing_transform=camel_to_snake)
+ multi_scale = fields.Bool()
+ nms_iou_threshold = fields.Float()
+ tile_grid_size = fields.Str()
+ tile_overlap_ratio = fields.Float()
+ tile_predictions_nms_threshold = fields.Float()
+ validation_iou_threshold = fields.Float()
+ validation_metric_type = StringTransformedEnum(
+ allowed_values=[o.value for o in ValidationMetricType],
+ casing_transform=camel_to_snake,
+ )
+ log_training_metrics = fields.Str()
+ log_validation_loss = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsObjectDetection
+
+ return ImageModelSettingsObjectDetection(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_object_detection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_object_detection.py
new file mode 100644
index 00000000..cb753882
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_object_detection.py
@@ -0,0 +1,66 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ InstanceSegmentationPrimaryMetrics,
+ ObjectDetectionPrimaryMetrics,
+ TaskType,
+)
+from azure.ai.ml._schema.automl.image_vertical.image_model_distribution_settings import (
+ ImageModelDistributionSettingsInstanceSegmentationSchema,
+ ImageModelDistributionSettingsObjectDetectionSchema,
+)
+from azure.ai.ml._schema.automl.image_vertical.image_model_settings import ImageModelSettingsObjectDetectionSchema
+from azure.ai.ml._schema.automl.image_vertical.image_vertical import ImageVerticalSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class ImageObjectDetectionSchema(ImageVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.IMAGE_OBJECT_DETECTION,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=ObjectDetectionPrimaryMetrics.MEAN_AVERAGE_PRECISION,
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ObjectDetectionPrimaryMetrics.MEAN_AVERAGE_PRECISION),
+ )
+ training_parameters = NestedField(ImageModelSettingsObjectDetectionSchema())
+ search_space = fields.List(NestedField(ImageModelDistributionSettingsObjectDetectionSchema()))
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
+
+
+class ImageInstanceSegmentationSchema(ImageVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.IMAGE_INSTANCE_SEGMENTATION,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=[InstanceSegmentationPrimaryMetrics.MEAN_AVERAGE_PRECISION],
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(InstanceSegmentationPrimaryMetrics.MEAN_AVERAGE_PRECISION),
+ )
+ training_parameters = NestedField(ImageModelSettingsObjectDetectionSchema())
+ search_space = fields.List(NestedField(ImageModelDistributionSettingsInstanceSegmentationSchema()))
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_sweep_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_sweep_settings.py
new file mode 100644
index 00000000..66dfd7ae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_sweep_settings.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from marshmallow import post_load, pre_dump
+
+from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField, SamplingAlgorithmField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class ImageSweepSettingsSchema(metaclass=PatchedSchemaMeta):
+ sampling_algorithm = SamplingAlgorithmField()
+ early_termination = EarlyTerminationField()
+
+ @pre_dump
+ def conversion(self, data, **kwargs):
+ # Dump from the REST representation, but keep the SDK early-termination
+ # policy object so that EarlyTerminationField can serialize it.
+ rest_obj = data._to_rest_object()
+ rest_obj.early_termination = data.early_termination
+ return rest_obj
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.automl import ImageSweepSettings
+
+ return ImageSweepSettings(**data)
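A hedged sketch of the mapping this schema accepts; the early-termination shape follows the sweep policy schemas, and all values are illustrative:

sweep_yaml = {
    "sampling_algorithm": "random",
    "early_termination": {
        "type": "bandit",  # discriminator for the policy schema
        "evaluation_interval": 2,
        "slack_factor": 0.2,
    },
}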
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_vertical.py
new file mode 100644
index 00000000..fdfaa79f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/image_vertical/image_vertical.py
@@ -0,0 +1,19 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._schema.automl.automl_vertical import AutoMLVerticalSchema
+from azure.ai.ml._schema.automl.image_vertical.image_limit_settings import ImageLimitsSchema
+from azure.ai.ml._schema.automl.image_vertical.image_sweep_settings import ImageSweepSettingsSchema
+from azure.ai.ml._schema.core.fields import NestedField, UnionField, fields
+from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema
+
+
+class ImageVerticalSchema(AutoMLVerticalSchema):
+ limits = NestedField(ImageLimitsSchema())
+ sweep = NestedField(ImageSweepSettingsSchema())
+ target_column_name = fields.Str(required=True)
+ test_data = UnionField([NestedField(MLTableInputSchema)])
+ test_data_size = fields.Float()
+ validation_data = UnionField([NestedField(MLTableInputSchema)])
+ validation_data_size = fields.Float()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_fixed_parameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_fixed_parameters.py
new file mode 100644
index 00000000..2a5cb336
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_fixed_parameters.py
@@ -0,0 +1,33 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpLearningRateScheduler
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+
+class NlpFixedParametersSchema(metaclass=PatchedSchemaMeta):
+ gradient_accumulation_steps = fields.Int()
+ learning_rate = fields.Float()
+ learning_rate_scheduler = StringTransformedEnum(
+ allowed_values=[obj.value for obj in NlpLearningRateScheduler],
+ casing_transform=camel_to_snake,
+ )
+ model_name = fields.Str()
+ number_of_epochs = fields.Int()
+ training_batch_size = fields.Int()
+ validation_batch_size = fields.Int()
+ warmup_ratio = fields.Float()
+ weight_decay = fields.Float()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.automl import NlpFixedParameters
+
+ return NlpFixedParameters(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_parameter_subspace.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_parameter_subspace.py
new file mode 100644
index 00000000..de963478
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_parameter_subspace.py
@@ -0,0 +1,106 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_dump, post_load, pre_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpLearningRateScheduler
+from azure.ai.ml._schema._sweep.search_space import (
+ ChoiceSchema,
+ NormalSchema,
+ QNormalSchema,
+ QUniformSchema,
+ RandintSchema,
+ UniformSchema,
+)
+from azure.ai.ml._schema.core.fields import (
+ DumpableIntegerField,
+ DumpableStringField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+
+def choice_schema_of_type(cls, **kwargs):
+ class CustomChoiceSchema(ChoiceSchema):
+ values = fields.List(cls(**kwargs))
+
+ return CustomChoiceSchema()
+
+
+def choice_and_single_value_schema_of_type(cls, **kwargs):
+ return UnionField([cls(**kwargs), NestedField(choice_schema_of_type(cls, **kwargs))])
+
+
+FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD = UnionField(
+ [
+ fields.Float(),
+ DumpableIntegerField(strict=True),
+ NestedField(choice_schema_of_type(DumpableIntegerField, strict=True)),
+ NestedField(choice_schema_of_type(fields.Float)),
+ NestedField(UniformSchema()),
+ NestedField(QUniformSchema()),
+ NestedField(NormalSchema()),
+ NestedField(QNormalSchema()),
+ NestedField(RandintSchema()),
+ ]
+)
+
+INT_SEARCH_SPACE_DISTRIBUTION_FIELD = UnionField(
+ [
+ DumpableIntegerField(strict=True),
+ NestedField(choice_schema_of_type(DumpableIntegerField, strict=True)),
+ NestedField(RandintSchema()),
+ ]
+)
+
+STRING_SEARCH_SPACE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(DumpableStringField)
+BOOL_SEARCH_SPACE_DISTRIBUTION_FIELD = choice_and_single_value_schema_of_type(fields.Bool)
+
+
+class NlpParameterSubspaceSchema(metaclass=PatchedSchemaMeta):
+ gradient_accumulation_steps = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ learning_rate = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ learning_rate_scheduler = choice_and_single_value_schema_of_type(
+ StringTransformedEnum,
+ allowed_values=[obj.value for obj in NlpLearningRateScheduler],
+ casing_transform=camel_to_snake,
+ )
+ model_name = STRING_SEARCH_SPACE_DISTRIBUTION_FIELD
+ number_of_epochs = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ training_batch_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ validation_batch_size = INT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ warmup_ratio = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+ weight_decay = FLOAT_SEARCH_SPACE_DISTRIBUTION_FIELD
+
+ @post_dump
+ def conversion(self, data, **kwargs):
+ if self.context.get("inside_pipeline", False): # pylint: disable=no-member
+ # An AutoML job inside a pipeline is converted via load(dump) rather than an
+ # explicit to_rest_object() call on the SDK job, so for pipeline jobs this
+ # method converts each Sweep Distribution dict to its string form after dump.
+ # For a standalone AutoML job, the same conversion happens in
+ # text_classification_job._to_rest_object().
+ from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_dict_to_str_dict
+
+ data = _convert_sweep_dist_dict_to_str_dict(data)
+ return data
+
+ @pre_load
+ def before_make(self, data, **kwargs):
+ if self.context.get("inside_pipeline", False): # pylint: disable=no-member
+ from azure.ai.ml.entities._job.automl.search_space_utils import _convert_sweep_dist_str_to_dict
+
+ # Converting Sweep Distribution str to Sweep Distribution dict for complying with search_space schema.
+ data = _convert_sweep_dist_str_to_dict(data)
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.automl import NlpSearchSpace
+
+ return NlpSearchSpace(**data)
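The two representations that conversion/before_make translate between look roughly like this (a hedged sketch; the exact dict keys come from the search-space schemas such as UniformSchema and ChoiceSchema):

# String form carried inside pipeline jobs:
as_string = {
    "learning_rate": "uniform(1e-5, 1e-4)",
    "model_name": "choice('bert-base-cased', 'roberta-base')",
}

# Dict form expected by the search_space schema:
as_dict = {
    "learning_rate": {"type": "uniform", "min_value": 1e-5, "max_value": 1e-4},
    "model_name": {"type": "choice", "values": ["bert-base-cased", "roberta-base"]},
}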
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_sweep_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_sweep_settings.py
new file mode 100644
index 00000000..ab9b5ec3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_sweep_settings.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from marshmallow import post_load, pre_dump
+
+from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField, SamplingAlgorithmField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class NlpSweepSettingsSchema(metaclass=PatchedSchemaMeta):
+ sampling_algorithm = SamplingAlgorithmField()
+ early_termination = EarlyTerminationField()
+
+ @pre_dump
+ def conversion(self, data, **kwargs):
+ # Dump from the REST representation, but keep the SDK early-termination
+ # policy object so that EarlyTerminationField can serialize it.
+ rest_obj = data._to_rest_object()
+ rest_obj.early_termination = data.early_termination
+ return rest_obj
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.automl import NlpSweepSettings
+
+ return NlpSweepSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical.py
new file mode 100644
index 00000000..f701ce95
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical.py
@@ -0,0 +1,24 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.automl.automl_vertical import AutoMLVerticalSchema
+from azure.ai.ml._schema.automl.featurization_settings import NlpFeaturizationSettingsSchema
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_fixed_parameters import NlpFixedParametersSchema
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_parameter_subspace import NlpParameterSubspaceSchema
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_sweep_settings import NlpSweepSettingsSchema
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_vertical_limit_settings import NlpLimitsSchema
+from azure.ai.ml._schema.core.fields import NestedField, UnionField
+from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class NlpVerticalSchema(AutoMLVerticalSchema):
+ limits = NestedField(NlpLimitsSchema())
+ sweep = NestedField(NlpSweepSettingsSchema())
+ training_parameters = NestedField(NlpFixedParametersSchema())
+ search_space = fields.List(NestedField(NlpParameterSubspaceSchema()))
+ featurization = NestedField(NlpFeaturizationSettingsSchema(), data_key=AutoMLConstants.FEATURIZATION_YAML)
+ validation_data = UnionField([NestedField(MLTableInputSchema)])
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical_limit_settings.py
new file mode 100644
index 00000000..fe054f38
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/nlp_vertical_limit_settings.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class NlpLimitsSchema(metaclass=PatchedSchemaMeta):
+ max_concurrent_trials = fields.Int()
+ max_trials = fields.Int()
+ max_nodes = fields.Int()
+ timeout_minutes = fields.Int() # type duration
+ trial_timeout_minutes = fields.Int() # type duration
+
+ @post_load
+ def make(self, data, **kwargs) -> "NlpLimitSettings":
+ from azure.ai.ml.automl import NlpLimitSettings
+
+ return NlpLimitSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification.py
new file mode 100644
index 00000000..14e0b7d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics, TaskType
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_vertical import NlpVerticalSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, fields
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class TextClassificationSchema(NlpVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.TEXT_CLASSIFICATION,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationPrimaryMetrics],
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ClassificationPrimaryMetrics.ACCURACY),
+ )
+ # declared here rather than on NlpVerticalSchema because target_column_name is optional for text_ner
+ target_column_name = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification_multilabel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification_multilabel.py
new file mode 100644
index 00000000..56cd5bc1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_classification_multilabel.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationMultilabelPrimaryMetrics, TaskType
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_vertical import NlpVerticalSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, fields
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class TextClassificationMultilabelSchema(NlpVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.TEXT_CLASSIFICATION_MULTILABEL,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=ClassificationMultilabelPrimaryMetrics.ACCURACY,
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ClassificationMultilabelPrimaryMetrics.ACCURACY),
+ )
+ # declared here rather than on NlpVerticalSchema because target_column_name is optional for text_ner
+ target_column_name = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_ner.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_ner.py
new file mode 100644
index 00000000..3609b1d0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/nlp_vertical/text_ner.py
@@ -0,0 +1,35 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics, TaskType
+from azure.ai.ml._schema.automl.nlp_vertical.nlp_vertical import NlpVerticalSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, fields
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class TextNerSchema(NlpVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.TEXT_NER,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=ClassificationPrimaryMetrics.ACCURACY,
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ClassificationPrimaryMetrics.ACCURACY),
+ )
+ target_column_name = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/classification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/classification.py
new file mode 100644
index 00000000..f9ce7b8b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/classification.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics, TaskType
+from azure.ai.ml._schema.automl.table_vertical.table_vertical import AutoMLTableVerticalSchema
+from azure.ai.ml._schema.automl.training_settings import ClassificationTrainingSettingsSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class AutoMLClassificationSchema(AutoMLTableVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.CLASSIFICATION,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationPrimaryMetrics],
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ClassificationPrimaryMetrics.AUC_WEIGHTED),
+ )
+ positive_label = fields.Str()
+ training = NestedField(ClassificationTrainingSettingsSchema(), data_key=AutoMLConstants.TRAINING_YAML)
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/forecasting.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/forecasting.py
new file mode 100644
index 00000000..7f302c97
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/forecasting.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ForecastingPrimaryMetrics, TaskType
+from azure.ai.ml._schema.automl.forecasting_settings import ForecastingSettingsSchema
+from azure.ai.ml._schema.automl.table_vertical.table_vertical import AutoMLTableVerticalSchema
+from azure.ai.ml._schema.automl.training_settings import ForecastingTrainingSettingsSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class AutoMLForecastingSchema(AutoMLTableVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.FORECASTING,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=[o.value for o in ForecastingPrimaryMetrics],
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(ForecastingPrimaryMetrics.NORMALIZED_ROOT_MEAN_SQUARED_ERROR),
+ )
+ training = NestedField(ForecastingTrainingSettingsSchema(), data_key=AutoMLConstants.TRAINING_YAML)
+ forecasting_settings = NestedField(ForecastingSettingsSchema(), data_key=AutoMLConstants.FORECASTING_YAML)
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/regression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/regression.py
new file mode 100644
index 00000000..fc1e3900
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/regression.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+
+from marshmallow import post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import RegressionPrimaryMetrics, TaskType
+from azure.ai.ml._schema.automl.table_vertical.table_vertical import AutoMLTableVerticalSchema
+from azure.ai.ml._schema.automl.training_settings import RegressionTrainingSettingsSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class AutoMLRegressionSchema(AutoMLTableVerticalSchema):
+ task_type = StringTransformedEnum(
+ allowed_values=TaskType.REGRESSION,
+ casing_transform=camel_to_snake,
+ data_key=AutoMLConstants.TASK_TYPE_YAML,
+ required=True,
+ )
+ primary_metric = StringTransformedEnum(
+ allowed_values=[o.value for o in RegressionPrimaryMetrics],
+ casing_transform=camel_to_snake,
+ load_default=camel_to_snake(RegressionPrimaryMetrics.NORMALIZED_ROOT_MEAN_SQUARED_ERROR),
+ )
+ training = NestedField(RegressionTrainingSettingsSchema(), data_key=AutoMLConstants.TRAINING_YAML)
+
+ @post_load
+ def make(self, data, **kwargs) -> Dict[str, Any]:
+ data.pop("task_type")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical.py
new file mode 100644
index 00000000..e98d7066
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NCrossValidationsMode
+from azure.ai.ml._schema.automl.automl_vertical import AutoMLVerticalSchema
+from azure.ai.ml._schema.automl.featurization_settings import TableFeaturizationSettingsSchema
+from azure.ai.ml._schema.automl.table_vertical.table_vertical_limit_settings import AutoMLTableLimitsSchema
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField, fields
+from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class AutoMLTableVerticalSchema(AutoMLVerticalSchema):
+ limits = NestedField(AutoMLTableLimitsSchema(), data_key=AutoMLConstants.LIMITS_YAML)
+ featurization = NestedField(TableFeaturizationSettingsSchema(), data_key=AutoMLConstants.FEATURIZATION_YAML)
+ target_column_name = fields.Str(required=True)
+ validation_data = UnionField([NestedField(MLTableInputSchema)])
+ validation_data_size = fields.Float()
+ cv_split_column_names = fields.List(fields.Str())
+ n_cross_validations = UnionField(
+ [
+ StringTransformedEnum(allowed_values=[NCrossValidationsMode.AUTO]),
+ fields.Int(),
+ ],
+ )
+ weight_column_name = fields.Str()
+ test_data = UnionField([NestedField(MLTableInputSchema)])
+ test_data_size = fields.Float()
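As with the forecasting settings, n_cross_validations is an "auto or int" union, so both of these shapes should validate (illustrative):

tabular_auto = {"target_column_name": "label", "n_cross_validations": "auto"}
tabular_explicit = {"target_column_name": "label", "n_cross_validations": 5}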
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical_limit_settings.py
new file mode 100644
index 00000000..122774a6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/table_vertical/table_vertical_limit_settings.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import ExperimentalField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+
+
+class AutoMLTableLimitsSchema(metaclass=PatchedSchemaMeta):
+ enable_early_termination = fields.Bool()
+ exit_score = fields.Float()
+ max_concurrent_trials = fields.Int()
+ max_cores_per_trial = fields.Int()
+ max_nodes = ExperimentalField(fields.Int())
+ max_trials = fields.Int(data_key=AutoMLConstants.MAX_TRIALS_YAML)
+ timeout_minutes = fields.Int() # type duration
+ trial_timeout_minutes = fields.Int() # type duration
+
+ @post_load
+ def make(self, data, **kwargs) -> "TabularLimitSettings":
+ from azure.ai.ml.automl import TabularLimitSettings
+
+ return TabularLimitSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py
new file mode 100644
index 00000000..57a76892
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/automl/training_settings.py
@@ -0,0 +1,122 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ClassificationModels,
+ ForecastingModels,
+ RegressionModels,
+ StackMetaLearnerType,
+)
+from azure.ai.ml.constants import TabularTrainingMode
+from azure.ai.ml._schema import ExperimentalField
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._job.automl.training_settings import (
+ ClassificationTrainingSettings,
+ ForecastingTrainingSettings,
+ RegressionTrainingSettings,
+)
+
+
+class StackEnsembleSettingsSchema(metaclass=PatchedSchemaMeta):
+ stack_meta_learner_kwargs = fields.Dict()
+ stack_meta_learner_train_percentage = fields.Float()
+ stack_meta_learner_type = StringTransformedEnum(
+ allowed_values=[o.value for o in StackMetaLearnerType],
+ casing_transform=camel_to_snake,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ # Converting it here, as there is no corresponding entity class
+ stack_meta_learner_type = data.pop("stack_meta_learner_type")
+ stack_meta_learner_type = StackMetaLearnerType[stack_meta_learner_type.upper()]
+ from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings
+
+ return StackEnsembleSettings(stack_meta_learner_type=stack_meta_learner_type, **data)
+
+
+class TrainingSettingsSchema(metaclass=PatchedSchemaMeta):
+ enable_dnn_training = fields.Bool()
+ enable_model_explainability = fields.Bool()
+ enable_onnx_compatible_models = fields.Bool()
+ enable_stack_ensemble = fields.Bool()
+ enable_vote_ensemble = fields.Bool()
+ ensemble_model_download_timeout = fields.Int(data_key=AutoMLConstants.ENSEMBLE_MODEL_DOWNLOAD_TIMEOUT_YAML)
+ stack_ensemble_settings = NestedField(StackEnsembleSettingsSchema())
+ training_mode = ExperimentalField(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in TabularTrainingMode],
+ casing_transform=camel_to_snake,
+ )
+ )
+
+
+class ClassificationTrainingSettingsSchema(TrainingSettingsSchema):
+ allowed_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
+ )
+ blocked_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in ClassificationModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "ClassificationTrainingSettings":
+ return ClassificationTrainingSettings(**data)
+
+
+class ForecastingTrainingSettingsSchema(TrainingSettingsSchema):
+ allowed_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in ForecastingModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
+ )
+ blocked_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in ForecastingModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "ForecastingTrainingSettings":
+ return ForecastingTrainingSettings(**data)
+
+
+class RegressionTrainingSettingsSchema(TrainingSettingsSchema):
+ allowed_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in RegressionModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.ALLOWED_ALGORITHMS_YAML,
+ )
+ blocked_training_algorithms = fields.List(
+ StringTransformedEnum(
+ allowed_values=[o.value for o in RegressionModels],
+ casing_transform=camel_to_snake,
+ ),
+ data_key=AutoMLConstants.BLOCKED_ALGORITHMS_YAML,
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "RegressionTrainingSettings":
+ return RegressionTrainingSettings(**data)
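
StackEnsembleSettingsSchema.make above restores the enum from YAML by upper-casing the snake_case value and doing a name lookup on StackMetaLearnerType. The same round-trip in miniature (DemoLearnerType is a hypothetical stand-in):

    from enum import Enum

    class DemoLearnerType(Enum):
        LOGISTIC_REGRESSION = "LogisticRegression"
        LIGHT_GBM_CLASSIFIER = "LightGBMClassifier"

    yaml_value = "logistic_regression"            # snake_case, as written in YAML
    member = DemoLearnerType[yaml_value.upper()]  # name-based lookup, as in make()
    print(member)                                 # DemoLearnerType.LOGISTIC_REGRESSION
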
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py
new file mode 100644
index 00000000..1b92f18e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/__init__.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
+
+from .command_component import AnonymousCommandComponentSchema, CommandComponentSchema, ComponentFileRefField
+from .component import ComponentSchema, ComponentYamlRefField
+from .data_transfer_component import (
+ AnonymousDataTransferCopyComponentSchema,
+ AnonymousDataTransferExportComponentSchema,
+ AnonymousDataTransferImportComponentSchema,
+ DataTransferCopyComponentFileRefField,
+ DataTransferCopyComponentSchema,
+ DataTransferExportComponentFileRefField,
+ DataTransferExportComponentSchema,
+ DataTransferImportComponentFileRefField,
+ DataTransferImportComponentSchema,
+)
+from .import_component import AnonymousImportComponentSchema, ImportComponentFileRefField, ImportComponentSchema
+from .parallel_component import AnonymousParallelComponentSchema, ParallelComponentFileRefField, ParallelComponentSchema
+from .spark_component import AnonymousSparkComponentSchema, SparkComponentFileRefField, SparkComponentSchema
+
+__all__ = [
+ "ComponentSchema",
+ "CommandComponentSchema",
+ "AnonymousCommandComponentSchema",
+ "ComponentFileRefField",
+ "ParallelComponentSchema",
+ "AnonymousParallelComponentSchema",
+ "ParallelComponentFileRefField",
+ "ImportComponentSchema",
+ "AnonymousImportComponentSchema",
+ "ImportComponentFileRefField",
+ "AnonymousSparkComponentSchema",
+ "SparkComponentFileRefField",
+ "SparkComponentSchema",
+ "AnonymousDataTransferCopyComponentSchema",
+ "DataTransferCopyComponentFileRefField",
+ "DataTransferCopyComponentSchema",
+ "AnonymousDataTransferImportComponentSchema",
+ "DataTransferImportComponentFileRefField",
+ "DataTransferImportComponentSchema",
+ "AnonymousDataTransferExportComponentSchema",
+ "DataTransferExportComponentFileRefField",
+ "DataTransferExportComponentSchema",
+ "ComponentYamlRefField",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py
new file mode 100644
index 00000000..aef98cca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/automl_component.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from azure.ai.ml._restclient.v2022_10_01_preview.models import TaskType
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants import JobType
+
+
+class AutoMLComponentSchema(ComponentSchema):
+ """AutoMl component schema.
+
+    Only has the type & task properties on top of the basic component properties. No inputs or outputs are allowed.
+ """
+
+ type = StringTransformedEnum(required=True, allowed_values=JobType.AUTOML)
+ task = StringTransformedEnum(
+ # TODO: verify if this works
+ allowed_values=[t for t in TaskType], # pylint: disable=unnecessary-comprehension
+ casing_transform=camel_to_snake,
+ required=True,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py
new file mode 100644
index 00000000..9d688ee0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py
@@ -0,0 +1,137 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_dump, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import (
+ OutputPortSchema,
+ PrimitiveOutputSchema,
+)
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import (
+ ExperimentalField,
+ FileRefField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.job.distribution import (
+ MPIDistributionSchema,
+ PyTorchDistributionSchema,
+ TensorFlowDistributionSchema,
+ RayDistributionSchema,
+)
+from azure.ai.ml._schema.job.parameterized_command import ParameterizedCommandSchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureDevopsArtifactsType
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class AzureDevopsArtifactsSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(allowed_values=[AzureDevopsArtifactsType.ARTIFACT])
+ feed = fields.Str()
+ name = fields.Str()
+ version = fields.Str()
+ scope = fields.Str()
+ organization = fields.Str()
+ project = fields.Str()
+
+
+class CommandComponentSchema(ComponentSchema, ParameterizedCommandSchema):
+ class Meta:
+ exclude = ["environment_variables"] # component doesn't have environment variables
+
+ type = StringTransformedEnum(allowed_values=[NodeType.COMMAND])
+ resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
+ distribution = UnionField(
+ [
+ NestedField(MPIDistributionSchema, unknown=INCLUDE),
+ NestedField(TensorFlowDistributionSchema, unknown=INCLUDE),
+ NestedField(PyTorchDistributionSchema, unknown=INCLUDE),
+ ExperimentalField(NestedField(RayDistributionSchema, unknown=INCLUDE)),
+ ],
+ metadata={"description": "Provides the configuration for a distributed run."},
+ )
+ # primitive output is only supported for command component & pipeline component
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(OutputPortSchema),
+ NestedField(PrimitiveOutputSchema, unknown=INCLUDE),
+ ]
+ ),
+ )
+ properties = fields.Dict(keys=fields.Str(), values=fields.Raw())
+
+    # Note: AzureDevopsArtifactsSchema is only available when the private preview flag is enabled before the
+    # command component schema class is initialized.
+ if is_private_preview_enabled():
+ additional_includes = fields.List(UnionField([fields.Str(), NestedField(AzureDevopsArtifactsSchema)]))
+ else:
+ additional_includes = fields.List(fields.Str())
+
+ @post_dump
+ def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+ # remove empty properties to keep the component spec unchanged
+ if not component_schema_dict.get("properties"):
+ component_schema_dict.pop("properties", None)
+ if (
+ component_schema_dict.get("additional_includes") is not None
+ and len(component_schema_dict["additional_includes"]) == 0
+ ):
+ component_schema_dict.pop("additional_includes")
+ return component_schema_dict
+
+
+class RestCommandComponentSchema(CommandComponentSchema):
+ """When component load from rest, won't validate on name since there might be existing component with invalid
+ name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousCommandComponentSchema(AnonymousAssetSchema, CommandComponentSchema):
+ """Anonymous command component schema.
+
+    Note inheritance follows the order AnonymousAssetSchema, CommandComponentSchema because we need name and version
+    to be dump_only (marshmallow collects fields following method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import CommandComponent
+
+        # Inline components have source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return CommandComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=ComponentSource.YAML_JOB,
+ **data,
+ )
+
+
+class ComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousCommandComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
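
remove_unnecessary_fields above relies on @post_dump to prune empty optional collections, so a dumped spec stays close to the source YAML. A plain-marshmallow sketch of the same pruning:

    from marshmallow import Schema, fields, post_dump

    class SpecSchema(Schema):
        name = fields.Str()
        properties = fields.Dict(keys=fields.Str(), values=fields.Raw())

        @post_dump
        def remove_unnecessary_fields(self, data, **kwargs):
            # Drop an empty or missing properties mapping from the dumped spec.
            if not data.get("properties"):
                data.pop("properties", None)
            return data

    print(SpecSchema().dump({"name": "train", "properties": {}}))  # {'name': 'train'}
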
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
new file mode 100644
index 00000000..5772a607
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
@@ -0,0 +1,143 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from pathlib import Path
+
+from marshmallow import ValidationError, fields, post_dump, pre_dump, pre_load
+from marshmallow.fields import Field
+
+from azure.ai.ml._schema.component.input_output import InputPortSchema, OutputPortSchema, ParameterSchema
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ ExperimentalField,
+ NestedField,
+ PythonFuncNameStr,
+ UnionField,
+)
+from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled, load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+
+from .._utils.utils import _resolve_group_inputs_for_component
+from ..assets.asset import AssetSchema
+from ..core.fields import RegistryStr
+
+
+class ComponentNameStr(PythonFuncNameStr):
+ def _get_field_name(self):
+ return "Component"
+
+
+class ComponentYamlRefField(Field):
+ """Allows you to nest a :class:`Schema <marshmallow.Schema>`
+ inside a yaml ref field.
+ """
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string"}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if not isinstance(value, str):
+ raise ValidationError(f"Nested yaml ref field expected a string but got {type(value)}.")
+
+ base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+
+ source_path = Path(value)
+ # raise if the string is not a valid path, like "azureml:xxx"
+ try:
+ source_path.resolve()
+ except OSError as ex:
+ raise ValidationError(f"Nested file ref field expected a local path but got {value}.") from ex
+
+ if not source_path.is_absolute():
+ source_path = base_path / source_path
+
+ if not source_path.is_file():
+ raise ValidationError(
+ f"Nested yaml ref field expected a local path but can't find {value} based on {base_path.as_posix()}."
+ )
+
+ loaded_value = load_yaml(source_path)
+
+ # local import to avoid circular import
+ from azure.ai.ml.entities import Component
+
+ component = Component._load(data=loaded_value, yaml_path=source_path) # pylint: disable=protected-access
+ return component
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ raise ValidationError("Serialize on RefField is not supported.")
+
+
+class ComponentSchema(AssetSchema):
+ schema = fields.Str(data_key="$schema", attribute="_schema")
+ name = ComponentNameStr(required=True)
+ id = UnionField(
+ [
+ RegistryStr(dump_only=True),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, dump_only=True),
+ ]
+ )
+ display_name = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+ is_deterministic = fields.Bool()
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(ParameterSchema),
+ NestedField(InputPortSchema),
+ ]
+ ),
+ )
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(OutputPortSchema),
+ )
+ # hide in private preview
+ if is_private_preview_enabled():
+ intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema))
+
+ def __init__(self, *args, **kwargs):
+        # Remove schema_ignored to enable serializing and deserializing the schema field.
+ self._declared_fields.pop("schema_ignored", None)
+ super().__init__(*args, **kwargs)
+
+ @pre_load
+ def convert_version_to_str(self, data, **kwargs): # pylint: disable=unused-argument
+ if isinstance(data, dict) and data.get("version", None):
+ data["version"] = str(data["version"])
+ return data
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the component object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+ @post_dump
+ def convert_input_value_to_str(self, data, **kwargs): # pylint:disable=unused-argument
+ if isinstance(data, dict) and data.get("inputs", None):
+ input_dict = data["inputs"]
+ for input_value in input_dict.values():
+ input_type = input_value.get("type", None)
+ if isinstance(input_type, str) and input_type.lower() == "float":
+ # Convert number to string to avoid precision issue
+ for key in ["default", "min", "max"]:
+ if input_value.get(key, None) is not None:
+ input_value[key] = str(input_value[key])
+ return data
+
+ @pre_dump
+ def flatten_group_inputs(self, data, **kwargs): # pylint: disable=unused-argument
+ return _resolve_group_inputs_for_component(data)
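
convert_input_value_to_str above stringifies the numeric bounds of float inputs on dump to avoid precision drift across YAML round-trips. The core transformation, extracted as a standalone sketch:

    def normalize_float_inputs(data: dict) -> dict:
        # Stringify default/min/max on float inputs, as the post_dump hook above does.
        for input_value in data.get("inputs", {}).values():
            input_type = input_value.get("type")
            if isinstance(input_type, str) and input_type.lower() == "float":
                for key in ("default", "min", "max"):
                    if input_value.get(key) is not None:
                        input_value[key] = str(input_value[key])
        return data

    spec = {"inputs": {"learning_rate": {"type": "float", "default": 0.01, "min": 0.0}}}
    print(normalize_float_inputs(spec))
    # {'inputs': {'learning_rate': {'type': 'float', 'default': '0.01', 'min': '0.0'}}}
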
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py
new file mode 100644
index 00000000..70035d57
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/data_transfer_component.py
@@ -0,0 +1,257 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load, validates, ValidationError
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import InputPortSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum, NestedField
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes
+from azure.ai.ml.constants._component import (
+ ComponentSource,
+ NodeType,
+ DataTransferTaskType,
+ DataCopyMode,
+ ExternalDataType,
+)
+
+
+class DataTransferComponentSchemaMixin(ComponentSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER])
+
+
+class DataTransferCopyComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True)
+ data_copy_mode = StringTransformedEnum(
+ allowed_values=[DataCopyMode.MERGE_WITH_OVERWRITE, DataCopyMode.FAIL_IF_CONFLICT]
+ )
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(InputPortSchema),
+ )
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ outputs_count = len(value)
+ if outputs_count != 1:
+ msg = "Only support single output in {}, but there're {} outputs."
+ raise ValidationError(
+ message=msg.format(DataTransferTaskType.COPY_DATA, outputs_count), field_name="outputs"
+ )
+
+
+class SinkSourceSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[ExternalDataType.FILE_SYSTEM, ExternalDataType.DATABASE], required=True
+ )
+
+
+class SourceInputsSchema(metaclass=PatchedSchemaMeta):
+ """
+    For the export task in DataTransfer, the inputs type only supports uri_file for database and uri_folder for filesystem.
+ """
+
+ type = StringTransformedEnum(allowed_values=[AssetTypes.URI_FOLDER, AssetTypes.URI_FILE], required=True)
+
+
+class SinkOutputsSchema(metaclass=PatchedSchemaMeta):
+ """
+    For the import task in DataTransfer, the outputs type only supports mltable for database and uri_folder for filesystem.
+ """
+
+ type = StringTransformedEnum(allowed_values=[AssetTypes.MLTABLE, AssetTypes.URI_FOLDER], required=True)
+
+
+class DataTransferImportComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True)
+ source = NestedField(SinkSourceSchema, required=True)
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(SinkOutputsSchema),
+ )
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ raise ValidationError(f"inputs field is not a valid filed in task type " f"{DataTransferTaskType.IMPORT_DATA}.")
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ if len(value) != 1 or value and list(value.keys())[0] != "sink":
+ raise ValidationError(
+ f"outputs field only support one output called sink in task type "
+ f"{DataTransferTaskType.IMPORT_DATA}."
+ )
+
+
+class DataTransferExportComponentSchema(DataTransferComponentSchemaMixin):
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA], required=True)
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(SourceInputsSchema),
+ )
+ sink = NestedField(SinkSourceSchema(), required=True)
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ if len(value) != 1 or value and list(value.keys())[0] != "source":
+ raise ValidationError(
+ f"inputs field only support one input called source in task type "
+ f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ raise ValidationError(
+ f"outputs field is not a valid filed in task type " f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+
+class RestDataTransferCopyComponentSchema(DataTransferCopyComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class RestDataTransferImportComponentSchema(DataTransferImportComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class RestDataTransferExportComponentSchema(DataTransferExportComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousDataTransferCopyComponentSchema(AnonymousAssetSchema, DataTransferCopyComponentSchema):
+ """Anonymous data transfer copy component schema.
+
+    Note inheritance follows the order AnonymousAssetSchema,
+    DataTransferCopyComponentSchema because we need name and version to be
+    dump_only (marshmallow collects fields following method resolution
+    order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent
+
+        # Inline components have source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return DataTransferCopyComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+# pylint: disable-next=name-too-long
+class AnonymousDataTransferImportComponentSchema(AnonymousAssetSchema, DataTransferImportComponentSchema):
+ """Anonymous data transfer import component schema.
+
+    Note inheritance follows the order AnonymousAssetSchema,
+    DataTransferImportComponentSchema because we need name and version to be
+    dump_only (marshmallow collects fields following method resolution
+    order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.datatransfer_component import DataTransferImportComponent
+
+        # Inline components have source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return DataTransferImportComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+# pylint: disable-next=name-too-long
+class AnonymousDataTransferExportComponentSchema(AnonymousAssetSchema, DataTransferExportComponentSchema):
+ """Anonymous data transfer export component schema.
+
+    Note inheritance follows the order AnonymousAssetSchema,
+    DataTransferExportComponentSchema because we need name and version to be
+    dump_only (marshmallow collects fields following method resolution
+    order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.datatransfer_component import DataTransferExportComponent
+
+        # Inline components have source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return DataTransferExportComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+class DataTransferCopyComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousDataTransferCopyComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
+
+
+class DataTransferImportComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousDataTransferImportComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
+
+
+class DataTransferExportComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousDataTransferExportComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
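
The @validates hooks above enforce the single-port contract for import and export tasks at field level. A plain-marshmallow sketch of the same kind of check (ImportLikeSchema is a hypothetical stand-in):

    from marshmallow import Schema, ValidationError, fields, validates

    class ImportLikeSchema(Schema):
        outputs = fields.Dict(keys=fields.Str(), values=fields.Raw())

        @validates("outputs")
        def outputs_key(self, value):
            # Mirror the constraint above: exactly one output, and it must be "sink".
            if len(value) != 1 or next(iter(value)) != "sink":
                raise ValidationError("outputs must contain exactly one entry named 'sink'.")

    ImportLikeSchema().load({"outputs": {"sink": {}}})           # passes
    # ImportLikeSchema().load({"outputs": {"a": {}, "b": {}}})   # raises ValidationError
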
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py
new file mode 100644
index 00000000..848220d3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py
@@ -0,0 +1,107 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema import YamlFileSchema
+from azure.ai.ml._schema.component import ComponentSchema
+from azure.ai.ml._schema.component.component import ComponentNameStr
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ EnvironmentField,
+ LocalPathField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._component import NodeType
+
+
+class _ComponentMetadataSchema(metaclass=PatchedSchemaMeta):
+ """Schema to recognize metadata of a flow as a component."""
+
+ name = ComponentNameStr()
+ version = fields.Str()
+ display_name = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+
+class _FlowAttributesSchema(metaclass=PatchedSchemaMeta):
+ """Schema to recognize attributes of a flow."""
+
+ variant = fields.Str()
+ column_mappings = fields.Dict(
+ fields.Str(),
+ fields.Str(),
+ )
+ connections = fields.Dict(
+ keys=fields.Str(),
+ values=fields.Dict(
+ keys=fields.Str(),
+ values=fields.Str(),
+ ),
+ )
+ environment_variables = fields.Dict(
+ fields.Str(),
+ fields.Str(),
+ )
+
+
+class _FlowComponentOverridesSchema(metaclass=PatchedSchemaMeta):
+ environment = EnvironmentField()
+ is_deterministic = fields.Bool()
+
+
+class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta):
+ # the field name must be the same as azure.ai.ml.constants._common.PROMPTFLOW_AZUREML_OVERRIDE_KEY
+    azureml = NestedField(_FlowComponentOverridesSchema)
+
+
+class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema):
+ """Schema for flow.dag.yaml file."""
+
+ environment_variables = fields.Dict(
+ fields.Str(),
+ fields.Str(),
+ )
+ additional_includes = fields.List(LocalPathField())
+
+
+class RunSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowAttributesSchema, _FlowComponentOverridableSchema):
+ """Schema for run.yaml file."""
+
+ flow = LocalPathField(required=True)
+
+
+class FlowComponentSchema(ComponentSchema, _FlowAttributesSchema, _FlowComponentOverridesSchema):
+ """FlowSchema and FlowRunSchema are used to load flow while FlowComponentSchema is used to dump flow."""
+
+ class Meta:
+ """Override this to exclude inputs & outputs as component doesn't have them."""
+
+ exclude = ["inputs", "outputs"] # component doesn't have inputs & outputs
+
+ # TODO: name should be required?
+ name = ComponentNameStr()
+
+ type = StringTransformedEnum(allowed_values=[NodeType.FLOW_PARALLEL], required=True)
+
+ # name, version, tags, display_name and is_deterministic are inherited from ComponentSchema
+ properties = fields.Dict(
+ fields.Str(),
+ fields.Str(),
+ )
+
+ # this is different from regular CodeField
+ code = UnionField(
+ [
+ LocalPathField(),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.CODE),
+ ],
+ metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."},
+ )
+ additional_includes = fields.List(LocalPathField(), load_only=True)
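
FlowComponentSchema above drops the inherited ports via Meta.exclude. A small plain-marshmallow sketch showing that excluded fields vanish from the subclass for both load and dump:

    from marshmallow import Schema, fields

    class BaseSchema(Schema):
        name = fields.Str()
        inputs = fields.Dict()

    class NoInputsSchema(BaseSchema):
        class Meta:
            # Excluded fields take part in neither load nor dump.
            exclude = ("inputs",)

    print(list(NoInputsSchema().fields))  # ['name']
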
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py
new file mode 100644
index 00000000..b0ec14ea
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/import_component.py
@@ -0,0 +1,74 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load, validate
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import OutputPortSchema, ParameterSchema
+from azure.ai.ml._schema.core.fields import FileRefField, NestedField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class ImportComponentSchema(ComponentSchema):
+ class Meta:
+ exclude = ["inputs", "outputs"] # inputs or outputs property not applicable to import job
+
+ type = StringTransformedEnum(allowed_values=[NodeType.IMPORT])
+ source = fields.Dict(
+ keys=fields.Str(validate=validate.OneOf(["type", "connection", "query", "path"])),
+ values=NestedField(ParameterSchema),
+ required=True,
+ )
+
+ output = NestedField(OutputPortSchema, required=True)
+
+
+class RestCommandComponentSchema(ImportComponentSchema):
+ """When component load from rest, won't validate on name since there might be existing component with invalid
+ name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousImportComponentSchema(AnonymousAssetSchema, ImportComponentSchema):
+ """Anonymous command component schema.
+
+ Note inheritance follows order: AnonymousAssetSchema, CommandComponentSchema because we need name and version to be
+ dump_only(marshmallow collects fields follows method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ from azure.ai.ml.entities._component.import_component import ImportComponent
+
+        # Inline components have source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return ImportComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=ComponentSource.YAML_JOB,
+ **data,
+ )
+
+
+class ImportComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousImportComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
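
ImportComponentFileRefField repeats the file-ref recipe shared by the command, parallel, and spark variants: read the referenced YAML, rebase the schema context to the file's parent directory so relative paths keep resolving, then load it through the anonymous schema. A condensed sketch of that recipe (load_component_ref and the "base_path" context key are illustrative, not the SDK's actual API):

    from pathlib import Path

    import yaml

    def load_component_ref(ref: str, context: dict, schema_cls):
        # Resolve the referenced component YAML relative to the current base path.
        source_path = Path(context["base_path"]) / ref
        component_dict = yaml.safe_load(source_path.read_text())
        # Rebase so relative paths inside the referenced file resolve against it.
        child_context = {**context, "base_path": source_path.parent}
        return schema_cls(context=child_context).load(component_dict)
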
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py
new file mode 100644
index 00000000..9fef9489
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py
@@ -0,0 +1,126 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields, pre_dump
+
+from azure.ai.ml._schema.core.fields import DumpableEnumField, ExperimentalField, NestedField, UnionField
+from azure.ai.ml._schema.core.intellectual_property import ProtectionLevelSchema
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import AssetTypes, InputOutputModes, LegacyAssetTypes
+from azure.ai.ml.constants._component import ComponentParameterTypes
+
+# Here we use an ad-hoc way to collect all class constant attributes by checking whether the name is upper-case,
+# because turning those constants into an Enum would break string serialization in marshmallow.
+asset_type_obj = AssetTypes()
+SUPPORTED_PORT_TYPES = [LegacyAssetTypes.PATH] + [
+ getattr(asset_type_obj, k) for k in dir(asset_type_obj) if k.isupper()
+]
+param_obj = ComponentParameterTypes()
+SUPPORTED_PARAM_TYPES = [getattr(param_obj, k) for k in dir(param_obj) if k.isupper()]
+
+input_output_type_obj = InputOutputModes()
+# Link mode is currently only supported at the component level.
+SUPPORTED_INPUT_OUTPUT_MODES = [
+ getattr(input_output_type_obj, k) for k in dir(input_output_type_obj) if k.isupper()
+] + ["link"]
+
+
+class InputPortSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PORT_TYPES,
+ required=True,
+ )
+ description = fields.Str()
+ optional = fields.Bool()
+ default = fields.Str()
+ mode = DumpableEnumField(
+ allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
+ )
+ # hide in private preview
+ if is_private_preview_enabled():
+ # only protection_level is allowed for inputs
+ intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the output object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ if hasattr(data, "_intellectual_property"):
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+
+class OutputPortSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PORT_TYPES,
+ required=True,
+ )
+ description = fields.Str()
+ mode = DumpableEnumField(
+ allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
+ )
+ # hide in private preview
+ if is_private_preview_enabled():
+ # only protection_level is allowed for outputs
+ intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the output object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ if hasattr(data, "_intellectual_property"):
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+
+class PrimitiveOutputSchema(OutputPortSchema):
+ # Note: according to marshmallow doc on Handling Unknown Fields:
+ # https://marshmallow.readthedocs.io/en/stable/quickstart.html#handling-unknown-fields
+    # specifying unknown at instantiation time will not take effect;
+    # it is still added here just to explicitly declare this behavior:
+    # primitive type outputs may be used in environments where the private preview flag is not enabled.
+ class Meta:
+ unknown = INCLUDE
+
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PARAM_TYPES,
+ required=True,
+ )
+ # hide early_available in spec
+ if is_private_preview_enabled():
+ early_available = fields.Bool()
+
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def _serialize(self, obj, *, many: bool = False):
+ """Override to add private preview hidden fields
+
+ :keyword many: Whether obj is a collection of objects.
+ :paramtype many: bool
+ """
+ from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe
+
+ ret = super()._serialize(obj, many=many) # pylint: disable=no-member
+ if has_attr_safe(obj, "early_available") and obj.early_available is not None and "early_available" not in ret:
+ ret["early_available"] = obj.early_available
+ return ret
+
+
+class ParameterSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_PARAM_TYPES,
+ required=True,
+ )
+ optional = fields.Bool()
+ default = UnionField([fields.Str(), fields.Number(), fields.Bool()])
+ description = fields.Str()
+ max = UnionField([fields.Str(), fields.Number()])
+ min = UnionField([fields.Str(), fields.Number()])
+ enum = fields.List(fields.Str())
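
The SUPPORTED_* lists at the top of this file are built ad hoc: every upper-case attribute of a constants class counts as an allowed value, which sidesteps the Enum string-serialization problem mentioned in the comment. The pattern in isolation (DemoTypes is a hypothetical stand-in for AssetTypes):

    class DemoTypes:
        URI_FILE = "uri_file"
        URI_FOLDER = "uri_folder"
        _internal = "not collected"  # not upper-case, so skipped

    demo_obj = DemoTypes()
    SUPPORTED = [getattr(demo_obj, k) for k in dir(demo_obj) if k.isupper()]
    print(SUPPORTED)  # ['uri_file', 'uri_folder']
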
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
new file mode 100644
index 00000000..70f286a9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
@@ -0,0 +1,108 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._schema.core.fields import DumpableEnumField, FileRefField, NestedField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LoggingLevel
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class ParallelComponentSchema(ComponentSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL], required=True)
+ resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
+ logging_level = DumpableEnumField(
+ allowed_values=[LoggingLevel.DEBUG, LoggingLevel.INFO, LoggingLevel.WARN],
+ dump_default=LoggingLevel.INFO,
+ metadata={
+ "description": "A string of the logging level name, which is defined in 'logging'. \
+ Possible values are 'WARNING', 'INFO', and 'DEBUG'."
+ },
+ )
+ task = NestedField(ComponentParallelTaskSchema, unknown=INCLUDE)
+ mini_batch_size = fields.Str(
+ metadata={"description": "The The batch size of current job."},
+ )
+ partition_keys = fields.List(
+ fields.Str(), metadata={"description": "The keys used to partition input data into mini-batches"}
+ )
+
+ input_data = fields.Str()
+ retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
+ max_concurrency_per_instance = fields.Integer(
+ dump_default=1,
+ metadata={"description": "The max parallellism that each compute instance has."},
+ )
+ error_threshold = fields.Integer(
+ dump_default=-1,
+ metadata={
+ "description": "The number of item processing failures should be ignored. \
+ If the error_threshold is reached, the job terminates. \
+ For a list of files as inputs, one item means one file reference. \
+ This setting doesn't apply to command parallelization."
+ },
+ )
+ mini_batch_error_threshold = fields.Integer(
+ dump_default=-1,
+ metadata={
+ "description": "The number of mini batch processing failures should be ignored. \
+ If the mini_batch_error_threshold is reached, the job terminates. \
+ For a list of files as inputs, one item means one file reference. \
+ This setting can be used by either command or python function parallelization. \
+ Only one error_threshold setting can be used in one job."
+ },
+ )
+
+
+class RestParallelComponentSchema(ParallelComponentSchema):
+ """When component load from rest, won't validate on name since there might be existing component with invalid
+ name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousParallelComponentSchema(AnonymousAssetSchema, ParallelComponentSchema):
+ """Anonymous parallel component schema.
+
+    Note inheritance follows the order AnonymousAssetSchema, ParallelComponentSchema because we need name and
+    version to be dump_only (marshmallow collects fields following method resolution order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+
+ return ParallelComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+class ParallelComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousParallelComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py
new file mode 100644
index 00000000..390a6683
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_task.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import CodeField, EnvironmentField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants import ParallelTaskType
+
+
+class ComponentParallelTaskSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[ParallelTaskType.RUN_FUNCTION, ParallelTaskType.MODEL, ParallelTaskType.FUNCTION],
+ required=True,
+ )
+ code = CodeField()
+ entry_script = fields.Str()
+ program_arguments = fields.Str()
+ model = fields.Str()
+ append_row_to = fields.Str()
+ environment = EnvironmentField(required=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py
new file mode 100644
index 00000000..592d740c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/resource.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import INCLUDE, post_dump, post_load
+
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+
+
+class ComponentResourceSchema(JobResourceConfigurationSchema):
+ class Meta:
+ unknown = INCLUDE
+
+ @post_load
+ def make(self, data, **kwargs):
+ return data
+
+ @post_dump(pass_original=True)
+ def dump_override(self, data, original, **kwargs):
+ return original
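
dump_override above uses post_dump(pass_original=True): the hook receives both the serialized dict and the original input, and returning the original makes dump a pass-through that preserves unknown keys. A plain-marshmallow sketch:

    from marshmallow import INCLUDE, Schema, fields, post_dump

    class PassThroughSchema(Schema):
        class Meta:
            unknown = INCLUDE

        instance_count = fields.Int()

        @post_dump(pass_original=True)
        def dump_override(self, data, original, **kwargs):
            # Discard the re-serialized dict and keep the input untouched.
            return original

    print(PassThroughSchema().dump({"instance_count": 2, "shm_size": "2g"}))
    # {'instance_count': 2, 'shm_size': '2g'}
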
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py
new file mode 100644
index 00000000..bac2c54d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/retry_settings.py
@@ -0,0 +1,13 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import DataBindingStr, UnionField
+
+
+class RetrySettingsSchema(metaclass=PatchedSchemaMeta):
+ timeout = UnionField([fields.Int(), DataBindingStr])
+ max_retries = UnionField([fields.Int(), DataBindingStr])
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
new file mode 100644
index 00000000..445481ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_dump, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+from ..job.parameterized_spark import ParameterizedSparkSchema
+
+
+class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
+ additional_includes = fields.List(fields.Str())
+
+ @post_dump
+ def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+ if (
+ component_schema_dict.get("additional_includes") is not None
+ and len(component_schema_dict["additional_includes"]) == 0
+ ):
+ component_schema_dict.pop("additional_includes")
+ return component_schema_dict
+
+
+class RestSparkComponentSchema(SparkComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousSparkComponentSchema(AnonymousAssetSchema, SparkComponentSchema):
+ """Anonymous spark component schema.
+
+    Note inheritance follows the order AnonymousAssetSchema,
+    SparkComponentSchema because we need name and version to be
+    dump_only (marshmallow collects fields following method resolution
+    order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.spark_component import SparkComponent
+
+        # Inline components have source=YAML.JOB,
+        # as only a fully separate component file is regarded as YAML.COMPONENT.
+ return SparkComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+class SparkComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousSparkComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py
new file mode 100644
index 00000000..304b0eae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute.py
@@ -0,0 +1,47 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants._compute import ComputeTier, ComputeType, ComputeSizeTier
+
+from ..core.fields import NestedField, StringTransformedEnum, UnionField
+from .compute import ComputeSchema, IdentitySchema, NetworkSettingsSchema
+
+
+class AmlComputeSshSettingsSchema(metaclass=PatchedSchemaMeta):
+ admin_username = fields.Str()
+ admin_password = fields.Str()
+ ssh_key_value = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AmlComputeSshSettings
+
+ return AmlComputeSshSettings(**data)
+
+
+class AmlComputeSchema(ComputeSchema):
+ type = StringTransformedEnum(allowed_values=[ComputeType.AMLCOMPUTE], required=True)
+ size = UnionField(
+ union_fields=[
+ fields.Str(metadata={"arm_type": ComputeSizeTier.AML_COMPUTE_DEDICATED, "tier": ComputeTier.DEDICATED}),
+ fields.Str(metadata={"arm_type": ComputeSizeTier.AML_COMPUTE_LOWPRIORITY, "tier": ComputeTier.LOWPRIORITY}),
+ ],
+ )
+ tier = StringTransformedEnum(allowed_values=[ComputeTier.LOWPRIORITY, ComputeTier.DEDICATED])
+ min_instances = fields.Int()
+ max_instances = fields.Int()
+ idle_time_before_scale_down = fields.Int()
+ ssh_public_access_enabled = fields.Bool()
+ ssh_settings = NestedField(AmlComputeSshSettingsSchema)
+ network_settings = NestedField(NetworkSettingsSchema)
+ identity = NestedField(IdentitySchema)
+ enable_node_public_ip = fields.Bool(
+ metadata={"description": "Enable or disable node public IP address provisioning."}
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute_node_info.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute_node_info.py
new file mode 100644
index 00000000..983f76f6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/aml_compute_node_info.py
@@ -0,0 +1,15 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class AmlComputeNodeInfoSchema(metaclass=PatchedSchemaMeta):
+ node_id = fields.Str()
+ private_ip_address = fields.Str()
+ public_ip_address = fields.Str()
+ port = fields.Str()
+ node_state = fields.Str()
+ current_job_name = fields.Str()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/attached_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/attached_compute.py
new file mode 100644
index 00000000..2ac4ce9e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/attached_compute.py
@@ -0,0 +1,12 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+
+from .compute import ComputeSchema
+
+
+class AttachedComputeSchema(ComputeSchema):
+ resource_id = fields.Str(required=True)
+ ssh_port = fields.Int()
+ compute_location = fields.Str()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py
new file mode 100644
index 00000000..4488b53d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute.py
@@ -0,0 +1,85 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml._vendor.azure_resources.models._resource_management_client_enums import ResourceIdentityType
+from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration
+
+from ..core.schema import PathAwareSchema
+
+
+class ComputeSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ id = fields.Str(dump_only=True)
+ type = fields.Str()
+ location = fields.Str()
+ description = fields.Str()
+ provisioning_errors = fields.Str(dump_only=True)
+ created_on = fields.Str(dump_only=True)
+ provisioning_state = fields.Str(dump_only=True)
+ resource_id = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+
+class NetworkSettingsSchema(PathAwareSchema):
+ vnet_name = fields.Str()
+ subnet = fields.Str()
+ public_ip_address = fields.Str(dump_only=True)
+ private_ip_address = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import NetworkSettings
+
+ return NetworkSettings(**data)
+
+
+class UserAssignedIdentitySchema(PathAwareSchema):
+ resource_id = fields.Str()
+ principal_id = fields.Str(dump_only=True)
+ client_id = fields.Str(dump_only=True)
+ tenant_id = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ return ManagedIdentityConfiguration(**data)
+
+
+class IdentitySchema(PathAwareSchema):
+ type = StringTransformedEnum(
+ allowed_values=[
+ ResourceIdentityType.SYSTEM_ASSIGNED,
+ ResourceIdentityType.USER_ASSIGNED,
+ ResourceIdentityType.NONE,
+ ResourceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "resource identity type."},
+ )
+ user_assigned_identities = fields.List(NestedField(UserAssignedIdentitySchema))
+ principal_id = fields.Str(dump_only=True)
+ tenant_id = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import IdentityConfiguration
+
+ user_assigned_identities_list = []
+ user_assigned_identities = data.pop("user_assigned_identities", None)
+ if user_assigned_identities:
+ for identity in user_assigned_identities:
+ user_assigned_identities_list.append(
+ ManagedIdentityConfiguration(
+ resource_id=identity.get("resource_id", None),
+ client_id=identity.get("client_id", None),
+ object_id=identity.get("object_id", None),
+ )
+ )
+ data["user_assigned_identities"] = user_assigned_identities_list
+ return IdentityConfiguration(**data)
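
Fields marked dump_only in ComputeSchema (id, provisioning_errors, created_on, provisioning_state) are server-populated: they appear when an entity is dumped but are rejected when a user document is loaded. A hedged sketch of that behavior with plain marshmallow:

    from marshmallow import Schema, ValidationError, fields


    class ComputeLikeSchema(Schema):
        name = fields.Str(required=True)
        provisioning_state = fields.Str(dump_only=True)


    schema = ComputeLikeSchema()
    # Dumping includes the read-only field...
    print(schema.dump({"name": "cpu-cluster", "provisioning_state": "Succeeded"}))
    # ...but loading it back is rejected: marshmallow 3 treats dump_only input
    # keys according to the `unknown` policy, which defaults to RAISE.
    try:
        schema.load({"name": "cpu-cluster", "provisioning_state": "Succeeded"})
    except ValidationError as err:
        print(err.messages)  # {'provisioning_state': ['Unknown field.']}
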
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute_instance.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute_instance.py
new file mode 100644
index 00000000..c72e06bb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/compute_instance.py
@@ -0,0 +1,83 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+# pylint: disable=unused-argument
+from azure.ai.ml._schema import PathAwareSchema
+from azure.ai.ml.constants._compute import ComputeType, ComputeSizeTier
+
+from ..core.fields import ExperimentalField, NestedField, StringTransformedEnum
+from .compute import ComputeSchema, IdentitySchema, NetworkSettingsSchema
+from .schedule import ComputeSchedulesSchema
+from .setup_scripts import SetupScriptsSchema
+from .custom_applications import CustomApplicationsSchema
+
+
+class ComputeInstanceSshSettingsSchema(PathAwareSchema):
+ admin_username = fields.Str(dump_only=True)
+ ssh_port = fields.Str(dump_only=True)
+ ssh_key_value = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import ComputeInstanceSshSettings
+
+ return ComputeInstanceSshSettings(**data)
+
+
+class CreateOnBehalfOfSchema(PathAwareSchema):
+ user_tenant_id = fields.Str()
+ user_object_id = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AssignedUserConfiguration
+
+ return AssignedUserConfiguration(**data)
+
+
+class OsImageMetadataSchema(PathAwareSchema):
+ is_latest_os_image_version = fields.Bool(dump_only=True)
+ current_image_version = fields.Str(dump_only=True)
+ latest_image_version = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import ImageMetadata
+
+ return ImageMetadata(**data)
+
+
+class ComputeInstanceSchema(ComputeSchema):
+ type = StringTransformedEnum(allowed_values=[ComputeType.COMPUTEINSTANCE], required=True)
+ size = fields.Str(metadata={"arm_type": ComputeSizeTier.COMPUTE_INSTANCE})
+ network_settings = NestedField(NetworkSettingsSchema)
+ create_on_behalf_of = NestedField(CreateOnBehalfOfSchema)
+ ssh_settings = NestedField(ComputeInstanceSshSettingsSchema)
+ ssh_public_access_enabled = fields.Bool(dump_default=None)
+ state = fields.Str(dump_only=True)
+ last_operation = fields.Dict(keys=fields.Str(), values=fields.Str(), dump_only=True)
+ services = fields.List(fields.Dict(keys=fields.Str(), values=fields.Str()), dump_only=True)
+ schedules = NestedField(ComputeSchedulesSchema)
+ identity = ExperimentalField(NestedField(IdentitySchema))
+ idle_time_before_shutdown = fields.Str()
+ idle_time_before_shutdown_minutes = fields.Int()
+ custom_applications = fields.List(NestedField(CustomApplicationsSchema))
+ setup_scripts = NestedField(SetupScriptsSchema)
+ os_image_metadata = NestedField(OsImageMetadataSchema, dump_only=True)
+ enable_node_public_ip = fields.Bool(
+ metadata={"description": "Enable or disable node public IP address provisioning."}
+ )
+ enable_sso = fields.Bool(metadata={"description": "Enable or disable single sign-on for the compute instance."})
+ enable_root_access = fields.Bool(
+ metadata={"description": "Enable or disable root access for the compute instance."}
+ )
+ release_quota_on_stop = fields.Bool(
+ metadata={"description": "Release quota on stop for the compute instance. Defaults to False."}
+ )
+ enable_os_patching = fields.Bool(
+ metadata={"description": "Enable or disable OS patching for the compute instance. Defaults to False."}
+ )
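
ComputeInstanceSchema accepts both idle_time_before_shutdown (an ISO 8601 duration string such as "PT30M") and idle_time_before_shutdown_minutes (a plain integer). A small sketch of translating one to the other; the PT<minutes>M-only parsing is a deliberate simplification, not the SDK's parser:

    import re


    def minutes_from_iso8601(duration: str) -> int:
        # Handles only the PT<minutes>M shape, e.g. "PT30M" -> 30.
        match = re.fullmatch(r"PT(\d+)M", duration)
        if not match:
            raise ValueError(f"unsupported duration: {duration!r}")
        return int(match.group(1))


    assert minutes_from_iso8601("PT30M") == 30
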
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/custom_applications.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/custom_applications.py
new file mode 100644
index 00000000..66fa587c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/custom_applications.py
@@ -0,0 +1,60 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants._compute import CustomApplicationDefaults
+
+
+class ImageSettingsSchema(metaclass=PatchedSchemaMeta):
+ reference = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._compute._custom_applications import ImageSettings
+
+ return ImageSettings(**data)
+
+
+class EndpointsSettingsSchema(metaclass=PatchedSchemaMeta):
+ target = fields.Int()
+ published = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._compute._custom_applications import EndpointsSettings
+
+ return EndpointsSettings(**data)
+
+
+class VolumeSettingsSchema(metaclass=PatchedSchemaMeta):
+ source = fields.Str()
+ target = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._compute._custom_applications import VolumeSettings
+
+ return VolumeSettings(**data)
+
+
+class CustomApplicationsSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(required=True)
+ type = StringTransformedEnum(allowed_values=[CustomApplicationDefaults.DOCKER])
+ image = NestedField(ImageSettingsSchema)
+ endpoints = fields.List(NestedField(EndpointsSettingsSchema))
+ environment_variables = fields.Dict()
+ bind_mounts = fields.List(NestedField(VolumeSettingsSchema))
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._compute._custom_applications import (
+ CustomApplications,
+ )
+
+ return CustomApplications(**data)
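
Put together, CustomApplicationsSchema expects a document shaped like the one below. This sketch uses plain marshmallow stand-ins for the nested schemas, and the image reference is an example value:

    from marshmallow import Schema, fields

    ImageLikeSchema = Schema.from_dict({"reference": fields.Str()})
    EndpointsLikeSchema = Schema.from_dict({"target": fields.Int(), "published": fields.Int()})


    class CustomAppLikeSchema(Schema):
        name = fields.Str(required=True)
        image = fields.Nested(ImageLikeSchema)
        endpoints = fields.List(fields.Nested(EndpointsLikeSchema))
        environment_variables = fields.Dict()


    doc = {
        "name": "rstudio",
        "image": {"reference": "example.azurecr.io/rstudio:latest"},
        "endpoints": [{"target": 8787, "published": 8788}],
        "environment_variables": {"DISABLE_AUTH": "true"},
    }
    print(CustomAppLikeSchema().load(doc))
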
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/kubernetes_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/kubernetes_compute.py
new file mode 100644
index 00000000..a84102ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/kubernetes_compute.py
@@ -0,0 +1,16 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+
+from azure.ai.ml.constants._compute import ComputeType
+
+from ..core.fields import NestedField, StringTransformedEnum
+from .compute import ComputeSchema, IdentitySchema
+
+
+class KubernetesComputeSchema(ComputeSchema):
+ type = StringTransformedEnum(allowed_values=[ComputeType.KUBERNETES], required=True)
+ namespace = fields.Str(required=True, dump_default="default")
+ properties = fields.Dict()
+ identity = NestedField(IdentitySchema)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/schedule.py
new file mode 100644
index 00000000..49f41edf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/schedule.py
@@ -0,0 +1,118 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputePowerAction, RecurrenceFrequency
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ScheduleStatus as ScheduleState
+from azure.ai.ml._restclient.v2022_10_01_preview.models import TriggerType, WeekDay
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class BaseTriggerSchema(metaclass=PatchedSchemaMeta):
+ start_time = fields.Str()
+ time_zone = fields.Str()
+
+
+class CronTriggerSchema(BaseTriggerSchema):
+ type = StringTransformedEnum(required=True, allowed_values=TriggerType.CRON)
+ expression = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import CronTrigger
+
+ data.pop("type")
+ return CronTrigger(**data)
+
+
+class RecurrenceScheduleSchema(metaclass=PatchedSchemaMeta):
+ week_days = fields.List(
+ StringTransformedEnum(
+ allowed_values=[
+ WeekDay.SUNDAY,
+ WeekDay.MONDAY,
+ WeekDay.TUESDAY,
+ WeekDay.WEDNESDAY,
+ WeekDay.THURSDAY,
+ WeekDay.FRIDAY,
+ WeekDay.SATURDAY,
+ ],
+ )
+ )
+ hours = fields.List(fields.Int())
+ minutes = fields.List(fields.Int())
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import RecurrencePattern
+
+ return RecurrencePattern(**data)
+
+
+class RecurrenceTriggerSchema(BaseTriggerSchema):
+ type = StringTransformedEnum(required=True, allowed_values=TriggerType.RECURRENCE)
+ frequency = StringTransformedEnum(
+ required=True,
+ allowed_values=[
+ RecurrenceFrequency.MINUTE,
+ RecurrenceFrequency.HOUR,
+ RecurrenceFrequency.DAY,
+ RecurrenceFrequency.WEEK,
+ RecurrenceFrequency.MONTH,
+ ],
+ )
+ interval = fields.Int()
+ schedule = NestedField(RecurrenceScheduleSchema)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import RecurrenceTrigger
+
+ data.pop("type")
+ return RecurrenceTrigger(**data)
+
+
+class ComputeStartStopScheduleSchema(metaclass=PatchedSchemaMeta):
+ trigger = UnionField(
+ [
+ NestedField(CronTriggerSchema()),
+ NestedField(RecurrenceTriggerSchema()),
+ ],
+ )
+ action = StringTransformedEnum(
+ required=True,
+ allowed_values=[
+ ComputePowerAction.START,
+ ComputePowerAction.STOP,
+ ],
+ )
+ state = StringTransformedEnum(
+ allowed_values=[
+ ScheduleState.ENABLED,
+ ScheduleState.DISABLED,
+ ],
+ )
+ schedule_id = fields.Str(dump_only=True)
+ provisioning_state = fields.Str(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import ComputeStartStopSchedule
+
+ return ComputeStartStopSchedule(**data)
+
+
+class ComputeSchedulesSchema(metaclass=PatchedSchemaMeta):
+ compute_start_stop = fields.List(NestedField(ComputeStartStopScheduleSchema))
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import ComputeSchedules
+
+ return ComputeSchedules(**data)
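
The trigger field of ComputeStartStopScheduleSchema is a UnionField: candidate schemas are tried in order and the first one that validates wins, with the required type field doing the discrimination. A plain-marshmallow sketch of that dispatch (the schemas here are stand-ins):

    from marshmallow import Schema, ValidationError, fields, validate


    class CronLikeSchema(Schema):
        type = fields.Str(required=True, validate=validate.Equal("cron"))
        expression = fields.Str(required=True)


    class RecurrenceLikeSchema(Schema):
        type = fields.Str(required=True, validate=validate.Equal("recurrence"))
        frequency = fields.Str(required=True)
        interval = fields.Int()


    def load_trigger(data):
        errors = []
        for schema in (CronLikeSchema(), RecurrenceLikeSchema()):
            try:
                return schema.load(data)
            except ValidationError as err:
                errors.append(err.messages)
        raise ValidationError(errors)


    print(load_trigger({"type": "cron", "expression": "0 18 * * *"}))
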
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/setup_scripts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/setup_scripts.py
new file mode 100644
index 00000000..da3f3c14
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/setup_scripts.py
@@ -0,0 +1,33 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class ScriptReferenceSchema(metaclass=PatchedSchemaMeta):
+ path = fields.Str()
+ command = fields.Str()
+ timeout_minutes = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._compute._setup_scripts import ScriptReference
+
+ return ScriptReference(**data)
+
+
+class SetupScriptsSchema(metaclass=PatchedSchemaMeta):
+ creation_script = NestedField(ScriptReferenceSchema())
+ startup_script = NestedField(ScriptReferenceSchema())
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._compute._setup_scripts import SetupScripts
+
+ return SetupScripts(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/synapsespark_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/synapsespark_compute.py
new file mode 100644
index 00000000..11760186
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/synapsespark_compute.py
@@ -0,0 +1,49 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml.constants._compute import ComputeType
+
+from ..core.fields import NestedField, StringTransformedEnum
+from ..core.schema import PathAwareSchema
+from .compute import ComputeSchema, IdentitySchema
+
+
+class AutoScaleSettingsSchema(PathAwareSchema):
+ min_node_count = fields.Int(dump_only=True)
+ max_node_count = fields.Int(dump_only=True)
+ auto_scale_enabled = fields.Bool(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AutoScaleSettings
+
+ return AutoScaleSettings(**data)
+
+
+class AutoPauseSettingsSchema(PathAwareSchema):
+ delay_in_minutes = fields.Int(dump_only=True)
+ auto_pause_enabled = fields.Bool(dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AutoPauseSettings
+
+ return AutoPauseSettings(**data)
+
+
+class SynapseSparkComputeSchema(ComputeSchema):
+ type = StringTransformedEnum(allowed_values=[ComputeType.SYNAPSESPARK], required=True)
+ resource_id = fields.Str(required=True)
+ identity = NestedField(IdentitySchema)
+ node_family = fields.Str(dump_only=True)
+ node_size = fields.Str(dump_only=True)
+ node_count = fields.Int(dump_only=True)
+ spark_version = fields.Str(dump_only=True)
+ scale_settings = NestedField(AutoScaleSettingsSchema)
+ auto_pause_settings = NestedField(AutoPauseSettingsSchema)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/usage.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/usage.py
new file mode 100644
index 00000000..4860946b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/usage.py
@@ -0,0 +1,42 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import UsageUnit
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+
+
+class UsageNameSchema(metaclass=PatchedSchemaMeta):
+ value = fields.Str()
+ localized_value = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import UsageName
+
+ return UsageName(**data)
+
+
+class UsageSchema(metaclass=PatchedSchemaMeta):
+ id = fields.Str()
+ aml_workspace_location = fields.Str()
+ type = fields.Str()
+ unit = UnionField(
+ [
+ fields.Str(),
+ StringTransformedEnum(
+ allowed_values=UsageUnit.COUNT,
+ casing_transform=camel_to_snake,
+ ),
+ ]
+ )
+ current_value = fields.Int()
+ limit = fields.Int()
+ name = NestedField(UsageNameSchema)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/virtual_machine_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/virtual_machine_compute.py
new file mode 100644
index 00000000..deb92d3c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/virtual_machine_compute.py
@@ -0,0 +1,34 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants._compute import ComputeType
+
+from ..core.fields import NestedField, StringTransformedEnum
+from .compute import ComputeSchema
+
+
+class VirtualMachineSshSettingsSchema(metaclass=PatchedSchemaMeta):
+ admin_username = fields.Str()
+ admin_password = fields.Str()
+ ssh_port = fields.Int()
+ ssh_private_key_file = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import VirtualMachineSshSettings
+
+ return VirtualMachineSshSettings(**data)
+
+
+class VirtualMachineComputeSchema(ComputeSchema):
+ type = StringTransformedEnum(allowed_values=[ComputeType.VIRTUALMACHINE], required=True)
+ resource_id = fields.Str(required=True)
+ compute_location = fields.Str(dump_only=True)
+ ssh_settings = NestedField(VirtualMachineSshSettingsSchema)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/vm_size.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/vm_size.py
new file mode 100644
index 00000000..79ee8ea7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/compute/vm_size.py
@@ -0,0 +1,19 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class VmSizeSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str()
+ family = fields.Str()
+ v_cp_us = fields.Int()
+ gpus = fields.Int()
+ os_vhd_size_mb = fields.Int()
+ max_resource_volume_mb = fields.Int()
+ memory_gb = fields.Float()
+ low_priority_capable = fields.Bool()
+ premium_io = fields.Bool()
+ supported_compute_types = fields.Str()
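
The oddly named v_cp_us is presumably what a camel_to_snake transform makes of the REST property vCPUs. A common two-pass implementation of that transform, shown as an illustration rather than the SDK's exact code:

    import re


    def camel_to_snake(text: str) -> str:
        # First split "Xyz" word boundaries, then lowercase/digit-to-uppercase joints.
        text = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", text)
        return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", text).lower()


    assert camel_to_snake("vCPUs") == "v_cp_us"
    assert camel_to_snake("CreatedGreaterThan") == "created_greater_than"
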
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/auto_delete_setting.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/auto_delete_setting.py
new file mode 100644
index 00000000..ca2bd2e1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/auto_delete_setting.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import AutoDeleteCondition
+from azure.ai.ml.entities._assets.auto_delete_setting import AutoDeleteSetting
+
+
+@experimental
+class BaseAutoDeleteSettingSchema(metaclass=PatchedSchemaMeta):
+ @post_load
+ def make(self, data, **kwargs) -> "AutoDeleteSetting":
+ return AutoDeleteSetting(**data)
+
+
+@experimental
+class AutoDeleteConditionSchema(BaseAutoDeleteSettingSchema):
+ condition = StringTransformedEnum(
+ allowed_values=[condition.name for condition in AutoDeleteCondition],
+ casing_transform=camel_to_snake,
+ )
+
+
+@experimental
+class ValueSchema(BaseAutoDeleteSettingSchema):
+ value = fields.Str()
+
+
+@experimental
+class AutoDeleteSettingSchema(AutoDeleteConditionSchema, ValueSchema):
+ pass
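
AutoDeleteSettingSchema picks up its fields purely through multiple inheritance, combining the condition field of AutoDeleteConditionSchema with the value field of ValueSchema. The same trick in plain marshmallow:

    from marshmallow import Schema, fields


    class ConditionSchema(Schema):
        condition = fields.Str()


    class ValueSchema(Schema):
        value = fields.Str()


    class SettingSchema(ConditionSchema, ValueSchema):
        """Inherits both fields; no body of its own is needed."""


    print(sorted(SettingSchema().fields))  # ['condition', 'value']
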
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/fields.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/fields.py
new file mode 100644
index 00000000..fd7956b8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/fields.py
@@ -0,0 +1,1029 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,too-many-lines
+
+import copy
+import logging
+import os
+import re
+import traceback
+import typing
+from abc import abstractmethod
+from pathlib import Path
+from typing import List, Optional, Union
+
+from marshmallow import RAISE, fields
+from marshmallow.exceptions import ValidationError
+from marshmallow.fields import Field, Nested
+from marshmallow.utils import FieldInstanceResolutionError, from_iso_datetime, resolve_field_instance
+
+from ..._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource, parse_name_label, parse_name_version
+from ..._utils._experimental import _is_warning_cached
+from ..._utils.utils import is_data_binding_expression, is_valid_node_name, load_file, load_yaml
+from ...constants._common import (
+ ARM_ID_PREFIX,
+ AZUREML_RESOURCE_PROVIDER,
+ BASE_PATH_CONTEXT_KEY,
+ CONDA_FILE,
+ DOCKER_FILE_NAME,
+ EXPERIMENTAL_FIELD_MESSAGE,
+ EXPERIMENTAL_LINK_MESSAGE,
+ FILE_PREFIX,
+ INTERNAL_REGISTRY_URI_FORMAT,
+ LOCAL_COMPUTE_TARGET,
+ LOCAL_PATH,
+ REGISTRY_URI_FORMAT,
+ RESOURCE_ID_FORMAT,
+ AzureMLResourceType,
+ DefaultOpenEncoding,
+)
+from ...entities._job.pipeline._attr_dict import try_get_non_arbitrary_attr
+from ...exceptions import MlException, ValidationException
+from ..core.schema import PathAwareSchema
+
+module_logger = logging.getLogger(__name__)
+T = typing.TypeVar("T")
+
+
+class StringTransformedEnum(Field):
+ def __init__(self, **kwargs):
+ # pop marshmallow unknown args to avoid warnings
+ self.allowed_values = kwargs.pop("allowed_values", None)
+ self.casing_transform = kwargs.pop("casing_transform", lambda x: x.lower())
+ self.pass_original = kwargs.pop("pass_original", False)
+ super().__init__(**kwargs)
+ if isinstance(self.allowed_values, str):
+ self.allowed_values = [self.allowed_values]
+ self.allowed_values = [self.casing_transform(x) for x in self.allowed_values]
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string", "enum": self.allowed_values}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if not value:
+ return None
+ if isinstance(value, str) and self.casing_transform(value) in self.allowed_values:
+ return value if self.pass_original else self.casing_transform(value)
+ raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}")
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if isinstance(value, str) and self.casing_transform(value) in self.allowed_values:
+ return value if self.pass_original else self.casing_transform(value)
+ raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}")
+
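
The net effect of StringTransformedEnum is case-insensitive membership checking with normalized output. A reduced, runnable sketch of the same idea, not the class above (which also supports pass_original and the JSON-schema hook):

    from marshmallow import ValidationError, fields


    class LowercaseEnum(fields.Field):
        def __init__(self, allowed_values, **kwargs):
            super().__init__(**kwargs)
            self.allowed_values = [v.lower() for v in allowed_values]

        def _check(self, value):
            if isinstance(value, str) and value.lower() in self.allowed_values:
                return value.lower()
            raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}")

        def _serialize(self, value, attr, obj, **kwargs):
            return None if value is None else self._check(value)

        def _deserialize(self, value, attr, data, **kwargs):
            return self._check(value)


    print(LowercaseEnum(allowed_values=["AmlCompute"]).deserialize("AMLCOMPUTE"))  # amlcompute
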
+
+class DumpableEnumField(StringTransformedEnum):
+ def __init__(self, **kwargs):
+ """Enum field that will raise exception when dumping."""
+ kwargs.pop("casing_transform", None)
+ super(DumpableEnumField, self).__init__(casing_transform=lambda x: x, **kwargs)
+
+
+class LocalPathField(fields.Str):
+ """A field that validates that the input is a local path.
+
+ Can only be used as fields of PathAwareSchema.
+ """
+
+ default_error_messages = {
+ "invalid_path": "The filename, directory name, or volume label syntax is incorrect.",
+ "path_not_exist": "Can't find {allow_type} in resolved absolute path: {path}.",
+ }
+
+ def __init__(self, allow_dir=True, allow_file=True, **kwargs):
+ self._allow_dir = allow_dir
+ self._allow_file = allow_file
+        self._pattern = kwargs.pop("pattern", None)
+        super().__init__(**kwargs)
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string", "arm_type": LOCAL_PATH}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ if self._pattern:
+ schema["pattern"] = self._pattern
+ return schema
+
+ # pylint: disable-next=docstring-missing-param
+ def _resolve_path(self, value: Union[str, os.PathLike]) -> Path:
+ """Resolve path to absolute path based on base_path in context.
+
+        The path is still resolved (symlinks included) even if it is already absolute.
+
+ :return: The resolved path
+ :rtype: Path
+ """
+ try:
+ result = Path(value)
+ base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+ if not result.is_absolute():
+ result = base_path / result
+
+ # for non-path string like "azureml:/xxx", OSError can be raised in either
+ # resolve() or is_dir() or is_file()
+ result = result.resolve()
+ if (self._allow_dir and result.is_dir()) or (self._allow_file and result.is_file()):
+ return result
+ except OSError as e:
+ raise self.make_error("invalid_path") from e
+ raise self.make_error("path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type)
+
+ @property
+ def allowed_path_type(self) -> str:
+ if self._allow_dir and self._allow_file:
+ return "directory or file"
+ if self._allow_dir:
+ return "directory"
+ return "file"
+
+ def _validate(self, value):
+ # inherited validations like required, allow_none, etc.
+ super(LocalPathField, self)._validate(value)
+
+ if value is None:
+ return
+ self._resolve_path(value)
+
+ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
+ # do not block serializing None even if required or not allow_none.
+ if value is None:
+ return None
+ # always dump path as absolute path in string as base_path will be dropped after serialization
+ return super(LocalPathField, self)._serialize(self._resolve_path(value).as_posix(), attr, obj, **kwargs)
+
+
+class SerializeValidatedUrl(fields.Url):
+ """This field will validate if value is an url during serialization, so that only valid urls can be serialized as
+ this schema.
+
+ Use this schema instead of fields.Url when unioned with ArmStr or its subclasses like ArmVersionedStr, so that the
+ field can be serialized correctly after deserialization. azureml:xxx => xxx => azureml:xxx e.g. The field will still
+ always be serializable as any string can be serialized as an ArmStr.
+ """
+
+ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
+ if value is None:
+ return None
+ self._validate(value)
+ return super(SerializeValidatedUrl, self)._serialize(value, attr, obj, **kwargs)
+
+
+class DataBindingStr(fields.Str):
+ """A string represents a binding to some data in pipeline job, e.g.: parent.jobs.inputs.input1,
+ parent.jobs.node1.outputs.output1."""
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string", "pattern": r"\$\{\{\s*(\S*)\s*\}\}"}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ # None value handling logic is inside _serialize but outside _validate/_deserialize
+ if value is None:
+ return None
+
+ from azure.ai.ml.entities._job.pipeline._io import InputOutputBase
+
+ if isinstance(value, InputOutputBase):
+ value = str(value)
+
+ self._validate(value)
+ return super(DataBindingStr, self)._serialize(value, attr, obj, **kwargs)
+
+ def _validate(self, value):
+ if is_data_binding_expression(value, is_singular=False):
+ return super(DataBindingStr, self)._validate(value)
+ raise ValidationError(f"Value passed is not a data binding string: {value}")
+
+
+class NodeBindingStr(DataBindingStr):
+ """A string represents a binding to some node in pipeline job, e.g.: parent.jobs.node1."""
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ # None value handling logic is inside _serialize but outside _validate/_deserialize
+ if value is None:
+ return None
+
+ from azure.ai.ml.entities._builders import BaseNode
+
+ if isinstance(value, BaseNode):
+ value = f"${{{{parent.jobs.{value.name}}}}}"
+
+ self._validate(value)
+ return super(NodeBindingStr, self)._serialize(value, attr, obj, **kwargs)
+
+ def _validate(self, value):
+ if is_data_binding_expression(value, is_singular=True):
+ return super(NodeBindingStr, self)._validate(value)
+ raise ValidationError(f"Value passed is not a node binding string: {value}")
+
+
+class DateTimeStr(fields.Str):
+ """A string represents a datetime in ISO8601 format."""
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string"}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if value is None:
+ return None
+ self._validate(value)
+ return super(DateTimeStr, self)._serialize(value, attr, obj, **kwargs)
+
+ def _validate(self, value):
+ try:
+ from_iso_datetime(value)
+ except Exception as e:
+ raise ValidationError(f"Not a valid ISO8601-formatted datetime string: {value}") from e
+
+
+class ArmStr(Field):
+ """A string represents an ARM ID for some AzureML resource."""
+
+ def __init__(self, **kwargs):
+ self.azureml_type = kwargs.pop("azureml_type", None)
+ self.pattern = kwargs.pop("pattern", r"^azureml:.+")
+ super().__init__(**kwargs)
+
+ def _jsonschema_type_mapping(self):
+ schema = {
+ "type": "string",
+ "pattern": self.pattern,
+ "arm_type": self.azureml_type,
+ }
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if isinstance(value, str):
+ serialized_value = value if value.startswith(ARM_ID_PREFIX) else f"{ARM_ID_PREFIX}{value}"
+ return serialized_value
+ if value is None and not self.required:
+ return None
+ raise ValidationError(f"Non-string passed to ArmStr for {attr}")
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if isinstance(value, str) and value.startswith(ARM_ID_PREFIX):
+ name = value[len(ARM_ID_PREFIX) :]
+ return name
+ formatted_resource_id = RESOURCE_ID_FORMAT.format(
+ "<subscription_id>",
+ "<resource_group>",
+ AZUREML_RESOURCE_PROVIDER,
+ "<workspace_name>/",
+ )
+ if self.azureml_type is not None:
+ azureml_type_suffix = self.azureml_type
+ else:
+ azureml_type_suffix = "<asset_type>" + "/<resource_name>/<version-if applicable>)"
+ raise ValidationError(
+ f"In order to specify an existing {self.azureml_type if self.azureml_type is not None else 'asset'}, "
+ "please provide either of the following prefixed with 'azureml:':\n"
+ "1. The full ARM ID for the resource, e.g."
+ f"azureml:{formatted_resource_id + azureml_type_suffix}\n"
+ "2. The short-hand name of the resource registered in the workspace, "
+ "eg: azureml:<short-hand-name>:<version-if applicable>. "
+ "For example, version 1 of the environment registered as "
+ "'my-env' in the workspace can be referenced as 'azureml:my-env:1'"
+ )
+
+
+class ArmVersionedStr(ArmStr):
+ """A string represents an ARM ID for some AzureML resource with version."""
+
+ def __init__(self, **kwargs):
+ self.allow_default_version = kwargs.pop("allow_default_version", False)
+ super().__init__(**kwargs)
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ arm_id = super()._deserialize(value, attr, data, **kwargs)
+ try:
+ AMLVersionedArmId(arm_id)
+ return arm_id
+ except ValidationException:
+ pass
+
+ if is_ARM_id_for_resource(name=arm_id, resource_type=self.azureml_type):
+ msg = "id for {} is invalid"
+ raise ValidationError(message=msg.format(attr))
+
+ try:
+ name, label = parse_name_label(arm_id)
+ except ValidationException as e:
+ # Schema will try to deserialize the value with all possible Schema & catch ValidationError
+ # So raise ValidationError instead of ValidationException
+ raise ValidationError(e.message) from e
+
+ version = None
+ if not label:
+ name, version = parse_name_version(arm_id)
+
+ if not (label or version):
+ if self.allow_default_version:
+ return name
+ raise ValidationError(f"Either version or label is not provided for {attr} or the id is not valid.")
+
+ if version:
+ return f"{name}:{version}"
+ return f"{name}@{label}"
+
+
+class FileRefField(Field):
+ """A string represents a file reference in pipeline job, e.g.: file:./my_file.txt, file:../my_file.txt,"""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string"}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if isinstance(value, str) and not value.startswith(FILE_PREFIX):
+ base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+ path = Path(value)
+ if not path.is_absolute():
+ path = base_path / path
+ path.resolve()
+ data = load_file(path)
+ return data
+ raise ValidationError(f"Not supporting non file for {attr}")
+
+ def _serialize(self, value: typing.Any, attr: str, obj: typing.Any, **kwargs):
+ raise ValidationError("Serialize on FileRefField is not supported.")
+
+
+class RefField(Field):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string"}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if isinstance(value, str) and (
+ value.startswith(FILE_PREFIX)
+ or (os.path.isdir(value) or os.path.isfile(value))
+ or value == DOCKER_FILE_NAME
+ ): # "Dockerfile" w/o file: prefix doesn't register as a path
+ if value.startswith(FILE_PREFIX):
+ value = value[len(FILE_PREFIX) :]
+ base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+
+ path = Path(value)
+ if not path.is_absolute():
+ path = base_path / path
+ path.resolve()
+ if attr == CONDA_FILE: # conda files should be loaded as dictionaries
+ data = load_yaml(path)
+ else:
+ data = load_file(path)
+ return data
+ raise ValidationError(f"Not supporting non file for {attr}")
+
+ def _serialize(self, value: typing.Any, attr: str, obj: typing.Any, **kwargs):
+ raise ValidationError("Serialize on RefField is not supported.")
+
+
+class NestedField(Nested):
+ """anticipates the default coming in next marshmallow version, unknown=True."""
+
+ def __init__(self, *args, **kwargs):
+ if kwargs.get("unknown") is None:
+ kwargs["unknown"] = RAISE
+ super().__init__(*args, **kwargs)
+
+
+# Note: the order in which union fields are passed currently matters.
+# For example, the first line below works, but the second fails upon calling load_from_dict
+# with the error "AttributeError: 'list' object has no attribute 'get'":
+# inputs = UnionField([fields.List(NestedField(DataSchema)), NestedField(DataSchema)])
+# inputs = UnionField([NestedField(DataSchema), fields.List(NestedField(DataSchema))])
+class UnionField(fields.Field):
+ """A field that can be one of multiple types."""
+
+ def __init__(self, union_fields: List[fields.Field], is_strict=False, **kwargs):
+ super().__init__(**kwargs)
+ try:
+ # add the validation and make sure union_fields must be subclasses or instances of
+ # marshmallow.base.FieldABC
+ self._union_fields = [resolve_field_instance(cls_or_instance) for cls_or_instance in union_fields]
+ # TODO: make serialization/de-serialization work in the same way as json schema when is_strict is True
+            self.is_strict = is_strict  # When True, combine fields with oneOf instead of anyOf at schema generation
+ except FieldInstanceResolutionError as error:
+ raise ValueError(
+ 'Elements of "union_fields" must be subclasses or instances of marshmallow.base.FieldABC.'
+ ) from error
+
+ @property
+ def union_fields(self):
+ return iter(self._union_fields)
+
+ def insert_union_field(self, field):
+ self._union_fields.insert(0, field)
+
+ # This sets the parent for the schema and also handles nesting.
+ def _bind_to_schema(self, field_name, schema):
+ super()._bind_to_schema(field_name, schema)
+ self._union_fields = self._create_bind_fields(self._union_fields, field_name)
+
+ def _create_bind_fields(self, _fields, field_name):
+ new_union_fields = []
+ for field in _fields:
+ field = copy.deepcopy(field)
+ field._bind_to_schema(field_name, self)
+ new_union_fields.append(field)
+ return new_union_fields
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if value is None:
+ return None
+ errors = []
+ for field in self._union_fields:
+ try:
+ return field._serialize(value, attr, obj, **kwargs)
+
+ except ValidationError as e:
+ errors.extend(e.messages)
+ except (TypeError, ValueError, AttributeError, ValidationException) as e:
+ errors.extend([str(e)])
+ raise ValidationError(message=errors, field_name=attr)
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ errors = []
+ for schema in self._union_fields:
+ try:
+ return schema.deserialize(value, attr, data, **kwargs)
+ except ValidationError as e:
+ errors.append(e.normalized_messages())
+ except ValidationException as e:
+ # ValidationException is explicitly raised in project code so usually easy to locate with error message
+ errors.append([str(e)])
+ except (FileNotFoundError, TypeError) as e:
+ # FileNotFoundError and TypeError can be raised in system code, so we need to add more information
+ # TODO: consider if it's possible to handle those errors in their directly relative
+ # code instead of in UnionField
+ trace = traceback.format_exc().splitlines()
+ if len(trace) >= 3:
+ errors.append([f"{trace[-1]} from {trace[-3]} {trace[-2]}"])
+ else:
+ errors.append([f"{e.__class__.__name__}: {e}"])
+ finally:
+ # Revert base path to original path when job schema fail to deserialize job. For example, when load
+ # parallel job with component file reference starting with FILE prefix, maybe first CommandSchema will
+ # load component yaml according to AnonymousCommandComponentSchema, and YamlFileSchema will update base
+ # path. When CommandSchema fail to load, then Parallelschema will load component yaml according to
+ # AnonymousParallelComponentSchema, but base path now is incorrect, and will raise path not found error
+ # when load component yaml file.
+ if (
+ hasattr(schema, "name")
+ and schema.name == "jobs"
+ and hasattr(schema, "schema")
+ and isinstance(schema.schema, PathAwareSchema)
+ ):
+ # use old base path to recover original base path
+ schema.schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.old_base_path
+ # recover base path of parent schema
+ schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.context[BASE_PATH_CONTEXT_KEY]
+ raise ValidationError(errors, field_name=attr)
+
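
As the note above UnionField says, candidates are tried first-match-wins, so ordering is significant. A standalone sketch that makes the effect visible, using plain marshmallow fields rather than UnionField itself:

    from marshmallow import ValidationError, fields


    def first_match(candidates, value):
        errors = []
        for field in candidates:
            try:
                return field.deserialize(value)
            except ValidationError as err:
                errors.append(err.messages)
        raise ValidationError(errors)


    # Int(strict=True) rejects the string "7", so Str wins and "7" stays a string;
    # with an actual int, Str fails first and Int wins instead.
    print(first_match([fields.Int(strict=True), fields.Str()], "7"))  # '7'
    print(first_match([fields.Str(), fields.Int(strict=True)], 7))    # 7
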
+
+class TypeSensitiveUnionField(UnionField):
+ """Union field which will try to simplify error messages based on type field in failed
+ serialization/deserialization.
+
+    If the value doesn't have a type, error messages from fields with a type field are skipped. If the value has a
+    type that doesn't match any allowed type, "Value {} not in set {}" is raised. If the value has a type that
+    matches at least 1 allowed type, the first matched error is raised.
+ """
+
+ def __init__(
+ self,
+ type_sensitive_fields_dict: typing.Dict[str, List[fields.Field]],
+ *,
+ plain_union_fields: Optional[List[fields.Field]] = None,
+ allow_load_from_file: bool = True,
+ type_field_name="type",
+ **kwargs,
+ ):
+ """param type_sensitive_fields_dict: a dict of type name to list of
+ type sensitive fields param plain_union_fields: list of fields that
+ will be used if value doesn't have type field type plain_union_fields:
+ List[fields.Field] param allow_load_from_file: whether to allow load
+ from file, default to True type allow_load_from_file: bool param
+ type_field_name: field name of type field, default value is "type" type
+ type_field_name: str."""
+ self._type_sensitive_fields_dict = {}
+ self._allow_load_from_yaml = allow_load_from_file
+
+ union_fields = plain_union_fields or []
+ for type_name, type_sensitive_fields in type_sensitive_fields_dict.items():
+ union_fields.extend(type_sensitive_fields)
+ self._type_sensitive_fields_dict[type_name] = [
+ resolve_field_instance(cls_or_instance) for cls_or_instance in type_sensitive_fields
+ ]
+
+ super(TypeSensitiveUnionField, self).__init__(union_fields, **kwargs)
+ self._type_field_name = type_field_name
+
+ def _bind_to_schema(self, field_name, schema):
+ super()._bind_to_schema(field_name, schema)
+ for (
+ type_name,
+ type_sensitive_fields,
+ ) in self._type_sensitive_fields_dict.items():
+ self._type_sensitive_fields_dict[type_name] = self._create_bind_fields(type_sensitive_fields, field_name)
+
+ @property
+ def type_field_name(self) -> str:
+ return self._type_field_name
+
+ @property
+ def allowed_types(self) -> List[str]:
+ return list(self._type_sensitive_fields_dict.keys())
+
+ # pylint: disable-next=docstring-missing-param
+ def insert_type_sensitive_field(self, type_name, field):
+ """Insert a new type sensitive field for a specific type."""
+ if type_name not in self._type_sensitive_fields_dict:
+ self._type_sensitive_fields_dict[type_name] = []
+ self._type_sensitive_fields_dict[type_name].insert(0, field)
+ self.insert_union_field(field)
+
+ # pylint: disable-next=docstring-missing-param
+ def _simplified_error_base_on_type(self, e, value, attr) -> Exception:
+ """Returns a simplified error based on value type
+
+ :return: Returns
+            * e if value doesn't have a type
+            * ValidationError("Value {} not in set {}") if value type is not in the allowed types
+            * the first matched error message if value has a type and it matches at least one field
+ :rtype: Exception
+ """
+ value_type = try_get_non_arbitrary_attr(value, self.type_field_name)
+ if value_type is None:
+ # if value has no type field, raise original error
+ return e
+ if value_type not in self.allowed_types:
+ # if value has type field but its value doesn't match any allowed value, raise ValidationError directly
+ return ValidationError(
+ message={self.type_field_name: f"Value {value_type!r} passed is not in set {self.allowed_types}"},
+ field_name=attr,
+ )
+ filtered_messages = []
+ # if value has type field and its value match at least 1 allowed value, raise first matched
+ for error in e.messages:
+ # for non-nested schema, their error message will be {"_schema": ["xxx"]}
+ if len(error) == 1 and "_schema" in error:
+ continue
+ # for nested schema, type field won't be within error only if type field value is matched
+ # then return first matched error message
+ if self.type_field_name in error:
+ continue
+ filtered_messages.append(error)
+
+ if len(filtered_messages) == 0:
+ # shouldn't happen
+ return e
+ # TODO: consider if we should keep all filtered messages
+ return ValidationError(message=filtered_messages[0], field_name=attr)
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ union_fields = self._union_fields[:]
+ value_type = try_get_non_arbitrary_attr(value, self.type_field_name)
+ if value_type is not None and value_type in self.allowed_types:
+ target_fields = self._type_sensitive_fields_dict[value_type]
+ if len(target_fields) == 1:
+ return target_fields[0]._serialize(value, attr, obj, **kwargs)
+ self._union_fields = target_fields
+
+ try:
+ return super(TypeSensitiveUnionField, self)._serialize(value, attr, obj, **kwargs)
+ except ValidationError as e:
+ raise self._simplified_error_base_on_type(e, value, attr)
+ finally:
+ self._union_fields = union_fields
+
+ def _try_load_from_yaml(self, value):
+ target_path = value
+ if target_path.startswith(FILE_PREFIX):
+ target_path = target_path[len(FILE_PREFIX) :]
+ try:
+ import yaml
+
+ base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+ target_path = Path(target_path)
+ if not target_path.is_absolute():
+ target_path = base_path / target_path
+ target_path.resolve()
+ if target_path.is_file():
+ self.context[BASE_PATH_CONTEXT_KEY] = target_path.parent
+ with target_path.open(encoding=DefaultOpenEncoding.READ) as f:
+ return yaml.safe_load(f)
+ except Exception: # pylint: disable=W0718
+ pass
+ return value
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ try:
+ return super(TypeSensitiveUnionField, self)._deserialize(value, attr, data, **kwargs)
+ except ValidationError as e:
+ if isinstance(value, str) and self._allow_load_from_yaml:
+ value = self._try_load_from_yaml(value)
+ raise self._simplified_error_base_on_type(e, value, attr)
+
+
+def ComputeField(**kwargs) -> Field:
+ """
+ :return: The compute field
+ :rtype: Field
+ """
+ return UnionField(
+ [
+ StringTransformedEnum(allowed_values=[LOCAL_COMPUTE_TARGET]),
+ ArmStr(azureml_type=AzureMLResourceType.COMPUTE),
+ # Case for virtual clusters
+ fields.Str(),
+ ],
+ metadata={"description": "The compute resource."},
+ **kwargs,
+ )
+
+
+def CodeField(**kwargs) -> Field:
+ """
+ :return: The code field
+ :rtype: Field
+ """
+ return UnionField(
+ [
+ LocalPathField(),
+ SerializeValidatedUrl(),
+ GitStr(),
+ RegistryStr(azureml_type=AzureMLResourceType.CODE),
+ InternalRegistryStr(azureml_type=AzureMLResourceType.CODE),
+ # put arm versioned string at last order as it can deserialize any string into "azureml:<origin>"
+ ArmVersionedStr(azureml_type=AzureMLResourceType.CODE),
+ ],
+ metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."},
+ **kwargs,
+ )
+
+
+def EnvironmentField(*, extra_fields: Optional[List[Field]] = None, **kwargs):
+ """Function to return a union field for environment.
+
+ :keyword extra_fields: Extra fields to be added to the union field
+ :paramtype extra_fields: List[Field]
+ :return: The environment field
+ :rtype: Field
+ """
+ extra_fields = extra_fields or []
+ # local import to avoid circular dependency
+ from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema
+
+ return UnionField(
+ [
+ NestedField(AnonymousEnvironmentSchema),
+ RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True),
+ ]
+ + extra_fields,
+ **kwargs,
+ )
+
+
+def DistributionField(**kwargs):
+ """Function to return a union field for distribution.
+
+ :return: The distribution field
+ :rtype: Field
+ """
+ from azure.ai.ml._schema.job.distribution import (
+ MPIDistributionSchema,
+ PyTorchDistributionSchema,
+ RayDistributionSchema,
+ TensorFlowDistributionSchema,
+ )
+
+ return UnionField(
+ [
+ NestedField(PyTorchDistributionSchema, **kwargs),
+ NestedField(TensorFlowDistributionSchema, **kwargs),
+ NestedField(MPIDistributionSchema, **kwargs),
+ ExperimentalField(NestedField(RayDistributionSchema, **kwargs)),
+ ]
+ )
+
+
+def PrimitiveValueField(**kwargs):
+ """Function to return a union field for primitive value.
+
+ :return: The primitive value field
+ :rtype: Field
+ """
+ return UnionField(
+ [
+ # Note: order matters here - to make sure value parsed correctly.
+ # By default when strict is false, marshmallow downcasts float to int.
+ # Setting it to true will throw a validation error when loading a float to int.
+ # https://github.com/marshmallow-code/marshmallow/pull/755
+ # Use DumpableIntegerField to make sure there will be validation error when
+ # loading/dumping a float to int.
+ # note that this field can serialize bool instance but cannot deserialize bool instance.
+ DumpableIntegerField(strict=True),
+ # Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float)
+ DumpableFloatField(strict=True),
+ # put string schema after Int and Float to make sure they won't dump to string
+ fields.Str(),
+ # fields.Bool comes last since it'll parse anything non-falsy to True
+ fields.Bool(),
+ ],
+ **kwargs,
+ )
+
+
+class VersionField(Field):
+ """A string represents a version, e.g.: 1, 1.0, 1.0.0.
+ Will always convert to string to ensure that "1.0" won't be converted to 1.
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ def _jsonschema_type_mapping(self):
+ schema = {"anyOf": [{"type": "string"}, {"type": "integer"}]}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _deserialize(self, value, attr, data, **kwargs) -> str:
+ if isinstance(value, str):
+ return value
+ if isinstance(value, (int, float)):
+ return str(value)
+ msg = f"Type {type(value)} is not supported for version."
+ raise MlException(message=msg, no_personal_data_message=msg)
+
+
+class NumberVersionField(VersionField):
+ """A string represents a version, e.g.: 1, 1.0, 1.0.0.
+ Will always convert to string to ensure that "1.0" won't be converted to 1.
+ """
+
+ default_error_messages = {
+ "max_version": "Version {input} is greater than or equal to upper bound {bound}.",
+ "min_version": "Version {input} is smaller than lower bound {bound}.",
+ "invalid": "Number version must be integers concatenated by '.', like 1.0.1.",
+ }
+
+ def __init__(self, *args, upper_bound: Optional[str] = None, lower_bound: Optional[str] = None, **kwargs) -> None:
+ self._upper = None if upper_bound is None else self._version_to_tuple(upper_bound)
+ self._lower = None if lower_bound is None else self._version_to_tuple(lower_bound)
+ super().__init__(*args, **kwargs)
+
+ def _version_to_tuple(self, value: str):
+ try:
+ return tuple(int(v) for v in str(value).split("."))
+ except ValueError as e:
+ raise self.make_error("invalid") from e
+
+ def _validate(self, value):
+ super()._validate(value)
+ value_tuple = self._version_to_tuple(value)
+ if self._upper is not None and value_tuple >= self._upper:
+ raise self.make_error("max_version", input=value, bound=self._upper)
+ if self._lower is not None and value_tuple < self._lower:
+ raise self.make_error("min_version", input=value, bound=self._lower)
+
+
+class DumpableIntegerField(fields.Integer):
+ """A int field that cannot serialize other type of values to int if self.strict."""
+
+ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
+ if self.strict and not isinstance(value, int):
+ # this implementation can serialize bool to bool
+ raise self.make_error("invalid", input=value)
+ return super()._serialize(value, attr, obj, **kwargs)
+
+
+class DumpableFloatField(fields.Float):
+ """A float field that cannot serialize other type of values to float if self.strict."""
+
+ def __init__(
+ self,
+ *,
+ strict: bool = False,
+ allow_nan: bool = False,
+ as_string: bool = False,
+ **kwargs,
+ ):
+ self.strict = strict
+ super().__init__(allow_nan=allow_nan, as_string=as_string, **kwargs)
+
+ def _validated(self, value):
+ if self.strict and not isinstance(value, float):
+ raise self.make_error("invalid", input=value)
+ return super()._validated(value)
+
+ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
+ return super()._serialize(self._validated(value), attr, obj, **kwargs)
+
+
+class DumpableStringField(fields.String):
+ """A string field that cannot serialize other type of values to string if self.strict."""
+
+ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
+ if not isinstance(value, str):
+ raise ValidationError("Given value is not a string")
+ return super()._serialize(value, attr, obj, **kwargs)
+
+
+class ExperimentalField(fields.Field):
+ def __init__(self, experimental_field: fields.Field, **kwargs):
+ super().__init__(**kwargs)
+ try:
+ self._experimental_field = resolve_field_instance(experimental_field)
+ self.required = experimental_field.required
+ except FieldInstanceResolutionError as error:
+ raise ValueError(
+ '"experimental_field" must be subclasses or instances of marshmallow.base.FieldABC.'
+ ) from error
+
+ @property
+ def experimental_field(self):
+ return self._experimental_field
+
+ # This sets the parent for the schema and also handles nesting.
+ def _bind_to_schema(self, field_name, schema):
+ super()._bind_to_schema(field_name, schema)
+ self._experimental_field._bind_to_schema(field_name, schema)
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if value is None:
+ return None
+ return self._experimental_field._serialize(value, attr, obj, **kwargs)
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if value is not None:
+ message = "Field '{0}': {1} {2}".format(attr, EXPERIMENTAL_FIELD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+ if not _is_warning_cached(message):
+ module_logger.warning(message)
+
+ return self._experimental_field._deserialize(value, attr, data, **kwargs)
+
+
+class RegistryStr(Field):
+ """A string represents a registry ID for some AzureML resource."""
+
+ def __init__(self, **kwargs):
+ self.azureml_type = kwargs.pop("azureml_type", None)
+ super().__init__(**kwargs)
+
+ def _jsonschema_type_mapping(self):
+ schema = {
+ "type": "string",
+ "pattern": "^azureml://registries/.*",
+ "arm_type": self.azureml_type,
+ }
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if isinstance(value, str) and value.startswith(REGISTRY_URI_FORMAT):
+ return f"{value}"
+ if value is None and not self.required:
+ return None
+ raise ValidationError(f"Non-string passed to RegistryStr for {attr}")
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if isinstance(value, str) and value.startswith(REGISTRY_URI_FORMAT):
+ return value
+ raise ValidationError(
+ f"In order to specify an existing {self.azureml_type}, "
+ "please provide the correct registry path prefixed with 'azureml://':\n"
+ )
+
+
+class InternalRegistryStr(RegistryStr):
+ """A string represents a registry ID for some internal AzureML resource."""
+
+ def _jsonschema_type_mapping(self):
+ schema = super()._jsonschema_type_mapping()
+ schema["pattern"] = "^azureml://feeds/.*"
+ return schema
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if isinstance(value, str) and value.startswith(INTERNAL_REGISTRY_URI_FORMAT):
+ value = value.replace(INTERNAL_REGISTRY_URI_FORMAT, REGISTRY_URI_FORMAT, 1)
+ return super()._deserialize(value, attr, data, **kwargs)
+
+
+class PythonFuncNameStr(fields.Str):
+ """A string represents a python function name."""
+
+ @abstractmethod
+ def _get_field_name(self) -> str:
+ """Returns field name, used for error message."""
+
+ # pylint: disable-next=docstring-missing-param
+ def _deserialize(self, value, attr, data, **kwargs) -> str:
+ """Validate component name.
+
+ :return: The component name
+ :rtype: str
+ """
+ name = super()._deserialize(value, attr, data, **kwargs)
+ pattern = r"^[a-z][a-z\d_]*$"
+ if not re.match(pattern, name):
+ raise ValidationError(
+ f"{self._get_field_name()} name should only contain "
+ "lower letter, number, underscore and start with a lower letter. "
+ f"Currently got {name}."
+ )
+ return name
+
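+# For reference, illustrative examples against the pattern used above:
+#     "train_step", "step2"        # accepted: lowercase start, then [a-z0-9_]
+#     "TrainStep", "2step", "a-b"  # rejected by r"^[a-z][a-z\d_]*$"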
+
+class PipelineNodeNameStr(fields.Str):
+ """A string represents a pipeline node name."""
+
+ @abstractmethod
+ def _get_field_name(self) -> str:
+ """Returns field name, used for error message."""
+
+ # pylint: disable-next=docstring-missing-param
+ def _deserialize(self, value, attr, data, **kwargs) -> str:
+ """Validate component name.
+
+ :return: The component name
+ :rtype: str
+ """
+ name = super()._deserialize(value, attr, data, **kwargs)
+ if not is_valid_node_name(name):
+ raise ValidationError(
+ f"{self._get_field_name()} name should be a valid python identifier"
+ "(lower letters, numbers, underscore and start with a letter or underscore). "
+ "Currently got {name}."
+ )
+ return name
+
+
+class GitStr(fields.Str):
+ """A string represents a git path."""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string", "pattern": "^git+"}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if isinstance(value, str) and value.startswith("git+"):
+ return f"{value}"
+ if value is None and not self.required:
+ return None
+ raise ValidationError(f"Non-string passed to GitStr for {attr}")
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if isinstance(value, str) and value.startswith("git+"):
+ return value
+ raise ValidationError("In order to specify a git path, please provide the correct path prefixed with 'git+\n")
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/intellectual_property.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/intellectual_property.py
new file mode 100644
index 00000000..2ae47130
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/intellectual_property.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._assets import IPProtectionLevel
+from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty
+
+
+@experimental
+class BaseIntellectualPropertySchema(metaclass=PatchedSchemaMeta):
+ @post_load
+ def make(self, data, **kwargs) -> "IntellectualProperty":
+ return IntellectualProperty(**data)
+
+
+@experimental
+class ProtectionLevelSchema(BaseIntellectualPropertySchema):
+ protection_level = StringTransformedEnum(
+ allowed_values=[level.name for level in IPProtectionLevel],
+ casing_transform=camel_to_snake,
+ )
+
+
+@experimental
+class PublisherSchema(BaseIntellectualPropertySchema):
+ publisher = fields.Str()
+
+
+@experimental
+class IntellectualPropertySchema(ProtectionLevelSchema, PublisherSchema):
+ pass
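+
+
+# Illustrative yaml fragment this schema would accept (assuming IPProtectionLevel
+# exposes levels such as ALL and NONE, whose names are lower-cased by camel_to_snake):
+#
+#     intellectual_property:
+#         publisher: contoso
+#         protection_level: all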
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/resource.py
new file mode 100644
index 00000000..dbbc6f63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/resource.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+import logging
+
+from marshmallow import fields, post_dump, post_load, pre_dump
+
+from ...constants._common import BASE_PATH_CONTEXT_KEY
+from .schema import YamlFileSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class ResourceSchema(YamlFileSchema):
+ name = fields.Str(attribute="name")
+ id = fields.Str(attribute="id")
+ description = fields.Str(attribute="description")
+ tags = fields.Dict(keys=fields.Str, attribute="tags")
+
+ @post_load(pass_original=True)
+ def pass_source_path(self, data, original, **kwargs):
+ path = self._resolve_path(original, base_path=self._previous_base_path)
+ if path is not None:
+ from ...entities import Resource
+
+ if isinstance(data, dict):
+ # data will be used in Resource.__init__
+ data["source_path"] = path.as_posix()
+ elif isinstance(data, Resource):
+ # some resources turn the dict into an object in their own post_load;
+ # not sure if there's a better way to unify them
+ data._source_path = path
+ return data
+
+ @pre_dump
+ def update_base_path_pre_dump(self, data, **kwargs):
+ # if base_path is set, push it onto the context; otherwise keep (inherit) the parent's base path
+ if data.base_path:
+ self._previous_base_path = self.context[BASE_PATH_CONTEXT_KEY]
+ self.context[BASE_PATH_CONTEXT_KEY] = data.base_path
+ return data
+
+ @post_dump
+ def reset_base_path_post_dump(self, data, **kwargs):
+ if self._previous_base_path is not None:
+ # pop state
+ self.context[BASE_PATH_CONTEXT_KEY] = self._previous_base_path
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema.py
new file mode 100644
index 00000000..062575bc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema.py
@@ -0,0 +1,123 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import copy
+import logging
+from pathlib import Path
+from typing import Optional
+
+from marshmallow import fields, post_load, pre_load
+from pydash import objects
+
+from azure.ai.ml._schema.core.schema_meta import PatchedBaseSchema, PatchedSchemaMeta
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, FILE_PREFIX, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.exceptions import MlException
+
+module_logger = logging.getLogger(__name__)
+
+
+class PathAwareSchema(PatchedBaseSchema, metaclass=PatchedSchemaMeta):
+ schema_ignored = fields.Str(data_key="$schema", dump_only=True)
+
+ def __init__(self, *args, **kwargs):
+ # this will make context of all PathAwareSchema child class point to one object
+ self.context = kwargs.get("context", None)
+ if self.context is None or self.context.get(BASE_PATH_CONTEXT_KEY, None) is None:
+ msg = "Base path for reading files is required when building PathAwareSchema"
+ raise MlException(message=msg, no_personal_data_message=msg)
+ # set old base path; note it's a Path object and points to the same object as
+ # self.context.get(BASE_PATH_CONTEXT_KEY)
+ self.old_base_path = self.context.get(BASE_PATH_CONTEXT_KEY)
+ super().__init__(*args, **kwargs)
+
+ @pre_load
+ def add_param_overrides(self, data, **kwargs):
+ # Remove params override from the context so that overriding is done once on the yaml;
+ # child schemas should not override the params.
+ params_override = self.context.pop(PARAMS_OVERRIDE_KEY, None)
+ if params_override is not None:
+ for override in params_override:
+ for param, val in override.items():
+ # Check that none of the intermediary levels are string references (azureml/file)
+ param_tokens = param.split(".")
+ test_layer = data
+ for layer in param_tokens:
+ if test_layer is None:
+ continue
+ if isinstance(test_layer, str):
+ msg = f"Cannot use '--set' on properties defined by reference strings: --set {param}"
+ raise MlException(
+ message=msg,
+ no_personal_data_message=msg,
+ )
+ test_layer = test_layer.get(layer, None)
+ objects.set_(data, param, val)
+ return data
+
+ @pre_load
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def trim_dump_only(self, data, **kwargs):
+ """Marshmallow raises if dump_only fields are present in the schema. This is not desirable for our use case,
+ where read-only properties can be present in the yaml, and should simply be ignored, while we should raise in.
+
+ the case an unknown field is present - to prevent typos.
+ """
+ if isinstance(data, str) or data is None:
+ return data
+ for key, value in self.fields.items():
+ if value.dump_only:
+ schema_key = value.data_key or key
+ if data.get(schema_key, None) is not None:
+ data.pop(schema_key)
+ return data
+
+
+class YamlFileSchema(PathAwareSchema):
+ """Base class that allows derived classes to be built from paths to separate yaml files in place of inline yaml
+ definitions.
+
+ This is transparent to any parent schema containing a nested schema of the derived class: the parent does not
+ need a union type, since a yaml-file string is resolved into a dictionary by the pre_load method. On loading
+ the child yaml, the base path is updated and used for loading sub-child files.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self._previous_base_path = None
+ super().__init__(*args, **kwargs)
+
+ @classmethod
+ def _resolve_path(cls, data, base_path) -> Optional[Path]:
+ if isinstance(data, str) and data.startswith(FILE_PREFIX):
+ # Use directly if absolute path
+ path = Path(data[len(FILE_PREFIX) :])
+ if not path.is_absolute():
+ path = Path(base_path) / path
+ path = path.resolve()  # resolve() returns a new Path; the result must be assigned
+ return path
+ return None
+
+ @pre_load
+ def load_from_file(self, data, **kwargs):
+ path = self._resolve_path(data, Path(self.context[BASE_PATH_CONTEXT_KEY]))
+ if path is not None:
+ self._previous_base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+ # Push update
+ # deepcopy self.context[BASE_PATH_CONTEXT_KEY] to update old base path
+ self.old_base_path = copy.deepcopy(self.context[BASE_PATH_CONTEXT_KEY])
+ self.context[BASE_PATH_CONTEXT_KEY] = path.parent
+
+ data = load_yaml(path)
+ return data
+ return data
+
+ # Schemas are read depth-first, so push/pop to update current path
+ @post_load
+ def reset_base_path_post_load(self, data, **kwargs):
+ if self._previous_base_path is not None:
+ # pop state
+ self.context[BASE_PATH_CONTEXT_KEY] = self._previous_base_path
+ return data
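+
+
+# A sketch of the behaviour described in YamlFileSchema's docstring, using
+# hypothetical paths: given a parent yaml containing
+#
+#     component: file:./train.yml
+#
+# load_from_file resolves "./train.yml" against the base path in the context,
+# pushes the child file's parent directory as the new base path, loads the
+# child yaml in its place, and reset_base_path_post_load pops the base path
+# once the nested schema has been loaded.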
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema_meta.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema_meta.py
new file mode 100644
index 00000000..d352137c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/core/schema_meta.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from collections import OrderedDict
+
+from marshmallow import RAISE
+from marshmallow.decorators import post_dump
+from marshmallow.schema import Schema, SchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class PatchedMeta:
+ ordered = True
+ unknown = RAISE
+
+
+class PatchedBaseSchema(Schema):
+ class Meta:
+ unknown = RAISE
+ ordered = True
+
+ @post_dump
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def remove_none(self, data, **kwargs):
+ """Prevents from dumping attributes that are None, thus making the dump more compact."""
+ return OrderedDict((key, value) for key, value in data.items() if value is not None)
+
+
+class PatchedSchemaMeta(SchemaMeta):
+ """Currently there is an open issue in marshmallow, that the "unknown" property is not inherited.
+
+ We use a metaclass to inject a Meta class into all our Schema classes.
+ """
+
+ def __new__(mcs, name, bases, dct):
+ meta = dct.get("Meta")
+ if meta is None:
+ dct["Meta"] = PatchedMeta
+ else:
+ if not hasattr(meta, "unknown"):
+ dct["Meta"].unknown = RAISE
+ if not hasattr(meta, "ordered"):
+ dct["Meta"].ordered = True
+
+ if PatchedBaseSchema not in bases:
+ bases = bases + (PatchedBaseSchema,)
+ klass = super().__new__(mcs, name, bases, dct)
+ return klass
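+
+
+# Net effect (illustrative sketch): any schema built with PatchedSchemaMeta
+# rejects unknown keys and keeps field order, even when it declares no Meta class.
+#
+#     class ExampleSchema(metaclass=PatchedSchemaMeta):
+#         name = marshmallow.fields.Str()
+#
+#     ExampleSchema().load({"name": "a", "typo": 1})  # raises: unknown field "typo"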
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py
new file mode 100644
index 00000000..24cc357c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/identity.py
@@ -0,0 +1,63 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import ValidationError, fields, post_load, pre_dump, validates
+
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml._vendor.azure_resources.models._resource_management_client_enums import ResourceIdentityType
+from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration
+
+
+class IdentitySchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[
+ ResourceIdentityType.SYSTEM_ASSIGNED,
+ ResourceIdentityType.USER_ASSIGNED,
+ ResourceIdentityType.NONE,
+ # ResourceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED, # This is for post PuPr
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "resource identity type."},
+ )
+ principal_id = fields.Str()
+ tenant_id = fields.Str()
+ user_assigned_identities = fields.List(fields.Dict(keys=fields.Str(), values=fields.Str()))
+
+ @validates("user_assigned_identities")
+ def validate_user_assigned_identities(self, data, **kwargs):
+ if len(data) > 1:
+ raise ValidationError(f"Only 1 user assigned identity is currently supported, {len(data)} found")
+
+ @post_load
+ def make(self, data, **kwargs):
+ user_assigned_identities_list = []
+ user_assigned_identities = data.pop("user_assigned_identities", None)
+ if user_assigned_identities:
+ for identity in user_assigned_identities:
+ user_assigned_identities_list.append(
+ ManagedIdentityConfiguration(
+ resource_id=identity.get("resource_id", None),
+ client_id=identity.get("client_id", None),
+ object_id=identity.get("object_id", None),
+ )
+ )
+ data["user_assigned_identities"] = user_assigned_identities_list
+ return IdentityConfiguration(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ if data.user_assigned_identities:
+ ids = []
+ for _id in data.user_assigned_identities:
+ item = {}
+ item["resource_id"] = _id.resource_id
+ item["principal_id"] = _id.principal_id
+ item["client_id"] = _id.client_id
+ ids.append(item)
+ data.user_assigned_identities = ids
+ return data
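+
+
+# Illustrative yaml this schema round-trips (hypothetical IDs); the validator
+# above limits user_assigned_identities to a single entry:
+#
+#     identity:
+#         type: user_assigned
+#         user_assigned_identities:
+#             - resource_id: /subscriptions/.../userAssignedIdentities/my-identity
+#               client_id: 00000000-0000-0000-0000-000000000000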
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/__init__.py
new file mode 100644
index 00000000..11687396
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/__init__.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
+
+from azure.ai.ml._schema.job.creation_context import CreationContextSchema
+
+from .base_job import BaseJobSchema
+from .command_job import CommandJobSchema
+from .import_job import ImportJobSchema
+from .parallel_job import ParallelJobSchema
+from .parameterized_command import ParameterizedCommandSchema
+from .parameterized_parallel import ParameterizedParallelSchema
+from .parameterized_spark import ParameterizedSparkSchema
+from .spark_job import SparkJobSchema
+
+__all__ = [
+ "BaseJobSchema",
+ "ParameterizedCommandSchema",
+ "ParameterizedParallelSchema",
+ "CommandJobSchema",
+ "ImportJobSchema",
+ "SparkJobSchema",
+ "ParallelJobSchema",
+ "CreationContextSchema",
+ "ParameterizedSparkSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/base_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/base_job.py
new file mode 100644
index 00000000..852d3921
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/base_job.py
@@ -0,0 +1,69 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import logging
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import ArmStr, ComputeField, NestedField, UnionField
+from azure.ai.ml._schema.core.resource import ResourceSchema
+from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+from .creation_context import CreationContextSchema
+from .services import (
+ JobServiceSchema,
+ SshJobServiceSchema,
+ VsCodeJobServiceSchema,
+ TensorBoardJobServiceSchema,
+ JupyterLabJobServiceSchema,
+)
+
+module_logger = logging.getLogger(__name__)
+
+
+class BaseJobSchema(ResourceSchema):
+ creation_context = NestedField(CreationContextSchema, dump_only=True)
+ services = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(SshJobServiceSchema),
+ NestedField(TensorBoardJobServiceSchema),
+ NestedField(VsCodeJobServiceSchema),
+ NestedField(JupyterLabJobServiceSchema),
+ # JobServiceSchema should be the last in the list.
+ # To support types not set by users like Custom, Tracking, Studio.
+ NestedField(JobServiceSchema),
+ ],
+ is_strict=True,
+ ),
+ )
+ name = fields.Str()
+ id = ArmStr(azureml_type=AzureMLResourceType.JOB, dump_only=True, required=False)
+ display_name = fields.Str(required=False)
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+ status = fields.Str(dump_only=True)
+ experiment_name = fields.Str()
+ properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+ description = fields.Str()
+ log_files = fields.Dict(
+ keys=fields.Str(),
+ values=fields.Str(),
+ dump_only=True,
+ metadata={
+ "description": (
+ "The list of log files associated with this run. This section is only populated "
+ "by the service and will be ignored if contained in a yaml sent to the service "
+ "(e.g. via `az ml job create` ...)"
+ )
+ },
+ )
+ compute = ComputeField(required=False)
+ identity = UnionField(
+ [
+ NestedField(ManagedIdentitySchema),
+ NestedField(AMLTokenIdentitySchema),
+ NestedField(UserIdentitySchema),
+ ]
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/command_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/command_job.py
new file mode 100644
index 00000000..9cce7de7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/command_job.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml.constants import JobType
+
+from .base_job import BaseJobSchema
+from .job_limits import CommandJobLimitsSchema
+from .parameterized_command import ParameterizedCommandSchema
+
+
+class CommandJobSchema(ParameterizedCommandSchema, BaseJobSchema):
+ type = StringTransformedEnum(allowed_values=JobType.COMMAND)
+ # do not promote it as CommandComponent has no field named 'limits'
+ limits = NestedField(CommandJobLimitsSchema)
+ parameters = fields.Dict(dump_only=True)
+ inputs = InputsField()
+ outputs = OutputsField()
+ parent_job_name = fields.Str()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/creation_context.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/creation_context.py
new file mode 100644
index 00000000..79956e1c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/creation_context.py
@@ -0,0 +1,16 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class CreationContextSchema(metaclass=PatchedSchemaMeta):
+ created_at = fields.DateTime()
+ created_by = fields.Str()
+ created_by_type = fields.Str()
+ last_modified_at = fields.DateTime()
+ last_modified_by = fields.Str()
+ last_modified_by_type = fields.Str()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/data_transfer_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/data_transfer_job.py
new file mode 100644
index 00000000..6ea54df6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/data_transfer_job.py
@@ -0,0 +1,60 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import validates, ValidationError, fields
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml._schema.job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._component import DataTransferTaskType, DataCopyMode
+
+from ..core.fields import ComputeField, StringTransformedEnum, UnionField
+from .base_job import BaseJobSchema
+
+
+class DataTransferCopyJobSchema(BaseJobSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.DATA_TRANSFER)
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True)
+ data_copy_mode = StringTransformedEnum(
+ allowed_values=[DataCopyMode.MERGE_WITH_OVERWRITE, DataCopyMode.FAIL_IF_CONFLICT]
+ )
+ compute = ComputeField()
+ inputs = InputsField()
+ outputs = OutputsField()
+
+
+class DataTransferImportJobSchema(BaseJobSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.DATA_TRANSFER)
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True)
+ compute = ComputeField()
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(nested=OutputSchema, allow_none=False),
+ metadata={"description": "Outputs of a data transfer job."},
+ )
+ source = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ if len(value) != 1 or list(value.keys())[0] != "sink":
+ raise ValidationError(
+ f"outputs field only support one output called sink in task type "
+ f"{DataTransferTaskType.IMPORT_DATA}."
+ )
+
+
+class DataTransferExportJobSchema(BaseJobSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.DATA_TRANSFER)
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA], required=True)
+ compute = ComputeField()
+ inputs = InputsField(allow_none=False)
+ sink = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ if len(value) != 1 or list(value.keys())[0] != "source":
+ raise ValidationError(
+ f"inputs field only support one input called source in task type "
+ f"{DataTransferTaskType.EXPORT_DATA}."
+ )
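+
+
+# The validators above pin the port names (illustrative yaml fragments with
+# hypothetical datastore paths):
+#
+#     # import_data: exactly one output, and it must be named "sink"
+#     outputs:
+#         sink: {type: mltable, path: azureml://datastores/my_store/paths/imported}
+#
+#     # export_data: exactly one input, and it must be named "source"
+#     inputs:
+#         source: {type: uri_file, path: azureml://datastores/my_store/paths/data.csv}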
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py
new file mode 100644
index 00000000..475792a3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/distribution.py
@@ -0,0 +1,104 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml.constants import DistributionType
+from azure.ai.ml._utils._experimental import experimental
+
+from ..core.schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class MPIDistributionSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=DistributionType.MPI)
+ process_count_per_instance = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml import MpiDistribution
+
+ data.pop("type", None)
+ return MpiDistribution(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml import MpiDistribution
+
+ if not isinstance(data, MpiDistribution):
+ raise ValidationError("Cannot dump non-MpiDistribution object into MpiDistributionSchema")
+ return data
+
+
+class TensorFlowDistributionSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=DistributionType.TENSORFLOW)
+ parameter_server_count = fields.Int()
+ worker_count = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml import TensorFlowDistribution
+
+ data.pop("type", None)
+ return TensorFlowDistribution(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml import TensorFlowDistribution
+
+ if not isinstance(data, TensorFlowDistribution):
+ raise ValidationError("Cannot dump non-TensorFlowDistribution object into TensorFlowDistributionSchema")
+ return data
+
+
+class PyTorchDistributionSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=DistributionType.PYTORCH)
+ process_count_per_instance = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml import PyTorchDistribution
+
+ data.pop("type", None)
+ return PyTorchDistribution(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml import PyTorchDistribution
+
+ if not isinstance(data, PyTorchDistribution):
+ raise ValidationError("Cannot dump non-PyTorchDistribution object into PyTorchDistributionSchema")
+ return data
+
+
+@experimental
+class RayDistributionSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(required=True, allowed_values=DistributionType.RAY)
+ port = fields.Int()
+ address = fields.Str()
+ include_dashboard = fields.Bool()
+ dashboard_port = fields.Int()
+ head_node_additional_args = fields.Str()
+ worker_node_additional_args = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml import RayDistribution
+
+ data.pop("type", None)
+ return RayDistribution(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml import RayDistribution
+
+ if not isinstance(data, RayDistribution):
+ raise ValidationError("Cannot dump non-RayDistribution object into RayDistributionSchema")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/identity.py
new file mode 100644
index 00000000..2f2be676
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/identity.py
@@ -0,0 +1,67 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ConnectionAuthType,
+ IdentityConfigurationType,
+)
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+)
+
+from ..core.schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class ManagedIdentitySchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=[IdentityConfigurationType.MANAGED, ConnectionAuthType.MANAGED_IDENTITY],
+ casing_transform=camel_to_snake,
+ )
+ client_id = fields.Str()
+ object_id = fields.Str()
+ msi_resource_id = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ data.pop("type")
+ return ManagedIdentityConfiguration(**data)
+
+
+class AMLTokenIdentitySchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=IdentityConfigurationType.AML_TOKEN,
+ casing_transform=camel_to_snake,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ data.pop("type")
+ return AmlTokenConfiguration(**data)
+
+
+class UserIdentitySchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ required=True,
+ allowed_values=IdentityConfigurationType.USER_IDENTITY,
+ casing_transform=camel_to_snake,
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ data.pop("type")
+ return UserIdentityConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/import_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/import_job.py
new file mode 100644
index 00000000..8f7c3908
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/import_job.py
@@ -0,0 +1,54 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._schema.job.input_output_entry import OutputSchema
+from azure.ai.ml.constants import ImportSourceType, JobType
+
+from .base_job import BaseJobSchema
+
+
+class DatabaseImportSourceSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[
+ ImportSourceType.AZURESQLDB,
+ ImportSourceType.AZURESYNAPSEANALYTICS,
+ ImportSourceType.SNOWFLAKE,
+ ],
+ required=True,
+ )
+ connection = fields.Str(required=True)
+ query = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._job.import_job import DatabaseImportSource
+
+ return DatabaseImportSource(**data)
+
+
+class FileImportSourceSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(allowed_values=[ImportSourceType.S3], required=True)
+ connection = fields.Str(required=True)
+ path = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._job.import_job import FileImportSource
+
+ return FileImportSource(**data)
+
+
+class ImportJobSchema(BaseJobSchema):
+ class Meta:
+ exclude = ["compute"] # compute property not applicable to import job
+
+ type = StringTransformedEnum(allowed_values=JobType.IMPORT)
+ source = UnionField([NestedField(DatabaseImportSourceSchema), NestedField(FileImportSourceSchema)], required=True)
+ output = NestedField(OutputSchema, required=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_entry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_entry.py
new file mode 100644
index 00000000..1300ab07
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_entry.py
@@ -0,0 +1,256 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ StringTransformedEnum,
+ UnionField,
+ LocalPathField,
+ NestedField,
+ VersionField,
+)
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta, PathAwareSchema
+from azure.ai.ml.constants._common import (
+ AssetTypes,
+ AzureMLResourceType,
+ InputOutputModes,
+)
+from azure.ai.ml.constants._component import ExternalDataType
+
+module_logger = logging.getLogger(__name__)
+
+
+class InputSchema(metaclass=PatchedSchemaMeta):
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._inputs_outputs import Input
+
+ return Input(**data)
+
+ @pre_dump
+ def check_dict(self, data, **kwargs):
+ from azure.ai.ml.entities._inputs_outputs import Input
+
+ if isinstance(data, Input):
+ return data
+ raise ValidationError("InputSchema needs type Input to dump")
+
+
+def generate_path_property(azureml_type):
+ return UnionField(
+ [
+ ArmVersionedStr(azureml_type=azureml_type),
+ fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+ fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+ LocalPathField(pattern=r"^file:.*"),
+ LocalPathField(
+ pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+ ),
+ ],
+ is_strict=True,
+ )
+
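+# The strict union above matches exactly one branch per value; illustrative
+# examples (hypothetical assets and paths):
+#     "azureml:my-data:1"            -> ArmVersionedStr (registered asset)
+#     "https://host/container/blob"  -> the http(s) pattern
+#     "wasbs://container@account/p"  -> the wasb(s) pattern
+#     "file:./data" or "./data"      -> one of the LocalPathField branches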
+
+def generate_path_on_compute_property(azureml_type):
+ return UnionField(
+ [
+ LocalPathField(pattern=r"^file:.*"),
+ ],
+ is_strict=True,
+ )
+
+
+def generate_datastore_property():
+ metadata = {
+ "description": "Name of the datastore to upload local paths to.",
+ "arm_type": AzureMLResourceType.DATASTORE,
+ }
+ return fields.Str(metadata=metadata, required=False)
+
+
+class ModelInputSchema(InputSchema):
+ mode = StringTransformedEnum(
+ allowed_values=[
+ InputOutputModes.DOWNLOAD,
+ InputOutputModes.RO_MOUNT,
+ InputOutputModes.DIRECT,
+ ],
+ required=False,
+ )
+ type = StringTransformedEnum(
+ allowed_values=[
+ AssetTypes.CUSTOM_MODEL,
+ AssetTypes.MLFLOW_MODEL,
+ AssetTypes.TRITON_MODEL,
+ ]
+ )
+ path = generate_path_property(azureml_type=AzureMLResourceType.MODEL)
+ datastore = generate_datastore_property()
+
+
+class DataInputSchema(InputSchema):
+ mode = StringTransformedEnum(
+ allowed_values=[
+ InputOutputModes.DOWNLOAD,
+ InputOutputModes.RO_MOUNT,
+ InputOutputModes.DIRECT,
+ ],
+ required=False,
+ )
+ type = StringTransformedEnum(
+ allowed_values=[
+ AssetTypes.URI_FILE,
+ AssetTypes.URI_FOLDER,
+ ]
+ )
+ path = generate_path_property(azureml_type=AzureMLResourceType.DATA)
+ path_on_compute = generate_path_on_compute_property(azureml_type=AzureMLResourceType.DATA)
+ datastore = generate_datastore_property()
+
+
+class MLTableInputSchema(InputSchema):
+ mode = StringTransformedEnum(
+ allowed_values=[
+ InputOutputModes.DOWNLOAD,
+ InputOutputModes.RO_MOUNT,
+ InputOutputModes.EVAL_MOUNT,
+ InputOutputModes.EVAL_DOWNLOAD,
+ InputOutputModes.DIRECT,
+ ],
+ required=False,
+ )
+ type = StringTransformedEnum(allowed_values=[AssetTypes.MLTABLE])
+ path = generate_path_property(azureml_type=AzureMLResourceType.DATA)
+ path_on_compute = generate_path_on_compute_property(azureml_type=AzureMLResourceType.DATA)
+ datastore = generate_datastore_property()
+
+
+class InputLiteralValueSchema(metaclass=PatchedSchemaMeta):
+ value = UnionField([fields.Str(), fields.Bool(), fields.Int(), fields.Float()])
+
+ @post_load
+ def make(self, data, **kwargs):
+ return data["value"]
+
+ @pre_dump
+ def check_dict(self, data, **kwargs):
+ if hasattr(data, "value"):
+ return data
+ raise ValidationError("InputLiteralValue must have a field value")
+
+
+class OutputSchema(PathAwareSchema):
+ name = fields.Str()
+ version = VersionField()
+ mode = StringTransformedEnum(
+ allowed_values=[
+ InputOutputModes.MOUNT,
+ InputOutputModes.UPLOAD,
+ InputOutputModes.RW_MOUNT,
+ InputOutputModes.DIRECT,
+ ],
+ required=False,
+ )
+ type = StringTransformedEnum(
+ allowed_values=[
+ AssetTypes.URI_FILE,
+ AssetTypes.URI_FOLDER,
+ AssetTypes.CUSTOM_MODEL,
+ AssetTypes.MLFLOW_MODEL,
+ AssetTypes.MLTABLE,
+ AssetTypes.TRITON_MODEL,
+ ]
+ )
+ path = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._inputs_outputs import Output
+
+ return Output(**data)
+
+ @pre_dump
+ def check_dict(self, data, **kwargs):
+ from azure.ai.ml.entities._inputs_outputs import Output
+
+ if isinstance(data, Output):
+ return data
+ # Assists with union schema
+ raise ValidationError("OutputSchema needs type Output to dump")
+
+
+class StoredProcedureParamsSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str()
+ value = fields.Str()
+ type = fields.Str()
+
+ @pre_dump
+ def check_dict(self, data, **kwargs):
+ for key in self.dump_fields.keys(): # pylint: disable=no-member
+ if data.get(key, None) is None:
+ msg = "StoredProcedureParams must have a {!r} value."
+ raise ValidationError(msg.format(key))
+ return data
+
+
+class DatabaseSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(allowed_values=[ExternalDataType.DATABASE], required=True)
+ table_name = fields.Str()
+ query = fields.Str(
+ metadata={"description": "The sql query command."},
+ )
+ stored_procedure = fields.Str()
+ stored_procedure_params = fields.List(NestedField(StoredProcedureParamsSchema))
+
+ connection = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.data_transfer import Database
+
+ data.pop("type", None)
+ return Database(**data)
+
+ @pre_dump
+ def check_dict(self, data, **kwargs):
+ from azure.ai.ml.data_transfer import Database
+
+ if isinstance(data, Database):
+ return data
+ msg = "DatabaseSchema needs type Database to dump, but got {!r}."
+ raise ValidationError(msg.format(type(data)))
+
+
+class FileSystemSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[
+ ExternalDataType.FILE_SYSTEM,
+ ],
+ )
+ path = fields.Str()
+
+ connection = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.data_transfer import FileSystem
+
+ data.pop("type", None)
+ return FileSystem(**data)
+
+ @pre_dump
+ def check_dict(self, data, **kwargs):
+ from azure.ai.ml.data_transfer import FileSystem
+
+ if isinstance(data, FileSystem):
+ return data
+ msg = "FileSystemSchema needs type FileSystem to dump, but got {!r}."
+ raise ValidationError(msg.format(type(data)))
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_fields_provider.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_fields_provider.py
new file mode 100644
index 00000000..7fb2e8e0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_output_fields_provider.py
@@ -0,0 +1,50 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
+from azure.ai.ml._schema.core.fields import NestedField, PrimitiveValueField, UnionField
+from azure.ai.ml._schema.job.input_output_entry import (
+ DataInputSchema,
+ InputLiteralValueSchema,
+ MLTableInputSchema,
+ ModelInputSchema,
+ OutputSchema,
+)
+
+
+def InputsField(*, support_databinding: bool = False, **kwargs):
+ value_fields = [
+ NestedField(DataInputSchema),
+ NestedField(ModelInputSchema),
+ NestedField(MLTableInputSchema),
+ NestedField(InputLiteralValueSchema),
+ PrimitiveValueField(is_strict=False),
+ # This ordering of types for the values keyword is intentional. The ordering of types
+ # determines what order schema values are matched and cast in. Changing the current ordering can
+ # result in values being mis-cast such as 1.0 translating into True.
+ ]
+
+ # As is_strict is set to True, 1 and only 1 value field must be matched.
+ # root level data-binding expression has already been covered by PrimitiveValueField;
+ # If support_databinding is True, we should only add data-binding expression support for nested fields.
+ if support_databinding:
+ for field_obj in value_fields:
+ if isinstance(field_obj, NestedField):
+ support_data_binding_expression_for_fields(field_obj.schema)
+
+ return fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(value_fields, metadata={"description": "Inputs to a job."}, is_strict=True, **kwargs),
+ )
+
+
+def OutputsField(**kwargs):
+ return fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(nested=OutputSchema, allow_none=True),
+ metadata={"description": "Outputs of a job."},
+ **kwargs
+ )
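+
+
+# Illustrative inputs mapping exercised by InputsField's strict union; each value
+# must match exactly one branch, and the branch ordering noted above prevents
+# mis-casts such as 1.0 turning into True:
+#
+#     inputs:
+#         training_data: {type: uri_folder, path: ./data}             # DataInputSchema
+#         base_model: {type: mlflow_model, path: azureml:my-model:1}  # ModelInputSchema
+#         learning_rate: 1.0                                          # primitive value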
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_port.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_port.py
new file mode 100644
index 00000000..f37b2a16
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/input_port.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import fields, post_load, validate
+
+from azure.ai.ml.entities import InputPort
+
+from ..core.schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class InputPortSchema(metaclass=PatchedSchemaMeta):
+ type_string = fields.Str(
+ data_key="type",
+ validate=validate.OneOf(["path", "number", "null"]),
+ dump_default="null",
+ )
+ default = fields.Str()
+ optional = fields.Bool()
+
+ @post_load
+ def make(self, data, **kwargs):
+ return InputPort(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_limits.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_limits.py
new file mode 100644
index 00000000..850e9b3d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_limits.py
@@ -0,0 +1,45 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load, validate
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class CommandJobLimitsSchema(metaclass=PatchedSchemaMeta):
+ timeout = fields.Int()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import CommandJobLimits
+
+ return CommandJobLimits(**data)
+
+
+class SweepJobLimitsSchema(metaclass=PatchedSchemaMeta):
+ max_concurrent_trials = fields.Int(metadata={"description": "Sweep Job max concurrent trials."})
+ max_total_trials = fields.Int(
+ metadata={"description": "Sweep Job max total trials."},
+ required=True,
+ )
+ timeout = fields.Int(
+ metadata={"description": "The max run duration in Seconds, after which the job will be cancelled."}
+ )
+ trial_timeout = fields.Int(metadata={"description": "Sweep Job Trial timeout value."})
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.sweep import SweepJobLimits
+
+ return SweepJobLimits(**data)
+
+
+class DoWhileLimitsSchema(metaclass=PatchedSchemaMeta):
+ max_iteration_count = fields.Int(
+ metadata={"description": "The max iteration for do_while loop."},
+ validate=validate.Range(min=1, max=1000),
+ required=True,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_output.py
new file mode 100644
index 00000000..80679119
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/job_output.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import ArmStr
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobOutputSchema(metaclass=PatchedSchemaMeta):
+ datastore_id = ArmStr(azureml_type=AzureMLResourceType.DATASTORE)
+ path = fields.Str()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parallel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parallel_job.py
new file mode 100644
index 00000000..c539e407
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parallel_job.py
@@ -0,0 +1,15 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml.constants import JobType
+
+from .base_job import BaseJobSchema
+from .parameterized_parallel import ParameterizedParallelSchema
+
+
+class ParallelJobSchema(ParameterizedParallelSchema, BaseJobSchema):
+ type = StringTransformedEnum(allowed_values=JobType.PARALLEL)
+ inputs = InputsField()
+ outputs = OutputsField()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_command.py
new file mode 100644
index 00000000..1c011bc9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_command.py
@@ -0,0 +1,41 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import (
+ CodeField,
+ DistributionField,
+ EnvironmentField,
+ ExperimentalField,
+ NestedField,
+)
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job.input_output_entry import InputLiteralValueSchema
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml._schema.queue_settings import QueueSettingsSchema
+
+from ..core.fields import UnionField
+
+
+class ParameterizedCommandSchema(PathAwareSchema):
+ command = fields.Str(
+ metadata={
+ # pylint: disable=line-too-long
+ "description": "The command run and the parameters passed. This string may contain place holders of inputs in {}. "
+ },
+ required=True,
+ )
+ code = CodeField()
+ environment = EnvironmentField(required=True)
+ environment_variables = UnionField(
+ [
+ fields.Dict(keys=fields.Str(), values=fields.Str()),
+ # Used for binding environment variables
+ NestedField(InputLiteralValueSchema),
+ ]
+ )
+ resources = NestedField(JobResourceConfigurationSchema)
+ distribution = DistributionField()
+ queue_settings = ExperimentalField(NestedField(QueueSettingsSchema))
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_parallel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_parallel.py
new file mode 100644
index 00000000..bb5cd063
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_parallel.py
@@ -0,0 +1,72 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields
+
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._schema.core.fields import DumpableEnumField, NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job.input_output_entry import InputLiteralValueSchema
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml.constants._common import LoggingLevel
+
+from ..core.fields import UnionField
+
+
+class ParameterizedParallelSchema(PathAwareSchema):
+ logging_level = DumpableEnumField(
+ allowed_values=[LoggingLevel.DEBUG, LoggingLevel.INFO, LoggingLevel.WARN],
+ dump_default=LoggingLevel.INFO,
+ metadata={
+ "description": (
+ "A string of the logging level name, which is defined in 'logging'. "
+ "Possible values are 'WARNING', 'INFO', and 'DEBUG'."
+ )
+ },
+ )
+ task = NestedField(ComponentParallelTaskSchema, unknown=INCLUDE)
+ mini_batch_size = fields.Str(
+ metadata={"description": "The batch size of current job."},
+ )
+ partition_keys = fields.List(
+ fields.Str(), metadata={"description": "The keys used to partition input data into mini-batches"}
+ )
+ input_data = fields.Str()
+ resources = NestedField(JobResourceConfigurationSchema)
+ retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
+ max_concurrency_per_instance = fields.Integer(
+ dump_default=1,
+ metadata={"description": "The max parallellism that each compute instance has."},
+ )
+ error_threshold = fields.Integer(
+ dump_default=-1,
+ metadata={
+ "description": (
+ "The number of item processing failures should be ignored. "
+ "If the error_threshold is reached, the job terminates. "
+ "For a list of files as inputs, one item means one file reference. "
+ "This setting doesn't apply to command parallelization."
+ )
+ },
+ )
+ mini_batch_error_threshold = fields.Integer(
+ dump_default=-1,
+ metadata={
+ "description": (
+ "The number of mini batch processing failures should be ignored. "
+ "If the mini_batch_error_threshold is reached, the job terminates. "
+ "For a list of files as inputs, one item means one file reference. "
+ "This setting can be used by either command or python function parallelization. "
+ "Only one error_threshold setting can be used in one job."
+ )
+ },
+ )
+ environment_variables = UnionField(
+ [
+ fields.Dict(keys=fields.Str(), values=fields.Str()),
+ # Used for binding environment variables
+ NestedField(InputLiteralValueSchema),
+ ]
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_spark.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_spark.py
new file mode 100644
index 00000000..49e9560a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/parameterized_spark.py
@@ -0,0 +1,151 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=unused-argument
+
+import re
+from typing import Any, Dict, List
+
+from marshmallow import ValidationError, fields, post_dump, post_load, pre_dump, pre_load, validates
+
+from azure.ai.ml._schema.core.fields import CodeField, EnvironmentField, NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+from ..core.fields import UnionField
+
+re_memory_pattern = re.compile("^\\d+[kKmMgGtTpP]$")
+
+
+class SparkEntryFileSchema(metaclass=PatchedSchemaMeta):
+ file = fields.Str(required=True)
+ # add spark_job_entry_type and make it dump only to align with model definition,
+ # this ensures we get the expected value when calling spark._from_rest_object()
+ spark_job_entry_type = fields.Str(dump_only=True)
+
+ @pre_dump
+ def to_dict(self, data, **kwargs):
+ return {"file": data.entry}
+
+
+class SparkEntryClassSchema(metaclass=PatchedSchemaMeta):
+ class_name = fields.Str(required=True)
+ # add spark_job_entry_type and make it dump only to align with model definition,
+ # this will make us get expected value when call spark._from_rest_object()
+ spark_job_entry_type = fields.Str(dump_only=True)
+
+ @pre_dump
+ def to_dict(self, data, **kwargs):
+ return {"class_name": data.entry}
+
+
+CONF_KEY_MAP = {
+ "driver_cores": "spark.driver.cores",
+ "driver_memory": "spark.driver.memory",
+ "executor_cores": "spark.executor.cores",
+ "executor_memory": "spark.executor.memory",
+ "executor_instances": "spark.executor.instances",
+ "dynamic_allocation_enabled": "spark.dynamicAllocation.enabled",
+ "dynamic_allocation_min_executors": "spark.dynamicAllocation.minExecutors",
+ "dynamic_allocation_max_executors": "spark.dynamicAllocation.maxExecutors",
+}
+
+
+def no_duplicates(name: str, value: List):
+ if len(value) != len(set(value)):
+ raise ValidationError(f"{name} must not contain duplicate entries.")
+
+
+class ParameterizedSparkSchema(PathAwareSchema):
+ code = CodeField()
+ entry = UnionField(
+ [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
+ required=True,
+ metadata={"description": "Entry."},
+ )
+ py_files = fields.List(fields.Str(required=True))
+ jars = fields.List(fields.Str(required=True))
+ files = fields.List(fields.Str(required=True))
+ archives = fields.List(fields.Str(required=True))
+ conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
+ properties = fields.Dict(keys=fields.Str(), values=fields.Raw())
+ environment = EnvironmentField(allow_none=True)
+ args = fields.Str(metadata={"description": "Command Line arguments."})
+
+ @validates("py_files")
+ def no_duplicate_py_files(self, value):
+ no_duplicates("py_files", value)
+
+ @validates("jars")
+ def no_duplicate_jars(self, value):
+ no_duplicates("jars", value)
+
+ @validates("files")
+ def no_duplicate_files(self, value):
+ no_duplicates("files", value)
+
+ @validates("archives")
+ def no_duplicate_archives(self, value):
+ no_duplicates("archives", value)
+
+ @pre_load
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def map_conf_field_names(self, data, **kwargs):
+ """Map the field names in the conf dictionary.
+ This function must be called after YamlFileSchema.load_from_file.
+ Given marshmallow executes the pre_load functions in the alphabetical order (marshmallow\\schema.py:L166,
+ functions will be checked in alphabetical order when inject to cls._hooks), we must make sure the function
+ name is alphabetically after "load_from_file".
+ """
+ # TODO: it's dangerous to depend on an alphabetical order, we'd better move related logic out of Schema.
+ conf = data["conf"] if "conf" in data else None
+ if conf is not None:
+ for field_key, dict_key in CONF_KEY_MAP.items():
+ value = conf.get(dict_key, None)
+ if dict_key in conf and value is not None:
+ del conf[dict_key]
+ conf[field_key] = value
+ data["conf"] = conf
+ return data
+
+ @post_dump(pass_original=True)
+ def serialize_field_names(self, data: Dict[str, Any], original_data: Dict[str, Any], **kwargs):
+ conf = data["conf"] if "conf" in data else {}
+ if original_data.conf is not None and conf is not None:
+ for field_name, value in original_data.conf.items():
+ if field_name not in conf:
+ if isinstance(value, str) and value.isdigit():
+ value = int(value)
+ conf[field_name] = value
+ if conf is not None:
+ for field_name, dict_name in CONF_KEY_MAP.items():
+ val = conf.get(field_name, None)
+ if field_name in conf and val is not None:
+ if isinstance(val, str) and val.isdigit():
+ val = int(val)
+ del conf[field_name]
+ conf[dict_name] = val
+ data["conf"] = conf
+ return data
+
+ @post_load
+ def demote_conf_fields(self, data, **kwargs):
+ conf = data["conf"] if "conf" in data else None
+ if conf is not None:
+ for field_name, _ in CONF_KEY_MAP.items():
+ value = conf.get(field_name, None)
+ if field_name in conf and value is not None:
+ del conf[field_name]
+ data[field_name] = value
+ return data
+
+ @pre_dump
+ def promote_conf_fields(self, data: object, **kwargs):
+ # copy fields from root object into the 'conf'
+ conf = data.conf or {}
+ for field_name, _ in CONF_KEY_MAP.items():
+ value = data.__getattribute__(field_name)
+ if value is not None:
+ conf[field_name] = value
+ data.__setattr__("conf", conf)
+ return data
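
The pre_load/post_load pair above is a key-renaming round trip over CONF_KEY_MAP. A minimal
standalone sketch of the load-direction mapping, using plain dicts and CONF_KEY_MAP as copied
from the module above (no marshmallow involved):

    CONF_KEY_MAP = {
        "driver_cores": "spark.driver.cores",
        "driver_memory": "spark.driver.memory",
        "executor_cores": "spark.executor.cores",
        "executor_memory": "spark.executor.memory",
        "executor_instances": "spark.executor.instances",
        "dynamic_allocation_enabled": "spark.dynamicAllocation.enabled",
        "dynamic_allocation_min_executors": "spark.dynamicAllocation.minExecutors",
        "dynamic_allocation_max_executors": "spark.dynamicAllocation.maxExecutors",
    }

    def map_conf_keys(conf: dict) -> dict:
        """Rename dotted Spark keys to schema field names, as map_conf_field_names does on load."""
        out = dict(conf)
        for field_key, dict_key in CONF_KEY_MAP.items():
            if out.get(dict_key) is not None:
                out[field_key] = out.pop(dict_key)
        return out

    assert map_conf_keys({"spark.driver.cores": 2})["driver_cores"] == 2
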
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py
new file mode 100644
index 00000000..f6fed8c2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/services.py
@@ -0,0 +1,100 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.entities._job.job_service import (
+ JobService,
+ SshJobService,
+ JupyterLabJobService,
+ VsCodeJobService,
+ TensorBoardJobService,
+)
+from azure.ai.ml.constants._job.job import JobServiceTypeNames
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField
+
+from ..core.schema import PathAwareSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobServiceBaseSchema(PathAwareSchema):
+ port = fields.Int()
+ endpoint = fields.Str(dump_only=True)
+ status = fields.Str(dump_only=True)
+ nodes = fields.Str()
+ error_message = fields.Str(dump_only=True)
+ properties = fields.Dict()
+
+
+class JobServiceSchema(JobServiceBaseSchema):
+ """This is to support tansformation of job services passed as dict type and internal job services like Custom,
+ Tracking, Studio set by the system."""
+
+ type = UnionField(
+ [
+ StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC,
+ pass_original=True,
+ ),
+ fields.Str(),
+ ]
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return JobService(**data)
+
+
+class TensorBoardJobServiceSchema(JobServiceBaseSchema):
+ type = StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.EntityNames.TENSOR_BOARD,
+ pass_original=True,
+ )
+ log_dir = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return TensorBoardJobService(**data)
+
+
+class SshJobServiceSchema(JobServiceBaseSchema):
+ type = StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.EntityNames.SSH,
+ pass_original=True,
+ )
+ ssh_public_keys = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return SshJobService(**data)
+
+
+class VsCodeJobServiceSchema(JobServiceBaseSchema):
+ type = StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.EntityNames.VS_CODE,
+ pass_original=True,
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return VsCodeJobService(**data)
+
+
+class JupyterLabJobServiceSchema(JobServiceBaseSchema):
+ type = StringTransformedEnum(
+ allowed_values=JobServiceTypeNames.EntityNames.JUPYTER_LAB,
+ pass_original=True,
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ data.pop("type", None)
+ return JupyterLabJobService(**data)
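
All of the post_load hooks above follow one pattern: drop the "type" discriminator, then build the
matching entity. A self-contained sketch of that pattern with a hypothetical stand-in entity (the
real classes live in azure.ai.ml.entities._job.job_service):

    from dataclasses import dataclass

    from marshmallow import Schema, fields, post_load

    @dataclass
    class ToySshService:  # hypothetical stand-in for SshJobService
        port: int = 0
        ssh_public_keys: str = ""

    class ToySshServiceSchema(Schema):
        type = fields.Str()
        port = fields.Int()
        ssh_public_keys = fields.Str()

        @post_load
        def make(self, data, **kwargs):
            data.pop("type", None)  # the discriminator is not a constructor argument
            return ToySshService(**data)

    svc = ToySshServiceSchema().load({"type": "ssh", "port": 22, "ssh_public_keys": "ssh-rsa AAA"})
    assert isinstance(svc, ToySshService) and svc.port == 22
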
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/spark_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/spark_job.py
new file mode 100644
index 00000000..f9363175
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job/spark_job.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml._schema.spark_resource_configuration import SparkResourceConfigurationSchema
+from azure.ai.ml.constants import JobType
+
+from ..core.fields import ComputeField, StringTransformedEnum, UnionField
+from .base_job import BaseJobSchema
+from .parameterized_spark import ParameterizedSparkSchema
+
+
+class SparkJobSchema(ParameterizedSparkSchema, BaseJobSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.SPARK)
+ compute = ComputeField()
+ inputs = InputsField()
+ outputs = OutputsField()
+ resources = NestedField(SparkResourceConfigurationSchema)
+ identity = UnionField(
+ [
+ NestedField(ManagedIdentitySchema),
+ NestedField(AMLTokenIdentitySchema),
+ NestedField(UserIdentitySchema),
+ ]
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resource_configuration.py
new file mode 100644
index 00000000..859eef31
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resource_configuration.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import UnionField
+
+from .resource_configuration import ResourceConfigurationSchema
+
+
+class JobResourceConfigurationSchema(ResourceConfigurationSchema):
+ locations = fields.List(fields.Str())
+ shm_size = fields.Str(
+ metadata={
+ "description": (
+ "The size of the docker container's shared memory block. "
+ "This should be in the format of `<number><unit>` where number as "
+ "to be greater than 0 and the unit can be one of "
+ "`b` (bytes), `k` (kilobytes), `m` (megabytes), or `g` (gigabytes)."
+ )
+ }
+ )
+ max_instance_count = fields.Int(
+ metadata={"description": "The maximum number of instances to make available to this job."}
+ )
+ docker_args = UnionField(
+ [
+            fields.Str(metadata={"description": "Arguments to pass to the Docker run command."}),
+ fields.List(fields.Str()),
+ ]
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import JobResourceConfiguration
+
+ return JobResourceConfiguration(**data)
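
The shm_size description pins down a `<number><unit>` grammar. A hypothetical validator for
exactly that format, assuming integer sizes (this helper is illustrative and not part of the SDK):

    import re

    _SHM_SIZE = re.compile(r"^(\d+)([bkmg])$")  # e.g. "2g", "512m"

    def is_valid_shm_size(value: str) -> bool:
        """True when value is a positive integer followed by b, k, m, or g."""
        match = _SHM_SIZE.match(value)
        return bool(match) and int(match.group(1)) > 0

    assert is_valid_shm_size("2g") and not is_valid_shm_size("0g")
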
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resources.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resources.py
new file mode 100644
index 00000000..49e6eaa0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/job_resources.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class JobResourcesSchema(metaclass=PatchedSchemaMeta):
+ instance_types = fields.List(
+        fields.Str(), metadata={"description": "The instance types to make available to this job."}
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import JobResources
+
+ return JobResources(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py
new file mode 100644
index 00000000..bd7fd69c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py
@@ -0,0 +1,19 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class AlertNotificationSchema(metaclass=PatchedSchemaMeta):
+ emails = fields.List(fields.Str)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.alert_notification import AlertNotification
+
+ return AlertNotification(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py
new file mode 100644
index 00000000..483b4ac5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class ComputeConfigurationSchema(metaclass=PatchedSchemaMeta):
+ compute_type = fields.Str(allowed_values=["ServerlessSpark"])
+
+
+class ServerlessSparkComputeSchema(ComputeConfigurationSchema):
+ runtime_version = fields.Str()
+ instance_type = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.compute import ServerlessSparkCompute
+
+ return ServerlessSparkCompute(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py
new file mode 100644
index 00000000..d5a6a4f9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.constants._monitoring import MonitorDatasetContext
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema, DataInputSchema
+
+
+class MonitorInputDataSchema(metaclass=PatchedSchemaMeta):
+ input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)])
+ data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext])
+ target_columns = fields.Dict()
+ job_type = fields.Str()
+ uri = fields.Str()
+
+
+class FixedInputDataSchema(MonitorInputDataSchema):
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.input_data import FixedInputData
+
+ return FixedInputData(**data)
+
+
+class TrailingInputDataSchema(MonitorInputDataSchema):
+ window_size = fields.Str()
+ window_offset = fields.Str()
+ pre_processing_component_id = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.input_data import TrailingInputData
+
+ return TrailingInputData(**data)
+
+
+class StaticInputDataSchema(MonitorInputDataSchema):
+ pre_processing_component_id = fields.Str()
+ window_start = fields.String()
+ window_end = fields.String()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.input_data import StaticInputData
+
+ return StaticInputData(**data)
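
TrailingInputData and StaticInputData differ mainly in how the window is expressed: a relative
size/offset versus absolute start/end strings. A toy resolver under that reading; the "P7D"
day-count parsing is an assumption for illustration, not the SDK's duration parser:

    from datetime import datetime, timedelta

    def resolve_window(now, window_size=None, window_start=None, window_end=None):
        """Trailing windows look back from 'now'; static windows use the given bounds."""
        if window_size is not None:  # trailing
            days = int(window_size.strip("PD"))  # "P7D" -> 7 (simplistic on purpose)
            return now - timedelta(days=days), now
        return window_start, window_end  # static

    start, end = resolve_window(datetime(2024, 1, 8), window_size="P7D")
    assert (start, end) == (datetime(2024, 1, 1), datetime(2024, 1, 8))
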
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py
new file mode 100644
index 00000000..3fe52c9d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.constants._monitoring import AZMONITORING
+from azure.ai.ml._schema.monitoring.target import MonitoringTargetSchema
+from azure.ai.ml._schema.monitoring.compute import ServerlessSparkComputeSchema
+from azure.ai.ml._schema.monitoring.signals import (
+ DataDriftSignalSchema,
+ DataQualitySignalSchema,
+ PredictionDriftSignalSchema,
+ FeatureAttributionDriftSignalSchema,
+ CustomMonitoringSignalSchema,
+ GenerationSafetyQualitySchema,
+ ModelPerformanceSignalSchema,
+ GenerationTokenStatisticsSchema,
+)
+from azure.ai.ml._schema.monitoring.alert_notification import AlertNotificationSchema
+from azure.ai.ml._schema.core.fields import NestedField, UnionField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class MonitorDefinitionSchema(metaclass=PatchedSchemaMeta):
+ compute = NestedField(ServerlessSparkComputeSchema)
+ monitoring_target = NestedField(MonitoringTargetSchema)
+ monitoring_signals = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ union_fields=[
+ NestedField(DataDriftSignalSchema),
+ NestedField(DataQualitySignalSchema),
+ NestedField(PredictionDriftSignalSchema),
+ NestedField(FeatureAttributionDriftSignalSchema),
+ NestedField(CustomMonitoringSignalSchema),
+ NestedField(GenerationSafetyQualitySchema),
+ NestedField(ModelPerformanceSignalSchema),
+ NestedField(GenerationTokenStatisticsSchema),
+ ]
+ ),
+ )
+ alert_notification = UnionField(
+ union_fields=[StringTransformedEnum(allowed_values=AZMONITORING), NestedField(AlertNotificationSchema)]
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.definition import MonitorDefinition
+
+ return MonitorDefinition(**data)
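
The monitoring_signals values pass through a UnionField, which on load tries each nested signal
schema until one accepts the payload. A rough standalone illustration of that try-in-order
dispatch, with toy schemas rather than the real signal schemas:

    from marshmallow import Schema, ValidationError, fields, validate

    class ToyDriftSchema(Schema):  # stand-in for DataDriftSignalSchema
        type = fields.Str(required=True, validate=validate.OneOf(["data_drift"]))
        threshold = fields.Float()

    class ToyQualitySchema(Schema):  # stand-in for DataQualitySignalSchema
        type = fields.Str(required=True, validate=validate.OneOf(["data_quality"]))
        null_rate = fields.Float()

    def load_signal(raw: dict):
        """Try each candidate schema in order, as a union field does on load."""
        for schema in (ToyDriftSchema(), ToyQualitySchema()):
            try:
                return schema.load(raw)
            except ValidationError:
                continue
        raise ValidationError("no schema matched")

    assert load_signal({"type": "data_quality", "null_rate": 0.01})["null_rate"] == 0.01
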
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py
new file mode 100644
index 00000000..a2034d33
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py
@@ -0,0 +1,11 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.monitoring.monitor_definition import MonitorDefinitionSchema
+from azure.ai.ml._schema.schedule.schedule import ScheduleSchema
+
+
+class MonitorScheduleSchema(ScheduleSchema):
+ create_monitor = NestedField(MonitorDefinitionSchema)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py
new file mode 100644
index 00000000..4f55393b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py
@@ -0,0 +1,348 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load, pre_dump, ValidationError
+
+from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, MLTableInputSchema
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._monitoring import (
+ MonitorSignalType,
+ ALL_FEATURES,
+ MonitorModelType,
+ MonitorDatasetContext,
+ FADColumnNames,
+)
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import ArmVersionedStr, NestedField, UnionField, StringTransformedEnum
+from azure.ai.ml._schema.monitoring.thresholds import (
+ DataDriftMetricThresholdSchema,
+ DataQualityMetricThresholdSchema,
+ PredictionDriftMetricThresholdSchema,
+ FeatureAttributionDriftMetricThresholdSchema,
+ ModelPerformanceMetricThresholdSchema,
+ CustomMonitoringMetricThresholdSchema,
+ GenerationSafetyQualityMetricThresholdSchema,
+ GenerationTokenStatisticsMonitorMetricThresholdSchema,
+)
+
+
+class DataSegmentSchema(metaclass=PatchedSchemaMeta):
+ feature_name = fields.Str()
+ feature_values = fields.List(fields.Str)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import DataSegment
+
+ return DataSegment(**data)
+
+
+class MonitorFeatureFilterSchema(metaclass=PatchedSchemaMeta):
+ top_n_feature_importance = fields.Int()
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import MonitorFeatureFilter
+
+ if not isinstance(data, MonitorFeatureFilter):
+ raise ValidationError("Cannot dump non-MonitorFeatureFilter object into MonitorFeatureFilter")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import MonitorFeatureFilter
+
+ return MonitorFeatureFilter(**data)
+
+
+class BaselineDataRangeSchema(metaclass=PatchedSchemaMeta):
+ window_start = fields.Str()
+ window_end = fields.Str()
+ lookback_window_size = fields.Str()
+ lookback_window_offset = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import BaselineDataRange
+
+ return BaselineDataRange(**data)
+
+
+class ProductionDataSchema(metaclass=PatchedSchemaMeta):
+ input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)])
+ data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext])
+ pre_processing_component = fields.Str()
+ data_window = NestedField(BaselineDataRangeSchema)
+ data_column_names = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import ProductionData
+
+ return ProductionData(**data)
+
+
+class ReferenceDataSchema(metaclass=PatchedSchemaMeta):
+ input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)])
+ data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext])
+ pre_processing_component = fields.Str()
+ target_column_name = fields.Str()
+ data_window = NestedField(BaselineDataRangeSchema)
+ data_column_names = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import ReferenceData
+
+ return ReferenceData(**data)
+
+
+class MonitoringSignalSchema(metaclass=PatchedSchemaMeta):
+ production_data = NestedField(ProductionDataSchema)
+ reference_data = NestedField(ReferenceDataSchema)
+ properties = fields.Dict()
+ alert_enabled = fields.Bool()
+
+
+class DataSignalSchema(MonitoringSignalSchema):
+ features = UnionField(
+ union_fields=[
+ NestedField(MonitorFeatureFilterSchema),
+ StringTransformedEnum(allowed_values=ALL_FEATURES),
+ fields.List(fields.Str),
+ ]
+ )
+ feature_type_override = fields.Dict()
+
+
+class DataDriftSignalSchema(DataSignalSchema):
+ type = StringTransformedEnum(allowed_values=MonitorSignalType.DATA_DRIFT, required=True)
+ metric_thresholds = NestedField(DataDriftMetricThresholdSchema)
+ data_segment = NestedField(DataSegmentSchema)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import DataDriftSignal
+
+ if not isinstance(data, DataDriftSignal):
+ raise ValidationError("Cannot dump non-DataDriftSignal object into DataDriftSignal")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import DataDriftSignal
+
+ data.pop("type", None)
+ return DataDriftSignal(**data)
+
+
+class DataQualitySignalSchema(DataSignalSchema):
+ type = StringTransformedEnum(allowed_values=MonitorSignalType.DATA_QUALITY, required=True)
+ metric_thresholds = NestedField(DataQualityMetricThresholdSchema)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import DataQualitySignal
+
+ if not isinstance(data, DataQualitySignal):
+ raise ValidationError("Cannot dump non-DataQualitySignal object into DataQualitySignal")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import DataQualitySignal
+
+ data.pop("type", None)
+ return DataQualitySignal(**data)
+
+
+class PredictionDriftSignalSchema(MonitoringSignalSchema):
+ type = StringTransformedEnum(allowed_values=MonitorSignalType.PREDICTION_DRIFT, required=True)
+ metric_thresholds = NestedField(PredictionDriftMetricThresholdSchema)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import PredictionDriftSignal
+
+ if not isinstance(data, PredictionDriftSignal):
+ raise ValidationError("Cannot dump non-PredictionDriftSignal object into PredictionDriftSignal")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import PredictionDriftSignal
+
+ data.pop("type", None)
+ return PredictionDriftSignal(**data)
+
+
+class ModelSignalSchema(MonitoringSignalSchema):
+ model_type = StringTransformedEnum(allowed_values=[MonitorModelType.CLASSIFICATION, MonitorModelType.REGRESSION])
+
+
+class FADProductionDataSchema(metaclass=PatchedSchemaMeta):
+ input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)])
+ data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext])
+ data_column_names = fields.Dict(
+ keys=StringTransformedEnum(allowed_values=[o.value for o in FADColumnNames]), values=fields.Str()
+ )
+ pre_processing_component = fields.Str()
+ data_window = NestedField(BaselineDataRangeSchema)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import FADProductionData
+
+ return FADProductionData(**data)
+
+
+class FeatureAttributionDriftSignalSchema(metaclass=PatchedSchemaMeta):
+ production_data = fields.List(NestedField(FADProductionDataSchema))
+ reference_data = NestedField(ReferenceDataSchema)
+ alert_enabled = fields.Bool()
+ type = StringTransformedEnum(allowed_values=MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT, required=True)
+ metric_thresholds = NestedField(FeatureAttributionDriftMetricThresholdSchema)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import FeatureAttributionDriftSignal
+
+ if not isinstance(data, FeatureAttributionDriftSignal):
+ raise ValidationError(
+ "Cannot dump non-FeatureAttributionDriftSignal object into FeatureAttributionDriftSignal"
+ )
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import FeatureAttributionDriftSignal
+
+ data.pop("type", None)
+ return FeatureAttributionDriftSignal(**data)
+
+
+class ModelPerformanceSignalSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(allowed_values=MonitorSignalType.MODEL_PERFORMANCE, required=True)
+ production_data = NestedField(ProductionDataSchema)
+ reference_data = NestedField(ReferenceDataSchema)
+ data_segment = NestedField(DataSegmentSchema)
+ alert_enabled = fields.Bool()
+ metric_thresholds = NestedField(ModelPerformanceMetricThresholdSchema)
+ properties = fields.Dict()
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import ModelPerformanceSignal
+
+ if not isinstance(data, ModelPerformanceSignal):
+ raise ValidationError("Cannot dump non-ModelPerformanceSignal object into ModelPerformanceSignal")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import ModelPerformanceSignal
+
+ data.pop("type", None)
+ return ModelPerformanceSignal(**data)
+
+
+class ConnectionSchema(metaclass=PatchedSchemaMeta):
+ environment_variables = fields.Dict(keys=fields.Str(), values=fields.Str())
+ secret_config = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import Connection
+
+ return Connection(**data)
+
+
+class CustomMonitoringSignalSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(allowed_values=MonitorSignalType.CUSTOM, required=True)
+ connection = NestedField(ConnectionSchema)
+ component_id = ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT)
+ metric_thresholds = fields.List(NestedField(CustomMonitoringMetricThresholdSchema))
+ input_data = fields.Dict(keys=fields.Str(), values=NestedField(ReferenceDataSchema))
+ alert_enabled = fields.Bool()
+ inputs = fields.Dict(
+        keys=fields.Str(), values=UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)])
+ )
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import CustomMonitoringSignal
+
+ if not isinstance(data, CustomMonitoringSignal):
+ raise ValidationError("Cannot dump non-CustomMonitoringSignal object into CustomMonitoringSignal")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import CustomMonitoringSignal
+
+ data.pop("type", None)
+ return CustomMonitoringSignal(**data)
+
+
+class LlmDataSchema(metaclass=PatchedSchemaMeta):
+ input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)])
+ data_column_names = fields.Dict()
+ data_window = NestedField(BaselineDataRangeSchema)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import LlmData
+
+ return LlmData(**data)
+
+
+class GenerationSafetyQualitySchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(allowed_values=MonitorSignalType.GENERATION_SAFETY_QUALITY, required=True)
+ production_data = fields.List(NestedField(LlmDataSchema))
+ connection_id = fields.Str()
+ metric_thresholds = NestedField(GenerationSafetyQualityMetricThresholdSchema)
+ alert_enabled = fields.Bool()
+ properties = fields.Dict()
+ sampling_rate = fields.Float()
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal
+
+ if not isinstance(data, GenerationSafetyQualitySignal):
+ raise ValidationError("Cannot dump non-GenerationSafetyQuality object into GenerationSafetyQuality")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal
+
+ data.pop("type", None)
+ return GenerationSafetyQualitySignal(**data)
+
+
+class GenerationTokenStatisticsSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(allowed_values=MonitorSignalType.GENERATION_TOKEN_STATISTICS, required=True)
+ production_data = NestedField(LlmDataSchema)
+ metric_thresholds = NestedField(GenerationTokenStatisticsMonitorMetricThresholdSchema)
+ alert_enabled = fields.Bool()
+ properties = fields.Dict()
+ sampling_rate = fields.Float()
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import GenerationTokenStatisticsSignal
+
+ if not isinstance(data, GenerationTokenStatisticsSignal):
+ raise ValidationError("Cannot dump non-GenerationSafetyQuality object into GenerationSafetyQuality")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.signals import GenerationTokenStatisticsSignal
+
+ data.pop("type", None)
+ return GenerationTokenStatisticsSignal(**data)
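
Each signal schema above guards pre_dump with an isinstance check: when several schemas sit in a
union, every schema must reject foreign entity types so the dump can fall through to the schema
that owns the object. The pattern in isolation, with toy types:

    from dataclasses import dataclass

    from marshmallow import Schema, ValidationError, fields, pre_dump

    @dataclass
    class ToyDrift:  # hypothetical stand-in entity
        threshold: float

    class ToyDriftSchema(Schema):
        threshold = fields.Float()

        @pre_dump
        def predump(self, data, **kwargs):
            if not isinstance(data, ToyDrift):  # refuse objects the schema does not own
                raise ValidationError("Cannot dump non-ToyDrift object into ToyDrift")
            return data

    ToyDriftSchema().dump(ToyDrift(threshold=0.2))  # accepted
    try:
        ToyDriftSchema().dump({"threshold": 0.2})  # plain dict: rejected
    except ValidationError:
        pass  # a union-style dumper would move on to the next schema
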
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py
new file mode 100644
index 00000000..6d3032ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._monitoring import MonitorTargetTasks
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import ArmVersionedStr, StringTransformedEnum
+
+
+class MonitoringTargetSchema(metaclass=PatchedSchemaMeta):
+ model_id = ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL)
+ ml_task = StringTransformedEnum(allowed_values=[o.value for o in MonitorTargetTasks])
+ endpoint_deployment_id = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.target import MonitoringTarget
+
+ return MonitoringTarget(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py
new file mode 100644
index 00000000..b7970fca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py
@@ -0,0 +1,211 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument, name-too-long
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml.constants._monitoring import MonitorFeatureType
+from azure.ai.ml._schema.core.fields import StringTransformedEnum, NestedField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class MetricThresholdSchema(metaclass=PatchedSchemaMeta):
+ threshold = fields.Number()
+
+
+class NumericalDriftMetricsSchema(metaclass=PatchedSchemaMeta):
+ jensen_shannon_distance = fields.Number()
+ normalized_wasserstein_distance = fields.Number()
+ population_stability_index = fields.Number()
+ two_sample_kolmogorov_smirnov_test = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import NumericalDriftMetrics
+
+ return NumericalDriftMetrics(**data)
+
+
+class CategoricalDriftMetricsSchema(metaclass=PatchedSchemaMeta):
+ jensen_shannon_distance = fields.Number()
+ population_stability_index = fields.Number()
+ pearsons_chi_squared_test = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import CategoricalDriftMetrics
+
+ return CategoricalDriftMetrics(**data)
+
+
+class DataDriftMetricThresholdSchema(MetricThresholdSchema):
+ data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL])
+
+ numerical = NestedField(NumericalDriftMetricsSchema)
+ categorical = NestedField(CategoricalDriftMetricsSchema)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import DataDriftMetricThreshold
+
+ return DataDriftMetricThreshold(**data)
+
+
+class DataQualityMetricsNumericalSchema(metaclass=PatchedSchemaMeta):
+ null_value_rate = fields.Number()
+ data_type_error_rate = fields.Number()
+ out_of_bounds_rate = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricsNumerical
+
+ return DataQualityMetricsNumerical(**data)
+
+
+class DataQualityMetricsCategoricalSchema(metaclass=PatchedSchemaMeta):
+ null_value_rate = fields.Number()
+ data_type_error_rate = fields.Number()
+ out_of_bounds_rate = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricsCategorical
+
+ return DataQualityMetricsCategorical(**data)
+
+
+class DataQualityMetricThresholdSchema(MetricThresholdSchema):
+ data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL])
+ numerical = NestedField(DataQualityMetricsNumericalSchema)
+ categorical = NestedField(DataQualityMetricsCategoricalSchema)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricThreshold
+
+ return DataQualityMetricThreshold(**data)
+
+
+class PredictionDriftMetricThresholdSchema(MetricThresholdSchema):
+ data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL])
+ numerical = NestedField(NumericalDriftMetricsSchema)
+ categorical = NestedField(CategoricalDriftMetricsSchema)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import PredictionDriftMetricThreshold
+
+ return PredictionDriftMetricThreshold(**data)
+
+
+# pylint: disable-next=name-too-long
+class FeatureAttributionDriftMetricThresholdSchema(MetricThresholdSchema):
+ normalized_discounted_cumulative_gain = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import FeatureAttributionDriftMetricThreshold
+
+ return FeatureAttributionDriftMetricThreshold(**data)
+
+
+class ModelPerformanceClassificationThresholdsSchema(metaclass=PatchedSchemaMeta):
+ accuracy = fields.Number()
+ precision = fields.Number()
+ recall = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceClassificationThresholds
+
+ return ModelPerformanceClassificationThresholds(**data)
+
+
+class ModelPerformanceRegressionThresholdsSchema(metaclass=PatchedSchemaMeta):
+ mae = fields.Number()
+ mse = fields.Number()
+ rmse = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceRegressionThresholds
+
+ return ModelPerformanceRegressionThresholds(**data)
+
+
+class ModelPerformanceMetricThresholdSchema(MetricThresholdSchema):
+ classification = NestedField(ModelPerformanceClassificationThresholdsSchema)
+ regression = NestedField(ModelPerformanceRegressionThresholdsSchema)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceMetricThreshold
+
+ return ModelPerformanceMetricThreshold(**data)
+
+
+class CustomMonitoringMetricThresholdSchema(MetricThresholdSchema):
+ metric_name = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import CustomMonitoringMetricThreshold
+
+ return CustomMonitoringMetricThreshold(**data)
+
+
+class GenerationSafetyQualityMetricThresholdSchema(metaclass=PatchedSchemaMeta): # pylint: disable=name-too-long
+ groundedness = fields.Dict(
+ keys=StringTransformedEnum(
+ allowed_values=["aggregated_groundedness_pass_rate", "acceptable_groundedness_score_per_instance"]
+ ),
+ values=fields.Number(),
+ )
+ relevance = fields.Dict(
+ keys=StringTransformedEnum(
+ allowed_values=["aggregated_relevance_pass_rate", "acceptable_relevance_score_per_instance"]
+ ),
+ values=fields.Number(),
+ )
+ coherence = fields.Dict(
+ keys=StringTransformedEnum(
+ allowed_values=["aggregated_coherence_pass_rate", "acceptable_coherence_score_per_instance"]
+ ),
+ values=fields.Number(),
+ )
+ fluency = fields.Dict(
+ keys=StringTransformedEnum(
+ allowed_values=["aggregated_fluency_pass_rate", "acceptable_fluency_score_per_instance"]
+ ),
+ values=fields.Number(),
+ )
+ similarity = fields.Dict(
+ keys=StringTransformedEnum(
+ allowed_values=["aggregated_similarity_pass_rate", "acceptable_similarity_score_per_instance"]
+ ),
+ values=fields.Number(),
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import GenerationSafetyQualityMonitoringMetricThreshold
+
+ return GenerationSafetyQualityMonitoringMetricThreshold(**data)
+
+
+class GenerationTokenStatisticsMonitorMetricThresholdSchema(
+ metaclass=PatchedSchemaMeta
+): # pylint: disable=name-too-long
+ totaltoken = fields.Dict(
+ keys=StringTransformedEnum(allowed_values=["total_token_count", "total_token_count_per_group"]),
+ values=fields.Number(),
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._monitoring.thresholds import GenerationTokenStatisticsMonitorMetricThreshold
+
+ return GenerationTokenStatisticsMonitorMetricThreshold(**data)
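
The generation-safety and token-statistics thresholds are dicts whose keys are restricted to fixed
metric names. In plain marshmallow the same restriction can be sketched with a validated key field
(the SDK's StringTransformedEnum additionally applies value transformation):

    from marshmallow import Schema, fields, validate

    class ToyGroundednessSchema(Schema):  # illustrative only
        groundedness = fields.Dict(
            keys=fields.Str(validate=validate.OneOf([
                "aggregated_groundedness_pass_rate",
                "acceptable_groundedness_score_per_instance",
            ])),
            values=fields.Number(),
        )

    loaded = ToyGroundednessSchema().load({"groundedness": {"aggregated_groundedness_pass_rate": 0.9}})
    assert loaded["groundedness"]["aggregated_groundedness_pass_rate"] == 0.9
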
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/__init__.py
new file mode 100644
index 00000000..a19931cd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/__init__.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=unused-import
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .component_job import (
+ CommandSchema,
+ ImportSchema,
+ ParallelSchema,
+ SparkSchema,
+ DataTransferCopySchema,
+ DataTransferImportSchema,
+ DataTransferExportSchema,
+)
+from .pipeline_job import PipelineJobSchema
+from .settings import PipelineJobSettingsSchema
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py
new file mode 100644
index 00000000..4b815db7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py
@@ -0,0 +1,148 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+from typing import List
+
+from marshmallow import fields, post_dump, post_load, pre_dump
+
+from azure.ai.ml._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
+from azure.ai.ml._schema.automl import AutoMLClassificationSchema, AutoMLForecastingSchema, AutoMLRegressionSchema
+from azure.ai.ml._schema.automl.image_vertical.image_classification import (
+ ImageClassificationMultilabelSchema,
+ ImageClassificationSchema,
+)
+from azure.ai.ml._schema.automl.image_vertical.image_object_detection import (
+ ImageInstanceSegmentationSchema,
+ ImageObjectDetectionSchema,
+)
+from azure.ai.ml._schema.automl.nlp_vertical.text_classification import TextClassificationSchema
+from azure.ai.ml._schema.automl.nlp_vertical.text_classification_multilabel import TextClassificationMultilabelSchema
+from azure.ai.ml._schema.automl.nlp_vertical.text_ner import TextNerSchema
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, UnionField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema, OutputSchema
+from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr
+
+
+class AutoMLNodeMixin(PathAwareSchema):
+ """Inherit this mixin to change automl job schemas to automl node schema.
+
+    e.g. compute is required for an AutoML job but optional for an AutoML node in a pipeline.
+    Note: inherit this before BaseJobSchema to make sure the optional override takes effect.
+ """
+
+ def __init__(self, **kwargs):
+ super(AutoMLNodeMixin, self).__init__(**kwargs)
+        # update field objects to support data binding; task_type & type are skipped and never bound
+ support_data_binding_expression_for_fields(self, attrs_to_skip=["task_type", "type"])
+
+ compute = ComputeField(required=False)
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([NestedField(OutputSchema), OutputBindingStr], allow_none=True),
+ )
+
+ @pre_dump
+ def resolve_outputs(self, job: "AutoMLJob", **kwargs):
+        # Try to resolve the object's inputs & outputs and return a resolved copy
+ import copy
+
+ result = copy.copy(job)
+ result._outputs = job._build_outputs()
+ return result
+
+ @post_dump(pass_original=True)
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def resolve_nested_data(self, job_dict: dict, job: "AutoMLJob", **kwargs):
+ """Resolve nested data into flatten format."""
+ from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+
+ if not isinstance(job, AutoMLJob):
+ return job_dict
+ # change output to rest output dicts
+ job_dict["outputs"] = job._to_rest_outputs()
+ return job_dict
+
+ @post_load
+ def make(self, data, **kwargs):
+ data["task"] = data.pop("task_type")
+ return data
+
+
+class AutoMLClassificationNodeSchema(AutoMLNodeMixin, AutoMLClassificationSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLRegressionNodeSchema(AutoMLNodeMixin, AutoMLRegressionSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLForecastingNodeSchema(AutoMLNodeMixin, AutoMLForecastingSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLTextClassificationNode(AutoMLNodeMixin, TextClassificationSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLTextClassificationMultilabelNode(AutoMLNodeMixin, TextClassificationMultilabelSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLTextNerNode(AutoMLNodeMixin, TextNerSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class ImageClassificationMulticlassNodeSchema(AutoMLNodeMixin, ImageClassificationSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class ImageClassificationMultilabelNodeSchema(AutoMLNodeMixin, ImageClassificationMultilabelSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class ImageObjectDetectionNodeSchema(AutoMLNodeMixin, ImageObjectDetectionSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class ImageInstanceSegmentationNodeSchema(AutoMLNodeMixin, ImageInstanceSegmentationSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+def AutoMLNodeSchema(**kwargs) -> List[fields.Field]:
+ """Get the list of all nested schema for all AutoML nodes.
+
+ :return: The list of fields
+ :rtype: List[fields.Field]
+ """
+ return [
+ # region: automl node schemas
+ NestedField(AutoMLClassificationNodeSchema, **kwargs),
+ NestedField(AutoMLRegressionNodeSchema, **kwargs),
+ NestedField(AutoMLForecastingNodeSchema, **kwargs),
+ # Vision
+ NestedField(ImageClassificationMulticlassNodeSchema, **kwargs),
+ NestedField(ImageClassificationMultilabelNodeSchema, **kwargs),
+ NestedField(ImageObjectDetectionNodeSchema, **kwargs),
+ NestedField(ImageInstanceSegmentationNodeSchema, **kwargs),
+ # NLP
+ NestedField(AutoMLTextClassificationNode, **kwargs),
+ NestedField(AutoMLTextClassificationMultilabelNode, **kwargs),
+ NestedField(AutoMLTextNerNode, **kwargs),
+ # endregion
+ ]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py
new file mode 100644
index 00000000..8f179479
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/component_job.py
@@ -0,0 +1,554 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+
+from marshmallow import INCLUDE, ValidationError, fields, post_dump, post_load, pre_dump, validates
+
+from ..._schema.component import (
+ AnonymousCommandComponentSchema,
+ AnonymousDataTransferCopyComponentSchema,
+ AnonymousImportComponentSchema,
+ AnonymousParallelComponentSchema,
+ AnonymousSparkComponentSchema,
+ ComponentFileRefField,
+ ComponentYamlRefField,
+ DataTransferCopyComponentFileRefField,
+ ImportComponentFileRefField,
+ ParallelComponentFileRefField,
+ SparkComponentFileRefField,
+)
+from ..._utils.utils import is_data_binding_expression
+from ...constants._common import AzureMLResourceType
+from ...constants._component import DataTransferTaskType, NodeType
+from ...entities._inputs_outputs import Input
+from ...entities._job.pipeline._attr_dict import _AttrDict
+from ...exceptions import ValidationException
+from .._sweep.parameterized_sweep import ParameterizedSweepSchema
+from .._utils.data_binding_expression import support_data_binding_expression_for_fields
+from ..component.flow import FlowComponentSchema
+from ..core.fields import (
+ ArmVersionedStr,
+ ComputeField,
+ EnvironmentField,
+ NestedField,
+ RegistryStr,
+ StringTransformedEnum,
+ TypeSensitiveUnionField,
+ UnionField,
+)
+from ..core.schema import PathAwareSchema
+from ..job import ParameterizedCommandSchema, ParameterizedParallelSchema, ParameterizedSparkSchema
+from ..job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from ..job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema
+from ..job.input_output_fields_provider import InputsField
+from ..job.job_limits import CommandJobLimitsSchema
+from ..job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
+from ..job.services import (
+ JobServiceSchema,
+ JupyterLabJobServiceSchema,
+ SshJobServiceSchema,
+ TensorBoardJobServiceSchema,
+ VsCodeJobServiceSchema,
+)
+from ..pipeline.pipeline_job_io import OutputBindingStr
+from ..spark_resource_configuration import SparkResourceConfigurationForNodeSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+# do inherit PathAwareSchema to support relative path & default partial load (allow None value if not specified)
+class BaseNodeSchema(PathAwareSchema):
+ """Base schema for all node schemas."""
+
+ unknown = INCLUDE
+
+ inputs = InputsField(support_databinding=True)
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([OutputBindingStr, NestedField(OutputSchema)], allow_none=True),
+ )
+ properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+ comment = fields.Str()
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+        # data binding expressions are not supported inside the component field, and the validation
+        # error message would be very long when component is an object (it would include
+        # str(component)), so add component to the skip list. The same applies to trial in Sweep.
+ support_data_binding_expression_for_fields(self, ["type", "component", "trial", "inputs"])
+
+ @post_dump(pass_original=True)
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def add_user_setting_attr_dict(self, data, original_data, **kwargs): # pylint: disable=unused-argument
+ """Support serializing unknown fields for pipeline node."""
+ if isinstance(original_data, _AttrDict):
+ user_setting_attr_dict = original_data._get_attrs()
+ # TODO: dump _AttrDict values to serializable data like dict instead of original object
+ # skip fields that are already serialized
+ for key, value in user_setting_attr_dict.items():
+ if key not in data:
+ data[key] = value
+ return data
+
+    # an alternative would be to set the schema property to load_only, but sub-schemas like CommandSchema
+    # usually also inherit from other schema classes that have a schema property, so setting post_dump
+    # here is more efficient.
+ @post_dump()
+ def remove_meaningless_key_for_node(
+ self,
+ data,
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ data.pop("$schema", None)
+ return data
+
+
+def _delete_type_for_binding(io):
+ for key in io:
+ if isinstance(io[key], Input) and io[key].path and is_data_binding_expression(io[key].path):
+ io[key].type = None
+
+
+def _resolve_inputs(result, original_job):
+ result._inputs = original_job._build_inputs()
+ # delete type for literal binding input
+ _delete_type_for_binding(result._inputs)
+
+
+def _resolve_outputs(result, original_job):
+ result._outputs = original_job._build_outputs()
+ # delete type for literal binding output
+ _delete_type_for_binding(result._outputs)
+
+
+def _resolve_inputs_outputs(job):
+    # Try to resolve the object's inputs & outputs and return a resolved copy
+ import copy
+
+ result = copy.copy(job)
+ _resolve_inputs(result, job)
+ _resolve_outputs(result, job)
+
+ return result
+
+
+class CommandSchema(BaseNodeSchema, ParameterizedCommandSchema):
+ """Schema for Command."""
+
+ # pylint: disable=unused-argument
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.COMMAND: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousCommandComponentSchema, unknown=INCLUDE),
+ # component file reference
+ ComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ # code is directly linked to component.code, so no need to validate or dump it
+ code = fields.Str(allow_none=True, load_only=True)
+ type = StringTransformedEnum(allowed_values=[NodeType.COMMAND])
+ compute = ComputeField()
+ # do not promote it as CommandComponent has no field named 'limits'
+ limits = NestedField(CommandJobLimitsSchema)
+ # Change required fields to optional
+ command = fields.Str(
+ metadata={
+ "description": "The command run and the parameters passed. \
+ This string may contain place holders of inputs in {}. "
+ },
+ load_only=True,
+ )
+ environment = EnvironmentField()
+ services = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(SshJobServiceSchema),
+ NestedField(JupyterLabJobServiceSchema),
+ NestedField(TensorBoardJobServiceSchema),
+ NestedField(VsCodeJobServiceSchema),
+                # JobServiceSchema should be the last in the list so that types not set by
+                # users (Custom, Tracking, Studio) still match.
+ NestedField(JobServiceSchema),
+ ],
+ is_strict=True,
+ ),
+ )
+ identity = UnionField(
+ [
+ NestedField(ManagedIdentitySchema),
+ NestedField(AMLTokenIdentitySchema),
+ NestedField(UserIdentitySchema),
+ ]
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "Command":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.command_func import command
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ command_node = command(**data)
+ except ValidationException as e:
+            # Initialization may raise ValidationException (e.g. from command._validate_io);
+            # re-raise it as a marshmallow ValidationError so it won't break
+            # SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return command_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class SweepSchema(BaseNodeSchema, ParameterizedSweepSchema):
+ """Schema for Sweep."""
+
+ # pylint: disable=unused-argument
+ type = StringTransformedEnum(allowed_values=[NodeType.SWEEP])
+ compute = ComputeField()
+ trial = TypeSensitiveUnionField(
+ {
+ NodeType.SWEEP: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousCommandComponentSchema, unknown=INCLUDE),
+ # component file reference
+ ComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "Sweep":
+ from azure.ai.ml.entities._builders import Sweep, parse_inputs_outputs
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ return Sweep(**data)
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class ParallelSchema(BaseNodeSchema, ParameterizedParallelSchema):
+ """
+ Schema for Parallel.
+ """
+
+ # pylint: disable=unused-argument
+ compute = ComputeField()
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.PARALLEL: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousParallelComponentSchema, unknown=INCLUDE),
+ # component file reference
+ ParallelComponentFileRefField(),
+ ],
+ NodeType.FLOW_PARALLEL: [
+ NestedField(FlowComponentSchema, unknown=INCLUDE, dump_only=True),
+ ComponentYamlRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ identity = UnionField(
+ [
+ NestedField(ManagedIdentitySchema),
+ NestedField(AMLTokenIdentitySchema),
+ NestedField(UserIdentitySchema),
+ ]
+ )
+ type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL])
+
+ @post_load
+ def make(self, data, **kwargs) -> "Parallel":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.parallel_func import parallel_run_function
+
+ data = parse_inputs_outputs(data)
+ parallel_node = parallel_run_function(**data)
+ return parallel_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class ImportSchema(BaseNodeSchema):
+ """
+ Schema for Import.
+ """
+
+ # pylint: disable=unused-argument
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.IMPORT: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousImportComponentSchema, unknown=INCLUDE),
+ # component file reference
+ ImportComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ type = StringTransformedEnum(allowed_values=[NodeType.IMPORT])
+
+ @post_load
+ def make(self, data, **kwargs) -> "Import":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.import_func import import_job
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ import_node = import_job(**data)
+ return import_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class SparkSchema(BaseNodeSchema, ParameterizedSparkSchema):
+ """
+ Schema for Spark.
+ """
+
+ # pylint: disable=unused-argument
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.SPARK: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousSparkComponentSchema, unknown=INCLUDE),
+ # component file reference
+ SparkComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
+ compute = ComputeField()
+ resources = NestedField(SparkResourceConfigurationForNodeSchema)
+ entry = UnionField(
+ [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
+ metadata={"description": "Entry."},
+ )
+ py_files = fields.List(fields.Str())
+ jars = fields.List(fields.Str())
+ files = fields.List(fields.Str())
+ archives = fields.List(fields.Str())
+ identity = UnionField(
+ [
+ NestedField(ManagedIdentitySchema),
+ NestedField(AMLTokenIdentitySchema),
+ NestedField(UserIdentitySchema),
+ ]
+ )
+
+ # code is directly linked to component.code, so no need to validate or dump it
+ code = fields.Str(allow_none=True, load_only=True)
+
+ @post_load
+ def make(self, data, **kwargs) -> "Spark":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.spark_func import spark
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ spark_node = spark(**data)
+ except ValidationException as e:
+            # Initialization may raise a ValidationException (e.g. from command._validate_io);
+            # re-raise it as a marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return spark_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class DataTransferCopySchema(BaseNodeSchema):
+ """
+ Schema for DataTransferCopy.
+ """
+
+ # pylint: disable=unused-argument
+ component = TypeSensitiveUnionField(
+ {
+ NodeType.DATA_TRANSFER: [
+ # inline component or component file reference starting with FILE prefix
+ NestedField(AnonymousDataTransferCopyComponentSchema, unknown=INCLUDE),
+ # component file reference
+ DataTransferCopyComponentFileRefField(),
+ ],
+ },
+ plain_union_fields=[
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.COPY_DATA], required=True)
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER], required=True)
+ compute = ComputeField()
+
+ @post_load
+ def make(self, data, **kwargs) -> "DataTransferCopy":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.data_transfer_func import copy_data
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ data_transfer_node = copy_data(**data)
+ except ValidationException as e:
+            # Initialization may raise a ValidationException (e.g. from data_transfer._validate_io);
+            # re-raise it as a marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return data_transfer_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class DataTransferImportSchema(BaseNodeSchema):
+ # pylint: disable=unused-argument
+ component = UnionField(
+ [
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.IMPORT_DATA], required=True)
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER], required=True)
+ compute = ComputeField()
+ source = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+ outputs = fields.Dict(
+ keys=fields.Str(), values=UnionField([OutputBindingStr, NestedField(OutputSchema)]), allow_none=False
+ )
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ raise ValidationError(f"inputs field is not a valid filed in task type " f"{DataTransferTaskType.IMPORT_DATA}.")
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ if len(value) != 1 or list(value.keys())[0] != "sink":
+ raise ValidationError(
+ f"outputs field only support one output called sink in task type "
+ f"{DataTransferTaskType.IMPORT_DATA}."
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "DataTransferImport":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.data_transfer_func import import_data
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ data_transfer_node = import_data(**data)
+ except ValidationException as e:
+            # Initialization may raise a ValidationException (e.g. from data_transfer._validate_io);
+            # re-raise it as a marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return data_transfer_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
+
+
+class DataTransferExportSchema(BaseNodeSchema):
+ # pylint: disable=unused-argument
+ component = UnionField(
+ [
+ # for registry type assets
+ RegistryStr(),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ ],
+ required=True,
+ )
+ task = StringTransformedEnum(allowed_values=[DataTransferTaskType.EXPORT_DATA])
+ type = StringTransformedEnum(allowed_values=[NodeType.DATA_TRANSFER])
+ compute = ComputeField()
+ inputs = InputsField(support_databinding=True, allow_none=False)
+ sink = UnionField([NestedField(DatabaseSchema), NestedField(FileSystemSchema)], required=True, allow_none=False)
+
+ @validates("inputs")
+ def inputs_key(self, value):
+ if len(value) != 1 or list(value.keys())[0] != "source":
+ raise ValidationError(
+ f"inputs field only support one input called source in task type "
+ f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+ @validates("outputs")
+ def outputs_key(self, value):
+ raise ValidationError(
+ f"outputs field is not a valid filed in task type " f"{DataTransferTaskType.EXPORT_DATA}."
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "DataTransferExport":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.data_transfer_func import export_data
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ try:
+ data_transfer_node = export_data(**data)
+ except ValidationException as e:
+            # Initialization may raise a ValidationException (e.g. from data_transfer._validate_io);
+            # re-raise it as a marshmallow ValidationError so it won't break SchemaValidatable._schema_validate.
+ raise ValidationError(e.message) from e
+ return data_transfer_node
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs):
+ return _resolve_inputs_outputs(job)
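A minimal usage sketch for the data-transfer node schemas above, assuming the
azure-ai-ml package is importable; the component id and compute name below are
illustrative, not values from this module.

    # Hypothetical: load a copy-data node dict through DataTransferCopySchema.
    from azure.ai.ml._schema.pipeline.component_job import DataTransferCopySchema
    from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY

    node_dict = {
        "type": "data_transfer",
        "task": "copy_data",
        "component": "azureml:my_copy_component:1",  # assumed existing component
        "compute": "azureml:cpu-cluster",            # assumed compute target
    }
    schema = DataTransferCopySchema(context={BASE_PATH_CONTEXT_KEY: "."})
    node = schema.load(node_dict)  # @post_load turns the dict into a DataTransferCopy node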
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py
new file mode 100644
index 00000000..a1d2901c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields, post_dump, ValidationError
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._schema.core.fields import DataBindingStr, NodeBindingStr, UnionField
+from azure.ai.ml._schema.pipeline.control_flow_job import ControlFlowSchema
+from azure.ai.ml.constants._component import ControlFlowType
+
+
+# ConditionNodeSchema does not inherit from BaseNodeSchema since it doesn't have inputs/outputs like other nodes.
+class ConditionNodeSchema(ControlFlowSchema):
+ type = StringTransformedEnum(allowed_values=[ControlFlowType.IF_ELSE])
+ condition = UnionField([DataBindingStr(), fields.Bool()])
+ true_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())])
+ false_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())])
+
+ @post_dump
+ def simplify_blocks(self, data, **kwargs): # pylint: disable=unused-argument
+        # simplify true_block and false_block to a single node if the list contains only one node;
+        # this keeps the request to the backend unchanged now that list-valued true/false blocks are supported
+ block_keys = ["true_block", "false_block"]
+ for block in block_keys:
+ if isinstance(data.get(block), list) and len(data.get(block)) == 1:
+ data[block] = data.get(block)[0]
+
+ # validate blocks intersection
+ def _normalize_blocks(key):
+ blocks = data.get(key, [])
+ if blocks:
+ if not isinstance(blocks, list):
+ blocks = [blocks]
+ else:
+ blocks = []
+ return blocks
+
+ true_block = _normalize_blocks("true_block")
+ false_block = _normalize_blocks("false_block")
+
+ if not true_block and not false_block:
+ raise ValidationError("True block and false block cannot be empty at the same time.")
+
+ intersection = set(true_block).intersection(set(false_block))
+ if intersection:
+ raise ValidationError(f"True block and false block cannot contain same nodes: {intersection}")
+
+ return data
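The post_dump hook above both collapses single-element blocks and rejects
overlapping ones; a standalone sketch of the collapsing rule in plain Python
(binding strings are illustrative):

    data = {"true_block": ["${{parent.jobs.a}}"], "false_block": ["${{parent.jobs.b}}"]}
    for block in ("true_block", "false_block"):
        if isinstance(data.get(block), list) and len(data[block]) == 1:
            data[block] = data[block][0]
    assert data == {"true_block": "${{parent.jobs.a}}", "false_block": "${{parent.jobs.b}}"}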
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/control_flow_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/control_flow_job.py
new file mode 100644
index 00000000..3d1e3e4a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/control_flow_job.py
@@ -0,0 +1,147 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import copy
+import json
+
+from marshmallow import INCLUDE, fields, pre_dump, pre_load
+
+from azure.ai.ml._schema.core.fields import DataBindingStr, NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml.constants._component import ControlFlowType
+
+from ..job.input_output_entry import OutputSchema
+from ..job.input_output_fields_provider import InputsField
+from ..job.job_limits import DoWhileLimitsSchema
+from .component_job import _resolve_outputs
+from .pipeline_job_io import OutputBindingStr
+
+# pylint: disable=protected-access
+
+
+class ControlFlowSchema(PathAwareSchema):
+ unknown = INCLUDE
+
+
+class BaseLoopSchema(ControlFlowSchema):
+ unknown = INCLUDE
+ body = DataBindingStr()
+
+ @pre_dump
+ def convert_control_flow_body_to_binding_str(self, data, **kwargs): # pylint: disable= unused-argument
+ result = copy.copy(data)
+ # Update body object to data_binding_str
+ result._body = data._get_body_binding_str()
+ return result
+
+
+class DoWhileSchema(BaseLoopSchema):
+ # pylint: disable=unused-argument
+ type = StringTransformedEnum(allowed_values=[ControlFlowType.DO_WHILE])
+ condition = UnionField(
+ [
+ DataBindingStr(),
+ fields.Str(),
+ ]
+ )
+ mapping = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ fields.List(fields.Str()),
+ fields.Str(),
+ ]
+ ),
+ required=True,
+ )
+ limits = NestedField(DoWhileLimitsSchema, required=True)
+
+ @pre_dump
+ def resolve_inputs_outputs(self, data, **kwargs):
+        # Try to resolve the object's mapping and condition and return a new resolved object
+ result = copy.copy(data)
+ mapping = {}
+ for k, v in result.mapping.items():
+ v = v if isinstance(v, list) else [v]
+ mapping[k] = [item._port_name for item in v]
+ result._mapping = mapping
+
+ try:
+ result._condition = result._condition._port_name
+ except AttributeError:
+ result._condition = result._condition
+
+ return result
+
+ @pre_dump
+ def convert_control_flow_body_to_binding_str(self, data, **kwargs):
+ return super(DoWhileSchema, self).convert_control_flow_body_to_binding_str(data, **kwargs)
+
+
+class ParallelForSchema(BaseLoopSchema):
+ type = StringTransformedEnum(allowed_values=[ControlFlowType.PARALLEL_FOR])
+ items = UnionField(
+ [
+ fields.Dict(keys=fields.Str(), values=InputsField()),
+ fields.List(InputsField()),
+            # put str last so that items of other types are not coerced to string on dump.
+ # TODO: only support binding here
+ fields.Str(),
+ ],
+ required=True,
+ )
+ max_concurrency = fields.Int()
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([OutputBindingStr, NestedField(OutputSchema)], allow_none=True),
+ )
+
+ @pre_load
+ def load_items(self, data, **kwargs): # pylint: disable= unused-argument
+        # items may arrive as a JSON string; parse it so the assets inside get converted properly
+ try:
+ items = data["items"]
+ if isinstance(items, str):
+ items = json.loads(items)
+ data["items"] = items
+ except Exception: # pylint: disable=W0718
+ pass
+ return data
+
+ @pre_dump
+ def convert_control_flow_body_to_binding_str(self, data, **kwargs):
+ return super(ParallelForSchema, self).convert_control_flow_body_to_binding_str(data, **kwargs)
+
+ @pre_dump
+ def resolve_outputs(self, job, **kwargs): # pylint: disable=unused-argument
+ result = copy.copy(job)
+ _resolve_outputs(result, job)
+ return result
+
+ @pre_dump
+ def serialize_items(self, data, **kwargs): # pylint: disable= unused-argument
+        # serialize items to a JSON string so they are not removed by _dump_for_validation
+ from azure.ai.ml.entities._job.pipeline._io import InputOutputBase
+
+ def _binding_handler(obj):
+ if isinstance(obj, InputOutputBase):
+ return str(obj)
+ return repr(obj)
+
+ result = copy.copy(data)
+ if isinstance(result.items, (dict, list)):
+ # use str to serialize input/output builder
+ result._items = json.dumps(result.items, default=_binding_handler)
+ return result
+
+
+class FLScatterGatherSchema(ControlFlowSchema):
+ # TODO determine serialization, or if this is actually needed
+
+ # @pre_dump
+ def serialize_items(self, data, **kwargs):
+ pass
+
+ # @pre_dump
+ def resolve_outputs(self, job, **kwargs):
+ pass
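The serialize_items hook relies on json.dumps falling back to a default
handler for objects it cannot encode; a self-contained sketch of that
strategy (FakeBinding stands in for InputOutputBase):

    import json

    class FakeBinding:  # stand-in for InputOutputBase
        def __str__(self):
            return "${{parent.jobs.step.outputs.out}}"

    def _binding_handler(obj):
        return str(obj) if isinstance(obj, FakeBinding) else repr(obj)

    items = [{"in1": FakeBinding()}, {"in1": 42}]
    print(json.dumps(items, default=_binding_handler))
    # [{"in1": "${{parent.jobs.step.outputs.out}}"}, {"in1": 42}]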
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_command_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_command_job.py
new file mode 100644
index 00000000..c2b96f85
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_command_job.py
@@ -0,0 +1,31 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import ComputeField, EnvironmentField, NestedField, UnionField
+from azure.ai.ml._schema.job.command_job import CommandJobSchema
+from azure.ai.ml._schema.job.input_output_entry import OutputSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineCommandJobSchema(CommandJobSchema):
+ compute = ComputeField()
+ environment = EnvironmentField()
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([NestedField(OutputSchema), fields.Str()], allow_none=True),
+ )
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any):
+ from azure.ai.ml.entities import CommandJob
+
+ return CommandJob(**data)
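A hedged sketch of the outputs section this schema accepts: each value may be
a nested Output mapping or a plain string (field values below are
illustrative):

    outputs = {
        "model_out": {"type": "uri_folder", "mode": "rw_mount"},   # OutputSchema branch
        "log_out": "${{parent.outputs.pipeline_log}}",             # plain-string branch
    }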
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py
new file mode 100644
index 00000000..05096e99
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_component.py
@@ -0,0 +1,297 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load, pre_dump
+
+from azure.ai.ml._schema._utils.utils import _resolve_group_inputs_for_component
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import OutputPortSchema, PrimitiveOutputSchema
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ FileRefField,
+ NestedField,
+ PipelineNodeNameStr,
+ RegistryStr,
+ StringTransformedEnum,
+ TypeSensitiveUnionField,
+ UnionField,
+)
+from azure.ai.ml._schema.pipeline.automl_node import AutoMLNodeSchema
+from azure.ai.ml._schema.pipeline.component_job import (
+ BaseNodeSchema,
+ CommandSchema,
+ DataTransferCopySchema,
+ DataTransferExportSchema,
+ DataTransferImportSchema,
+ ImportSchema,
+ ParallelSchema,
+ SparkSchema,
+ SweepSchema,
+ _resolve_inputs_outputs,
+)
+from azure.ai.ml._schema.pipeline.condition_node import ConditionNodeSchema
+from azure.ai.ml._schema.pipeline.control_flow_job import DoWhileSchema, ParallelForSchema
+from azure.ai.ml._schema.pipeline.pipeline_command_job import PipelineCommandJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_datatransfer_job import (
+ PipelineDataTransferCopyJobSchema,
+ PipelineDataTransferExportJobSchema,
+ PipelineDataTransferImportJobSchema,
+)
+from azure.ai.ml._schema.pipeline.pipeline_import_job import PipelineImportJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_parallel_job import PipelineParallelJobSchema
+from azure.ai.ml._schema.pipeline.pipeline_spark_job import PipelineSparkJobSchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+from azure.ai.ml.constants._component import (
+ CONTROL_FLOW_TYPES,
+ ComponentSource,
+ ControlFlowType,
+ DataTransferTaskType,
+ NodeType,
+)
+
+
+class NodeNameStr(PipelineNodeNameStr):
+ def _get_field_name(self) -> str:
+ return "Pipeline node"
+
+
+def PipelineJobsField():
+ pipeline_enable_job_type = {
+ NodeType.COMMAND: [
+ NestedField(CommandSchema, unknown=INCLUDE),
+ NestedField(PipelineCommandJobSchema),
+ ],
+ NodeType.IMPORT: [
+ NestedField(ImportSchema, unknown=INCLUDE),
+ NestedField(PipelineImportJobSchema),
+ ],
+ NodeType.SWEEP: [NestedField(SweepSchema, unknown=INCLUDE)],
+ NodeType.PARALLEL: [
+            # ParallelSchema supports parallel pipeline yml with "component"
+ NestedField(ParallelSchema, unknown=INCLUDE),
+ NestedField(PipelineParallelJobSchema, unknown=INCLUDE),
+ ],
+ NodeType.PIPELINE: [NestedField("PipelineSchema", unknown=INCLUDE)],
+ NodeType.AUTOML: AutoMLNodeSchema(unknown=INCLUDE),
+ NodeType.SPARK: [
+ NestedField(SparkSchema, unknown=INCLUDE),
+ NestedField(PipelineSparkJobSchema),
+ ],
+ }
+
+    # Note: the private node types are only available when the private preview flag is enabled before the
+    # pipeline job schema class is initialized.
+ if is_private_preview_enabled():
+ pipeline_enable_job_type[ControlFlowType.DO_WHILE] = [NestedField(DoWhileSchema, unknown=INCLUDE)]
+ pipeline_enable_job_type[ControlFlowType.IF_ELSE] = [NestedField(ConditionNodeSchema, unknown=INCLUDE)]
+ pipeline_enable_job_type[ControlFlowType.PARALLEL_FOR] = [NestedField(ParallelForSchema, unknown=INCLUDE)]
+
+    # TODO: keep the data_transfer logic last to avoid conflicting error messages; tracking item:
+ # https://msdata.visualstudio.com/Vienna/_workitems/edit/2244262/
+ pipeline_enable_job_type[NodeType.DATA_TRANSFER] = [
+ TypeSensitiveUnionField(
+ {
+ DataTransferTaskType.COPY_DATA: [
+ NestedField(DataTransferCopySchema, unknown=INCLUDE),
+ NestedField(PipelineDataTransferCopyJobSchema),
+ ],
+ DataTransferTaskType.IMPORT_DATA: [
+ NestedField(DataTransferImportSchema, unknown=INCLUDE),
+ NestedField(PipelineDataTransferImportJobSchema),
+ ],
+ DataTransferTaskType.EXPORT_DATA: [
+ NestedField(DataTransferExportSchema, unknown=INCLUDE),
+ NestedField(PipelineDataTransferExportJobSchema),
+ ],
+ },
+ type_field_name="task",
+ unknown=INCLUDE,
+ )
+ ]
+
+ pipeline_job_field = fields.Dict(
+ keys=NodeNameStr(),
+ values=TypeSensitiveUnionField(pipeline_enable_job_type),
+ )
+ return pipeline_job_field
+
+
+# pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+def _post_load_pipeline_jobs(context, data: dict) -> dict:
+ """Silently convert Job in pipeline jobs to node."""
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.condition_node import ConditionNode
+ from azure.ai.ml.entities._builders.do_while import DoWhile
+ from azure.ai.ml.entities._builders.parallel_for import ParallelFor
+ from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+ from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+ # convert JobNode to Component here
+ jobs = data.get("jobs", {})
+
+ for key, job_instance in jobs.items():
+ if isinstance(job_instance, dict):
+ # convert AutoML job dict to instance
+ if job_instance.get("type") == NodeType.AUTOML:
+ job_instance = AutoMLJob._create_instance_from_schema_dict(
+ loaded_data=job_instance,
+ )
+ elif job_instance.get("type") in CONTROL_FLOW_TYPES:
+ # Set source to yaml job for control flow node.
+ job_instance["_source"] = ComponentSource.YAML_JOB
+
+ job_type = job_instance.get("type")
+ if job_type == ControlFlowType.IF_ELSE:
+ # Convert to if-else node.
+ job_instance = ConditionNode._create_instance_from_schema_dict(loaded_data=job_instance)
+ elif job_instance.get("type") == ControlFlowType.DO_WHILE:
+ # Convert to do-while node.
+ job_instance = DoWhile._create_instance_from_schema_dict(
+ pipeline_jobs=jobs, loaded_data=job_instance
+ )
+ elif job_instance.get("type") == ControlFlowType.PARALLEL_FOR:
+                # Convert to parallel-for node.
+ job_instance = ParallelFor._create_instance_from_schema_dict(
+ pipeline_jobs=jobs, loaded_data=job_instance
+ )
+ jobs[key] = job_instance
+
+ for key, job_instance in jobs.items():
+        # Translate the job to a node if it is translatable and overrides _to_node.
+ if isinstance(job_instance, ComponentTranslatableMixin) and "_to_node" in type(job_instance).__dict__:
+ # set source as YAML
+ job_instance = job_instance._to_node(
+ context=context,
+ pipeline_job_dict=data,
+ )
+ if job_instance.type == NodeType.DATA_TRANSFER and job_instance.task != DataTransferTaskType.COPY_DATA:
+ job_instance._source = ComponentSource.BUILTIN
+ else:
+ job_instance.component._source = ComponentSource.YAML_JOB
+ job_instance._source = job_instance.component._source
+ jobs[key] = job_instance
+ # update job instance name to key
+ job_instance.name = key
+ return data
+
+
+class PipelineComponentSchema(ComponentSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
+ jobs = PipelineJobsField()
+
+ # primitive output is only supported for command component & pipeline component
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(PrimitiveOutputSchema, unknown=INCLUDE),
+ NestedField(OutputPortSchema),
+ ]
+ ),
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ return _post_load_pipeline_jobs(self.context, data)
+
+
+class RestPipelineComponentSchema(PipelineComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class _AnonymousPipelineComponentSchema(AnonymousAssetSchema, PipelineComponentSchema):
+ """Anonymous pipeline component schema.
+
+    Note that inline definition of an anonymous pipeline component is not
+    supported directly. Inheritance follows the order AnonymousAssetSchema,
+    PipelineComponentSchema because we need name and version to be dump_only
+    (marshmallow collects fields following the method resolution order).
+    """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.pipeline_component import PipelineComponent
+
+        # post-processing of pipeline jobs is required before the pipeline component is initialized:
+        # it converts control-flow node dicts to entities.
+        # However, @post_load invocation order is not guaranteed, so call it explicitly here.
+ _post_load_pipeline_jobs(self.context, data)
+
+ return PipelineComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ **data,
+ )
+
+
+class PipelineComponentFileRefField(FileRefField):
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def _serialize(self, value, attr, obj, **kwargs):
+ """FileRefField does not support serialize.
+
+ Call AnonymousPipelineComponent schema to serialize. This
+ function is overwrite because we need Pipeline can be dumped.
+ """
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ value = _resolve_group_inputs_for_component(value)
+ return _AnonymousPipelineComponentSchema(context=component_schema_context)._serialize(value, **kwargs)
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = _AnonymousPipelineComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component
+
+
+# Note: PipelineSchema is defined here instead of in component_job.py to
+# resolve a circular import and to support recursive schemas.
+class PipelineSchema(BaseNodeSchema):
+ # pylint: disable=unused-argument
+ # do not support inline define a pipeline node
+ component = UnionField(
+ [
+ # for registry type assets
+ RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ # component file reference
+ PipelineComponentFileRefField(),
+ ],
+ required=True,
+ )
+ type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
+
+ @post_load
+ def make(self, data, **kwargs) -> "Pipeline":
+ from azure.ai.ml.entities._builders import parse_inputs_outputs
+ from azure.ai.ml.entities._builders.pipeline import Pipeline
+
+ data = parse_inputs_outputs(data)
+ return Pipeline(**data)
+
+ @pre_dump
+ def resolve_inputs_outputs(self, data, **kwargs):
+ return _resolve_inputs_outputs(data)
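PipelineJobsField dispatches each node dict on its "type" field (and, for
data_transfer nodes, on the nested "task" field). An illustrative jobs
mapping, with made-up component ids:

    jobs = {
        "train": {"type": "command", "component": "azureml:train_comp:1"},
        "copy_step": {
            "type": "data_transfer",
            "task": "copy_data",
            "component": "azureml:copy_comp:1",
        },
    }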
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_datatransfer_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_datatransfer_job.py
new file mode 100644
index 00000000..a63e687d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_datatransfer_job.py
@@ -0,0 +1,55 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NestedField, UnionField
+from azure.ai.ml._schema.job.input_output_entry import OutputSchema
+from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr
+from azure.ai.ml._schema.job.data_transfer_job import (
+ DataTransferCopyJobSchema,
+ DataTransferImportJobSchema,
+ DataTransferExportJobSchema,
+)
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineDataTransferCopyJobSchema(DataTransferCopyJobSchema):
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([NestedField(OutputSchema), OutputBindingStr], allow_none=True),
+ )
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any):
+ from azure.ai.ml.entities._job.data_transfer.data_transfer_job import DataTransferCopyJob
+
+ return DataTransferCopyJob(**data)
+
+
+class PipelineDataTransferImportJobSchema(DataTransferImportJobSchema):
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([NestedField(OutputSchema), OutputBindingStr], allow_none=True),
+ )
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any):
+ from azure.ai.ml.entities._job.data_transfer.data_transfer_job import DataTransferImportJob
+
+ return DataTransferImportJob(**data)
+
+
+class PipelineDataTransferExportJobSchema(DataTransferExportJobSchema):
+ @post_load
+ def make(self, data: Any, **kwargs: Any):
+ from azure.ai.ml.entities._job.data_transfer.data_transfer_job import DataTransferExportJob
+
+ return DataTransferExportJob(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_import_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_import_job.py
new file mode 100644
index 00000000..ae338597
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_import_job.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import post_load
+
+from azure.ai.ml._schema.job.import_job import ImportJobSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineImportJobSchema(ImportJobSchema):
+ class Meta:
+ exclude = ["compute"] # compute property not applicable to import job
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any):
+ from azure.ai.ml.entities._job.import_job import ImportJob
+
+ return ImportJob(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py
new file mode 100644
index 00000000..46daeb92
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job.py
@@ -0,0 +1,76 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+
+from marshmallow import INCLUDE, ValidationError, post_load, pre_dump, pre_load
+
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ ComputeField,
+ NestedField,
+ RegistryStr,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml._schema.pipeline.component_job import _resolve_inputs_outputs
+from azure.ai.ml._schema.pipeline.pipeline_component import (
+ PipelineComponentFileRefField,
+ PipelineJobsField,
+ _post_load_pipeline_jobs,
+)
+from azure.ai.ml._schema.pipeline.settings import PipelineJobSettingsSchema
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineJobSchema(BaseJobSchema):
+ type = StringTransformedEnum(allowed_values=[JobType.PIPELINE])
+ compute = ComputeField()
+ settings = NestedField(PipelineJobSettingsSchema, unknown=INCLUDE)
+    # Support databinding in inputs as we support macros like ${{name}}
+ inputs = InputsField(support_databinding=True)
+ outputs = OutputsField()
+ jobs = PipelineJobsField()
+ component = UnionField(
+ [
+ # for registry type assets
+ RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ # component file reference
+ PipelineComponentFileRefField(),
+ ],
+ )
+
+ @pre_dump()
+ def backup_jobs_and_remove_component(self, job, **kwargs):
+ # pylint: disable=protected-access
+ job_copy = _resolve_inputs_outputs(job)
+ if not isinstance(job_copy.component, str):
+            # If component is a pipeline component object,
+            # copy its jobs onto the job and remove the component reference.
+ if not job_copy._jobs:
+ job_copy._jobs = job_copy.component.jobs
+ job_copy.component = None
+ return job_copy
+
+ @pre_load()
+ def check_exclusive_fields(self, data: dict, **kwargs) -> dict:
+ error_msg = "'jobs' and 'component' are mutually exclusive fields in pipeline job."
+ # When loading from yaml, data["component"] must be a local path (str)
+ # Otherwise, data["component"] can be a PipelineComponent so data["jobs"] must exist
+ if isinstance(data.get("component"), str) and data.get("jobs"):
+ raise ValidationError(error_msg)
+ return data
+
+ @post_load
+ def make(self, data: dict, **kwargs) -> dict:
+ return _post_load_pipeline_jobs(self.context, data)
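A standalone sketch of the check_exclusive_fields rule: a YAML-sourced
pipeline job may reference a component file or define inline jobs, but not
both (paths and job names are illustrative):

    def is_valid(data):
        return not (isinstance(data.get("component"), str) and data.get("jobs"))

    assert is_valid({"component": "./pipeline_component.yml"})
    assert not is_valid({"component": "./pipeline_component.yml", "jobs": {"step1": {}}})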
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job_io.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job_io.py
new file mode 100644
index 00000000..3fb6a7b7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_job_io.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+import re
+
+from marshmallow import ValidationError, fields
+
+from azure.ai.ml.constants._component import ComponentJobConstants
+
+module_logger = logging.getLogger(__name__)
+
+
+class OutputBindingStr(fields.Field):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string", "pattern": ComponentJobConstants.OUTPUT_PATTERN}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ if isinstance(value, str) and re.match(ComponentJobConstants.OUTPUT_PATTERN, value):
+ return value
+        # _to_job_output in io.py will return an Output;
+        # this branch determines whether the original value is a simple binding or an Output
+ if (
+ isinstance(value.path, str)
+ and re.match(ComponentJobConstants.OUTPUT_PATTERN, value.path)
+ and value.mode is None
+ ):
+ return value.path
+ raise ValidationError(f"Invalid output binding string '{value}' passed")
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if (
+ isinstance(value, dict)
+ and "path" in value
+ and "mode" not in value
+ and "name" not in value
+ and "version" not in value
+ ):
+ value = value["path"]
+ if isinstance(value, str) and re.match(ComponentJobConstants.OUTPUT_PATTERN, value):
+ return value
+ raise ValidationError(f"Invalid output binding string '{value}' passed")
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_parallel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_parallel_job.py
new file mode 100644
index 00000000..3b30fb66
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_parallel_job.py
@@ -0,0 +1,40 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import post_load
+
+from azure.ai.ml._schema.core.fields import ComputeField, EnvironmentField, StringTransformedEnum
+from azure.ai.ml._schema.job import ParameterizedParallelSchema
+from azure.ai.ml._schema.pipeline.component_job import BaseNodeSchema
+
+from ...constants._component import NodeType
+
+module_logger = logging.getLogger(__name__)
+
+
+# parallel job inherits parallel attributes from ParameterizedParallelSchema and node functionality from BaseNodeSchema
+class PipelineParallelJobSchema(BaseNodeSchema, ParameterizedParallelSchema):
+ """Schema for ParallelJob in PipelineJob/PipelineComponent."""
+
+ type = StringTransformedEnum(allowed_values=NodeType.PARALLEL)
+ compute = ComputeField()
+ environment = EnvironmentField()
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any):
+ """Construct a ParallelJob from deserialized data.
+
+ :param data: The deserialized data.
+ :type data: dict[str, Any]
+ :return: A ParallelJob.
+ :rtype: azure.ai.ml.entities._job.parallel.ParallelJob
+ """
+ from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob
+
+ return ParallelJob(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_spark_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_spark_job.py
new file mode 100644
index 00000000..69d58255
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/pipeline_spark_job.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NestedField, UnionField
+from azure.ai.ml._schema.job.input_output_entry import OutputSchema
+from azure.ai.ml._schema.job.spark_job import SparkJobSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineSparkJobSchema(SparkJobSchema):
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([NestedField(OutputSchema), fields.Str()], allow_none=True),
+ )
+
+ @post_load
+ def make(self, data: Any, **kwargs: Any):
+ from azure.ai.ml.entities._job.spark_job import SparkJob
+
+ return SparkJob(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/settings.py
new file mode 100644
index 00000000..1e5227b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/settings.py
@@ -0,0 +1,42 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import INCLUDE, Schema, fields, post_dump, post_load
+
+from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.pipeline.pipeline_component import NodeNameStr
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import AzureMLResourceType, SERVERLESS_COMPUTE
+
+
+class PipelineJobSettingsSchema(Schema):
+ class Meta:
+ unknown = INCLUDE
+
+ default_datastore = ArmStr(azureml_type=AzureMLResourceType.DATASTORE)
+ default_compute = UnionField(
+ [
+ StringTransformedEnum(allowed_values=[SERVERLESS_COMPUTE]),
+ ArmStr(azureml_type=AzureMLResourceType.COMPUTE),
+ ]
+ )
+ continue_on_step_failure = fields.Bool()
+ force_rerun = fields.Bool()
+
+ # move init/finalize under private preview flag to hide them in spec
+ if is_private_preview_enabled():
+ on_init = NodeNameStr()
+ on_finalize = NodeNameStr()
+
+ @post_load
+ def make(self, data, **kwargs) -> "PipelineJobSettings":
+ from azure.ai.ml.entities import PipelineJobSettings
+
+ return PipelineJobSettings(**data)
+
+ @post_dump
+ def remove_none(self, data, **kwargs):
+ return {key: value for key, value in data.items() if value is not None}
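A standalone illustration of the remove_none hook: unset settings are dropped
so they never appear in serialized output:

    dumped = {"default_compute": "cpu-cluster", "default_datastore": None, "force_rerun": None}
    assert {k: v for k, v in dumped.items() if v is not None} == {"default_compute": "cpu-cluster"}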
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/queue_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/queue_settings.py
new file mode 100644
index 00000000..3196a00c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/queue_settings.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import post_load
+from azure.ai.ml.constants._job.job import JobPriorityValues, JobTierNames
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class QueueSettingsSchema(metaclass=PatchedSchemaMeta):
+ job_tier = StringTransformedEnum(
+ allowed_values=JobTierNames.ALLOWED_NAMES,
+ )
+ priority = StringTransformedEnum(
+ allowed_values=JobPriorityValues.ALLOWED_VALUES,
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ from azure.ai.ml.entities import QueueSettings
+
+ return QueueSettings(**data)
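A hypothetical queue_settings block this schema would accept; the exact
allowed values come from JobTierNames and JobPriorityValues, so the literals
here are assumptions for illustration:

    queue_settings = {"job_tier": "spot", "priority": "low"}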
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/__init__.py
new file mode 100644
index 00000000..9c2fe189
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/__init__.py
@@ -0,0 +1,9 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
+
+from .registry import RegistrySchema
+
+__all__ = ["RegistrySchema"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry.py
new file mode 100644
index 00000000..17233195
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import DumpableStringField, NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.intellectual_property import PublisherSchema
+from azure.ai.ml._schema.core.resource import ResourceSchema
+from azure.ai.ml._schema.workspace.identity import IdentitySchema
+from azure.ai.ml._utils.utils import snake_to_pascal
+from azure.ai.ml.constants._common import PublicNetworkAccess
+from azure.ai.ml.constants._registry import AcrAccountSku
+from azure.ai.ml.entities._registry.registry_support_classes import SystemCreatedAcrAccount
+
+from .registry_region_arm_details import RegistryRegionDetailsSchema
+from .system_created_acr_account import SystemCreatedAcrAccountSchema
+from .util import acr_format_validator
+
+
+# Based on 10-01-preview api
+class RegistrySchema(ResourceSchema):
+ # Inherits name, id, tags, and description fields from ResourceSchema
+
+ # Values from RegistryTrackedResource (Client name: Registry)
+ location = fields.Str(required=True)
+
+ # Values from Registry (Client name: RegistryProperties)
+ public_network_access = StringTransformedEnum(
+ allowed_values=[PublicNetworkAccess.DISABLED, PublicNetworkAccess.ENABLED],
+ casing_transform=snake_to_pascal,
+ )
+ replication_locations = fields.List(NestedField(RegistryRegionDetailsSchema))
+ intellectual_property = NestedField(PublisherSchema)
+ # This is an acr account which will be applied to every registryRegionArmDetail defined
+ # in replication_locations. This is different from the internal swagger
+ # definition, which has a per-region list of acr accounts.
+ # Per-region acr account configuration is NOT possible through yaml configs for now.
+ container_registry = UnionField(
+ [DumpableStringField(validate=acr_format_validator), NestedField(SystemCreatedAcrAccountSchema)],
+ required=False,
+ is_strict=True,
+ load_default=SystemCreatedAcrAccount(acr_account_sku=AcrAccountSku.PREMIUM),
+ )
+
+ # Values that can only be set by return values from the system, never
+ # set by the user.
+ identity = NestedField(IdentitySchema, dump_only=True)
+ kind = fields.Str(dump_only=True)
+ sku = fields.Str(dump_only=True)
+ managed_resource_group = fields.Str(dump_only=True)
+ mlflow_registry_uri = fields.Str(dump_only=True)
+ discovery_url = fields.Str(dump_only=True)
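A hedged sketch of the two forms container_registry accepts: a full ACR ARM
resource id string, or a system-created account spec (the id below is
illustrative, not a real resource):

    container_registry_as_id = (
        "/subscriptions/0000/resourceGroups/rg/providers/"
        "Microsoft.ContainerRegistry/registries/myacr"
    )
    container_registry_as_spec = {"acr_account_sku": "premium"}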
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry_region_arm_details.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry_region_arm_details.py
new file mode 100644
index 00000000..c861b94c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/registry_region_arm_details.py
@@ -0,0 +1,61 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import DumpableStringField, NestedField, UnionField
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._registry import StorageAccountType
+from azure.ai.ml.entities._registry.registry_support_classes import SystemCreatedStorageAccount
+
+from .system_created_storage_account import SystemCreatedStorageAccountSchema
+from .util import storage_account_validator
+
+
+# Differs from the swagger def in that the acr_details can only be supplied as a
+# single registry-wide instance, rather than a per-region list.
+@experimental
+class RegistryRegionDetailsSchema(metaclass=PatchedSchemaMeta):
+ # Commenting this out for the time being.
+ # We do not want to surface the acr_config as a per-region configurable
+ # field. Instead we want to simplify the UX and surface it as a non-list,
+ # top-level value called 'container_registry'.
+ # We don't even want to show the per-region acr accounts when displaying a
+ # registry to the user, so this isn't even left as a dump-only field.
+ """acr_config = fields.List(
+ UnionField(
+ [DumpableStringField(validate=acr_format_validator), NestedField(SystemCreatedAcrAccountSchema)],
+ dump_only=True,
+ is_strict=True,
+ )
+ )"""
+ location = fields.Str()
+ storage_config = UnionField(
+ [
+ NestedField(SystemCreatedStorageAccountSchema),
+ fields.List(DumpableStringField(validate=storage_account_validator)),
+ ],
+ is_strict=True,
+ load_default=SystemCreatedStorageAccount(
+ storage_account_hns=False, storage_account_type=StorageAccountType.STANDARD_LRS
+ ),
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import RegistryRegionDetails
+
+ data.pop("type", None)
+ return RegistryRegionDetails(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities import RegistryRegionDetails
+
+ if not isinstance(data, RegistryRegionDetails):
+ raise ValidationError("Cannot dump non-RegistryRegionDetails object into RegistryRegionDetailsSchema")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_acr_account.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_acr_account.py
new file mode 100644
index 00000000..08b78c2e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_acr_account.py
@@ -0,0 +1,35 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._registry import AcrAccountSku
+
+
+@experimental
+class SystemCreatedAcrAccountSchema(metaclass=PatchedSchemaMeta):
+ arm_resource_id = fields.Str(dump_only=True)
+ acr_account_sku = StringTransformedEnum(
+ allowed_values=[sku.value for sku in AcrAccountSku], casing_transform=lambda x: x.lower()
+ )
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import SystemCreatedAcrAccount
+
+ data.pop("type", None)
+ return SystemCreatedAcrAccount(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities import SystemCreatedAcrAccount
+
+ if not isinstance(data, SystemCreatedAcrAccount):
+ raise ValidationError("Cannot dump non-SystemCreatedAcrAccount object into SystemCreatedAcrAccountSchema")
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_storage_account.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_storage_account.py
new file mode 100644
index 00000000..cdbbcd67
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/system_created_storage_account.py
@@ -0,0 +1,40 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import ValidationError, fields, post_load, pre_dump
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants._registry import StorageAccountType
+
+
+class SystemCreatedStorageAccountSchema(metaclass=PatchedSchemaMeta):
+ arm_resource_id = fields.Str(dump_only=True)
+ storage_account_hns = fields.Bool(load_default=False)
+ storage_account_type = StringTransformedEnum(
+ load_default=StorageAccountType.STANDARD_LRS,
+ allowed_values=[accountType.value for accountType in StorageAccountType],
+ casing_transform=lambda x: x.lower(),
+ )
+ replication_count = fields.Int(load_default=1, validate=lambda count: count > 0)
+ replicated_ids = fields.List(fields.Str(), dump_only=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import SystemCreatedStorageAccount
+
+ data.pop("type", None)
+ return SystemCreatedStorageAccount(**data)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ from azure.ai.ml.entities import SystemCreatedStorageAccount
+
+ if not isinstance(data, SystemCreatedStorageAccount):
+ raise ValidationError(
+ "Cannot dump non-SystemCreatedStorageAccount object into SystemCreatedStorageAccountSchema"
+ )
+ return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/util.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/util.py
new file mode 100644
index 00000000..19c01e9a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/registry/util.py
@@ -0,0 +1,15 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# Simple helper methods to avoid re-using lambdas everywhere
+
+from azure.ai.ml.constants._registry import ACR_ACCOUNT_FORMAT, STORAGE_ACCOUNT_FORMAT
+
+
+def storage_account_validator(storage_id: str):
+ return STORAGE_ACCOUNT_FORMAT.match(storage_id) is not None
+
+
+def acr_format_validator(acr_id: str):
+ return ACR_ACCOUNT_FORMAT.match(acr_id) is not None
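A usage sketch for the validators above, assuming the package is importable;
marshmallow calls each validator with the candidate string and treats a False
return as a validation failure:

    from azure.ai.ml._schema.registry.util import acr_format_validator

    acr_format_validator("not-an-arm-id")  # False for a malformed id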
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/resource_configuration.py
new file mode 100644
index 00000000..fece59a2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/resource_configuration.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class ResourceConfigurationSchema(metaclass=PatchedSchemaMeta):
+ instance_count = fields.Int()
+ instance_type = fields.Str(metadata={"description": "The instance type to make available to this job."})
+ properties = fields.Dict(keys=fields.Str())
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import ResourceConfiguration
+
+ return ResourceConfiguration(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py
new file mode 100644
index 00000000..084f8a5b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/create_job.py
@@ -0,0 +1,144 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+import copy
+from typing import Optional
+
+import yaml
+from marshmallow import INCLUDE, ValidationError, fields, post_load, pre_load
+
+from azure.ai.ml._schema import CommandJobSchema
+from azure.ai.ml._schema.core.fields import (
+ ArmStr,
+ ComputeField,
+ EnvironmentField,
+ FileRefField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import InputsField, OutputsField
+from azure.ai.ml._schema.pipeline.settings import PipelineJobSettingsSchema
+from azure.ai.ml._utils.utils import load_file, merge_dict
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+
+_SCHEDULED_JOB_UPDATES_KEY = "scheduled_job_updates"
+
+
+class CreateJobFileRefField(FileRefField):
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def _serialize(self, value, attr, obj, **kwargs):
+ """FileRefField does not support serialize.
+
+ This function is overwrite because we need job can be dumped inside schedule.
+ """
+ from azure.ai.ml.entities._builders import BaseNode
+
+ if isinstance(value, BaseNode):
+ # Dump as Job to avoid missing field.
+ value = value._to_job()
+ return value._to_dict()
+
+ def _deserialize(self, value, attr, data, **kwargs) -> "Job":
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ job_dict = yaml.safe_load(data)
+
+ from azure.ai.ml.entities import Job
+
+ return Job._load(
+ data=job_dict,
+ yaml_path=self.context[BASE_PATH_CONTEXT_KEY] / value,
+ **kwargs,
+ )
+
+
+class BaseCreateJobSchema(BaseJobSchema):
+ compute = ComputeField()
+ job = UnionField(
+ [
+ ArmStr(azureml_type=AzureMLResourceType.JOB),
+ CreateJobFileRefField,
+ ],
+ required=True,
+ )
+
+ # pylint: disable-next=docstring-missing-param
+ def _get_job_instance_for_remote_job(self, id: Optional[str], data: Optional[dict], **kwargs) -> "Job":
+ """Get a job instance to store updates for remote job.
+
+ :return: The remote job
+ :rtype: Job
+ """
+ from azure.ai.ml.entities import Job
+
+ data = {} if data is None else data
+ if "type" not in data:
+ raise ValidationError("'type' must be specified when scheduling a remote job with updates.")
+ # Create a job instance if job is arm id
+ job_instance = Job._load(
+ data=data,
+ **kwargs,
+ )
+ # Set back the id and base path to created job
+ job_instance._id = id
+ job_instance._base_path = self.context[BASE_PATH_CONTEXT_KEY]
+ return job_instance
+
+ @pre_load
+ def pre_load(self, data, **kwargs): # pylint: disable=unused-argument
+ if isinstance(data, dict):
+            # Put a raw copy of the data into the context;
+            # a dict here indicates there are updates to the scheduled job.
+ copied_data = copy.deepcopy(data)
+ copied_data.pop("job", None)
+ self.context[_SCHEDULED_JOB_UPDATES_KEY] = copied_data
+ return data
+
+ @post_load
+ def make(self, data: dict, **kwargs) -> "Job":
+ from azure.ai.ml.entities import Job
+
+ # Get the loaded job
+ job = data.pop("job")
+ # Get the raw dict data before load
+ raw_data = self.context.get(_SCHEDULED_JOB_UPDATES_KEY, {})
+ if isinstance(job, Job):
+ if job._source_path is None:
+ raise ValidationError("Could not load job for schedule without '_source_path' set.")
+ # Load local job again with updated values
+ job_dict = yaml.safe_load(load_file(job._source_path))
+ return Job._load(
+ data=merge_dict(job_dict, raw_data),
+ yaml_path=job._source_path,
+ **kwargs,
+ )
+ # Create a job instance for the remote job
+ return self._get_job_instance_for_remote_job(job, raw_data, **kwargs)
+
+
+class PipelineCreateJobSchema(BaseCreateJobSchema):
+ # Note: we do not inherit PipelineJobSchema here, as we don't need its pre_load/post_load hooks.
+ type = StringTransformedEnum(allowed_values=[JobType.PIPELINE])
+ inputs = InputsField()
+ outputs = OutputsField()
+ settings = NestedField(PipelineJobSettingsSchema, unknown=INCLUDE)
+
+
+class CommandCreateJobSchema(BaseCreateJobSchema, CommandJobSchema):
+ class Meta:
+ # Refer to https://github.com/Azure/azureml_run_specification/blob/master
+ # /specs/job-endpoint.md#properties-in-difference-job-types
+ # code and command cannot be set at runtime
+ exclude = ["code", "command"]
+
+ environment = EnvironmentField()
+
+
+class SparkCreateJobSchema(BaseCreateJobSchema):
+ type = StringTransformedEnum(allowed_values=[JobType.SPARK])
+ conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
+ environment = EnvironmentField(allow_none=True)
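+
+
+def _example_schedule_with_updates():
+    """A minimal usage sketch, not part of the original module: the paths and
+    field values are hypothetical. A schedule YAML may point ``job`` at a local
+    file and place overrides next to it; ``make`` above reloads the file and
+    merges the overrides into it::
+
+        create_job:
+          job: ./train_job.yml
+          experiment_name: weekly-run
+    """
+    from azure.ai.ml import load_schedule
+
+    return load_schedule("./schedule.yml")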
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/schedule.py
new file mode 100644
index 00000000..fbde3e9b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/schedule.py
@@ -0,0 +1,44 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.fields import ArmStr, NestedField, UnionField
+from azure.ai.ml._schema.core.resource import ResourceSchema
+from azure.ai.ml._schema.job import CreationContextSchema
+from azure.ai.ml._schema.schedule.create_job import (
+ CommandCreateJobSchema,
+ CreateJobFileRefField,
+ PipelineCreateJobSchema,
+ SparkCreateJobSchema,
+)
+from azure.ai.ml._schema.schedule.trigger import CronTriggerSchema, RecurrenceTriggerSchema
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+
+class ScheduleSchema(ResourceSchema):
+ name = fields.Str(attribute="name", required=True)
+ display_name = fields.Str(attribute="display_name")
+ trigger = UnionField(
+ [
+ NestedField(CronTriggerSchema),
+ NestedField(RecurrenceTriggerSchema),
+ ],
+ )
+ creation_context = NestedField(CreationContextSchema, dump_only=True)
+ is_enabled = fields.Boolean(dump_only=True)
+ provisioning_state = fields.Str(dump_only=True)
+ properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+
+
+class JobScheduleSchema(ScheduleSchema):
+ create_job = UnionField(
+ [
+ ArmStr(azureml_type=AzureMLResourceType.JOB),
+ CreateJobFileRefField,
+ NestedField(PipelineCreateJobSchema),
+ NestedField(CommandCreateJobSchema),
+ NestedField(SparkCreateJobSchema),
+ ]
+ )
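+
+
+def _example_job_schedule():
+    """A minimal sketch, assuming the public azure.ai.ml surface; the schedule
+    name and job path are hypothetical. Shows the entities this schema loads
+    into when ``create_job`` is a local job."""
+    from azure.ai.ml import load_job
+    from azure.ai.ml.entities import CronTrigger, JobSchedule
+
+    job = load_job("./train_job.yml")
+    return JobSchedule(
+        name="weekly-training",
+        trigger=CronTrigger(expression="15 10 * * 1"),
+        create_job=job,
+    )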
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/trigger.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/trigger.py
new file mode 100644
index 00000000..37147d48
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/schedule/trigger.py
@@ -0,0 +1,82 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_dump, post_load
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import RecurrenceFrequency, TriggerType, WeekDay
+from azure.ai.ml._schema.core.fields import (
+ DateTimeStr,
+ DumpableIntegerField,
+ NestedField,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants import TimeZone
+
+
+class TriggerSchema(metaclass=PatchedSchemaMeta):
+ start_time = UnionField([fields.DateTime(), DateTimeStr()])
+ end_time = UnionField([fields.DateTime(), DateTimeStr()])
+ time_zone = fields.Str()
+
+ @post_dump(pass_original=True)
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def resolve_time_zone(self, data, original_data, **kwargs): # pylint: disable= unused-argument
+ """
+ Auto-conversion yields a string like "TimeZone.UTC" for a TimeZone enum object,
+ while the valid result should be "UTC".
+ """
+ if isinstance(original_data.time_zone, TimeZone):
+ data["time_zone"] = original_data.time_zone.value
+ return data
+
+
+class CronTriggerSchema(TriggerSchema):
+ type = StringTransformedEnum(allowed_values=TriggerType.CRON, required=True)
+ expression = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs) -> "CronTrigger": # pylint: disable= unused-argument
+ from azure.ai.ml.entities import CronTrigger
+
+ data.pop("type")
+ return CronTrigger(**data)
+
+
+class RecurrencePatternSchema(metaclass=PatchedSchemaMeta):
+ hours = UnionField([DumpableIntegerField(), fields.List(fields.Int())], required=True)
+ minutes = UnionField([DumpableIntegerField(), fields.List(fields.Int())], required=True)
+ week_days = UnionField(
+ [
+ StringTransformedEnum(allowed_values=[o.value for o in WeekDay]),
+ fields.List(StringTransformedEnum(allowed_values=[o.value for o in WeekDay])),
+ ]
+ )
+ month_days = UnionField(
+ [
+ fields.Int(),
+ fields.List(fields.Int()),
+ ]
+ )
+
+ @post_load
+ def make(self, data, **kwargs) -> "RecurrencePattern": # pylint: disable= unused-argument
+ from azure.ai.ml.entities import RecurrencePattern
+
+ return RecurrencePattern(**data)
+
+
+class RecurrenceTriggerSchema(TriggerSchema):
+ type = StringTransformedEnum(allowed_values=TriggerType.RECURRENCE, required=True)
+ frequency = StringTransformedEnum(allowed_values=[o.value for o in RecurrenceFrequency], required=True)
+ interval = fields.Int(required=True)
+ schedule = NestedField(RecurrencePatternSchema())
+
+ @post_load
+ def make(self, data, **kwargs) -> "RecurrenceTrigger": # pylint: disable= unused-argument
+ from azure.ai.ml.entities import RecurrenceTrigger
+
+ data.pop("type")
+ return RecurrenceTrigger(**data)
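+
+
+def _example_triggers():
+    """A minimal sketch (all values hypothetical) of the trigger entities the
+    post_load hooks above build, assuming they are exported from
+    azure.ai.ml.entities."""
+    from azure.ai.ml.entities import CronTrigger, RecurrencePattern, RecurrenceTrigger
+
+    cron = CronTrigger(expression="0 9 * * *")
+    weekly = RecurrenceTrigger(
+        frequency="week",
+        interval=1,
+        schedule=RecurrencePattern(hours=9, minutes=0, week_days=["monday"]),
+    )
+    return cron, weekly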
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/spark_resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/spark_resource_configuration.py
new file mode 100644
index 00000000..8571adf1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/spark_resource_configuration.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NumberVersionField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+
+
+class SparkResourceConfigurationSchema(metaclass=PatchedSchemaMeta):
+ """Schema for SparkResourceConfiguration."""
+
+ instance_type = fields.Str(metadata={"description": "Optional type of VM to use, as supported by the compute target."})
+ runtime_version = NumberVersionField()
+
+ @post_load
+ def make(self, data, **kwargs):
+ """Construct a SparkResourceConfiguration object from the marshalled data.
+
+ :param data: The marshalled data.
+ :type data: dict[str, str]
+ :return: A SparkResourceConfiguration object.
+ :rtype: ~azure.ai.ml.entities.SparkResourceConfiguration
+ """
+ from azure.ai.ml.entities import SparkResourceConfiguration
+
+ return SparkResourceConfiguration(**data)
+
+
+class SparkResourceConfigurationForNodeSchema(SparkResourceConfigurationSchema):
+ """
+ Schema for SparkResourceConfiguration as used for node configuration, where the
+ validation logic moves into the schema.
+ """
+
+ instance_type = StringTransformedEnum(
+ allowed_values=[
+ "standard_e4s_v3",
+ "standard_e8s_v3",
+ "standard_e16s_v3",
+ "standard_e32s_v3",
+ "standard_e64s_v3",
+ ],
+ required=True,
+ metadata={"description": "Optional type of VM used as supported by the compute target."},
+ )
+ runtime_version = NumberVersionField(
+ required=True,
+ )
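+
+
+def _example_spark_resources():
+    """A minimal sketch (values hypothetical); note the node-level schema above
+    only accepts the listed standard_e*s_v3 instance types."""
+    from azure.ai.ml.entities import SparkResourceConfiguration
+
+    return SparkResourceConfiguration(instance_type="standard_e8s_v3", runtime_version="3.3")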
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/__init__.py
new file mode 100644
index 00000000..dc8b82e2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/__init__.py
@@ -0,0 +1,11 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
+
+from .workspace import WorkspaceSchema
+from .ai_workspaces.project import ProjectSchema
+from .ai_workspaces.hub import HubSchema
+
+__all__ = ["WorkspaceSchema", "ProjectSchema", "HubSchema"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/capability_host.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/capability_host.py
new file mode 100644
index 00000000..cdccb24c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/capability_host.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class CapabilityHostSchema(PathAwareSchema):
+ name = fields.Str()
+ description = fields.Str()
+ capability_host_kind = fields.Str()
+ vector_store_connections = fields.List(fields.Str(), required=False)
+ ai_services_connections = fields.List(fields.Str(), required=False)
+ storage_connections = fields.List(fields.Str(), required=False)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/hub.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/hub.py
new file mode 100644
index 00000000..94a7c380
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/hub.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._schema.workspace import WorkspaceSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import WorkspaceKind
+
+
+@experimental
+class HubSchema(WorkspaceSchema):
+ # additional_workspace_storage_accounts: this field exists in the API but is unused, and thus not surfaced yet.
+ kind = StringTransformedEnum(required=True, allowed_values=WorkspaceKind.HUB)
+ default_resource_group = fields.Str(required=False)
+ associated_workspaces = fields.List(fields.Str(), required=False, dump_only=True)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/project.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/project.py
new file mode 100644
index 00000000..86daa735
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/ai_workspaces/project.py
@@ -0,0 +1,16 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._schema.workspace import WorkspaceSchema
+from azure.ai.ml.constants import WorkspaceKind
+
+
+@experimental
+class ProjectSchema(WorkspaceSchema):
+ kind = StringTransformedEnum(required=True, allowed_values=WorkspaceKind.PROJECT)
+ hub_id = fields.Str(required=True)
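+
+
+def _example_project():
+    """A minimal sketch (name and hub ARM id hypothetical) of the experimental
+    Project entity this schema maps to, assuming Project is exported from
+    azure.ai.ml.entities."""
+    from azure.ai.ml.entities import Project
+
+    return Project(
+        name="my-project",
+        hub_id=(
+            "/subscriptions/<sub>/resourceGroups/<rg>/providers/"
+            "Microsoft.MachineLearningServices/workspaces/my-hub"
+        ),
+    )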
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py
new file mode 100644
index 00000000..fa462cfb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/__init__.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
+
+from .workspace_connection import WorkspaceConnectionSchema
+from .connection_subtypes import (
+ AzureBlobStoreConnectionSchema,
+ MicrosoftOneLakeConnectionSchema,
+ AzureOpenAIConnectionSchema,
+ AzureAIServicesConnectionSchema,
+ AzureAISearchConnectionSchema,
+ AzureContentSafetyConnectionSchema,
+ AzureSpeechServicesConnectionSchema,
+ APIKeyConnectionSchema,
+ OpenAIConnectionSchema,
+ SerpConnectionSchema,
+ ServerlessConnectionSchema,
+ OneLakeArtifactSchema,
+)
+
+__all__ = [
+ "WorkspaceConnectionSchema",
+ "AzureBlobStoreConnectionSchema",
+ "MicrosoftOneLakeConnectionSchema",
+ "AzureOpenAIConnectionSchema",
+ "AzureAIServicesConnectionSchema",
+ "AzureAISearchConnectionSchema",
+ "AzureContentSafetyConnectionSchema",
+ "AzureSpeechServicesConnectionSchema",
+ "APIKeyConnectionSchema",
+ "OpenAIConnectionSchema",
+ "SerpConnectionSchema",
+ "ServerlessConnectionSchema",
+ "OneLakeArtifactSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py
new file mode 100644
index 00000000..d04b3e76
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py
@@ -0,0 +1,225 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+from marshmallow.exceptions import ValidationError
+from marshmallow.decorators import pre_load
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import ConnectionTypes
+from azure.ai.ml._schema.workspace.connections.one_lake_artifacts import OneLakeArtifactSchema
+from azure.ai.ml._schema.workspace.connections.credentials import (
+ SasTokenConfigurationSchema,
+ ServicePrincipalConfigurationSchema,
+ AccountKeyConfigurationSchema,
+ AadCredentialConfigurationSchema,
+)
+from azure.ai.ml.entities import AadCredentialConfiguration
+from .workspace_connection import WorkspaceConnectionSchema
+
+
+class AzureBlobStoreConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionCategory.AZURE_BLOB, casing_transform=camel_to_snake, required=True
+ )
+ credentials = UnionField(
+ [
+ NestedField(SasTokenConfigurationSchema),
+ NestedField(AccountKeyConfigurationSchema),
+ NestedField(AadCredentialConfigurationSchema),
+ ],
+ required=False,
+ load_default=AadCredentialConfiguration(),
+ )
+
+ url = fields.Str()
+
+ account_name = fields.Str(required=True, allow_none=False)
+ container_name = fields.Str(required=True, allow_none=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AzureBlobStoreConnection
+
+ return AzureBlobStoreConnection(**data)
+
+
+class MicrosoftOneLakeConnectionSchema(WorkspaceConnectionSchema):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionCategory.AZURE_ONE_LAKE, casing_transform=camel_to_snake, required=True
+ )
+ credentials = UnionField(
+ [NestedField(ServicePrincipalConfigurationSchema), NestedField(AadCredentialConfigurationSchema)],
+ required=False,
+ load_default=AadCredentialConfiguration(),
+ )
+ artifact = NestedField(OneLakeArtifactSchema, required=False, allow_none=True)
+
+ endpoint = fields.Str(required=False)
+ one_lake_workspace_name = fields.Str(required=False)
+
+ @pre_load
+ def check_for_target(self, data, **kwargs):
+ target = data.get("target", None)
+ artifact = data.get("artifact", None)
+ endpoint = data.get("endpoint", None)
+ one_lake_workspace_name = data.get("one_lake_workspace_name", None)
+ # If the user supplies a target, they don't need the artifact or the OneLake workspace name.
+ # This is distinct from when the user sets the 'endpoint' value, which is also used to construct
+ # the target. If the target is already present, the loaded connection YAML was probably produced
+ # by dumping an extant connection.
+ if target is None:
+ if artifact is None:
+ raise ValidationError("If target is unset, then artifact must be set")
+ if endpoint is None:
+ raise ValidationError("If target is unset, then endpoint must be set")
+ if one_lake_workspace_name is None:
+ raise ValidationError("If target is unset, then one_lake_workspace_name must be set")
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import MicrosoftOneLakeConnection
+
+ return MicrosoftOneLakeConnection(**data)
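+
+
+def _example_one_lake_connection():
+    """A minimal sketch (all values hypothetical) mirroring check_for_target:
+    without a target, the artifact, endpoint, and one_lake_workspace_name must
+    all be set. Assumes these classes are exported from azure.ai.ml.entities."""
+    from azure.ai.ml.entities import MicrosoftOneLakeConnection, OneLakeConnectionArtifact
+
+    return MicrosoftOneLakeConnection(
+        name="my-onelake",
+        artifact=OneLakeConnectionArtifact(name="my-lakehouse"),
+        endpoint="https://onelake.dfs.fabric.microsoft.com",
+        one_lake_workspace_name="my-fabric-workspace",
+    )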
+
+
+class AzureOpenAIConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionCategory.AZURE_OPEN_AI, casing_transform=camel_to_snake, required=True
+ )
+ api_key = fields.Str(required=False, allow_none=True)
+ api_version = fields.Str(required=False, allow_none=True)
+
+ azure_endpoint = fields.Str()
+ open_ai_resource_id = fields.Str(required=False, allow_none=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AzureOpenAIConnection
+
+ return AzureOpenAIConnection(**data)
+
+
+class AzureAIServicesConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionTypes.AZURE_AI_SERVICES, casing_transform=camel_to_snake, required=True
+ )
+ api_key = fields.Str(required=False, allow_none=True)
+ endpoint = fields.Str()
+ ai_services_resource_id = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AzureAIServicesConnection
+
+ return AzureAIServicesConnection(**data)
+
+
+class AzureAISearchConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionTypes.AZURE_SEARCH, casing_transform=camel_to_snake, required=True
+ )
+ api_key = fields.Str(required=False, allow_none=True)
+ endpoint = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AzureAISearchConnection
+
+ return AzureAISearchConnection(**data)
+
+
+class AzureContentSafetyConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionTypes.AZURE_CONTENT_SAFETY, casing_transform=camel_to_snake, required=True
+ )
+ api_key = fields.Str(required=False, allow_none=True)
+ endpoint = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AzureContentSafetyConnection
+
+ return AzureContentSafetyConnection(**data)
+
+
+class AzureSpeechServicesConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionTypes.AZURE_SPEECH_SERVICES, casing_transform=camel_to_snake, required=True
+ )
+ api_key = fields.Str(required=False, allow_none=True)
+ endpoint = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import AzureSpeechServicesConnection
+
+ return AzureSpeechServicesConnection(**data)
+
+
+class APIKeyConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionCategory.API_KEY, casing_transform=camel_to_snake, required=True
+ )
+ api_key = fields.Str(required=True)
+ api_base = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import APIKeyConnection
+
+ return APIKeyConnection(**data)
+
+
+class OpenAIConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionCategory.OPEN_AI, casing_transform=camel_to_snake, required=True
+ )
+ api_key = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import OpenAIConnection
+
+ return OpenAIConnection(**data)
+
+
+class SerpConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(allowed_values=ConnectionCategory.SERP, casing_transform=camel_to_snake, required=True)
+ api_key = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import SerpConnection
+
+ return SerpConnection(**data)
+
+
+class ServerlessConnectionSchema(WorkspaceConnectionSchema):
+ # type and credentials limited
+ type = StringTransformedEnum(
+ allowed_values=ConnectionCategory.SERVERLESS, casing_transform=camel_to_snake, required=True
+ )
+ api_key = fields.Str(required=True)
+ endpoint = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import ServerlessConnection
+
+ return ServerlessConnection(**data)
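+
+
+def _example_azure_openai_connection():
+    """A minimal sketch (endpoint, key, and version hypothetical) of one of the
+    connection subtypes above; api_key and api_version are optional per the
+    schema."""
+    from azure.ai.ml.entities import AzureOpenAIConnection
+
+    return AzureOpenAIConnection(
+        name="my-aoai",
+        azure_endpoint="https://my-resource.openai.azure.com/",
+        api_key="<redacted>",
+        api_version="2024-02-01",
+    )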
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py
new file mode 100644
index 00000000..52213c08
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/credentials.py
@@ -0,0 +1,178 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+##### DEV NOTE: These schemas correspond to the classes defined in ~azure.ai.ml.entities._credentials.
+# There used to be a credentials.py file in ~azure.ai.ml.entities.workspace.connections,
+# but it was, as far as I could tell, never used. So I removed it and added this comment.
+
+from typing import Dict
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionAuthType
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._credentials import (
+ ManagedIdentityConfiguration,
+ PatTokenConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+ UsernamePasswordConfiguration,
+ AccessKeyConfiguration,
+ ApiKeyConfiguration,
+ AccountKeyConfiguration,
+ AadCredentialConfiguration,
+ NoneCredentialConfiguration,
+)
+
+
+class WorkspaceCredentialsSchema(metaclass=PatchedSchemaMeta):
+ type = fields.Str()
+
+
+class PatTokenConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.PAT,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ pat = fields.Str()
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> PatTokenConfiguration:
+ data.pop("type")
+ return PatTokenConfiguration(**data)
+
+
+class SasTokenConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.SAS,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ sas_token = fields.Str()
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> SasTokenConfiguration:
+ data.pop("type")
+ return SasTokenConfiguration(**data)
+
+
+class UsernamePasswordConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.USERNAME_PASSWORD,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ username = fields.Str()
+ password = fields.Str()
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> UsernamePasswordConfiguration:
+ data.pop("type")
+ return UsernamePasswordConfiguration(**data)
+
+
+class ManagedIdentityConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.MANAGED_IDENTITY,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ client_id = fields.Str()
+ resource_id = fields.Str()
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> ManagedIdentityConfiguration:
+ data.pop("type")
+ return ManagedIdentityConfiguration(**data)
+
+
+class ServicePrincipalConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.SERVICE_PRINCIPAL,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+
+ client_id = fields.Str()
+ client_secret = fields.Str()
+ tenant_id = fields.Str()
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> ServicePrincipalConfiguration:
+ data.pop("type")
+ return ServicePrincipalConfiguration(**data)
+
+
+class AccessKeyConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.ACCESS_KEY,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ access_key_id = fields.Str()
+ secret_access_key = fields.Str()
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> AccessKeyConfiguration:
+ data.pop("type")
+ return AccessKeyConfiguration(**data)
+
+
+class ApiKeyConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.API_KEY,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ key = fields.Str()
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> ApiKeyConfiguration:
+ data.pop("type")
+ return ApiKeyConfiguration(**data)
+
+
+class AccountKeyConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.ACCOUNT_KEY,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+ account_key = fields.Str()
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> AccountKeyConfiguration:
+ data.pop("type")
+ return AccountKeyConfiguration(**data)
+
+
+class AadCredentialConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.AAD,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> AadCredentialConfiguration:
+ data.pop("type")
+ return AadCredentialConfiguration(**data)
+
+
+class NoneCredentialConfigurationSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=ConnectionAuthType.NONE,
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+
+ @post_load
+ def make(self, data: Dict[str, str], **kwargs) -> NoneCredentialConfiguration:
+ data.pop("type")
+ return NoneCredentialConfiguration(**data)
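+
+
+def _example_credential_load():
+    """A minimal sketch (token hypothetical): each schema above pops the 'type'
+    discriminator in post_load, so loading a dict yields the bare credential
+    entity, e.g. PatTokenConfiguration(pat='<token>')."""
+    return PatTokenConfigurationSchema().load({"type": "pat", "pat": "<token>"})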
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py
new file mode 100644
index 00000000..563a9359
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/one_lake_artifacts.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import OneLakeArtifactTypes
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class OneLakeArtifactSchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=OneLakeArtifactTypes.ONE_LAKE, casing_transform=camel_to_snake, required=True
+ )
+ name = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import OneLakeConnectionArtifact
+
+ return OneLakeConnectionArtifact(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py
new file mode 100644
index 00000000..20863a5a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/connections/workspace_connection.py
@@ -0,0 +1,86 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.resource import ResourceSchema
+from azure.ai.ml._schema.job import CreationContextSchema
+from azure.ai.ml._schema.workspace.connections.credentials import (
+ AccountKeyConfigurationSchema,
+ ManagedIdentityConfigurationSchema,
+ PatTokenConfigurationSchema,
+ SasTokenConfigurationSchema,
+ ServicePrincipalConfigurationSchema,
+ UsernamePasswordConfigurationSchema,
+ AccessKeyConfigurationSchema,
+ ApiKeyConfigurationSchema,
+ AadCredentialConfigurationSchema,
+ NoneCredentialConfigurationSchema,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import ConnectionTypes
+from azure.ai.ml.entities import NoneCredentialConfiguration, AadCredentialConfiguration
+
+
+class WorkspaceConnectionSchema(ResourceSchema):
+ # Inherits name, id, tags, and description fields from ResourceSchema
+ creation_context = NestedField(CreationContextSchema, dump_only=True)
+ type = StringTransformedEnum(
+ allowed_values=[
+ ConnectionCategory.GIT,
+ ConnectionCategory.CONTAINER_REGISTRY,
+ ConnectionCategory.PYTHON_FEED,
+ ConnectionCategory.S3,
+ ConnectionCategory.SNOWFLAKE,
+ ConnectionCategory.AZURE_SQL_DB,
+ ConnectionCategory.AZURE_SYNAPSE_ANALYTICS,
+ ConnectionCategory.AZURE_MY_SQL_DB,
+ ConnectionCategory.AZURE_POSTGRES_DB,
+ ConnectionTypes.CUSTOM,
+ ConnectionTypes.AZURE_DATA_LAKE_GEN_2,
+ ],
+ casing_transform=camel_to_snake,
+ required=True,
+ )
+
+ # required=False is not strictly accurate: some connection types require this field,
+ # some don't, and some rename it for client familiarity reasons.
+ target = fields.Str(required=False)
+
+ credentials = UnionField(
+ [
+ NestedField(PatTokenConfigurationSchema),
+ NestedField(SasTokenConfigurationSchema),
+ NestedField(UsernamePasswordConfigurationSchema),
+ NestedField(ManagedIdentityConfigurationSchema),
+ NestedField(ServicePrincipalConfigurationSchema),
+ NestedField(AccessKeyConfigurationSchema),
+ NestedField(ApiKeyConfigurationSchema),
+ NestedField(AccountKeyConfigurationSchema),
+ NestedField(AadCredentialConfigurationSchema),
+ NestedField(NoneCredentialConfigurationSchema),
+ ],
+ required=False,
+ load_default=NoneCredentialConfiguration(),
+ )
+
+ is_shared = fields.Bool(load_default=True)
+ metadata = fields.Dict(required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import WorkspaceConnection
+
+ # Most non-subclassed connections default to a none credential if none
+ # is provided. ADLS Gen 2 connections default to AAD with this code.
+ if (
+ data.get("type", None) == ConnectionTypes.AZURE_DATA_LAKE_GEN_2
+ and data.get("credentials", None) == NoneCredentialConfiguration()
+ ):
+ data["credentials"] = AadCredentialConfiguration()
+ return WorkspaceConnection(**data)
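+
+
+def _example_custom_connection():
+    """A minimal sketch (target and token hypothetical): a custom connection
+    with an explicit credential; with no credential it would load with
+    NoneCredentialConfiguration, while azure_data_lake_gen2 is upgraded to AAD
+    by make() above."""
+    from azure.ai.ml.entities import WorkspaceConnection
+    from azure.ai.ml.entities._credentials import PatTokenConfiguration
+
+    return WorkspaceConnection(
+        name="my-custom-conn",
+        type="custom",
+        target="https://example.contoso.com",
+        credentials=PatTokenConfiguration(pat="<token>"),
+    )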
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/customer_managed_key.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/customer_managed_key.py
new file mode 100644
index 00000000..459507fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/customer_managed_key.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class CustomerManagedKeySchema(metaclass=PatchedSchemaMeta):
+ key_vault = fields.Str()
+ key_uri = fields.Url()
+ cosmosdb_id = fields.Str()
+ storage_id = fields.Str()
+ search_id = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import CustomerManagedKey
+
+ return CustomerManagedKey(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/endpoint_connection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/endpoint_connection.py
new file mode 100644
index 00000000..ba926d9e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/endpoint_connection.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+
+class EndpointConnectionSchema(metaclass=PatchedSchemaMeta):
+ subscription_id = fields.UUID()
+ resource_group = fields.Str()
+ location = fields.Str()
+ vnet_name = fields.Str()
+ subnet_name = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import EndpointConnection
+
+ return EndpointConnection(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py
new file mode 100644
index 00000000..d0348c3b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/identity.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields
+from marshmallow.decorators import post_load, pre_dump
+
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
+from azure.ai.ml.constants._workspace import ManagedServiceIdentityType
+from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration
+
+
+class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta):
+ principal_id = fields.Str(required=False)
+ client_id = fields.Str(required=False)
+ resource_id = fields.Str(required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ return ManagedIdentityConfiguration(**data)
+
+
+class IdentitySchema(metaclass=PatchedSchemaMeta):
+ type = StringTransformedEnum(
+ allowed_values=[
+ ManagedServiceIdentityType.SYSTEM_ASSIGNED,
+ ManagedServiceIdentityType.USER_ASSIGNED,
+ ManagedServiceIdentityType.NONE,
+ ManagedServiceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "resource identity type."},
+ )
+ principal_id = fields.Str(required=False)
+ tenant_id = fields.Str(required=False)
+ user_assigned_identities = fields.Dict(
+ keys=fields.Str(required=True), values=NestedField(UserAssignedIdentitySchema, allow_none=True), allow_none=True
+ )
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ if data and isinstance(data, IdentityConfiguration):
+ data.user_assigned_identities = self.uai_list2dict(data.user_assigned_identities)
+ return data
+
+ @post_load
+ def make(self, data, **kwargs):
+ if data.get("user_assigned_identities", False):
+ data["user_assigned_identities"] = self.uai_dict2list(data.pop("user_assigned_identities"))
+ data["type"] = snake_to_camel(data.pop("type"))
+ return IdentityConfiguration(**data)
+
+ def uai_dict2list(self, uai_dict):
+ res = []
+ for resource_id, meta in uai_dict.items():
+ if not isinstance(meta, ManagedIdentityConfiguration):
+ continue
+ c_id = meta.client_id
+ p_id = meta.principal_id
+ res.append(ManagedIdentityConfiguration(resource_id=resource_id, client_id=c_id, principal_id=p_id))
+ return res
+
+ def uai_list2dict(self, uai_list):
+ res = {}
+ if uai_list and isinstance(uai_list, list):
+ for uai in uai_list:
+ if not isinstance(uai, ManagedIdentityConfiguration):
+ continue
+ meta = {}
+ if uai.client_id:
+ meta["client_id"] = uai.client_id
+ if uai.principal_id:
+ meta["principal_id"] = uai.principal_id
+ res[uai.resource_id] = meta
+ return res if res else None
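+
+
+def _example_identity_load():
+    """A minimal sketch (ARM ids hypothetical): YAML maps each user-assigned
+    identity's resource id to its metadata, and make() above converts that dict
+    into a list of ManagedIdentityConfiguration objects."""
+    return IdentitySchema().load(
+        {
+            "type": "user_assigned",
+            "user_assigned_identities": {
+                "/subscriptions/<sub>/resourceGroups/<rg>/providers/"
+                "Microsoft.ManagedIdentity/userAssignedIdentities/my-uai": {"client_id": "<guid>"}
+            },
+        }
+    )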
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/network_acls.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/network_acls.py
new file mode 100644
index 00000000..e9e5e8ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/network_acls.py
@@ -0,0 +1,63 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import ValidationError, fields, post_load, validates_schema
+
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml.entities._workspace.network_acls import DefaultActionType, IPRule, NetworkAcls
+
+
+class IPRuleSchema(PathAwareSchema):
+ """Schema for IPRule."""
+
+ value = fields.Str(required=True)
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ """Create an IPRule object from the marshmallow schema.
+
+ :param data: The data from which the IPRule is being loaded.
+ :type data: OrderedDict[str, Any]
+ :returns: An IPRule object.
+ :rtype: azure.ai.ml.entities._workspace.network_acls.NetworkAcls.IPRule
+ """
+ return IPRule(**data)
+
+
+class NetworkAclsSchema(PathAwareSchema):
+ """Schema for NetworkAcls.
+
+ :param default_action: Specifies the default action when no IP rules are matched.
+ :type default_action: str
+ :param ip_rules: Rules governing the accessibility of a resource from a specific IP address or IP range.
+ :type ip_rules: Optional[List[IPRule]]
+ """
+
+ default_action = fields.Str(required=True)
+ ip_rules = fields.List(fields.Nested(IPRuleSchema), allow_none=True)
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ """Create a NetworkAcls object from the marshmallow schema.
+
+ :param data: The data from which the NetworkAcls is being loaded.
+ :type data: OrderedDict[str, Any]
+ :returns: A NetworkAcls object.
+ :rtype: azure.ai.ml.entities._workspace.network_acls.NetworkAcls
+ """
+ return NetworkAcls(**data)
+
+ @validates_schema
+ def validate_schema(self, data, **kwargs): # pylint: disable=unused-argument
+ """Validate the NetworkAcls schema.
+
+ :param data: The data to validate.
+ :type data: OrderedDict[str, Any]
+ :raises ValidationError: If the schema is invalid.
+ """
+ if data["default_action"] not in set([DefaultActionType.DENY, DefaultActionType.ALLOW]):
+ raise ValidationError("Invalid value for default_action. Must be 'Deny' or 'Allow'.")
+
+ if data["default_action"] == DefaultActionType.DENY and not data.get("ip_rules"):
+ raise ValidationError("ip_rules must be provided when default_action is 'Deny'.")
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py
new file mode 100644
index 00000000..f228ee3e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/networking.py
@@ -0,0 +1,224 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,no-else-return
+
+from marshmallow import EXCLUDE, fields
+from marshmallow.decorators import post_load, pre_dump
+
+from azure.ai.ml._schema import ExperimentalField
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import _snake_to_camel, camel_to_snake
+from azure.ai.ml.constants._workspace import FirewallSku, IsolationMode, OutboundRuleCategory
+from azure.ai.ml.entities._workspace.networking import (
+ FqdnDestination,
+ ManagedNetwork,
+ PrivateEndpointDestination,
+ ServiceTagDestination,
+)
+
+
+class ManagedNetworkStatusSchema(metaclass=PatchedSchemaMeta):
+ spark_ready = fields.Bool(dump_only=True)
+ status = fields.Str(dump_only=True)
+
+
+class FqdnOutboundRuleSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(required=True)
+ parent_rule_names = fields.List(fields.Str(), dump_only=True)
+ type = fields.Constant("fqdn")
+ destination = fields.Str(required=True)
+ category = StringTransformedEnum(
+ allowed_values=[
+ OutboundRuleCategory.REQUIRED,
+ OutboundRuleCategory.RECOMMENDED,
+ OutboundRuleCategory.USER_DEFINED,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "outbound rule category."},
+ dump_only=True,
+ )
+ status = fields.Str(dump_only=True)
+
+ @post_load
+ def createdestobject(self, data, **kwargs):
+ dest = data.get("destination")
+ category = data.get("category", OutboundRuleCategory.USER_DEFINED)
+ name = data.get("name")
+ status = data.get("status", None)
+ return FqdnDestination(
+ name=name,
+ destination=dest,
+ category=_snake_to_camel(category),
+ status=status,
+ )
+
+
+class ServiceTagDestinationSchema(metaclass=PatchedSchemaMeta):
+ service_tag = fields.Str(required=True)
+ protocol = fields.Str(required=True)
+ port_ranges = fields.Str(required=True)
+ address_prefixes = fields.List(fields.Str())
+
+
+class ServiceTagOutboundRuleSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(required=True)
+ parent_rule_names = fields.List(fields.Str(), dump_only=True)
+ type = fields.Constant("service_tag")
+ destination = NestedField(ServiceTagDestinationSchema, required=True)
+ category = StringTransformedEnum(
+ allowed_values=[
+ OutboundRuleCategory.REQUIRED,
+ OutboundRuleCategory.RECOMMENDED,
+ OutboundRuleCategory.USER_DEFINED,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "outbound rule category."},
+ dump_only=True,
+ )
+ status = fields.Str(dump_only=True)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ data.destination = self.service_tag_dest2dict(
+ data.service_tag, data.protocol, data.port_ranges, data.address_prefixes
+ )
+ return data
+
+ @post_load
+ def createdestobject(self, data, **kwargs):
+ dest = data.get("destination")
+ category = data.get("category", OutboundRuleCategory.USER_DEFINED)
+ name = data.get("name")
+ status = data.get("status", None)
+ return ServiceTagDestination(
+ name=name,
+ service_tag=dest["service_tag"],
+ protocol=dest["protocol"],
+ port_ranges=dest["port_ranges"],
+ address_prefixes=dest.get("address_prefixes", None),
+ category=_snake_to_camel(category),
+ status=status,
+ )
+
+ def service_tag_dest2dict(self, service_tag, protocol, port_ranges, address_prefixes):
+ service_tag_dest = {}
+ service_tag_dest["service_tag"] = service_tag
+ service_tag_dest["protocol"] = protocol
+ service_tag_dest["port_ranges"] = port_ranges
+ service_tag_dest["address_prefixes"] = address_prefixes
+ return service_tag_dest
+
+
+class PrivateEndpointDestinationSchema(metaclass=PatchedSchemaMeta):
+ service_resource_id = fields.Str(required=True)
+ subresource_target = fields.Str(required=True)
+ spark_enabled = fields.Bool(required=True)
+
+
+class PrivateEndpointOutboundRuleSchema(metaclass=PatchedSchemaMeta):
+ name = fields.Str(required=True)
+ parent_rule_names = fields.List(fields.Str(), dump_only=True)
+ type = fields.Constant("private_endpoint")
+ destination = NestedField(PrivateEndpointDestinationSchema, required=True)
+ fqdns = fields.List(fields.Str())
+ category = StringTransformedEnum(
+ allowed_values=[
+ OutboundRuleCategory.REQUIRED,
+ OutboundRuleCategory.RECOMMENDED,
+ OutboundRuleCategory.USER_DEFINED,
+ OutboundRuleCategory.DEPENDENCY,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "outbound rule category."},
+ dump_only=True,
+ )
+ status = fields.Str(dump_only=True)
+
+ @pre_dump
+ def predump(self, data, **kwargs):
+ data.destination = self.pe_dest2dict(data.service_resource_id, data.subresource_target, data.spark_enabled)
+ return data
+
+ @post_load
+ def createdestobject(self, data, **kwargs):
+ dest = data.get("destination")
+ category = data.get("category", OutboundRuleCategory.USER_DEFINED)
+ name = data.get("name")
+ status = data.get("status", None)
+ fqdns = data.get("fqdns", None)
+ return PrivateEndpointDestination(
+ name=name,
+ service_resource_id=dest["service_resource_id"],
+ subresource_target=dest["subresource_target"],
+ spark_enabled=dest["spark_enabled"],
+ category=_snake_to_camel(category),
+ status=status,
+ fqdns=fqdns,
+ )
+
+ def pe_dest2dict(self, service_resource_id, subresource_target, spark_enabled):
+ pedest = {}
+ pedest["service_resource_id"] = service_resource_id
+ pedest["subresource_target"] = subresource_target
+ pedest["spark_enabled"] = spark_enabled
+ return pedest
+
+
+class ManagedNetworkSchema(metaclass=PatchedSchemaMeta):
+ isolation_mode = StringTransformedEnum(
+ allowed_values=[
+ IsolationMode.DISABLED,
+ IsolationMode.ALLOW_INTERNET_OUTBOUND,
+ IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "isolation mode for the workspace managed network."},
+ )
+ outbound_rules = fields.List(
+ UnionField(
+ [
+ NestedField(PrivateEndpointOutboundRuleSchema, allow_none=False, unknown=EXCLUDE),
+ NestedField(ServiceTagOutboundRuleSchema, allow_none=False, unknown=EXCLUDE),
+ NestedField(
+ FqdnOutboundRuleSchema, allow_none=False, unknown=EXCLUDE
+ ), # this needs to be last, since otherwise the union field will match destination as a string
+ ],
+ allow_none=False,
+ is_strict=True,
+ ),
+ allow_none=True,
+ )
+ firewall_sku = ExperimentalField(
+ StringTransformedEnum(
+ allowed_values=[
+ FirewallSku.STANDARD,
+ FirewallSku.BASIC,
+ ],
+ casing_transform=camel_to_snake,
+ metadata={"description": "Firewall sku for FQDN rules in AllowOnlyApprovedOutbound mode"},
+ )
+ )
+ network_id = fields.Str(required=False, dump_only=True)
+ status = NestedField(ManagedNetworkStatusSchema, allow_none=False, unknown=EXCLUDE)
+
+ @post_load
+ def make(self, data, **kwargs):
+ outbound_rules = data.get("outbound_rules", False)
+
+ firewall_sku = data.get("firewall_sku", False)
+ firewall_sku_value = _snake_to_camel(data["firewall_sku"]) if firewall_sku else FirewallSku.STANDARD
+
+ if outbound_rules:
+ return ManagedNetwork(
+ isolation_mode=_snake_to_camel(data["isolation_mode"]),
+ outbound_rules=outbound_rules,
+ firewall_sku=firewall_sku_value,
+ )
+ else:
+ return ManagedNetwork(
+ isolation_mode=_snake_to_camel(data["isolation_mode"]),
+ firewall_sku=firewall_sku_value,
+ )
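+
+
+def _example_managed_network():
+    """A minimal sketch (rule name and FQDN hypothetical) of the entity this
+    schema builds, using the entity classes already imported above."""
+    return ManagedNetwork(
+        isolation_mode=IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND,
+        outbound_rules=[FqdnDestination(name="allow-pypi", destination="pypi.org")],
+    )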
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/private_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/private_endpoint.py
new file mode 100644
index 00000000..0235a4a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/private_endpoint.py
@@ -0,0 +1,23 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+
+from .endpoint_connection import EndpointConnectionSchema
+
+
+class PrivateEndpointSchema(metaclass=PatchedSchemaMeta):
+ approval_type = fields.Str()
+ connections = fields.Dict(keys=fields.Str(), values=NestedField(EndpointConnectionSchema))
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities import PrivateEndpoint
+
+ return PrivateEndpoint(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/serverless_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/serverless_compute.py
new file mode 100644
index 00000000..5137e57f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/serverless_compute.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+from marshmallow.decorators import post_load, validates
+
+from azure.ai.ml._schema._utils.utils import ArmId
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml.entities._workspace.serverless_compute import ServerlessComputeSettings
+
+
+class ServerlessComputeSettingsSchema(PathAwareSchema):
+ """Schema for ServerlessComputeSettings.
+
+ :param custom_subnet: The custom subnet to use for serverless computes created in the workspace.
+ :type custom_subnet: Optional[ArmId]
+ :param no_public_ip: Whether to disable public ip for the compute. Only valid if custom_subnet is defined.
+ :type no_public_ip: bool
+ """
+
+ custom_subnet = fields.Str(allow_none=True)
+ no_public_ip = fields.Bool(load_default=False)
+
+ @post_load
+ def make(self, data, **_kwargs) -> ServerlessComputeSettings:
+ """Create a ServerlessComputeSettings object from the marshmallow schema.
+
+ :param data: The data from which the ServerlessComputeSettings are being loaded.
+ :type data: OrderedDict[str, Any]
+ :returns: A ServerlessComputeSettings object.
+ :rtype: azure.ai.ml.entities._workspace.serverless_compute.ServerlessComputeSettings
+ """
+ custom_subnet = data.pop("custom_subnet", None)
+ if custom_subnet == "None":
+ custom_subnet = None # For loading from YAML when the user wants to trigger a removal
+ no_public_ip = data.pop("no_public_ip", False)
+ return ServerlessComputeSettings(custom_subnet=custom_subnet, no_public_ip=no_public_ip)
+
+ @validates("custom_subnet")
+ def validate_custom_subnet(self, data: str, **_kwargs):
+ """Validates the custom_subnet field matches the ARM ID format or is a None-recognizable value.
+
+ :param data: The candidate custom_subnet to validate.
+ :type data: str
+ :raises ValidationError: If the custom_subnet is not formatted as an ARM ID.
+ """
+ if data == "None" or data is None:
+ # If the string is literally "None", then it should be deserialized to None
+ pass
+ else:
+ # Verify that we can transform it to an ArmId if it is not None.
+ ArmId(data)
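+
+
+def _example_serverless_compute():
+    """A minimal sketch (subnet ARM id hypothetical): custom_subnet must be a
+    full ARM id, per validate_custom_subnet above."""
+    return ServerlessComputeSettings(
+        custom_subnet=(
+            "/subscriptions/<sub>/resourceGroups/<rg>/providers/"
+            "Microsoft.Network/virtualNetworks/my-vnet/subnets/my-subnet"
+        ),
+        no_public_ip=True,
+    )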
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/workspace.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/workspace.py
new file mode 100644
index 00000000..1df06f97
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/workspace/workspace.py
@@ -0,0 +1,49 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import EXCLUDE, fields
+
+from azure.ai.ml._schema._utils.utils import validate_arm_str
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.workspace.customer_managed_key import CustomerManagedKeySchema
+from azure.ai.ml._schema.workspace.identity import IdentitySchema
+from azure.ai.ml._schema.workspace.network_acls import NetworkAclsSchema
+from azure.ai.ml._schema.workspace.networking import ManagedNetworkSchema
+from azure.ai.ml._schema.workspace.serverless_compute import ServerlessComputeSettingsSchema
+from azure.ai.ml._utils.utils import snake_to_pascal
+from azure.ai.ml.constants._common import PublicNetworkAccess
+
+
+class WorkspaceSchema(PathAwareSchema):
+ name = fields.Str(required=True)
+ location = fields.Str()
+ id = fields.Str(dump_only=True)
+ resource_group = fields.Str()
+ description = fields.Str()
+ discovery_url = fields.Str()
+ display_name = fields.Str()
+ hbi_workspace = fields.Bool()
+ storage_account = fields.Str(validate=validate_arm_str)
+ container_registry = fields.Str(validate=validate_arm_str)
+ key_vault = fields.Str(validate=validate_arm_str)
+ application_insights = fields.Str(validate=validate_arm_str)
+ customer_managed_key = NestedField(CustomerManagedKeySchema)
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+ mlflow_tracking_uri = fields.Str(dump_only=True)
+ image_build_compute = fields.Str()
+ public_network_access = StringTransformedEnum(
+ allowed_values=[PublicNetworkAccess.DISABLED, PublicNetworkAccess.ENABLED],
+ casing_transform=snake_to_pascal,
+ )
+ network_acls = NestedField(NetworkAclsSchema)
+ system_datastores_auth_mode = fields.Str()
+ identity = NestedField(IdentitySchema)
+ primary_user_assigned_identity = fields.Str()
+ workspace_hub = fields.Str(validate=validate_arm_str)
+ managed_network = NestedField(ManagedNetworkSchema, unknown=EXCLUDE)
+ provision_network_now = fields.Bool()
+ enable_data_isolation = fields.Bool()
+ allow_roleassignment_on_rg = fields.Bool()
+ serverless_compute = NestedField(ServerlessComputeSettingsSchema)
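+
+
+def _example_workspace():
+    """A minimal sketch (name and region hypothetical) of a workspace matching
+    this schema, assuming Workspace is exported from azure.ai.ml.entities."""
+    from azure.ai.ml.entities import Workspace
+
+    return Workspace(
+        name="my-workspace",
+        location="eastus",
+        public_network_access=PublicNetworkAccess.ENABLED,
+    )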