path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities
author     S. Solomon Darnell      2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell      2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/__init__.py631
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/__init__.py17
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/base_environment_source.py48
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/inferencing_server.py216
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_configuration.py55
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_package.py338
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/artifact.py131
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/code.py142
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/data.py237
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/feature_set.py220
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/index.py137
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/model.py219
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/asset.py145
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/auto_delete_setting.py42
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/environment.py478
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/federated_learning_silo.py123
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/intellectual_property.py49
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/workspace_asset_reference.py87
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/__init__.py20
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_model_base.py881
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_patch.py20
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_serialization.py1998
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/__init__.py18
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_models.py214
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_patch.py223
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/__init__.py28
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/base_node.py568
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command.py1017
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command_func.py314
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/condition_node.py146
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/control_flow_node.py170
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer.py575
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer_func.py335
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/do_while.py357
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/fl_scatter_gather.py886
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_func.py93
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_node.py205
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel.py551
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_for.py362
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_func.py285
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/pipeline.py225
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark.py663
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark_func.py306
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/subcomponents.py59
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py454
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/_additional_includes.py541
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/automl_component.py42
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/code.py297
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/command_component.py300
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component.py641
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component_factory.py171
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/datatransfer_component.py325
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/flow.py553
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/import_component.py96
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/parallel_component.py305
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/pipeline_component.py529
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/spark_component.py211
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_aml_compute_node_info.py50
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_custom_applications.py221
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_image_metadata.py63
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_schedule.py153
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_setup_scripts.py90
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_usage.py100
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_vm_size.py104
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/aml_compute.py281
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute.py261
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute_instance.py511
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/kubernetes_compute.py105
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/synapsespark_compute.py234
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/unsupported_compute.py62
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/virtual_machine_compute.py172
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py964
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/mltable_metadata.py92
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/data_import.py130
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/schedule.py115
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_constants.py8
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem.py121
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem_credentials.py128
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/adls_gen1.py106
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/azure_storage.py337
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/datastore.py221
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/one_lake.py153
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/utils.py70
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_deployment.py356
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_job.py38
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/code_configuration.py93
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/container_resource_settings.py74
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_asset.py38
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_collector.py84
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment.py213
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_collection.py62
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_settings.py200
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/event_hub.py32
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/job_definition.py58
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment.py207
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment_settings.py81
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/online_deployment.py742
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/oversize_data_config.py25
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/payload_response.py26
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/pipeline_component_batch_deployment.py150
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/request_logging.py39
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/resource_requirements_settings.py84
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/run_settings.py50
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/scale_settings.py173
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py62
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py134
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py145
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py647
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/data_availability_status.py15
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/delay_metadata.py17
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature.py54
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_metadata.py39
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_request.py91
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_materialization_metadata.py98
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_specification.py46
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_transformation_code_metadata.py13
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_window.py34
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/featureset_spec_metadata.py101
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_compute_resource.py41
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_settings.py100
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_type.py14
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_metadata.py69
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_process_code_metadata.py13
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/timestamp_column_metadata.py14
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/_constants.py15
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/feature_store.py226
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/materialization_store.py49
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column.py80
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column_type.py34
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/feature_store_entity.py146
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py16
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py748
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py243
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py31
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py47
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py62
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py122
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py10
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py25
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py36
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py248
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py73
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py34
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py133
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py207
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py251
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py547
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py180
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py479
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_input_output_helpers.py427
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_studio_url_from_job_id.py26
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/__init__.py16
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_job.py283
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_vertical.py134
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/featurization_settings.py32
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/__init__.py35
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image.py244
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_classification_base.py439
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_object_detection_base.py524
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_job.py244
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_multilabel_job.py252
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_search_space.py437
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_instance_segmentation_job.py249
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_limit_settings.py117
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_model_settings.py876
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_job.py240
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_search_space.py899
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_sweep_settings.py86
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/__init__.py25
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/automl_nlp_job.py467
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_featurization_settings.py47
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_fixed_parameters.py117
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_limit_settings.py79
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_search_space.py185
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_sweep_settings.py65
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_job.py248
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_multilabel_job.py252
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_ner_job.py231
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space.py14
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space_utils.py276
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/stack_ensemble_settings.py70
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/__init__.py22
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/automl_tabular.py607
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/classification_job.py352
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/featurization_settings.py170
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_job.py686
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_settings.py383
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/limit_settings.py101
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/regression_job.py239
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/training_settings.py357
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py47
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/base_job.py85
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py314
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/compute_configuration.py110
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/data_transfer_job.py358
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/constants.py20
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/distillation_job.py542
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/endpoint_request_settings.py90
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/prompt_settings.py138
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/teacher_model_settings.py93
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py229
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_finetuning_job.py242
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_hyperparameters.py125
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/custom_model_finetuning_job.py258
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_job.py224
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_vertical.py202
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/import_job.py285
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_output_entry.py27
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_port.py18
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job.py363
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_io_mixin.py37
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_limits.py201
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_name_generator.py487
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resource_configuration.py239
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resources.py33
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py424
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py244
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py119
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py96
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py78
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py66
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_command.py170
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_spark.py88
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_attr_dict.py161
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_component_translatable.py412
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/__init__.py21
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/attr_dict.py170
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/base.py848
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/mixin.py623
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_load_component.py313
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_expression.py662
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_job_helpers.py182
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/data/expression_component_template.yml16
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job.py711
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job_settings.py75
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/queue_settings.py87
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/resource_configuration.py98
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/service_instance.py59
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_helpers.py210
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job.py393
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry.py59
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry_mixin.py64
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_resource_configuration.py91
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/early_termination_policy.py191
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/objective.py53
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/parameterized_sweep.py341
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sampling_algorithm.py141
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/search_space.py393
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sweep_job.py361
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/to_rest_functions.py82
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_load_functions.py1103
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_mixins.py163
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/alert_notification.py54
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/compute.py55
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/definition.py162
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/input_data.py206
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/schedule.py175
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/signals.py1338
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/target.py55
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/thresholds.py954
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/notification.py33
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry.py231
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry_support_classes.py273
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/util.py17
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_resource.py194
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/schedule.py513
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/trigger.py290
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_system_data.py77
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_util.py645
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validate_funcs.py94
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py18
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py531
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py53
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py162
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py156
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/_constants.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/capability_host.py187
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/hub.py220
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/project.py89
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/compute_runtime.py41
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/connection_subtypes.py748
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/one_lake_artifacts.py25
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/workspace_connection.py677
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/customer_managed_key.py48
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/diagnose.py214
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/feature_store_settings.py61
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/network_acls.py90
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/networking.py348
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/private_endpoint.py53
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/serverless_compute.py52
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace.py491
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace_keys.py100
320 files changed, 66328 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/__init__.py
new file mode 100644
index 00000000..508dea7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/__init__.py
@@ -0,0 +1,631 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""Contains entities and SDK objects for Azure Machine Learning SDKv2.
+
+Main areas include managing compute targets, creating/managing workspaces and jobs, and submitting/accessing models,
+runs, and run output/logging.
+"""
+# pylint: disable=naming-mismatch
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+import logging
+from typing import Any, Optional
+
+from azure.ai.ml._restclient.v2022_10_01.models import CreatedByType
+from azure.ai.ml._restclient.v2022_10_01_preview.models import UsageUnit
+
+from ._assets._artifacts._package.base_environment_source import BaseEnvironment
+from ._assets._artifacts._package.inferencing_server import (
+ AzureMLBatchInferencingServer,
+ AzureMLOnlineInferencingServer,
+ CustomInferencingServer,
+ Route,
+ TritonInferencingServer,
+)
+from ._assets._artifacts._package.model_configuration import ModelConfiguration
+from ._assets._artifacts._package.model_package import (
+ ModelPackage,
+ ModelPackageInput,
+ PackageInputPathId,
+ PackageInputPathUrl,
+ PackageInputPathVersion,
+)
+from ._assets._artifacts.data import Data
+from ._assets._artifacts.feature_set import FeatureSet
+from ._assets._artifacts.index import Index
+from ._assets._artifacts.model import Model
+from ._assets.asset import Asset
+from ._assets.environment import BuildContext, Environment
+from ._assets.intellectual_property import IntellectualProperty
+from ._assets.workspace_asset_reference import (
+ WorkspaceAssetReference as WorkspaceModelReference,
+)
+from ._autogen_entities.models import (
+ AzureOpenAIDeployment,
+ MarketplacePlan,
+ MarketplaceSubscription,
+ ServerlessEndpoint,
+)
+from ._builders import Command, Parallel, Pipeline, Spark, Sweep
+from ._component.command_component import CommandComponent
+from ._component.component import Component
+from ._component.parallel_component import ParallelComponent
+from ._component.pipeline_component import PipelineComponent
+from ._component.spark_component import SparkComponent
+from ._compute._aml_compute_node_info import AmlComputeNodeInfo
+from ._compute._custom_applications import (
+ CustomApplications,
+ EndpointsSettings,
+ ImageSettings,
+ VolumeSettings,
+)
+from ._compute._image_metadata import ImageMetadata
+from ._compute._schedule import (
+ ComputePowerAction,
+ ComputeSchedules,
+ ComputeStartStopSchedule,
+ ScheduleState,
+)
+from ._compute._setup_scripts import ScriptReference, SetupScripts
+from ._compute._usage import Usage, UsageName
+from ._compute._vm_size import VmSize
+from ._compute.aml_compute import AmlCompute, AmlComputeSshSettings
+from ._compute.compute import Compute, NetworkSettings
+from ._compute.compute_instance import (
+ AssignedUserConfiguration,
+ ComputeInstance,
+ ComputeInstanceSshSettings,
+)
+from ._compute.kubernetes_compute import KubernetesCompute
+from ._compute.synapsespark_compute import (
+ AutoPauseSettings,
+ AutoScaleSettings,
+ SynapseSparkCompute,
+)
+from ._compute.unsupported_compute import UnsupportedCompute
+from ._compute.virtual_machine_compute import (
+ VirtualMachineCompute,
+ VirtualMachineSshSettings,
+)
+from ._credentials import (
+ AadCredentialConfiguration,
+ AccessKeyConfiguration,
+ AccountKeyConfiguration,
+ AmlTokenConfiguration,
+ ApiKeyConfiguration,
+ CertificateConfiguration,
+ IdentityConfiguration,
+ ManagedIdentityConfiguration,
+ NoneCredentialConfiguration,
+ PatTokenConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+ UserIdentityConfiguration,
+ UsernamePasswordConfiguration,
+)
+from ._data_import.data_import import DataImport
+from ._data_import.schedule import ImportDataSchedule
+from ._datastore.adls_gen1 import AzureDataLakeGen1Datastore
+from ._datastore.azure_storage import (
+ AzureBlobDatastore,
+ AzureDataLakeGen2Datastore,
+ AzureFileDatastore,
+)
+from ._datastore.datastore import Datastore
+from ._datastore.one_lake import OneLakeArtifact, OneLakeDatastore
+from ._deployment.batch_deployment import BatchDeployment
+from ._deployment.batch_job import BatchJob
+from ._deployment.code_configuration import CodeConfiguration
+from ._deployment.container_resource_settings import ResourceSettings
+from ._deployment.data_asset import DataAsset
+from ._deployment.data_collector import DataCollector
+from ._deployment.deployment_collection import DeploymentCollection
+from ._deployment.deployment_settings import (
+ BatchRetrySettings,
+ OnlineRequestSettings,
+ ProbeSettings,
+)
+from ._deployment.model_batch_deployment import ModelBatchDeployment
+from ._deployment.model_batch_deployment_settings import ModelBatchDeploymentSettings
+from ._deployment.online_deployment import (
+ Deployment,
+ KubernetesOnlineDeployment,
+ ManagedOnlineDeployment,
+ OnlineDeployment,
+)
+from ._deployment.pipeline_component_batch_deployment import (
+ PipelineComponentBatchDeployment,
+)
+from ._deployment.request_logging import RequestLogging
+from ._deployment.resource_requirements_settings import ResourceRequirementsSettings
+from ._deployment.scale_settings import (
+ DefaultScaleSettings,
+ OnlineScaleSettings,
+ TargetUtilizationScaleSettings,
+)
+from ._endpoint.batch_endpoint import BatchEndpoint
+from ._endpoint.endpoint import Endpoint
+from ._endpoint.online_endpoint import (
+ EndpointAadToken,
+ EndpointAuthKeys,
+ EndpointAuthToken,
+ KubernetesOnlineEndpoint,
+ ManagedOnlineEndpoint,
+ OnlineEndpoint,
+)
+from ._feature_set.data_availability_status import DataAvailabilityStatus
+from ._feature_set.feature import Feature
+from ._feature_set.feature_set_backfill_metadata import FeatureSetBackfillMetadata
+from ._feature_set.feature_set_backfill_request import FeatureSetBackfillRequest
+from ._feature_set.feature_set_materialization_metadata import (
+ FeatureSetMaterializationMetadata,
+)
+from ._feature_set.feature_set_specification import FeatureSetSpecification
+from ._feature_set.feature_window import FeatureWindow
+from ._feature_set.materialization_compute_resource import (
+ MaterializationComputeResource,
+)
+from ._feature_set.materialization_settings import MaterializationSettings
+from ._feature_set.materialization_type import MaterializationType
+from ._feature_store.feature_store import FeatureStore
+from ._feature_store.materialization_store import MaterializationStore
+from ._feature_store_entity.data_column import DataColumn
+from ._feature_store_entity.data_column_type import DataColumnType
+from ._feature_store_entity.feature_store_entity import FeatureStoreEntity
+from ._indexes import AzureAISearchConfig, GitSource, IndexDataSource, LocalSource
+from ._indexes import ModelConfiguration as IndexModelConfiguration
+from ._job.command_job import CommandJob
+from ._job.compute_configuration import ComputeConfiguration
+from ._job.finetuning.custom_model_finetuning_job import CustomModelFineTuningJob
+from ._job.input_port import InputPort
+from ._job.job import Job
+from ._job.job_limits import CommandJobLimits
+from ._job.job_resources import JobResources
+from ._job.job_resource_configuration import JobResourceConfiguration
+from ._job.job_service import (
+ JobService,
+ JupyterLabJobService,
+ SshJobService,
+ TensorBoardJobService,
+ VsCodeJobService,
+)
+from ._job.parallel.parallel_task import ParallelTask
+from ._job.parallel.retry_settings import RetrySettings
+from ._job.parameterized_command import ParameterizedCommand
+
+# Pipeline-related entities are imported after component entities because they depend on them
+from ._job.pipeline.pipeline_job import PipelineJob, PipelineJobSettings
+from ._job.queue_settings import QueueSettings
+from ._job.resource_configuration import ResourceConfiguration
+from ._job.service_instance import ServiceInstance
+from ._job.spark_job import SparkJob
+from ._job.spark_job_entry import SparkJobEntry, SparkJobEntryType
+from ._job.spark_resource_configuration import SparkResourceConfiguration
+from ._monitoring.alert_notification import AlertNotification
+from ._monitoring.compute import ServerlessSparkCompute
+from ._monitoring.definition import MonitorDefinition
+from ._monitoring.input_data import (
+ FixedInputData,
+ MonitorInputData,
+ StaticInputData,
+ TrailingInputData,
+)
+from ._monitoring.schedule import MonitorSchedule
+from ._monitoring.signals import (
+ BaselineDataRange,
+ CustomMonitoringSignal,
+ DataDriftSignal,
+ DataQualitySignal,
+ DataSegment,
+ FADProductionData,
+ FeatureAttributionDriftSignal,
+ GenerationSafetyQualitySignal,
+ GenerationTokenStatisticsSignal,
+ LlmData,
+ ModelPerformanceSignal,
+ MonitorFeatureFilter,
+ PredictionDriftSignal,
+ ProductionData,
+ ReferenceData,
+)
+from ._monitoring.target import MonitoringTarget
+from ._monitoring.thresholds import (
+ CategoricalDriftMetrics,
+ CustomMonitoringMetricThreshold,
+ DataDriftMetricThreshold,
+ DataQualityMetricsCategorical,
+ DataQualityMetricsNumerical,
+ DataQualityMetricThreshold,
+ FeatureAttributionDriftMetricThreshold,
+ GenerationSafetyQualityMonitoringMetricThreshold,
+ GenerationTokenStatisticsMonitorMetricThreshold,
+ ModelPerformanceClassificationThresholds,
+ ModelPerformanceMetricThreshold,
+ ModelPerformanceRegressionThresholds,
+ NumericalDriftMetrics,
+ PredictionDriftMetricThreshold,
+)
+from ._notification.notification import Notification
+from ._registry.registry import Registry
+from ._registry.registry_support_classes import (
+ RegistryRegionDetails,
+ SystemCreatedAcrAccount,
+ SystemCreatedStorageAccount,
+)
+from ._resource import Resource
+from ._schedule.schedule import JobSchedule, Schedule, ScheduleTriggerResult
+from ._schedule.trigger import CronTrigger, RecurrencePattern, RecurrenceTrigger
+from ._system_data import SystemData
+from ._validation import ValidationResult
+from ._workspace._ai_workspaces.hub import Hub
+from ._workspace._ai_workspaces.project import Project
+from ._workspace.compute_runtime import ComputeRuntime
+from ._workspace.connections.connection_subtypes import (
+ APIKeyConnection,
+ AzureAISearchConnection,
+ AzureAIServicesConnection,
+ AzureBlobStoreConnection,
+ AzureContentSafetyConnection,
+ AzureOpenAIConnection,
+ AzureSpeechServicesConnection,
+ MicrosoftOneLakeConnection,
+ OpenAIConnection,
+ SerpConnection,
+ ServerlessConnection,
+)
+from ._workspace.connections.one_lake_artifacts import OneLakeConnectionArtifact
+from ._workspace.connections.workspace_connection import WorkspaceConnection
+from ._workspace.customer_managed_key import CustomerManagedKey
+from ._workspace.diagnose import (
+ DiagnoseRequestProperties,
+ DiagnoseResponseResult,
+ DiagnoseResponseResultValue,
+ DiagnoseResult,
+ DiagnoseWorkspaceParameters,
+)
+from ._workspace.feature_store_settings import FeatureStoreSettings
+from ._workspace.network_acls import DefaultActionType, IPRule, NetworkAcls
+from ._workspace.networking import (
+ FqdnDestination,
+ IsolationMode,
+ ManagedNetwork,
+ ManagedNetworkProvisionStatus,
+ OutboundRule,
+ PrivateEndpointDestination,
+ ServiceTagDestination,
+)
+from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint
+from ._workspace.serverless_compute import ServerlessComputeSettings
+from ._workspace.workspace import Workspace
+from ._workspace._ai_workspaces.capability_host import (
+ CapabilityHost,
+ CapabilityHostKind,
+)
+from ._workspace.workspace_keys import (
+ ContainerRegistryCredential,
+ NotebookAccessKeys,
+ WorkspaceKeys,
+)
+
+__all__ = [
+ "Resource",
+ "Job",
+ "CommandJob",
+ "PipelineJob",
+ "ServiceInstance",
+ "SystemData",
+ "SparkJob",
+ "SparkJobEntry",
+ "SparkJobEntryType",
+ "CommandJobLimits",
+ "ComputeConfiguration",
+ "CustomModelFineTuningJob",
+ "CreatedByType",
+ "ResourceConfiguration",
+ "JobResources",
+ "JobResourceConfiguration",
+ "QueueSettings",
+ "JobService",
+ "SshJobService",
+ "TensorBoardJobService",
+ "VsCodeJobService",
+ "JupyterLabJobService",
+ "SparkResourceConfiguration",
+ "ParameterizedCommand",
+ "InputPort",
+ "BatchEndpoint",
+ "OnlineEndpoint",
+ "Deployment",
+ "BatchDeployment",
+ "BatchJob",
+ "CodeConfiguration",
+ "Endpoint",
+ "OnlineDeployment",
+ "Data",
+ "KubernetesOnlineEndpoint",
+ "ManagedOnlineEndpoint",
+ "KubernetesOnlineDeployment",
+ "ManagedOnlineDeployment",
+ "OnlineRequestSettings",
+ "OnlineScaleSettings",
+ "ProbeSettings",
+ "BatchRetrySettings",
+ "RetrySettings",
+ "ParallelTask",
+ "DefaultScaleSettings",
+ "TargetUtilizationScaleSettings",
+ "Asset",
+ "Environment",
+ "BuildContext",
+ "Model",
+ "ModelBatchDeployment",
+ "ModelBatchDeploymentSettings",
+ "IPRule",
+ "DefaultActionType",
+ "NetworkAcls",
+ "Workspace",
+ "WorkspaceKeys",
+ "WorkspaceConnection",
+ "AzureBlobStoreConnection",
+ "MicrosoftOneLakeConnection",
+ "AzureOpenAIConnection",
+ "AzureAIServicesConnection",
+ "AzureAISearchConnection",
+ "AzureContentSafetyConnection",
+ "AzureSpeechServicesConnection",
+ "APIKeyConnection",
+ "OpenAIConnection",
+ "SerpConnection",
+ "ServerlessConnection",
+ "DiagnoseRequestProperties",
+ "DiagnoseResult",
+ "DiagnoseResponseResult",
+ "DiagnoseResponseResultValue",
+ "DiagnoseWorkspaceParameters",
+ "PrivateEndpoint",
+ "OutboundRule",
+ "ManagedNetwork",
+ "FqdnDestination",
+ "ServiceTagDestination",
+ "PrivateEndpointDestination",
+ "IsolationMode",
+ "ManagedNetworkProvisionStatus",
+ "EndpointConnection",
+ "CustomerManagedKey",
+ "DataImport",
+ "Datastore",
+ "AzureDataLakeGen1Datastore",
+ "AzureBlobDatastore",
+ "AzureDataLakeGen2Datastore",
+ "AzureFileDatastore",
+ "OneLakeDatastore",
+ "OneLakeArtifact",
+ "OneLakeConnectionArtifact",
+ "Compute",
+ "VirtualMachineCompute",
+ "AmlCompute",
+ "ComputeInstance",
+ "UnsupportedCompute",
+ "KubernetesCompute",
+ "NetworkSettings",
+ "Component",
+ "PipelineJobSettings",
+ "PipelineComponentBatchDeployment",
+ "ParallelComponent",
+ "CommandComponent",
+ "SparkComponent",
+ "ResourceRequirementsSettings",
+ "ResourceSettings",
+ "AssignedUserConfiguration",
+ "ComputeInstanceSshSettings",
+ "VmSize",
+ "Usage",
+ "UsageName",
+ "UsageUnit",
+ "CronTrigger",
+ "RecurrenceTrigger",
+ "RecurrencePattern",
+ "JobSchedule",
+ "ImportDataSchedule",
+ "Schedule",
+ "ScheduleTriggerResult",
+ "ComputePowerAction",
+ "ComputeSchedules",
+ "ComputeStartStopSchedule",
+ "ScheduleState",
+ "PipelineComponent",
+ "VirtualMachineSshSettings",
+ "AmlComputeSshSettings",
+ "AmlComputeNodeInfo",
+ "ImageMetadata",
+ "CustomApplications",
+ "ImageSettings",
+ "EndpointsSettings",
+ "VolumeSettings",
+ "SetupScripts",
+ "ScriptReference",
+ "SystemCreatedAcrAccount",
+ "SystemCreatedStorageAccount",
+ "ValidationResult",
+ "RegistryRegionDetails",
+ "Registry",
+ "SynapseSparkCompute",
+ "AutoScaleSettings",
+ "AutoPauseSettings",
+ "WorkspaceModelReference",
+ "Hub",
+ "Project",
+ "CapabilityHost",
+ "CapabilityHostKind",
+ "Feature",
+ "FeatureSet",
+ "FeatureSetBackfillRequest",
+ "ComputeRuntime",
+ "FeatureStoreSettings",
+ "FeatureStoreEntity",
+ "DataColumn",
+ "DataColumnType",
+ "FeatureSetSpecification",
+ "MaterializationComputeResource",
+ "FeatureWindow",
+ "MaterializationSettings",
+ "MaterializationType",
+ "FeatureStore",
+ "MaterializationStore",
+ "Notification",
+ "FeatureSetBackfillMetadata",
+ "DataAvailabilityStatus",
+ "FeatureSetMaterializationMetadata",
+ "ServerlessComputeSettings",
+ # builders
+ "Command",
+ "Parallel",
+ "Sweep",
+ "Spark",
+ "Pipeline",
+ "PatTokenConfiguration",
+ "SasTokenConfiguration",
+ "ManagedIdentityConfiguration",
+ "AccountKeyConfiguration",
+ "ServicePrincipalConfiguration",
+ "CertificateConfiguration",
+ "UsernamePasswordConfiguration",
+ "UserIdentityConfiguration",
+ "AmlTokenConfiguration",
+ "IdentityConfiguration",
+ "NotebookAccessKeys",
+ "ContainerRegistryCredential",
+ "EndpointAuthKeys",
+ "EndpointAuthToken",
+ "EndpointAadToken",
+ "ModelPackage",
+ "ModelPackageInput",
+ "AzureMLOnlineInferencingServer",
+ "AzureMLBatchInferencingServer",
+ "TritonInferencingServer",
+ "CustomInferencingServer",
+ "ModelConfiguration",
+ "BaseEnvironment",
+ "PackageInputPathId",
+ "PackageInputPathUrl",
+ "PackageInputPathVersion",
+ "Route",
+ "AccessKeyConfiguration",
+ "AlertNotification",
+ "ServerlessSparkCompute",
+ "ApiKeyConfiguration",
+ "MonitorDefinition",
+ "MonitorInputData",
+ "MonitorSchedule",
+ "DataDriftSignal",
+ "DataQualitySignal",
+ "PredictionDriftSignal",
+ "FeatureAttributionDriftSignal",
+ "CustomMonitoringSignal",
+ "GenerationSafetyQualitySignal",
+ "GenerationTokenStatisticsSignal",
+ "ModelPerformanceSignal",
+ "MonitorFeatureFilter",
+ "DataSegment",
+ "FADProductionData",
+ "LlmData",
+ "ProductionData",
+ "ReferenceData",
+ "BaselineDataRange",
+ "MonitoringTarget",
+ "FixedInputData",
+ "StaticInputData",
+ "TrailingInputData",
+ "DataDriftMetricThreshold",
+ "DataQualityMetricThreshold",
+ "PredictionDriftMetricThreshold",
+ "FeatureAttributionDriftMetricThreshold",
+ "CustomMonitoringMetricThreshold",
+ "GenerationSafetyQualityMonitoringMetricThreshold",
+ "GenerationTokenStatisticsMonitorMetricThreshold",
+ "CategoricalDriftMetrics",
+ "NumericalDriftMetrics",
+ "DataQualityMetricsNumerical",
+ "DataQualityMetricsCategorical",
+ "ModelPerformanceMetricThreshold",
+ "ModelPerformanceClassificationThresholds",
+ "ModelPerformanceRegressionThresholds",
+ "DataCollector",
+ "IntellectualProperty",
+ "DataAsset",
+ "DeploymentCollection",
+ "RequestLogging",
+ "NoneCredentialConfiguration",
+ "MarketplacePlan",
+ "MarketplaceSubscription",
+ "ServerlessEndpoint",
+ "AccountKeyConfiguration",
+ "AadCredentialConfiguration",
+ "Index",
+ "AzureOpenAIDeployment",
+ "AzureAISearchConfig",
+ "IndexDataSource",
+ "GitSource",
+ "LocalSource",
+ "IndexModelConfiguration",
+]
+
+# Allow importing these types for backwards compatibility
+
+
+def __getattr__(name: str):
+ requested: Optional[Any] = None
+
+ if name == "Choice":
+ from ..sweep import Choice
+
+ requested = Choice
+ if name == "LogNormal":
+ from ..sweep import LogNormal
+
+ requested = LogNormal
+ if name == "LogUniform":
+ from ..sweep import LogUniform
+
+ requested = LogUniform
+ if name == "Normal":
+ from ..sweep import Normal
+
+ requested = Normal
+ if name == "QLogNormal":
+ from ..sweep import QLogNormal
+
+ requested = QLogNormal
+ if name == "QLogUniform":
+ from ..sweep import QLogUniform
+
+ requested = QLogUniform
+ if name == "QNormal":
+ from ..sweep import QNormal
+
+ requested = QNormal
+ if name == "QUniform":
+ from ..sweep import QUniform
+
+ requested = QUniform
+ if name == "Randint":
+ from ..sweep import Randint
+
+ requested = Randint
+ if name == "Uniform":
+ from ..sweep import Uniform
+
+ requested = Uniform
+
+ if requested:
+ if not getattr(__getattr__, "warning_issued", False):
+ logging.warning(
+ " %s will be removed from the azure.ai.ml.entities namespace in a future release."
+ " Please import from the azure.ai.ml.sweep namespace instead.",
+ name,
+ )
+ __getattr__.warning_issued = True # type: ignore[attr-defined]
+ return requested
+
+ raise AttributeError(f"module 'azure.ai.ml.entities' has no attribute {name}")
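The module-level __getattr__ above uses the PEP 562 pattern to keep the old import path working: sweep distribution types such as Choice can still be imported from azure.ai.ml.entities, but the first such import logs a one-time warning and forwards the lookup to azure.ai.ml.sweep. A minimal sketch of what a caller sees; the logging setup and the search-space dict at the end are illustrative, not part of the SDK:

import logging

logging.basicConfig(level=logging.WARNING)

# Legacy path: resolved through entities.__getattr__, which logs a warning once
# and returns the class from azure.ai.ml.sweep.
from azure.ai.ml.entities import Choice

# Preferred path going forward, per the warning text above.
from azure.ai.ml.sweep import Choice  # noqa: F811

# Illustrative use of the imported type in a hyperparameter search space.
search_space = {"learning_rate": Choice(values=[0.01, 0.1, 1.0])}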
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/__init__.py
new file mode 100644
index 00000000..5ee0f971
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/__init__.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+
+from ._artifacts.artifact import Artifact
+from ._artifacts.code import Code
+from ._artifacts.data import Data
+from ._artifacts.index import Index
+from ._artifacts.model import Model
+from .environment import Environment
+from ._artifacts._package.model_package import ModelPackage
+from .workspace_asset_reference import WorkspaceAssetReference
+
+__all__ = ["Artifact", "Model", "Code", "Data", "Index", "Environment", "WorkspaceAssetReference", "ModelPackage"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/base_environment_source.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/base_environment_source.py
new file mode 100644
index 00000000..1be67144
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/base_environment_source.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=redefined-builtin
+
+from typing import Dict, Optional
+
+from azure.ai.ml._restclient.v2023_08_01_preview.models import BaseEnvironmentId as RestBaseEnvironmentId
+from azure.ai.ml._schema.assets.package.base_environment_source import BaseEnvironmentSourceSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+
+@experimental
+class BaseEnvironment:
+ """Base environment type.
+
+ All required parameters must be populated in order to send to Azure.
+
+ :param type: The type of the base environment.
+ :type type: str
+ :param resource_id: The resource id of the base environment. e.g. azureml:name:version
+ :type resource_id: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START base_env_entity_create]
+ :end-before: [END base_env_entity_create]
+ :language: python
+ :dedent: 8
+ :caption: Create a Base Environment object.
+ """
+
+ def __init__(self, type: str, resource_id: Optional[str] = None):
+ self.type = type
+ self.resource_id = resource_id
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestBaseEnvironmentId) -> "BaseEnvironment":
+ return BaseEnvironment(type=rest_obj.base_environment_source_type, resource_id=rest_obj.resource_id)
+
+ def _to_dict(self) -> Dict:
+ return dict(BaseEnvironmentSourceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self))
+
+ def _to_rest_object(self) -> RestBaseEnvironmentId:
+ return RestBaseEnvironmentId(base_environment_source_type=self.type, resource_id=self.resource_id)
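BaseEnvironment pairs a source type string with an environment resource id and converts to and from the RestBaseEnvironmentId model shown above; it is typically supplied when building a model package. A minimal sketch, assuming "EnvironmentAsset" is an accepted source type and using a placeholder azureml:name:version id:

from azure.ai.ml.entities import BaseEnvironment

# Both values are illustrative: the type string is assumed, and the
# resource_id follows the azureml:name:version convention from the docstring.
base_env = BaseEnvironment(type="EnvironmentAsset", resource_id="azureml:my-curated-env:1")

# Round-trip through the REST model defined in this module.
rest_obj = base_env._to_rest_object()
restored = BaseEnvironment._from_rest_object(rest_obj)
assert restored.resource_id == base_env.resource_id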
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/inferencing_server.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/inferencing_server.py
new file mode 100644
index 00000000..6e685244
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/inferencing_server.py
@@ -0,0 +1,216 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,unused-argument
+
+from typing import Any, Optional
+
+from azure.ai.ml._restclient.v2023_02_01_preview.models import (
+ AzureMLOnlineInferencingServer as RestAzureMLOnlineInferencingServer,
+)
+from azure.ai.ml._restclient.v2023_02_01_preview.models import CustomInferencingServer as RestCustomInferencingServer
+from azure.ai.ml._restclient.v2023_02_01_preview.models import (
+ OnlineInferenceConfiguration as RestOnlineInferenceConfiguration,
+)
+from azure.ai.ml._restclient.v2023_02_01_preview.models import Route as RestRoute
+from azure.ai.ml._restclient.v2023_02_01_preview.models import TritonInferencingServer as RestTritonInferencingServer
+from azure.ai.ml._restclient.v2023_08_01_preview.models import (
+ AzureMLBatchInferencingServer as RestAzureMLBatchInferencingServer,
+)
+from azure.ai.ml._restclient.v2023_08_01_preview.models import (
+ AzureMLOnlineInferencingServer as RestAzureMLOnlineInferencingServer,
+)
+from azure.ai.ml._utils._experimental import experimental
+
+from ...._deployment.code_configuration import CodeConfiguration
+
+
+@experimental
+class AzureMLOnlineInferencingServer:
+ """Azure ML online inferencing configurations.
+
+ :param code_configuration: The code configuration of the inferencing server.
+    :type code_configuration: ~azure.ai.ml.entities.CodeConfiguration
+ :ivar type: The type of the inferencing server.
+ """
+
+ def __init__(self, *, code_configuration: Optional[CodeConfiguration] = None, **kwargs: Any):
+ self.type = "azureml_online"
+ self.code_configuration = code_configuration
+
+ @classmethod
+    def _from_rest_object(cls, rest_obj: RestAzureMLOnlineInferencingServer) -> "AzureMLOnlineInferencingServer":
+ return AzureMLOnlineInferencingServer(type=rest_obj.server_type, code_configuration=rest_obj.code_configuration)
+
+ def _to_rest_object(self) -> RestAzureMLOnlineInferencingServer:
+ return RestAzureMLOnlineInferencingServer(server_type=self.type, code_configuration=self.code_configuration)
+
+
+@experimental
+class AzureMLBatchInferencingServer:
+ """Azure ML batch inferencing configurations.
+
+ :param code_configuration: The code configuration of the inferencing server.
+ :type code_configuration: azure.ai.ml.entities.CodeConfiguration
+ :ivar type: The type of the inferencing server.
+ """
+
+ def __init__(self, *, code_configuration: Optional[CodeConfiguration] = None, **kwargs: Any):
+ self.type = "azureml_batch"
+ self.code_configuration = code_configuration
+
+ @classmethod
+    def _from_rest_object(cls, rest_obj: RestAzureMLBatchInferencingServer) -> "AzureMLBatchInferencingServer":
+ return AzureMLBatchInferencingServer(code_configuration=rest_obj.code_configuration)
+
+ def _to_rest_object(self) -> RestAzureMLBatchInferencingServer:
+ return RestAzureMLBatchInferencingServer(server_type=self.type, code_configuration=self.code_configuration)
+
+
+@experimental
+class TritonInferencingServer:
+ """Azure ML triton inferencing configurations.
+
+ :param inference_configuration: The inference configuration of the inferencing server.
+ :type inference_configuration: azure.ai.ml.entities.CodeConfiguration
+ :ivar type: The type of the inferencing server.
+ """
+
+ def __init__(self, *, inference_configuration: Optional[CodeConfiguration] = None, **kwargs: Any):
+ self.type = "triton"
+ self.inference_configuration = inference_configuration
+
+ @classmethod
+    def _from_rest_object(cls, rest_obj: RestTritonInferencingServer) -> "TritonInferencingServer":
+        return TritonInferencingServer(inference_configuration=rest_obj.inference_configuration)
+
+    def _to_rest_object(self) -> RestTritonInferencingServer:
+        return RestTritonInferencingServer(server_type=self.type, inference_configuration=self.inference_configuration)
+
+
+@experimental
+class Route:
+ """Route.
+
+ :param port: The port of the route.
+ :type port: str
+ :param path: The path of the route.
+ :type path: str
+ """
+
+ def __init__(self, *, port: Optional[str] = None, path: Optional[str] = None):
+ self.port = port
+ self.path = path
+
+ @classmethod
+    def _from_rest_object(cls, rest_obj: RestRoute) -> "Route":
+ return Route(port=rest_obj.port, path=rest_obj.path)
+
+ def _to_rest_object(self) -> Optional[RestRoute]:
+ return RestRoute(port=self.port, path=self.path)
+
+
+@experimental
+class OnlineInferenceConfiguration:
+ """Online inference configurations.
+
+ :param liveness_route: The liveness route of the online inference configuration.
+ :type liveness_route: Route
+ :param readiness_route: The readiness route of the online inference configuration.
+ :type readiness_route: Route
+ :param scoring_route: The scoring route of the online inference configuration.
+ :type scoring_route: Route
+ :param entry_script: The entry script of the online inference configuration.
+ :type entry_script: str
+ :param configuration: The configuration of the online inference configuration.
+ :type configuration: dict
+ """
+
+ def __init__(
+ self,
+ liveness_route: Optional[Route] = None,
+ readiness_route: Optional[Route] = None,
+ scoring_route: Optional[Route] = None,
+ entry_script: Optional[str] = None,
+ configuration: Optional[dict] = None,
+ ):
+ self.liveness_route = liveness_route
+ self.readiness_route = readiness_route
+ self.scoring_route = scoring_route
+ self.entry_script = entry_script
+ self.configuration = configuration
+
+ @classmethod
+    def _from_rest_object(cls, rest_obj: RestOnlineInferenceConfiguration) -> "OnlineInferenceConfiguration":
+ return OnlineInferenceConfiguration(
+ liveness_route=Route._from_rest_object(rest_obj.liveness_route),
+ readiness_route=Route._from_rest_object(rest_obj.readiness_route),
+ scoring_route=Route._from_rest_object(rest_obj.scoring_route),
+ entry_script=rest_obj.entry_script,
+ configuration=rest_obj.configuration,
+ )
+
+ def _to_rest_object(self) -> RestOnlineInferenceConfiguration:
+        return RestOnlineInferenceConfiguration(
+            liveness_route=self.liveness_route._to_rest_object() if self.liveness_route is not None else None,
+            readiness_route=self.readiness_route._to_rest_object() if self.readiness_route is not None else None,
+            scoring_route=self.scoring_route._to_rest_object() if self.scoring_route is not None else None,
+            entry_script=self.entry_script,
+            configuration=self.configuration,
+        )
+
+
+@experimental
+class CustomInferencingServer:
+ """Custom inferencing configurations.
+
+ :param inference_configuration: The inference configuration of the inferencing server.
+ :type inference_configuration: OnlineInferenceConfiguration
+ :ivar type: The type of the inferencing server.
+ """
+
+ def __init__(self, *, inference_configuration: Optional[OnlineInferenceConfiguration] = None, **kwargs: Any):
+ self.type = "custom"
+ self.inference_configuration = inference_configuration
+
+ @classmethod
+    def _from_rest_object(cls, rest_obj: RestCustomInferencingServer) -> "CustomInferencingServer":
+ return CustomInferencingServer(
+ type=rest_obj.server_type, inference_configuration=rest_obj.inference_configuration
+ )
+
+ def _to_rest_object(self) -> RestCustomInferencingServer:
+ return RestCustomInferencingServer(server_type=self.type, inference_configuration=self.inference_configuration)
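A short sketch of how the servers defined above might be wired together. The code path, script names, ports, and routes are placeholders; only the constructors shown in this file are relied on.

    from azure.ai.ml.entities import CodeConfiguration
    from azure.ai.ml.entities._assets._artifacts._package.inferencing_server import (
        AzureMLOnlineInferencingServer,
        CustomInferencingServer,
        OnlineInferenceConfiguration,
        Route,
    )

    # Online server: scoring code plus entry script, as in an online deployment.
    online_server = AzureMLOnlineInferencingServer(
        code_configuration=CodeConfiguration(code="./src", scoring_script="score.py")
    )

    # Custom server: routes and entry script are described by OnlineInferenceConfiguration.
    custom_server = CustomInferencingServer(
        inference_configuration=OnlineInferenceConfiguration(
            liveness_route=Route(port="8080", path="/health"),
            readiness_route=Route(port="8080", path="/health"),
            scoring_route=Route(port="8080", path="/score"),
            entry_script="entry.py",
        )
    )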
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_configuration.py
new file mode 100644
index 00000000..73c777cf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_configuration.py
@@ -0,0 +1,55 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ----------------------------------------------------------
+
+
+from typing import Optional
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ModelConfiguration as RestModelConfiguration
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+
+@experimental
+class ModelConfiguration:
+ """ModelConfiguration.
+
+ :param mode: The mode of the model. Possible values include: "Copy", "Download".
+ :type mode: str
+ :param mount_path: The mount path of the model.
+ :type mount_path: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START model_configuration_entity_create]
+ :end-before: [END model_configuration_entity_create]
+ :language: python
+ :dedent: 8
+ :caption: Creating a Model Configuration object.
+ """
+
+ def __init__(self, *, mode: Optional[str] = None, mount_path: Optional[str] = None):
+ self.mode = mode
+ self.mount_path = mount_path
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestModelConfiguration) -> "ModelConfiguration":
+ return ModelConfiguration(mode=rest_obj.mode, mount_path=rest_obj.mount_path)
+
+ def _to_rest_object(self) -> RestModelConfiguration:
+ self._validate()
+ return RestModelConfiguration(mode=self.mode, mount_path=self.mount_path)
+
+ def _validate(self) -> None:
+ if self.mode is not None and self.mode.lower() not in ["copy", "download"]:
+ msg = "Mode must be either 'Copy' or 'Download'"
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.MODEL,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ log_and_raise_error(err)
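A minimal sketch of the class above; the mount path is a placeholder. Note that _validate() lowercases the mode, so both "Download" and "download" pass.

    from azure.ai.ml.entities._assets._artifacts._package.model_configuration import ModelConfiguration

    model_config = ModelConfiguration(mode="download", mount_path="/var/azureml-model")
    rest_config = model_config._to_rest_object()   # runs _validate() first; an unknown mode raises ValidationException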
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_package.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_package.py
new file mode 100644
index 00000000..c4797c20
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/_package/model_package.py
@@ -0,0 +1,338 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access, redefined-builtin
+
+import re
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_08_01_preview.models import CodeConfiguration
+from azure.ai.ml._restclient.v2023_08_01_preview.models import ModelPackageInput as RestModelPackageInput
+from azure.ai.ml._restclient.v2023_08_01_preview.models import PackageInputPathId as RestPackageInputPathId
+from azure.ai.ml._restclient.v2023_08_01_preview.models import PackageInputPathUrl as RestPackageInputPathUrl
+from azure.ai.ml._restclient.v2023_08_01_preview.models import PackageInputPathVersion as RestPackageInputPathVersion
+from azure.ai.ml._restclient.v2023_08_01_preview.models import PackageRequest, PackageResponse
+from azure.ai.ml._schema.assets.package.model_package import ModelPackageSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import dump_yaml_to_file, snake_to_pascal
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._util import load_from_dict
+
+from .base_environment_source import BaseEnvironment
+from .inferencing_server import AzureMLBatchInferencingServer, AzureMLOnlineInferencingServer
+from .model_configuration import ModelConfiguration
+
+
+@experimental
+class PackageInputPathId:
+ """Package input path specified with a resource ID.
+
+ :param input_path_type: The type of the input path. Accepted values are "Url", "PathId", and "PathVersion".
+ :type input_path_type: Optional[str]
+ :param resource_id: The resource ID of the input path. e.g. "azureml://subscriptions/<>/resourceGroups/
+ <>/providers/Microsoft.MachineLearningServices/workspaces/<>/data/<>/versions/<>".
+ :type resource_id: Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ input_path_type: Optional[str] = None,
+ resource_id: Optional[str] = None,
+ ) -> None:
+ self.input_path_type = input_path_type
+ self.resource_id = resource_id
+
+ def _to_rest_object(self) -> RestPackageInputPathId:
+ return RestPackageInputPathId(
+ input_path_type=self.input_path_type,
+ resource_id=self.resource_id,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, package_input_path_id_rest_object: RestPackageInputPathId) -> "PackageInputPathId":
+ return PackageInputPathId(
+ input_path_type=package_input_path_id_rest_object.input_path_type,
+ resource_id=package_input_path_id_rest_object.resource_id,
+ )
+
+
+@experimental
+class PackageInputPathVersion:
+ """Package input path specified with a resource name and version.
+
+ :param input_path_type: The type of the input path. Accepted values are "Url", "PathId", and "PathVersion".
+ :type input_path_type: Optional[str]
+ :param resource_name: The resource name of the input path.
+ :type resource_name: Optional[str]
+ :param resource_version: The resource version of the input path.
+ :type resource_version: Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ input_path_type: Optional[str] = None,
+ resource_name: Optional[str] = None,
+ resource_version: Optional[str] = None,
+ ) -> None:
+ self.input_path_type = input_path_type
+ self.resource_name = resource_name
+ self.resource_version = resource_version
+
+ def _to_rest_object(self) -> RestPackageInputPathVersion:
+ return RestPackageInputPathVersion(
+ input_path_type=self.input_path_type,
+ resource_name=self.resource_name,
+ resource_version=self.resource_version,
+ )
+
+ @classmethod
+ def _from_rest_object(
+ cls, package_input_path_version_rest_object: RestPackageInputPathVersion
+ ) -> "PackageInputPathVersion":
+ return PackageInputPathVersion(
+ input_path_type=package_input_path_version_rest_object.input_path_type,
+ resource_name=package_input_path_version_rest_object.resource_name,
+ resource_version=package_input_path_version_rest_object.resource_version,
+ )
+
+
+@experimental
+class PackageInputPathUrl:
+ """Package input path specified with a url.
+
+ :param input_path_type: The type of the input path. Accepted values are "Url", "PathId", and "PathVersion".
+ :type input_path_type: Optional[str]
+ :param url: The url of the input path. e.g. "azureml://subscriptions/<>/resourceGroups/
+ <>/providers/Microsoft.MachineLearningServices/workspaces/data/<>/versions/<>".
+ :type url: Optional[str]
+ """
+
+ def __init__(self, *, input_path_type: Optional[str] = None, url: Optional[str] = None) -> None:
+ self.input_path_type = input_path_type
+ self.url = url
+
+ def _to_rest_object(self) -> RestPackageInputPathUrl:
+ return RestPackageInputPathUrl(
+ input_path_type=self.input_path_type,
+ url=self.url,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, package_input_path_url_rest_object: RestPackageInputPathUrl) -> "PackageInputPathUrl":
+ return PackageInputPathUrl(
+ input_path_type=package_input_path_url_rest_object.input_path_type,
+ url=package_input_path_url_rest_object.url,
+ )
+
+
+@experimental
+class ModelPackageInput:
+ """Model package input.
+
+ :param type: The type of the input.
+ :type type: Optional[str]
+ :param path: The path of the input.
+ :type path: Optional[Union[~azure.ai.ml.entities.PackageInputPathId, ~azure.ai.ml.entities.PackageInputPathUrl,
+ ~azure.ai.ml.entities.PackageInputPathVersion]]
+ :param mode: The input mode.
+ :type mode: Optional[str]
+ :param mount_path: The mount path for the input.
+ :type mount_path: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START model_package_input_entity_create]
+ :end-before: [END model_package_input_entity_create]
+ :language: python
+ :dedent: 8
+ :caption: Create a Model Package Input object.
+ """
+
+ def __init__(
+ self,
+ *,
+ type: Optional[str] = None,
+ path: Optional[Union[PackageInputPathId, PackageInputPathUrl, PackageInputPathVersion]] = None,
+ mode: Optional[str] = None,
+ mount_path: Optional[str] = None,
+ ) -> None:
+ self.type = type
+ self.path = path
+ self.mode = mode
+ self.mount_path = mount_path
+
+ def _to_rest_object(self) -> RestModelPackageInput:
+ if self.path is None:
+ return RestModelPackageInput(
+ input_type=snake_to_pascal(self.type),
+ path=None,
+ mode=snake_to_pascal(self.mode),
+ mount_path=self.mount_path,
+ )
+ return RestModelPackageInput(
+ input_type=snake_to_pascal(self.type),
+ path=self.path._to_rest_object(),
+ mode=snake_to_pascal(self.mode),
+ mount_path=self.mount_path,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, model_package_input_rest_object: RestModelPackageInput) -> "ModelPackageInput":
+ return ModelPackageInput(
+ type=model_package_input_rest_object.input_type,
+ path=model_package_input_rest_object.path._from_rest_object(),
+ mode=model_package_input_rest_object.mode,
+ mount_path=model_package_input_rest_object.mount_path,
+ )
+
+
+@experimental
+class ModelPackage(Resource, PackageRequest):
+ """Model package.
+
+    :param target_environment: The target environment name for the model package. May also be given as an
+        "azureml:<name>:<version>" ID, from which the name and version are parsed, or as a dict with a "name" key.
+    :type target_environment: Union[str, dict[str, str]]
+ :param inferencing_server: The inferencing server of the model package.
+ :type inferencing_server: Union[~azure.ai.ml.entities.AzureMLOnlineInferencingServer,
+ ~azure.ai.ml.entities.AzureMLBatchInferencingServer]
+ :param base_environment_source: The base environment source of the model package.
+ :type base_environment_source: Optional[~azure.ai.ml.entities.BaseEnvironment]
+    :ivar environment_version: The environment version parsed from target_environment when it is provided
+        as an "azureml:<name>:<version>" ID; otherwise None.
+ :param environment_variables: The environment variables of the model package.
+ :type environment_variables: Optional[dict[str, str]]
+ :param inputs: The inputs of the model package.
+ :type inputs: Optional[list[~azure.ai.ml.entities.ModelPackageInput]]
+ :param model_configuration: The model configuration.
+ :type model_configuration: Optional[~azure.ai.ml.entities.ModelConfiguration]
+ :param tags: The tags of the model package.
+ :type tags: Optional[dict[str, str]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START model_package_entity_create]
+ :end-before: [END model_package_entity_create]
+ :language: python
+ :dedent: 8
+ :caption: Create a Model Package object.
+ """
+
+ def __init__(
+ self,
+ *,
+ target_environment: Union[str, Dict[str, str]],
+ inferencing_server: Union[AzureMLOnlineInferencingServer, AzureMLBatchInferencingServer],
+ base_environment_source: Optional[BaseEnvironment] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ inputs: Optional[List[ModelPackageInput]] = None,
+ model_configuration: Optional[ModelConfiguration] = None,
+ tags: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ):
+ if isinstance(target_environment, dict):
+ target_environment = target_environment["name"]
+ env_version = None
+ else:
+ parse_id = re.match(r"azureml:(\w+):(\d+)$", target_environment)
+
+ if parse_id:
+ target_environment = parse_id.group(1)
+ env_version = parse_id.group(2)
+ else:
+ env_version = None
+
+ super().__init__(
+ name=target_environment,
+ target_environment_id=target_environment,
+ base_environment_source=base_environment_source,
+ inferencing_server=inferencing_server,
+ model_configuration=model_configuration,
+ inputs=inputs,
+ tags=tags,
+ environment_variables=environment_variables,
+ )
+ self.environment_version = env_version
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "ModelPackage":
+ params_override = params_override or []
+ data = data or {}
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: ModelPackage = load_from_dict(ModelPackageSchema, data, context, **kwargs)
+ return res
+
+ def dump(
+ self,
+ dest: Union[str, PathLike, IO[AnyStr]],
+ **kwargs: Any,
+ ) -> None:
+ """Dumps the job content into a file in YAML format.
+
+ :param dest: The local path or file stream to write the YAML content to.
+ If dest is a file path, a new file will be created.
+ If dest is an open file, the file will be written to directly.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ :raises FileExistsError: Raised if dest is a file path and the file already exists.
+ :raises IOError: Raised if dest is an open file and the file is not writable.
+ """
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False)
+
+ def _to_dict(self) -> Dict:
+ return dict(ModelPackageSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self))
+
+ @classmethod
+ def _from_rest_object(cls, model_package_rest_object: PackageResponse) -> Any:
+ target_environment_id = model_package_rest_object.target_environment_id
+ return target_environment_id
+
+ def _to_rest_object(self) -> PackageRequest:
+ code = None
+
+ if (
+ self.inferencing_server
+ and hasattr(self.inferencing_server, "code_configuration")
+ and self.inferencing_server.code_configuration
+ ):
+ self.inferencing_server.code_configuration._validate()
+ code_id = (
+ self.inferencing_server.code_configuration.code
+ if isinstance(self.inferencing_server.code_configuration.code, str)
+ else self.inferencing_server.code_configuration.code.id
+ )
+ code = CodeConfiguration(
+ code_id=code_id,
+ scoring_script=self.inferencing_server.code_configuration.scoring_script,
+ )
+ self.inferencing_server.code_configuration = code
+
+ package_request = PackageRequest(
+ target_environment_id=self.target_environment_id,
+ base_environment_source=(
+ self.base_environment_source._to_rest_object() if self.base_environment_source else None
+ ),
+ inferencing_server=self.inferencing_server._to_rest_object() if self.inferencing_server else None,
+ model_configuration=self.model_configuration._to_rest_object() if self.model_configuration else None,
+ inputs=[input._to_rest_object() for input in self.inputs] if self.inputs else None,
+ tags=self.tags,
+ environment_variables=self.environment_variables,
+ )
+
+ return package_request
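A sketch of assembling a package request from the pieces defined in this directory. Names, versions, paths, and environment variables are placeholders; the "azureml:<name>:<version>" form is parsed by the regex in __init__ (which only matches names made of word characters).

    from azure.ai.ml.entities import CodeConfiguration
    from azure.ai.ml.entities._assets._artifacts._package.base_environment_source import BaseEnvironment
    from azure.ai.ml.entities._assets._artifacts._package.inferencing_server import AzureMLOnlineInferencingServer
    from azure.ai.ml.entities._assets._artifacts._package.model_package import ModelPackage

    package = ModelPackage(
        target_environment="azureml:packaged_env:1",   # parsed into name "packaged_env" and version "1"
        base_environment_source=BaseEnvironment(type="EnvironmentAsset", resource_id="azureml:base-env:1"),
        inferencing_server=AzureMLOnlineInferencingServer(
            code_configuration=CodeConfiguration(code="./src", scoring_script="score.py")
        ),
        environment_variables={"LOG_LEVEL": "debug"},
    )
    # _to_rest_object() produces the PackageRequest submitted by the model packaging operation
    # (typically ml_client.models.package(...); verify the call against your SDK version).
    request = package._to_rest_object()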
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/artifact.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/artifact.py
new file mode 100644
index 00000000..f82e2aa0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/artifact.py
@@ -0,0 +1,131 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from abc import abstractmethod
+from os import PathLike
+from pathlib import Path, PurePosixPath
+from typing import Any, Dict, Optional, Union
+from urllib.parse import urljoin
+
+from azure.ai.ml._utils.utils import is_mlflow_uri, is_url
+from azure.ai.ml.entities._assets.asset import Asset
+
+
+class ArtifactStorageInfo:
+ def __init__(
+ self,
+ name: str,
+ version: str,
+ relative_path: str,
+ datastore_arm_id: Optional[str],
+ container_name: str,
+ storage_account_url: Optional[str] = None,
+ is_file: Optional[bool] = None,
+ indicator_file: Optional[str] = None,
+ ):
+ self.name = name
+ self.version = version
+ self.relative_path = relative_path
+ self.datastore_arm_id = datastore_arm_id
+ self.container_name = container_name
+ self.storage_account_url = storage_account_url
+ self.is_file = is_file
+ self.indicator_file = indicator_file
+
+ @property
+ def full_storage_path(self) -> Optional[str]:
+ if self.storage_account_url is None:
+ return f"{self.container_name}/{self.relative_path}"
+ return urljoin(self.storage_account_url, f"{self.container_name}/{self.relative_path}")
+
+ @property
+ def subdir_path(self) -> Optional[str]:
+ if self.is_file:
+ path = PurePosixPath(self.relative_path).parent
+ if self.storage_account_url is None:
+ return f"{self.container_name}/{path}"
+ return urljoin(self.storage_account_url, f"{self.container_name}/{path}")
+ return self.full_storage_path
+
+
+class Artifact(Asset):
+ """Base class for artifact, can't be instantiated directly.
+
+ :param name: Name of the resource.
+ :type name: str
+ :param version: Version of the resource.
+ :type version: str
+ :param path: The local or remote path to the asset.
+ :type path: Union[str, os.PathLike]
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param datastore: The datastore to upload the local artifact to.
+ :type datastore: str
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ path: Optional[Union[str, PathLike]] = None,
+ datastore: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+ self.path = path
+ self.datastore = datastore
+
+ @property
+ def path(self) -> Optional[Union[str, PathLike]]:
+ return self._path
+
+ @path.setter
+ def path(self, value: Optional[Union[str, PathLike]]) -> None:
+ if not value or is_url(value) or Path(value).is_absolute() or is_mlflow_uri(value):
+ self._path = value
+ else:
+ self._path = Path(self.base_path, value).resolve()
+
+ @abstractmethod
+ def _to_dict(self) -> Dict:
+ pass
+
+ def __eq__(self, other: Any) -> bool:
+ return (
+ type(self) == type(other) # pylint: disable = unidiomatic-typecheck
+ and self.name == other.name
+ and self.id == other.id
+ and self.version == other.version
+ and self.description == other.description
+ and self.tags == other.tags
+ and self.properties == other.properties
+ and self.base_path == other.base_path
+ and self._is_anonymous == other._is_anonymous
+ )
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+ @abstractmethod
+ def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None:
+ """Updates an an artifact with the remote path of a local upload.
+
+ :param asset_artifact: The asset storage info of the artifact
+ :type asset_artifact: ArtifactStorageInfo
+ """
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/code.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/code.py
new file mode 100644
index 00000000..b08149ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/code.py
@@ -0,0 +1,142 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2022_05_01.models import CodeVersionData, CodeVersionDetails
+from azure.ai.ml._schema import CodeAssetSchema
+from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId
+from azure.ai.ml._utils._asset_utils import IgnoreFile, get_content_hash, get_content_hash_version, get_ignore_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ArmConstants
+from azure.ai.ml.entities._assets import Artifact
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+from .artifact import ArtifactStorageInfo
+
+
+class Code(Artifact):
+ """Code for training and scoring.
+
+ :param name: Name of the resource.
+ :type name: str
+ :param version: Version of the resource.
+ :type version: str
+    :param path: A local path or a remote URI. Example of a datastore remote URI:
+        "azureml://subscriptions/{}/resourcegroups/{}/workspaces/{}/datastores/{}/paths/path_on_datastore/"
+ :type path: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param ignore_file: Ignore file for the resource.
+ :type ignore_file: IgnoreFile
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ path: Optional[Union[str, PathLike]] = None,
+ ignore_file: Optional[IgnoreFile] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ properties=properties,
+ path=path,
+ **kwargs,
+ )
+ self._arm_type = ArmConstants.CODE_VERSION_TYPE
+ if self.path and os.path.isabs(self.path):
+ # Only calculate hash for local files
+ self._ignore_file = get_ignore_file(self.path) if ignore_file is None else ignore_file
+ self._hash_sha256 = get_content_hash(self.path, self._ignore_file)
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Code":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: Code = load_from_dict(CodeAssetSchema, data, context, **kwargs)
+ return res
+
+ def _to_dict(self) -> Dict:
+ res: dict = CodeAssetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, code_rest_object: CodeVersionData) -> "Code":
+ rest_code_version: CodeVersionDetails = code_rest_object.properties
+ arm_id = AMLVersionedArmId(arm_id=code_rest_object.id)
+ code = Code(
+ id=code_rest_object.id,
+ name=arm_id.asset_name,
+ version=arm_id.asset_version,
+ path=rest_code_version.code_uri,
+ description=rest_code_version.description,
+ tags=rest_code_version.tags,
+ properties=rest_code_version.properties,
+ # pylint: disable=protected-access
+ creation_context=SystemData._from_rest_object(code_rest_object.system_data),
+ is_anonymous=rest_code_version.is_anonymous,
+ )
+ return code
+
+ def _to_rest_object(self) -> CodeVersionData:
+ properties = {}
+ if hasattr(self, "_hash_sha256"):
+ properties["hash_sha256"] = self._hash_sha256
+ properties["hash_version"] = get_content_hash_version()
+ code_version = CodeVersionDetails(code_uri=self.path, is_anonymous=self._is_anonymous, properties=properties)
+ code_version_resource = CodeVersionData(properties=code_version)
+
+ return code_version_resource
+
+ def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None:
+ """Update an artifact with the remote path of a local upload.
+
+ :param asset_artifact: The asset storage info of the artifact
+ :type asset_artifact: ArtifactStorageInfo
+ """
+ if asset_artifact.is_file:
+ # Code paths cannot be pointers to single files. It must be a pointer to a container
+ # Skipping the setter to avoid being resolved as a local path
+ self._path = asset_artifact.subdir_path # pylint: disable=attribute-defined-outside-init
+ else:
+ self._path = asset_artifact.full_storage_path # pylint: disable=attribute-defined-outside-init
+
+ # pylint: disable=unused-argument
+ def _to_arm_resource_param(self, **kwargs: Any) -> Dict:
+ properties = self._to_rest_object().properties
+
+ return {
+ self._arm_type: {
+ ArmConstants.NAME: self.name,
+ ArmConstants.VERSION: self.version,
+ ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "CodeVersionDetails"),
+ }
+ }
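A minimal sketch of declaring a local code asset with the class above; the path and names are placeholders. For local paths the constructor also computes a content hash that _to_rest_object() records in the version properties.

    from azure.ai.ml.entities import Code

    scoring_code = Code(
        name="scoring-code",
        version="1",
        path="./src",                 # resolved against base_path and hashed because it is local
        description="Scoring scripts for the model",
    )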
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/data.py
new file mode 100644
index 00000000..710e959a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/data.py
@@ -0,0 +1,237 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import os
+import re
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ DataContainer,
+ DataContainerProperties,
+ DataType,
+ DataVersionBase,
+ DataVersionBaseProperties,
+ MLTableData,
+ UriFileDataVersion,
+ UriFolderDataVersion,
+)
+from azure.ai.ml._schema import DataSchema
+from azure.ai.ml._utils._arm_id_utils import get_arm_id_object_from_id
+from azure.ai.ml._utils.utils import is_url
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, SHORT_URI_FORMAT, AssetTypes
+from azure.ai.ml.entities._assets import Artifact
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .artifact import ArtifactStorageInfo
+
+DataAssetTypeModelMap: Dict[str, Type[DataVersionBaseProperties]] = {
+ AssetTypes.URI_FILE: UriFileDataVersion,
+ AssetTypes.URI_FOLDER: UriFolderDataVersion,
+ AssetTypes.MLTABLE: MLTableData,
+}
+
+
+def getModelForDataAssetType(data_asset_type: str) -> Optional[Type[DataVersionBaseProperties]]:
+ model = DataAssetTypeModelMap.get(data_asset_type)
+ if model is None:
+ msg = "Unknown DataType {}".format(data_asset_type)
+ err = ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ log_and_raise_error(err)
+ return model
+
+
+DataTypeMap: Dict[DataType, str] = {
+ DataType.URI_FILE: AssetTypes.URI_FILE,
+ DataType.URI_FOLDER: AssetTypes.URI_FOLDER,
+ DataType.MLTABLE: AssetTypes.MLTABLE,
+}
+
+
+def getDataAssetType(data_type: DataType) -> str:
+ return DataTypeMap.get(data_type, data_type) # pass through value if no match found
+
+
+class Data(Artifact):
+ """Data for training and scoring.
+
+ :param name: Name of the resource.
+ :type name: str
+ :param version: Version of the resource.
+ :type version: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+    :param path: The path to the asset on the datastore. This can be local or remote.
+ :type path: str
+ :param type: The type of the asset. Valid values are uri_file, uri_folder, mltable. Defaults to uri_folder.
+ :type type: Literal[AssetTypes.URI_FILE, AssetTypes.URI_FOLDER, AssetTypes.MLTABLE]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ path: Optional[str] = None, # if type is mltable, the path has to be a folder.
+ type: str = AssetTypes.URI_FOLDER, # pylint: disable=redefined-builtin
+ **kwargs: Any,
+ ):
+ self._path: Optional[Union[Path, str, PathLike]] = None
+
+ self._skip_validation = kwargs.pop("skip_validation", False)
+ self._mltable_schema_url = kwargs.pop("mltable_schema_url", None)
+ self._referenced_uris = kwargs.pop("referenced_uris", None)
+ self.type = type
+ super().__init__(
+ name=name,
+ version=version,
+ path=path,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+ self.path = path
+
+ @property
+ def path(self) -> Optional[Union[Path, str, PathLike]]:
+ return self._path
+
+ @path.setter
+ def path(self, value: str) -> None:
+ # Call the parent setter to resolve the path with base_path if it was a local path
+ # TODO: Bug Item number: 2883424
+ super(Data, type(self)).path.fset(self, value) # type: ignore
+ if self.type == AssetTypes.URI_FOLDER and self._path is not None and not is_url(self._path):
+ self._path = Path(os.path.join(self._path, ""))
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Data":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ data_asset = Data._load_from_dict(yaml_data=data, context=context, **kwargs)
+
+ return data_asset
+
+ @classmethod
+ def _load_from_dict(cls, yaml_data: Dict, context: Dict, **kwargs: Any) -> "Data":
+ return Data(**load_from_dict(DataSchema, yaml_data, context, **kwargs))
+
+ def _to_dict(self) -> Dict:
+ res: dict = DataSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def _to_container_rest_object(self) -> DataContainer:
+ VersionDetailsClass = getModelForDataAssetType(self.type)
+ return DataContainer(
+ properties=DataContainerProperties(
+ properties=self.properties,
+ tags=self.tags,
+ is_archived=False,
+ data_type=VersionDetailsClass.data_type if VersionDetailsClass is not None else None,
+ )
+ )
+
+ def _to_rest_object(self) -> Optional[DataVersionBase]:
+ VersionDetailsClass = getModelForDataAssetType(self.type)
+ if VersionDetailsClass is not None:
+ data_version_details = VersionDetailsClass(
+ description=self.description,
+ is_anonymous=self._is_anonymous,
+ tags=self.tags,
+ is_archived=False,
+ properties=self.properties,
+ data_uri=self.path,
+ auto_delete_setting=self.auto_delete_setting,
+ )
+ if VersionDetailsClass._attribute_map.get("referenced_uris") is not None:
+ data_version_details.referenced_uris = self._referenced_uris
+ return DataVersionBase(properties=data_version_details)
+
+ return None
+
+ @classmethod
+ def _from_container_rest_object(cls, data_container_rest_object: DataContainer) -> "Data":
+ data_rest_object_details: DataContainerProperties = data_container_rest_object.properties
+ data = Data(
+ name=data_container_rest_object.name,
+ creation_context=SystemData._from_rest_object(data_container_rest_object.system_data),
+ tags=data_rest_object_details.tags,
+ properties=data_rest_object_details.properties,
+ type=getDataAssetType(data_rest_object_details.data_type),
+ )
+ data.latest_version = data_rest_object_details.latest_version
+ return data
+
+ @classmethod
+ def _from_rest_object(cls, data_rest_object: DataVersionBase) -> "Data":
+ data_rest_object_details: DataVersionBaseProperties = data_rest_object.properties
+ arm_id_object = get_arm_id_object_from_id(data_rest_object.id)
+ path = data_rest_object_details.data_uri
+ data = Data(
+ id=data_rest_object.id,
+ name=arm_id_object.asset_name,
+ version=arm_id_object.asset_version,
+ path=path,
+ type=getDataAssetType(data_rest_object_details.data_type),
+ description=data_rest_object_details.description,
+ tags=data_rest_object_details.tags,
+ properties=data_rest_object_details.properties,
+ creation_context=SystemData._from_rest_object(data_rest_object.system_data),
+ is_anonymous=data_rest_object_details.is_anonymous,
+ referenced_uris=getattr(data_rest_object_details, "referenced_uris", None),
+ auto_delete_setting=getattr(data_rest_object_details, "auto_delete_setting", None),
+ )
+ return data
+
+ @classmethod
+ def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] = None) -> Tuple:
+ from azure.ai.ml.entities._data_import.data_import import DataImport
+
+ if "source" in data:
+ return DataImport, None
+
+ return cls, None
+
+ def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None:
+ regex = r"datastores\/(.+)"
+ # datastore_arm_id is null for registry scenario, so capture the full_storage_path
+ if not asset_artifact.datastore_arm_id and asset_artifact.full_storage_path:
+ self.path = asset_artifact.full_storage_path
+ else:
+ groups = re.search(regex, asset_artifact.datastore_arm_id) # type: ignore
+ if groups:
+ datastore_name = groups.group(1)
+ self.path = SHORT_URI_FORMAT.format(datastore_name, asset_artifact.relative_path)
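A minimal sketch of a folder-type data asset using the class above; the datastore URI is a placeholder. For local uri_folder paths the setter appends a trailing separator so the path unambiguously names a folder; URL paths like the one below are stored as given.

    from azure.ai.ml.constants import AssetTypes
    from azure.ai.ml.entities import Data

    training_data = Data(
        name="training-data",
        version="1",
        type=AssetTypes.URI_FOLDER,
        path="azureml://datastores/workspaceblobstore/paths/training/",
        description="Training data folder",
    )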
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/feature_set.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/feature_set.py
new file mode 100644
index 00000000..a5bb73fe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/feature_set.py
@@ -0,0 +1,220 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2023_10_01.models import (
+ FeaturesetContainer,
+ FeaturesetContainerProperties,
+ FeaturesetVersion,
+ FeaturesetVersionProperties,
+)
+from azure.ai.ml._schema._feature_set.feature_set_schema import FeatureSetSchema
+from azure.ai.ml._utils._arm_id_utils import AMLNamedArmId, get_arm_id_object_from_id
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LONG_URI_FORMAT, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets import Artifact
+from azure.ai.ml.entities._feature_set.feature_set_specification import FeatureSetSpecification
+from azure.ai.ml.entities._feature_set.materialization_settings import MaterializationSettings
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .artifact import ArtifactStorageInfo
+
+
+class FeatureSet(Artifact):
+ """Feature Set
+
+ :param name: The name of the Feature Set resource.
+ :type name: str
+ :param version: The version of the Feature Set resource.
+ :type version: str
+ :param entities: Specifies list of entities.
+ :type entities: list[str]
+ :param specification: Specifies the feature set spec details.
+ :type specification: ~azure.ai.ml.entities.FeatureSetSpecification
+    :param stage: Feature set stage. Allowed values: Development, Production, Archived. Defaults to Development.
+ :type stage: Optional[str]
+ :param description: The description of the Feature Set resource. Defaults to None.
+ :type description: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :type tags: Optional[dict[str, str]]
+ :param materialization_settings: Specifies the materialization settings. Defaults to None.
+ :type materialization_settings: Optional[~azure.ai.ml.entities.MaterializationSettings]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ :raises ValidationException: Raised if stage is specified and is not valid.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_featurestore.py
+ :start-after: [START configure_feature_set]
+ :end-before: [END configure_feature_set]
+ :language: Python
+ :dedent: 8
+ :caption: Instantiating a Feature Set object
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ version: str,
+ entities: List[str],
+ specification: Optional[FeatureSetSpecification],
+ stage: Optional[str] = "Development",
+ description: Optional[str] = None,
+ materialization_settings: Optional[MaterializationSettings] = None,
+ tags: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ path=specification.path if specification is not None else None,
+ **kwargs,
+ )
+ if stage and stage not in ["Development", "Production", "Archived"]:
+ msg = f"Stage must be Development, Production, or Archived, found {stage}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.FEATURE_SET,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.entities = entities
+ self.specification = specification
+ self.stage = stage
+ self.materialization_settings = materialization_settings
+ self.latest_version = None
+
+ def _to_rest_object(self) -> FeaturesetVersion:
+ featureset_version_properties = FeaturesetVersionProperties(
+ description=self.description,
+ properties=self.properties,
+ tags=self.tags,
+ entities=self.entities,
+ materialization_settings=(
+ self.materialization_settings._to_rest_object() if self.materialization_settings else None
+ ),
+ specification=self.specification._to_rest_object() if self.specification is not None else None,
+ stage=self.stage,
+ )
+ return FeaturesetVersion(name=self.name, properties=featureset_version_properties)
+
+ @classmethod
+ def _from_rest_object(cls, featureset_rest_object: FeaturesetVersion) -> Optional["FeatureSet"]:
+ if not featureset_rest_object:
+ return None
+ featureset_rest_object_details: FeaturesetVersionProperties = featureset_rest_object.properties
+ arm_id_object = get_arm_id_object_from_id(featureset_rest_object.id)
+ featureset = FeatureSet(
+ id=featureset_rest_object.id,
+ name=arm_id_object.asset_name,
+ version=arm_id_object.asset_version,
+ description=featureset_rest_object_details.description,
+ tags=featureset_rest_object_details.tags,
+ entities=featureset_rest_object_details.entities,
+ materialization_settings=MaterializationSettings._from_rest_object(
+ featureset_rest_object_details.materialization_settings
+ ),
+ specification=FeatureSetSpecification._from_rest_object(featureset_rest_object_details.specification),
+ stage=featureset_rest_object_details.stage,
+ properties=featureset_rest_object_details.properties,
+ )
+ return featureset
+
+ @classmethod
+ def _from_container_rest_object(cls, rest_obj: FeaturesetContainer) -> "FeatureSet":
+ rest_object_details: FeaturesetContainerProperties = rest_obj.properties
+ arm_id_object = get_arm_id_object_from_id(rest_obj.id)
+ featureset = FeatureSet(
+ name=arm_id_object.asset_name,
+ description=rest_object_details.description,
+ tags=rest_object_details.tags,
+ entities=[],
+ specification=FeatureSetSpecification(),
+ version="",
+ )
+ featureset.latest_version = rest_object_details.latest_version
+ return featureset
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "FeatureSet":
+ data = data or {}
+ params_override = params_override or []
+ base_path = Path(yaml_path).parent if yaml_path else Path("./")
+ context = {
+ BASE_PATH_CONTEXT_KEY: base_path,
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ loaded_schema = load_from_dict(FeatureSetSchema, data, context, **kwargs)
+ feature_set = FeatureSet(base_path=base_path, **loaded_schema)
+ return feature_set
+
+ def _to_dict(self) -> Dict:
+ return dict(FeatureSetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self))
+
+ def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None:
+ # if datastore_arm_id is null, capture the full_storage_path
+ if not asset_artifact.datastore_arm_id and asset_artifact.full_storage_path:
+ self.path = asset_artifact.full_storage_path
+ else:
+ aml_datastore_id = AMLNamedArmId(asset_artifact.datastore_arm_id)
+ self.path = LONG_URI_FORMAT.format(
+ aml_datastore_id.subscription_id,
+ aml_datastore_id.resource_group_name,
+ aml_datastore_id.workspace_name,
+ aml_datastore_id.asset_name,
+ asset_artifact.relative_path,
+ )
+
+ if self.specification is not None:
+ self.specification.path = self.path
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the asset content into a file in YAML format.
+
+ :param dest: The local path or file stream to write the YAML content to.
+ If dest is a file path, a new file will be created.
+ If dest is an open file, the file will be written to directly.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ :raises FileExistsError: Raised if dest is a file path and the file already exists.
+ :raises IOError: Raised if dest is an open file and the file is not writable.
+ """
+
+ import os
+ import shutil
+
+ from azure.ai.ml._utils.utils import is_url
+
+ origin_spec_path = self.specification.path if self.specification is not None else None
+ if isinstance(dest, (PathLike, str)) and self.specification is not None and not is_url(self.specification.path):
+ if os.path.exists(dest):
+ raise FileExistsError(f"File {dest} already exists.")
+ relative_path = os.path.basename(cast(PathLike, self.specification.path))
+ src_spec_path = (
+ str(Path(self._base_path, self.specification.path)) if self.specification.path is not None else ""
+ )
+ dest_spec_path = str(Path(os.path.dirname(dest), relative_path))
+ if os.path.exists(dest_spec_path):
+ shutil.rmtree(dest_spec_path)
+ shutil.copytree(src=src_spec_path, dst=dest_spec_path)
+ self.specification.path = str(Path("./", relative_path))
+ super().dump(dest=dest, **kwargs)
+
+ if self.specification is not None:
+ self.specification.path = origin_spec_path
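A minimal sketch of the class above; the entity reference and spec folder are placeholders. An unrecognized stage raises ValidationException, as shown in __init__.

    from azure.ai.ml.entities import FeatureSet, FeatureSetSpecification

    transactions = FeatureSet(
        name="transactions",
        version="1",
        entities=["azureml:account:1"],
        specification=FeatureSetSpecification(path="./featuresets/transactions/spec"),
        stage="Development",
        description="Aggregated transaction features",
    )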
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/index.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/index.py
new file mode 100644
index 00000000..35f671d3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/index.py
@@ -0,0 +1,137 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union, cast
+
+# cspell:disable-next-line
+from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401.models import Index as RestIndex
+from azure.ai.ml._schema import IndexAssetSchema
+from azure.ai.ml._utils._arm_id_utils import AMLAssetId, AMLNamedArmId
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LONG_URI_FORMAT, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets import Artifact
+from azure.ai.ml.entities._assets._artifacts.artifact import ArtifactStorageInfo
+from azure.ai.ml.entities._system_data import RestSystemData, SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+
+@experimental
+class Index(Artifact):
+ """Index asset.
+
+ :ivar name: Name of the resource.
+ :vartype name: str
+ :ivar version: Version of the resource.
+ :vartype version: str
+ :ivar id: Fully qualified resource Id:
+ azureml://workspace/{workspaceName}/indexes/{name}/versions/{version} of the index. Required.
+ :vartype id: str
+ :ivar stage: Update stage to 'Archive' for soft delete. Default is Development, which means the
+ asset is under development. Required.
+ :vartype stage: str
+ :ivar description: Description information of the asset.
+ :vartype description: Optional[str]
+ :ivar tags: Asset's tags.
+ :vartype tags: Optional[dict[str, str]]
+ :ivar properties: Asset's properties.
+ :vartype properties: Optional[dict[str, str]]
+ :ivar path: The local or remote path to the asset.
+ :vartype path: Optional[Union[str, os.PathLike]]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ version: Optional[str] = None,
+ stage: str = "Development",
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ properties: Optional[Dict[str, str]] = None,
+ path: Optional[Union[str, PathLike]] = None,
+ datastore: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ self.stage = stage
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ properties=properties,
+ path=path,
+ datastore=datastore,
+ **kwargs,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, index_rest_object: RestIndex) -> "Index":
+ """Convert the response from the Index API into a Index
+
+ :param RestIndex index_rest_object:
+ :return: An Index Asset
+ :rtype: Index
+ """
+ asset_id = AMLAssetId(asset_id=index_rest_object.id)
+
+ return Index(
+ id=index_rest_object.id,
+ name=asset_id.asset_name,
+ version=asset_id.asset_version,
+ description=index_rest_object.description,
+ tags=index_rest_object.tags,
+ properties=index_rest_object.properties,
+ stage=index_rest_object.stage,
+ path=index_rest_object.storage_uri,
+ # pylint: disable-next=protected-access
+ creation_context=SystemData._from_rest_object(
+ RestSystemData.from_dict(index_rest_object.system_data.as_dict())
+ ),
+ )
+
+ def _to_rest_object(self) -> RestIndex:
+ # Note: Index.name and Index.version get dropped going to RestIndex, since both are encoded in the id
+ # (when present)
+ return RestIndex(
+ stage=self.stage,
+ storage_uri=self.path,
+ description=self.description,
+ tags=self.tags,
+ properties=self.properties,
+ id=self.id,
+ )
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Index":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ return cast(Index, load_from_dict(IndexAssetSchema, data, context, **kwargs))
+
+ def _to_dict(self) -> Dict:
+ return cast(dict, IndexAssetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self))
+
+ def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None:
+ """Updates an an artifact with the remote path of a local upload.
+
+ :param ArtifactStorageInfo asset_artifact: The asset storage info of the artifact
+ """
+ aml_datastore_id = AMLNamedArmId(asset_artifact.datastore_arm_id)
+ self.path = LONG_URI_FORMAT.format(
+ aml_datastore_id.subscription_id,
+ aml_datastore_id.resource_group_name,
+ aml_datastore_id.workspace_name,
+ aml_datastore_id.asset_name,
+ asset_artifact.relative_path,
+ )
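A minimal sketch of the experimental Index asset above; the storage URI and names are placeholders. As noted in _to_rest_object, name and version are dropped on conversion because both are encoded in the id once the asset exists service-side.

    from azure.ai.ml.entities._assets._artifacts.index import Index

    faq_index = Index(
        name="faq-index",
        version="1",
        stage="Development",
        path="azureml://datastores/workspaceblobstore/paths/indexes/faq/",
        description="Vector index built from FAQ documents",
    )
    rest_index = faq_index._to_rest_object()   # id is None until the asset has been created in the workspace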
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/model.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/model.py
new file mode 100644
index 00000000..8e65bd3e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/_artifacts/model.py
@@ -0,0 +1,219 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ FlavorData,
+ ModelContainer,
+ ModelVersion,
+ ModelVersionProperties,
+)
+from azure.ai.ml._schema import ModelSchema
+from azure.ai.ml._utils._arm_id_utils import AMLNamedArmId, AMLVersionedArmId
+from azure.ai.ml._utils._asset_utils import get_ignore_file, get_object_hash
+from azure.ai.ml.constants._common import (
+ BASE_PATH_CONTEXT_KEY,
+ LONG_URI_FORMAT,
+ PARAMS_OVERRIDE_KEY,
+ ArmConstants,
+ AssetTypes,
+)
+from azure.ai.ml.entities._assets import Artifact
+from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import get_md5_string, load_from_dict
+
+from .artifact import ArtifactStorageInfo
+
+
+class Model(Artifact): # pylint: disable=too-many-instance-attributes
+ """Model for training and scoring.
+
+ :param name: The name of the model. Defaults to a random GUID.
+ :type name: Optional[str]
+ :param version: The version of the model. Defaults to "1" if either no name or an unregistered name is provided.
+ Otherwise, defaults to autoincrement from the last registered version of the model with that name.
+ :type version: Optional[str]
+ :param type: The storage format for this entity, used for NCD (No Code Deployment). Accepted values are
+ "custom_model", "mlflow_model", or "triton_model". Defaults to "custom_model".
+ :type type: Optional[str]
+ :param utc_time_created: The date and time when the model was created, in
+ UTC ISO 8601 format (e.g. '2020-10-19 17:44:02.096572').
+ :type utc_time_created: Optional[str]
+ :param flavors: The flavors in which the model can be interpreted. Defaults to None.
+ :type flavors: Optional[dict[str, Any]]
+ :param path: A remote uri or a local path pointing to a model. Defaults to None.
+ :type path: Optional[str]
+ :param description: The description of the resource. Defaults to None
+ :type description: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :type tags: Optional[dict[str, str]]
+ :param properties: The asset property dictionary. Defaults to None.
+ :type properties: Optional[dict[str, str]]
+ :param stage: The stage of the resource. Defaults to None.
+ :type stage: Optional[str]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: Optional[dict]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START model_entity_create]
+ :end-before: [END model_entity_create]
+ :language: python
+ :dedent: 8
+ :caption: Creating a Model object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ path: Optional[Union[str, PathLike]] = None,
+ utc_time_created: Optional[str] = None,
+ flavors: Optional[Dict[str, Dict[str, Any]]] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ stage: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.job_name = kwargs.pop("job_name", None)
+ self._intellectual_property = kwargs.pop("intellectual_property", None)
+ self._system_metadata = kwargs.pop("system_metadata", None)
+ super().__init__(
+ name=name,
+ version=version,
+ path=path,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+ self.utc_time_created = utc_time_created
+ self.flavors = dict(flavors) if flavors else None
+ self._arm_type = ArmConstants.MODEL_VERSION_TYPE
+ self.type = type or AssetTypes.CUSTOM_MODEL
+ self.stage = stage
+ if self._is_anonymous and self.path:
+ _ignore_file = get_ignore_file(self.path)
+ _upload_hash = get_object_hash(self.path, _ignore_file)
+ self.name = get_md5_string(_upload_hash)
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Model":
+ params_override = params_override or []
+ data = data or {}
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: Model = load_from_dict(ModelSchema, data, context, **kwargs)
+ return res
+
+ def _to_dict(self) -> Dict:
+ return dict(ModelSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self))
+
+ @classmethod
+ def _from_rest_object(cls, model_rest_object: ModelVersion) -> "Model":
+ rest_model_version: ModelVersionProperties = model_rest_object.properties
+ arm_id = AMLVersionedArmId(arm_id=model_rest_object.id)
+ model_stage = rest_model_version.stage if hasattr(rest_model_version, "stage") else None
+ model_system_metadata = (
+ rest_model_version.system_metadata if hasattr(rest_model_version, "system_metadata") else None
+ )
+ if hasattr(rest_model_version, "flavors"):
+ flavors = {key: flavor.data for key, flavor in rest_model_version.flavors.items()}
+ model = Model(
+ id=model_rest_object.id,
+ name=arm_id.asset_name,
+ version=arm_id.asset_version,
+ path=rest_model_version.model_uri,
+ description=rest_model_version.description,
+ tags=rest_model_version.tags,
+ flavors=flavors, # pylint: disable=possibly-used-before-assignment
+ properties=rest_model_version.properties,
+ stage=model_stage,
+ # pylint: disable=protected-access
+ creation_context=SystemData._from_rest_object(model_rest_object.system_data),
+ type=rest_model_version.model_type,
+ job_name=rest_model_version.job_name,
+ intellectual_property=(
+ IntellectualProperty._from_rest_object(rest_model_version.intellectual_property)
+ if rest_model_version.intellectual_property
+ else None
+ ),
+ system_metadata=model_system_metadata,
+ )
+ return model
+
+ @classmethod
+ def _from_container_rest_object(cls, model_container_rest_object: ModelContainer) -> "Model":
+ model = Model(
+ name=model_container_rest_object.name,
+ version="1",
+ id=model_container_rest_object.id,
+ # pylint: disable=protected-access
+ creation_context=SystemData._from_rest_object(model_container_rest_object.system_data),
+ )
+ model.latest_version = model_container_rest_object.properties.latest_version
+
+ # Setting version to None since if version is not provided it is defaulted to "1".
+ # This should go away once container concept is finalized.
+ model.version = None
+ return model
+
+ def _to_rest_object(self) -> ModelVersion:
+ model_version = ModelVersionProperties(
+ description=self.description,
+ tags=self.tags,
+ properties=self.properties,
+ flavors=(
+ {key: FlavorData(data=dict(value)) for key, value in self.flavors.items()} if self.flavors else None
+ ), # flatten OrderedDict to dict
+ model_type=self.type,
+ model_uri=self.path,
+ stage=self.stage,
+ is_anonymous=self._is_anonymous,
+ )
+ model_version.system_metadata = self._system_metadata if hasattr(self, "_system_metadata") else None
+
+ model_version_resource = ModelVersion(properties=model_version)
+
+ return model_version_resource
+
+ def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None:
+ # datastore_arm_id is null for registry scenario, so capture the full_storage_path
+ if not asset_artifact.datastore_arm_id and asset_artifact.full_storage_path:
+ self.path = asset_artifact.full_storage_path
+ else:
+ aml_datastore_id = AMLNamedArmId(asset_artifact.datastore_arm_id)
+ self.path = LONG_URI_FORMAT.format(
+ aml_datastore_id.subscription_id,
+ aml_datastore_id.resource_group_name,
+ aml_datastore_id.workspace_name,
+ aml_datastore_id.asset_name,
+ asset_artifact.relative_path,
+ )
+
+ def _to_arm_resource_param(self, **kwargs: Any) -> Dict: # pylint: disable=unused-argument
+ properties = self._to_rest_object().properties
+
+ return {
+ self._arm_type: {
+ ArmConstants.NAME: self.name,
+ ArmConstants.VERSION: self.version,
+ ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "ModelVersionProperties"),
+ }
+ }
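For reference, a minimal sketch of constructing the Model entity defined above, assuming the public exports azure.ai.ml.entities.Model and azure.ai.ml.constants.AssetTypes; the path is a placeholder:

    from azure.ai.ml.constants import AssetTypes
    from azure.ai.ml.entities import Model

    sample_model = Model(
        name="sample-model",
        version="1",
        type=AssetTypes.CUSTOM_MODEL,  # the default when type is omitted
        path="./model_dir",            # placeholder local directory or remote URI
        description="Sketch of constructing a Model asset",
        tags={"framework": "sklearn"},
    )
    print(sample_model.type)  # "custom_model"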
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/asset.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/asset.py
new file mode 100644
index 00000000..b6ee2b55
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/asset.py
@@ -0,0 +1,145 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import uuid
+from abc import abstractmethod
+from os import PathLike
+from typing import IO, Any, AnyStr, Dict, Optional, Union
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+
+class Asset(Resource):
+ """Base class for asset.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+ :param name: The name of the asset. Defaults to a random GUID.
+ :type name: Optional[str]
+ :param version: The version of the asset. Defaults to "1" if no name is provided, otherwise defaults to
+ autoincrement from the last registered version of the asset with that name. For a model name that has
+ never been registered, a default version will be assigned.
+ :type version: Optional[str]
+ :param description: The description of the resource. Defaults to None.
+ :type description: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :type tags: Optional[dict[str, str]]
+ :param properties: The asset property dictionary. Defaults to None.
+ :type properties: Optional[dict[str, str]]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: Optional[dict]
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._is_anonymous = kwargs.pop("is_anonymous", False)
+ self._auto_increment_version = kwargs.pop("auto_increment_version", False)
+ self.auto_delete_setting = kwargs.pop("auto_delete_setting", None)
+
+ if not name and version is None:
+ name = _get_random_name()
+ version = "1"
+ self._is_anonymous = True
+ elif version is not None and not name:
+ msg = "If version is specified, name must be specified also."
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.ASSET,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ log_and_raise_error(err)
+
+ super().__init__(
+ name=name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+
+ self.version = version
+ self.latest_version = None
+
+ @abstractmethod
+ def _to_dict(self) -> Dict:
+ """Dump the artifact content into a pure dict object."""
+
+ @property
+ def version(self) -> Optional[str]:
+ """The asset version.
+
+ :return: The asset version.
+ :rtype: str
+ """
+ return self._version
+
+ @version.setter
+ def version(self, value: str) -> None:
+ """Sets the asset version.
+
+ :param value: The asset version.
+ :type value: str
+ :raises ValidationException: Raised if value is not a string.
+ """
+ if value:
+ if not isinstance(value, str):
+ msg = f"Asset version must be a string, not type {type(value)}."
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.ASSET,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ log_and_raise_error(err)
+
+ self._version = value
+ self._auto_increment_version = self.name and not self._version
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the asset content into a file in YAML format.
+
+ :param dest: The local path or file stream to write the YAML content to.
+ If dest is a file path, a new file will be created.
+ If dest is an open file, the file will be written to directly.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ :raises FileExistsError: Raised if dest is a file path and the file already exists.
+ :raises IOError: Raised if dest is an open file and the file is not writable.
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ def __eq__(self, other: Any) -> bool:
+ return bool(
+ self.name == other.name
+ and self.id == other.id
+ and self.version == other.version
+ and self.description == other.description
+ and self.tags == other.tags
+ and self.properties == other.properties
+ and self.base_path == other.base_path
+ and self._is_anonymous == other._is_anonymous
+ and self._auto_increment_version == other._auto_increment_version
+ and self.auto_delete_setting == other.auto_delete_setting
+ )
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+
+def _get_random_name() -> str:
+ return str(uuid.uuid4())
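The base class above enforces that versions are strings and drives value-based equality. A small sketch of that behaviour using the Model subclass from earlier (placeholder path, nothing is uploaded):

    from azure.ai.ml.entities import Model

    a = Model(name="sample-model", version="1", path="./model_dir", tags={"team": "ml"})
    b = Model(name="sample-model", version="1", path="./model_dir", tags={"team": "ml"})
    print(a == b)    # True: __eq__ compares name, version, tags, properties, ...
    a.version = "2"  # versions must be strings; other types raise a ValidationException
    print(a == b)    # False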
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/auto_delete_setting.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/auto_delete_setting.py
new file mode 100644
index 00000000..ea6bf9e8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/auto_delete_setting.py
@@ -0,0 +1,42 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoDeleteSetting as RestAutoDeleteSetting
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import AutoDeleteCondition
+from azure.ai.ml.entities._mixins import DictMixin
+
+
+@experimental
+class AutoDeleteSetting(DictMixin):
+ """Class which defines the auto delete setting.
+ :param condition: When to check if an asset is expired.
+ Possible values include: "CreatedGreaterThan", "LastAccessedGreaterThan".
+ :type condition: AutoDeleteCondition
+ :param value: Expiration condition value.
+ :type value: str
+ """
+
+ def __init__(
+ self,
+ *,
+ condition: AutoDeleteCondition = AutoDeleteCondition.CREATED_GREATER_THAN,
+ value: Union[str, None] = None
+ ):
+ self.condition = condition
+ self.value = value
+
+ def _to_rest_object(self) -> RestAutoDeleteSetting:
+ return RestAutoDeleteSetting(condition=self.condition, value=self.value)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestAutoDeleteSetting) -> "AutoDeleteSetting":
+ return cls(condition=obj.condition, value=obj.value)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, AutoDeleteSetting):
+ return NotImplemented
+ return self.condition == other.condition and self.value == other.value
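A small sketch of constructing the setting and round-tripping it through its REST shape via the private helpers above; imports use the module paths shown in this diff, and the "30d" value is a purely illustrative placeholder:

    from azure.ai.ml.constants._common import AutoDeleteCondition
    from azure.ai.ml.entities._assets.auto_delete_setting import AutoDeleteSetting

    setting = AutoDeleteSetting(
        condition=AutoDeleteCondition.CREATED_GREATER_THAN,
        value="30d",  # illustrative placeholder for the expiration value
    )
    rest_setting = setting._to_rest_object()
    assert AutoDeleteSetting._from_rest_object(rest_setting) == setting  # __eq__ checks condition and value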
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/environment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/environment.py
new file mode 100644
index 00000000..865273fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/environment.py
@@ -0,0 +1,478 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access, too-many-instance-attributes
+
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import yaml # type: ignore[import]
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2023_04_01_preview.models import BuildContext as RestBuildContext
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ EnvironmentContainer,
+ EnvironmentVersion,
+ EnvironmentVersionProperties,
+)
+from azure.ai.ml._schema import EnvironmentSchema
+from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId
+from azure.ai.ml._utils._asset_utils import get_ignore_file, get_object_hash
+from azure.ai.ml._utils.utils import dump_yaml, is_url, load_file, load_yaml
+from azure.ai.ml.constants._common import ANONYMOUS_ENV_NAME, BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ArmConstants
+from azure.ai.ml.entities._assets.asset import Asset
+from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty
+from azure.ai.ml.entities._mixins import LocalizableMixin
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import get_md5_string, load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+
+class BuildContext:
+ """Docker build context for Environment.
+
+ :param path: The local or remote path to the Docker build context directory.
+ :type path: Union[str, os.PathLike]
+ :param dockerfile_path: The path to the Dockerfile, relative to the root of the Docker build context directory.
+ :type dockerfile_path: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START build_context_entity_create]
+ :end-before: [END build_context_entity_create]
+ :language: python
+ :dedent: 8
+ :caption: Create a Build Context object.
+ """
+
+ def __init__(
+ self,
+ *,
+ dockerfile_path: Optional[str] = None,
+ path: Optional[Union[str, os.PathLike]] = None,
+ ):
+ self.dockerfile_path = dockerfile_path
+ self.path = path
+
+ def _to_rest_object(self) -> RestBuildContext:
+ return RestBuildContext(context_uri=self.path, dockerfile_path=self.dockerfile_path)
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestBuildContext) -> "BuildContext":
+ return BuildContext(
+ path=rest_obj.context_uri,
+ dockerfile_path=rest_obj.dockerfile_path,
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = self.dockerfile_path == other.dockerfile_path and self.path == other.path
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+
+class Environment(Asset, LocalizableMixin):
+ """Environment for training.
+
+ :param name: Name of the resource.
+ :type name: str
+ :param version: Version of the asset.
+ :type version: str
+ :param description: Description of the resource.
+ :type description: str
+ :param image: URI of a custom base image.
+ :type image: str
+ :param build: Docker build context used to create the environment. Mutually exclusive with "image".
+ :type build: ~azure.ai.ml.entities._assets.environment.BuildContext
+ :param conda_file: Path to configuration file listing conda packages to install.
+ :type conda_file: typing.Union[str, os.PathLike]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param datastore: The datastore to upload the local artifact to.
+ :type datastore: str
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START env_entity_create]
+ :end-before: [END env_entity_create]
+ :language: python
+ :dedent: 8
+ :caption: Create an Environment object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ image: Optional[str] = None,
+ build: Optional[BuildContext] = None,
+ conda_file: Optional[Union[str, os.PathLike, Dict]] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ datastore: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ self._arm_type: str = ""
+ self.latest_version: str = "" # type: ignore[assignment]
+ self.image: Optional[str] = None
+ inference_config = kwargs.pop("inference_config", None)
+ os_type = kwargs.pop("os_type", None)
+ self._intellectual_property = kwargs.pop("intellectual_property", None)
+
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+
+ self.conda_file = conda_file
+ self.image = image
+ self.build = build
+ self.inference_config = inference_config
+ self.os_type = os_type
+ self._arm_type = ArmConstants.ENVIRONMENT_VERSION_TYPE
+ self._conda_file_path = (
+ _resolve_path(base_path=self.base_path, input=conda_file)
+ if isinstance(conda_file, (os.PathLike, str))
+ else None
+ )
+ self.path = None
+ self.datastore = datastore
+ self._upload_hash = None
+
+ self._translated_conda_file = None
+ if self.conda_file:
+ self._translated_conda_file = dump_yaml(self.conda_file, sort_keys=True) # service needs str representation
+
+ if self.build and self.build.path and not is_url(self.build.path):
+ path = Path(self.build.path)
+ if not path.is_absolute():
+ path = Path(self.base_path, path).resolve()
+ self.path = path
+
+ if self._is_anonymous:
+ if self.path:
+ self._ignore_file = get_ignore_file(path)
+ self._upload_hash = get_object_hash(path, self._ignore_file)
+ self._generate_anonymous_name_version(source="build")
+ elif self.image:
+ self._generate_anonymous_name_version(
+ source="image", conda_file=self._translated_conda_file, inference_config=self.inference_config
+ )
+
+ @property
+ def conda_file(self) -> Optional[Union[str, os.PathLike, Dict]]:
+ """Conda environment specification.
+
+ :return: Conda dependencies loaded from `conda_file` param.
+ :rtype: Optional[Union[str, os.PathLike]]
+ """
+ return self._conda_file
+
+ @conda_file.setter
+ def conda_file(self, value: Optional[Union[str, os.PathLike, Dict]]) -> None:
+ """Set conda environment specification.
+
+ :param value: A path to a local conda dependencies yaml file or a loaded yaml dictionary of dependencies.
+ :type value: Union[str, os.PathLike, Dict]
+ :return: None
+ """
+ if not isinstance(value, Dict):
+ value = _deserialize(self.base_path, value, is_conda=True)
+ self._conda_file = value
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[dict] = None,
+ yaml_path: Optional[Union[os.PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Environment":
+ params_override = params_override or []
+ data = data or {}
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: Environment = load_from_dict(EnvironmentSchema, data, context, **kwargs)
+ return res
+
+ def _to_rest_object(self) -> EnvironmentVersion:
+ self.validate()
+ environment_version = EnvironmentVersionProperties()
+ if self.conda_file:
+ environment_version.conda_file = self._translated_conda_file
+ if self.image:
+ environment_version.image = self.image
+ if self.build:
+ environment_version.build = self.build._to_rest_object()
+ if self.os_type:
+ environment_version.os_type = self.os_type
+ if self.tags:
+ environment_version.tags = self.tags
+ if self._is_anonymous:
+ environment_version.is_anonymous = self._is_anonymous
+ if self.inference_config:
+ environment_version.inference_config = self.inference_config
+ if self.description:
+ environment_version.description = self.description
+ if self.properties:
+ environment_version.properties = self.properties
+
+ environment_version_resource = EnvironmentVersion(properties=environment_version)
+
+ return environment_version_resource
+
+ @classmethod
+ def _from_rest_object(cls, env_rest_object: EnvironmentVersion) -> "Environment":
+ rest_env_version = env_rest_object.properties
+ arm_id = AMLVersionedArmId(arm_id=env_rest_object.id)
+
+ environment = Environment(
+ id=env_rest_object.id,
+ name=arm_id.asset_name,
+ version=arm_id.asset_version,
+ description=rest_env_version.description,
+ tags=rest_env_version.tags,
+ creation_context=(
+ SystemData._from_rest_object(env_rest_object.system_data) if env_rest_object.system_data else None
+ ),
+ is_anonymous=rest_env_version.is_anonymous,
+ image=rest_env_version.image,
+ os_type=rest_env_version.os_type,
+ inference_config=rest_env_version.inference_config,
+ build=BuildContext._from_rest_object(rest_env_version.build) if rest_env_version.build else None,
+ properties=rest_env_version.properties,
+ intellectual_property=(
+ IntellectualProperty._from_rest_object(rest_env_version.intellectual_property)
+ if rest_env_version.intellectual_property
+ else None
+ ),
+ )
+
+ if rest_env_version.conda_file:
+ translated_conda_file = yaml.safe_load(rest_env_version.conda_file)
+ environment.conda_file = translated_conda_file
+ environment._translated_conda_file = rest_env_version.conda_file
+
+ return environment
+
+ @classmethod
+ def _from_container_rest_object(cls, env_container_rest_object: EnvironmentContainer) -> "Environment":
+ env = Environment(
+ name=env_container_rest_object.name,
+ version="1",
+ id=env_container_rest_object.id,
+ creation_context=SystemData._from_rest_object(env_container_rest_object.system_data),
+ )
+ env.latest_version = env_container_rest_object.properties.latest_version
+
+ # Setting version to None since if version is not provided it is defaulted to "1".
+ # This should go away once container concept is finalized.
+ env.version = None
+ return env
+
+ def _to_arm_resource_param(self, **kwargs: Any) -> Dict: # pylint: disable=unused-argument
+ properties = self._to_rest_object().properties
+
+ return {
+ self._arm_type: {
+ ArmConstants.NAME: self.name,
+ ArmConstants.VERSION: self.version,
+ ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "EnvironmentVersion"),
+ }
+ }
+
+ def _to_dict(self) -> Dict:
+ res: dict = EnvironmentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def validate(self) -> None:
+ """Validate the environment by checking its name, image and build
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START env_entities_validate]
+ :end-before: [END env_entities_validate]
+ :language: python
+ :dedent: 8
+ :caption: Validate environment example.
+ """
+
+ if self.name is None:
+ msg = "Environment name is required"
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.ENVIRONMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ log_and_raise_error(err)
+ if self.image is None and self.build is None:
+ msg = "Docker image or Dockerfile is required for environments"
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.ENVIRONMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ log_and_raise_error(err)
+ if self.image and self.build:
+ msg = "Docker image or Dockerfile should be provided not both"
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.ENVIRONMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ log_and_raise_error(err)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Environment):
+ return NotImplemented
+ return (
+ self.name == other.name
+ and self.id == other.id
+ and self.version == other.version
+ and self.description == other.description
+ and self.tags == other.tags
+ and self.properties == other.properties
+ and self.base_path == other.base_path
+ and self.image == other.image
+ and self.build == other.build
+ and self.conda_file == other.conda_file
+ and self.inference_config == other.inference_config
+ and self._is_anonymous == other._is_anonymous
+ and self.os_type == other.os_type
+ and self._intellectual_property == other._intellectual_property
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+ def _generate_anonymous_name_version(
+ self, source: str, conda_file: Optional[str] = None, inference_config: Optional[Dict] = None
+ ) -> None:
+ hash_str = ""
+ if source == "image":
+ hash_str = hash_str.join(get_md5_string(self.image))
+ if inference_config:
+ hash_str = hash_str.join(get_md5_string(yaml.dump(inference_config, sort_keys=True)))
+ if conda_file:
+ hash_str = hash_str.join(get_md5_string(conda_file))
+ if source == "build":
+ if self.build is not None and not self.build.dockerfile_path:
+ hash_str = hash_str.join(get_md5_string(self._upload_hash))
+ else:
+ if self.build is not None:
+ hash_str = hash_str.join(get_md5_string(self._upload_hash)).join(
+ get_md5_string(self.build.dockerfile_path)
+ )
+ version_hash = get_md5_string(hash_str)
+ self.version = version_hash
+ self.name = ANONYMOUS_ENV_NAME
+
+ def _localize(self, base_path: str) -> None:
+ """Called on an asset got from service to clean up remote attributes like id, creation_context, etc. and update
+ base_path.
+
+ :param base_path: The base path
+ :type base_path: str
+ """
+ if not getattr(self, "id", None):
+ raise ValueError("Only remote asset can be localize but got a {} without id.".format(type(self)))
+ self._id = None
+ self._creation_context = None
+ self._base_path = base_path
+ if self._is_anonymous:
+ self.name, self.version = None, None
+
+
+# TODO: Remove _DockerBuild and _DockerConfiguration classes once local endpoint moves to using updated env
+class _DockerBuild:
+ """Helper class to encapsulate Docker build info for Environment."""
+
+ def __init__(
+ self,
+ base_path: Optional[Union[str, os.PathLike]] = None,
+ dockerfile: Optional[str] = None,
+ ):
+ self.dockerfile = _deserialize(base_path, dockerfile)
+
+ @classmethod
+ def _to_rest_object(cls) -> None:
+ return None
+
+ def _from_rest_object(self, rest_obj: Any) -> None:
+ self.dockerfile = rest_obj.dockerfile
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = self.dockerfile == other.dockerfile
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+
+def _deserialize(
+ base_path: Optional[Union[str, os.PathLike]],
+ input: Optional[Union[str, os.PathLike, Dict]], # pylint: disable=redefined-builtin
+ is_conda: bool = False,
+) -> Optional[Union[str, os.PathLike, Dict]]:
+ """Deserialize user input files for conda and docker.
+
+ :param base_path: The base path for all files supplied by user.
+ :type base_path: Union[str, os.PathLike]
+ :param input: Input to be deserialized. Will be either dictionary of file contents or path to file.
+ :type input: Union[str, os.PathLike, Dict[str, str]]
+ :param is_conda: If True, the file is a conda file and its contents are returned as a dictionary.
+ :type is_conda: bool
+ :return: The deserialized data
+ :rtype: Union[str, Dict]
+ """
+
+ if input:
+ path = _resolve_path(base_path=base_path, input=input)
+ data: Union[str, Dict] = ""
+ if is_conda:
+ data = load_yaml(path)
+ else:
+ data = load_file(path)
+ return data
+ return input
+
+
+def _resolve_path(base_path: Any, input: Any) -> Path: # pylint: disable=redefined-builtin
+ """Deserialize user input files for conda and docker.
+
+ :param base_path: The base path for all files supplied by user.
+ :type base_path: Union[str, os.PathLike]
+ :param input: Input to be deserialized. Will be either dictionary of file contents or path to file.
+ :type input: Union[str, os.PathLike, Dict[str, str]]
+ :return: The resolved path
+ :rtype: Path
+ """
+
+ path = Path(input)
+ if not path.is_absolute():
+ path = Path(base_path, path).resolve()
+ return path
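A minimal sketch of the two mutually exclusive ways to define an Environment described above: a prebuilt image with an inline conda specification, or a Docker build context. The image name and paths are placeholders:

    from azure.ai.ml.entities import BuildContext, Environment

    # Option 1: prebuilt base image plus an in-memory conda specification.
    env_from_image = Environment(
        name="sample-env",
        image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04",  # example base image
        conda_file={"dependencies": ["python=3.10", "pip", {"pip": ["numpy"]}]},
    )
    env_from_image.validate()  # passes: exactly one of image/build is set

    # Option 2: a Docker build context directory (mutually exclusive with image=).
    env_from_build = Environment(
        name="sample-env-build",
        build=BuildContext(path="./docker_context", dockerfile_path="Dockerfile"),
    )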
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/federated_learning_silo.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/federated_learning_silo.py
new file mode 100644
index 00000000..8255f887
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/federated_learning_silo.py
@@ -0,0 +1,123 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# TODO determine where this file should live.
+
+from os import PathLike
+from typing import IO, Any, AnyStr, Dict, List, Optional, Union
+
+from azure.ai.ml import Input
+from azure.ai.ml._utils.utils import dump_yaml_to_file, load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+
+# Entity representation of a federated learning silo.
+# Used by Federated Learning DSL nodes as inputs for creating
+# FL subgraphs in pipelines.
+# The functionality of this entity is limited, and it exists mostly
+ # to simplify the process of loading and validating these objects from YAML.
+class FederatedLearningSilo:
+ def __init__(
+ self,
+ *,
+ compute: str,
+ datastore: str,
+ inputs: Dict[str, Input],
+ ):
+ """
+ A pseudo-entity that represents a federated learning silo, which is an isolated compute with its own
+ datastore and input targets. This is meant to be used in conjunction with the
+ Federated Learning DSL node to create federated learning pipelines. This does NOT represent any specific
+ AML resource, and is instead merely meant to simplify client-side experiences with managing FL data distribution.
+ Standard usage involves the "load_list" classmethod to load a list of these objects from YAML, which serves
+ as a necessary input for FL processes.
+
+
+ :param compute: The resource id of a compute.
+ :type compute: str
+ :param datastore: The resource id of a datastore.
+ :type datastore: str
+ :param inputs: A dictionary of input entities that exist in the previously specified datastore.
+ The keys of this dictionary are the keyword names that these inputs should be entered into.
+ :type inputs: dict[str, Input]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+ self.compute = compute
+ self.datastore = datastore
+ self.inputs = inputs
+
+ def dump(
+ self,
+ dest: Union[str, PathLike, IO[AnyStr]],
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Dump the Federated Learning Silo spec into a file in yaml format.
+
+ :param dest: Either
+ * A path to a local file
+ * A writeable file-like object
+ :type dest: Union[str, PathLike, IO[AnyStr]]
+ """
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False)
+
+ def _to_dict(self) -> Dict:
+ # JIT import to avoid experimental warnings on unrelated calls
+ from azure.ai.ml._schema.assets.federated_learning_silo import FederatedLearningSiloSchema
+
+ schema = FederatedLearningSiloSchema(context={BASE_PATH_CONTEXT_KEY: "./"})
+
+ return dict(schema.dump(self))
+
+ @classmethod
+ def _load_from_dict(cls, silo_dict: dict) -> "FederatedLearningSilo":
+ data_input = silo_dict.get("inputs", {})
+ return FederatedLearningSilo(compute=silo_dict["compute"], datastore=silo_dict["datastore"], inputs=data_input)
+
+ # simple load based off mltable metadata loading style
+ @classmethod
+ def _load(
+ cls,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ ) -> "FederatedLearningSilo":
+ yaml_dict = load_yaml(yaml_path)
+ return FederatedLearningSilo._load_from_dict(silo_dict=yaml_dict)
+
+ @classmethod
+ def load_list(
+ cls,
+ *,
+ yaml_path: Optional[Union[PathLike, str]],
+ list_arg: str,
+ ) -> List["FederatedLearningSilo"]:
+ """
+ Loads a list of federated learning silos from YAML. This is the expected entry point
+ for this class; load a list of these, then supply them to the federated learning DSL
+ package node in order to produce an FL pipeline.
+
+ The structure of the supplied YAML file is assumed to be a list of FL silos under the
+ name specified by the list_arg input, as shown below.
+
+ list_arg:
+ - silo 1 ...
+ - silo 2 ...
+
+ :keyword yaml_path: A path leading to a local YAML file which contains a list of
+ FederatedLearningSilo objects.
+ :paramtype yaml_path: Optional[Union[PathLike, str]]
+ :keyword list_arg: A string that names the top-level value which contains the list
+ of FL silos.
+ :paramtype list_arg: str
+ :return: The list of federated learning silos
+ :rtype: List[FederatedLearningSilo]
+ """
+ yaml_dict = load_yaml(yaml_path)
+ return [
+ FederatedLearningSilo._load_from_dict(silo_dict=silo_yaml_dict) for silo_yaml_dict in yaml_dict[list_arg]
+ ]
+
+ # There are no to/from rest object functions because this object has no
+ # rest object equivalent. Any conversions should be done as part of the
+ # to/from rest object functions of OTHER entity objects.
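A small sketch of the load_list entry point described above, assuming a hypothetical silos.yml whose top-level key matches list_arg:

    # silos.yml (hypothetical):
    #
    # silos:
    #   - compute: <compute resource id>
    #     datastore: <datastore resource id>
    #     inputs:
    #       silo_data:
    #         type: uri_folder
    #         path: <datastore path>
    #   - compute: <compute resource id>
    #     datastore: <datastore resource id>
    #     inputs: {}

    from azure.ai.ml.entities._assets.federated_learning_silo import FederatedLearningSilo

    silos = FederatedLearningSilo.load_list(yaml_path="silos.yml", list_arg="silos")
    for silo in silos:
        print(silo.compute, silo.datastore, sorted(silo.inputs))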
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/intellectual_property.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/intellectual_property.py
new file mode 100644
index 00000000..58b96a1b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/intellectual_property.py
@@ -0,0 +1,49 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import IntellectualProperty as RestIntellectualProperty
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._assets import IPProtectionLevel
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+@experimental
+class IntellectualProperty(RestTranslatableMixin):
+ """Intellectual property settings definition.
+
+ :keyword publisher: The publisher's name.
+ :paramtype publisher: Optional[str]
+ :keyword protection_level: Asset Protection Level. Accepted values are IPProtectionLevel.ALL ("all") and
+ IPProtectionLevel.NONE ("none"). Defaults to IPProtectionLevel.ALL ("all").
+ :paramtype protection_level: Optional[Union[str, ~azure.ai.ml.constants.IPProtectionLevel]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START intellectual_property_configuration]
+ :end-before: [END intellectual_property_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring intellectual property settings on a CommandComponent.
+ """
+
+ def __init__(
+ self, *, publisher: Optional[str] = None, protection_level: IPProtectionLevel = IPProtectionLevel.ALL
+ ) -> None:
+ self.publisher = publisher
+ self.protection_level = protection_level
+
+ def _to_rest_object(self) -> RestIntellectualProperty:
+ return RestIntellectualProperty(publisher=self.publisher, protection_level=self.protection_level)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestIntellectualProperty) -> "IntellectualProperty":
+ return cls(publisher=obj.publisher, protection_level=obj.protection_level)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, IntellectualProperty):
+ return NotImplemented
+ return self.publisher == other.publisher and self.protection_level == other.protection_level
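A minimal sketch of the settings object and its REST round trip via the private helpers above, using the module paths shown in this diff; the publisher name is a placeholder:

    from azure.ai.ml.constants._assets import IPProtectionLevel
    from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty

    ip = IntellectualProperty(publisher="contoso", protection_level=IPProtectionLevel.ALL)
    assert IntellectualProperty._from_rest_object(ip._to_rest_object()) == ip  # __eq__ compares both fields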
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/workspace_asset_reference.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/workspace_asset_reference.py
new file mode 100644
index 00000000..1e7d1ba2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_assets/workspace_asset_reference.py
@@ -0,0 +1,87 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview.models import (
+ ResourceManagementAssetReferenceData,
+ ResourceManagementAssetReferenceDetails,
+)
+from azure.ai.ml._schema import WorkspaceAssetReferenceSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets.asset import Asset
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class WorkspaceAssetReference(Asset):
+ """Workspace Model Reference.
+
+ This is for SDK internal use only and might be deprecated in the future.
+
+ :param name: Model name
+ :type name: str
+ :param version: Model version
+ :type version: str
+ :param asset_id: Model asset id
+ :type asset_id: str
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ asset_id: Optional[str] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ name=name,
+ version=version,
+ properties=properties,
+ **kwargs,
+ )
+ self.asset_id = asset_id
+
+ @classmethod
+ def _load(
+ cls: Any,
+ data: Optional[dict] = None,
+ yaml_path: Optional[Union[os.PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "WorkspaceAssetReference":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: WorkspaceAssetReference = load_from_dict(WorkspaceAssetReferenceSchema, data, context, **kwargs)
+ return res
+
+ def _to_rest_object(self) -> ResourceManagementAssetReferenceData:
+ resource_management_details = ResourceManagementAssetReferenceDetails(
+ destination_name=self.name,
+ destination_version=self.version,
+ source_asset_id=self.asset_id,
+ )
+ resource_management = ResourceManagementAssetReferenceData(properties=resource_management_details)
+ return resource_management
+
+ @classmethod
+ def _from_rest_object(cls, resource_object: ResourceManagementAssetReferenceData) -> "WorkspaceAssetReference":
+ resource_management = WorkspaceAssetReference(
+ name=resource_object.properties.destination_name,
+ version=resource_object.properties.destination_version,
+ asset_id=resource_object.properties.source_asset_id,
+ )
+
+ return resource_management
+
+ def _to_dict(self) -> Dict:
+ return dict(WorkspaceAssetReferenceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self))
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/__init__.py
new file mode 100644
index 00000000..8dfc61b2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/__init__.py
@@ -0,0 +1,20 @@
+# coding=utf-8
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# Code generated by Microsoft (R) Python Code Generator.
+# Changes may cause incorrect behavior and will be lost if the code is regenerated.
+# --------------------------------------------------------------------------
+
+
+try:
+ from ._patch import __all__ as _patch_all
+ from ._patch import * # pylint: disable=unused-wildcard-import
+except ImportError:
+ _patch_all = []
+from ._patch import patch_sdk as _patch_sdk
+
+__all__ = []
+__all__.extend([p for p in _patch_all if p not in __all__])
+
+_patch_sdk()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_model_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_model_base.py
new file mode 100644
index 00000000..5bf680b4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_model_base.py
@@ -0,0 +1,881 @@
+# coding=utf-8
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=protected-access, broad-except
+
+import calendar
+import decimal
+import functools
+import sys
+import logging
+import base64
+import re
+import copy
+import typing
+import enum
+import email.utils
+from datetime import datetime, date, time, timedelta, timezone
+from json import JSONEncoder
+from typing_extensions import Self
+import isodate
+from azure.core.exceptions import DeserializationError
+from azure.core import CaseInsensitiveEnumMeta
+from azure.core.pipeline import PipelineResponse
+from azure.core.serialization import NULL
+
+if sys.version_info >= (3, 9):
+ from collections.abc import MutableMapping
+else:
+ from typing import MutableMapping
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"]
+
+TZ_UTC = timezone.utc
+_T = typing.TypeVar("_T")
+
+
+def _timedelta_as_isostr(td: timedelta) -> str:
+ """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
+
+ Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython
+
+ :param timedelta td: The timedelta to convert
+ :rtype: str
+ :return: ISO8601 version of this timedelta
+ """
+
+ # Split seconds to larger units
+ seconds = td.total_seconds()
+ minutes, seconds = divmod(seconds, 60)
+ hours, minutes = divmod(minutes, 60)
+ days, hours = divmod(hours, 24)
+
+ days, hours, minutes = list(map(int, (days, hours, minutes)))
+ seconds = round(seconds, 6)
+
+ # Build date
+ date_str = ""
+ if days:
+ date_str = "%sD" % days
+
+ if hours or minutes or seconds:
+ # Build time
+ time_str = "T"
+
+ # Hours
+ bigger_exists = date_str or hours
+ if bigger_exists:
+ time_str += "{:02}H".format(hours)
+
+ # Minutes
+ bigger_exists = bigger_exists or minutes
+ if bigger_exists:
+ time_str += "{:02}M".format(minutes)
+
+ # Seconds
+ try:
+ if seconds.is_integer():
+ seconds_string = "{:02}".format(int(seconds))
+ else:
+ # 9 chars long w/ leading 0, 6 digits after decimal
+ seconds_string = "%09.6f" % seconds
+ # Remove trailing zeros
+ seconds_string = seconds_string.rstrip("0")
+ except AttributeError: # int.is_integer() raises
+ seconds_string = "{:02}".format(seconds)
+
+ time_str += "{}S".format(seconds_string)
+ else:
+ time_str = ""
+
+ return "P" + date_str + time_str
+
+
+def _serialize_bytes(o, format: typing.Optional[str] = None) -> str:
+ encoded = base64.b64encode(o).decode()
+ if format == "base64url":
+ return encoded.strip("=").replace("+", "-").replace("/", "_")
+ return encoded
+
+
+def _serialize_datetime(o, format: typing.Optional[str] = None):
+ if hasattr(o, "year") and hasattr(o, "hour"):
+ if format == "rfc7231":
+ return email.utils.format_datetime(o, usegmt=True)
+ if format == "unix-timestamp":
+ return int(calendar.timegm(o.utctimetuple()))
+
+ # astimezone() fails for naive times in Python 2.7, so make sure o is aware (tzinfo is set)
+ if not o.tzinfo:
+ iso_formatted = o.replace(tzinfo=TZ_UTC).isoformat()
+ else:
+ iso_formatted = o.astimezone(TZ_UTC).isoformat()
+ # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt)
+ return iso_formatted.replace("+00:00", "Z")
+ # Next try datetime.date or datetime.time
+ return o.isoformat()
+
+
+def _is_readonly(p):
+ try:
+ return p._visibility == ["read"] # pylint: disable=protected-access
+ except AttributeError:
+ return False
+
+
+class SdkJSONEncoder(JSONEncoder):
+ """A JSON encoder that's capable of serializing datetime objects and bytes."""
+
+ def __init__(self, *args, exclude_readonly: bool = False, format: typing.Optional[str] = None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.exclude_readonly = exclude_readonly
+ self.format = format
+
+ def default(self, o): # pylint: disable=too-many-return-statements
+ if _is_model(o):
+ if self.exclude_readonly:
+ readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
+ return {k: v for k, v in o.items() if k not in readonly_props}
+ return dict(o.items())
+ try:
+ return super(SdkJSONEncoder, self).default(o)
+ except TypeError:
+ if isinstance(o, type(NULL)):
+ return None
+ if isinstance(o, decimal.Decimal):
+ return float(o)
+ if isinstance(o, (bytes, bytearray)):
+ return _serialize_bytes(o, self.format)
+ try:
+ # First try datetime.datetime
+ return _serialize_datetime(o, self.format)
+ except AttributeError:
+ pass
+ # Last, try datetime.timedelta
+ try:
+ return _timedelta_as_isostr(o)
+ except AttributeError:
+ # This will be raised when it hits value.total_seconds in the method above
+ pass
+ return super(SdkJSONEncoder, self).default(o)
+
+
+_VALID_DATE = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" + r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?")
+_VALID_RFC7231 = re.compile(
+ r"(Mon|Tue|Wed|Thu|Fri|Sat|Sun),\s\d{2}\s"
+ r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT"
+)
+
+
+def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
+ """Deserialize ISO-8601 formatted string into Datetime object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: ~datetime.datetime
+ :returns: The datetime object from that input
+ """
+ if isinstance(attr, datetime):
+ # i'm already deserialized
+ return attr
+ attr = attr.upper()
+ match = _VALID_DATE.match(attr)
+ if not match:
+ raise ValueError("Invalid datetime string: " + attr)
+
+ check_decimal = attr.split(".")
+ if len(check_decimal) > 1:
+ decimal_str = ""
+ for digit in check_decimal[1]:
+ if digit.isdigit():
+ decimal_str += digit
+ else:
+ break
+ if len(decimal_str) > 6:
+ attr = attr.replace(decimal_str, decimal_str[0:6])
+
+ date_obj = isodate.parse_datetime(attr)
+ test_utc = date_obj.utctimetuple()
+ if test_utc.tm_year > 9999 or test_utc.tm_year < 1:
+ raise OverflowError("Hit max or min date")
+ return date_obj
+
+
+def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime:
+ """Deserialize RFC7231 formatted string into Datetime object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: ~datetime.datetime
+ :returns: The datetime object from that input
+ """
+ if isinstance(attr, datetime):
+ # i'm already deserialized
+ return attr
+ match = _VALID_RFC7231.match(attr)
+ if not match:
+ raise ValueError("Invalid datetime string: " + attr)
+
+ return email.utils.parsedate_to_datetime(attr)
+
+
+def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime:
+ """Deserialize unix timestamp into Datetime object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: ~datetime.datetime
+ :returns: The datetime object from that input
+ """
+ if isinstance(attr, datetime):
+ # i'm already deserialized
+ return attr
+ return datetime.fromtimestamp(attr, TZ_UTC)
+
+
+def _deserialize_date(attr: typing.Union[str, date]) -> date:
+ """Deserialize ISO-8601 formatted string into Date object.
+ :param str attr: response string to be deserialized.
+ :rtype: date
+ :returns: The date object from that input
+ """
+ # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
+ if isinstance(attr, date):
+ return attr
+ return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore
+
+
+def _deserialize_time(attr: typing.Union[str, time]) -> time:
+ """Deserialize ISO-8601 formatted string into time object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: datetime.time
+ :returns: The time object from that input
+ """
+ if isinstance(attr, time):
+ return attr
+ return isodate.parse_time(attr)
+
+
+def _deserialize_bytes(attr):
+ if isinstance(attr, (bytes, bytearray)):
+ return attr
+ return bytes(base64.b64decode(attr))
+
+
+def _deserialize_bytes_base64(attr):
+ if isinstance(attr, (bytes, bytearray)):
+ return attr
+ padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore
+ attr = attr + padding # type: ignore
+ encoded = attr.replace("-", "+").replace("_", "/")
+ return bytes(base64.b64decode(encoded))
+
+
+def _deserialize_duration(attr):
+ if isinstance(attr, timedelta):
+ return attr
+ return isodate.parse_duration(attr)
+
+
+def _deserialize_decimal(attr):
+ if isinstance(attr, decimal.Decimal):
+ return attr
+ return decimal.Decimal(str(attr))
+
+
+_DESERIALIZE_MAPPING = {
+ datetime: _deserialize_datetime,
+ date: _deserialize_date,
+ time: _deserialize_time,
+ bytes: _deserialize_bytes,
+ bytearray: _deserialize_bytes,
+ timedelta: _deserialize_duration,
+ typing.Any: lambda x: x,
+ decimal.Decimal: _deserialize_decimal,
+}
+
+_DESERIALIZE_MAPPING_WITHFORMAT = {
+ "rfc3339": _deserialize_datetime,
+ "rfc7231": _deserialize_datetime_rfc7231,
+ "unix-timestamp": _deserialize_datetime_unix_timestamp,
+ "base64": _deserialize_bytes,
+ "base64url": _deserialize_bytes_base64,
+}
+
+
+def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None):
+ if rf and rf._format:
+ return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format)
+ return _DESERIALIZE_MAPPING.get(annotation)
+
+
+def _get_type_alias_type(module_name: str, alias_name: str):
+ types = {
+ k: v
+ for k, v in sys.modules[module_name].__dict__.items()
+ if isinstance(v, typing._GenericAlias) # type: ignore
+ }
+ if alias_name not in types:
+ return alias_name
+ return types[alias_name]
+
+
+def _get_model(module_name: str, model_name: str):
+ models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)}
+ module_end = module_name.rsplit(".", 1)[0]
+ models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)})
+ if isinstance(model_name, str):
+ model_name = model_name.split(".")[-1]
+ if model_name not in models:
+ return model_name
+ return models[model_name]
+
+
+_UNSET = object()
+
+
+class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=unsubscriptable-object
+ def __init__(self, data: typing.Dict[str, typing.Any]) -> None:
+ self._data = copy.deepcopy(data)
+
+ def __contains__(self, key: typing.Any) -> bool:
+ return key in self._data
+
+ def __getitem__(self, key: str) -> typing.Any:
+ return self._data.__getitem__(key)
+
+ def __setitem__(self, key: str, value: typing.Any) -> None:
+ self._data.__setitem__(key, value)
+
+ def __delitem__(self, key: str) -> None:
+ self._data.__delitem__(key)
+
+ def __iter__(self) -> typing.Iterator[typing.Any]:
+ return self._data.__iter__()
+
+ def __len__(self) -> int:
+ return self._data.__len__()
+
+ def __ne__(self, other: typing.Any) -> bool:
+ return not self.__eq__(other)
+
+ def keys(self) -> typing.KeysView[str]:
+ return self._data.keys()
+
+ def values(self) -> typing.ValuesView[typing.Any]:
+ return self._data.values()
+
+ def items(self) -> typing.ItemsView[str, typing.Any]:
+ return self._data.items()
+
+ def get(self, key: str, default: typing.Any = None) -> typing.Any:
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ @typing.overload
+ def pop(self, key: str) -> typing.Any: ...
+
+ @typing.overload
+ def pop(self, key: str, default: _T) -> _T: ...
+
+ @typing.overload
+ def pop(self, key: str, default: typing.Any) -> typing.Any: ...
+
+ def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
+ if default is _UNSET:
+ return self._data.pop(key)
+ return self._data.pop(key, default)
+
+ def popitem(self) -> typing.Tuple[str, typing.Any]:
+ return self._data.popitem()
+
+ def clear(self) -> None:
+ self._data.clear()
+
+ def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
+ self._data.update(*args, **kwargs)
+
+ @typing.overload
+ def setdefault(self, key: str, default: None = None) -> None: ...
+
+ @typing.overload
+ def setdefault(self, key: str, default: typing.Any) -> typing.Any: ...
+
+ def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
+ if default is _UNSET:
+ return self._data.setdefault(key)
+ return self._data.setdefault(key, default)
+
+ def __eq__(self, other: typing.Any) -> bool:
+ try:
+ other_model = self.__class__(other)
+ except Exception:
+ return False
+ return self._data == other_model._data
+
+ def __repr__(self) -> str:
+ return str(self._data)
+
+
+def _is_model(obj: typing.Any) -> bool:
+ return getattr(obj, "_is_model", False)
+
+
+def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements
+ if isinstance(o, list):
+ return [_serialize(x, format) for x in o]
+ if isinstance(o, dict):
+ return {k: _serialize(v, format) for k, v in o.items()}
+ if isinstance(o, set):
+ return {_serialize(x, format) for x in o}
+ if isinstance(o, tuple):
+ return tuple(_serialize(x, format) for x in o)
+ if isinstance(o, (bytes, bytearray)):
+ return _serialize_bytes(o, format)
+ if isinstance(o, decimal.Decimal):
+ return float(o)
+ if isinstance(o, enum.Enum):
+ return o.value
+ try:
+ # First try datetime.datetime
+ return _serialize_datetime(o, format)
+ except AttributeError:
+ pass
+ # Last, try datetime.timedelta
+ try:
+ return _timedelta_as_isostr(o)
+ except AttributeError:
+ # This will be raised when it hits value.total_seconds in the method above
+ pass
+ return o
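+ # Illustrative sketch of _serialize's recursion (the Color enum below is hypothetical):
+ # enums collapse to their value, Decimal to float, and containers are walked recursively.
+ #
+ #     class Color(enum.Enum):
+ #         RED = "red"
+ #
+ #     _serialize({"c": Color.RED, "n": decimal.Decimal("1.5"), "tags": ("a", "b")})
+ #     # -> {"c": "red", "n": 1.5, "tags": ("a", "b")}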
+
+
+def _get_rest_field(
+ attr_to_rest_field: typing.Dict[str, "_RestField"], rest_name: str
+) -> typing.Optional["_RestField"]:
+ try:
+ return next(rf for rf in attr_to_rest_field.values() if rf._rest_name == rest_name)
+ except StopIteration:
+ return None
+
+
+def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typing.Any:
+ if not rf:
+ return _serialize(value, None)
+ if rf._is_multipart_file_input:
+ return value
+ if rf._is_model:
+ return _deserialize(rf._type, value)
+ return _serialize(value, rf._format)
+
+
+class Model(_MyMutableMapping):
+ _is_model = True
+
+ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
+ class_name = self.__class__.__name__
+ if len(args) > 1:
+ raise TypeError(f"{class_name}.__init__() takes 2 positional arguments but {len(args) + 1} were given")
+ dict_to_pass = {
+ rest_field._rest_name: rest_field._default
+ for rest_field in self._attr_to_rest_field.values()
+ if rest_field._default is not _UNSET
+ }
+ if args:
+ dict_to_pass.update(
+ {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()}
+ )
+ else:
+ non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field]
+ if non_attr_kwargs:
+ # actual type errors only throw the first wrong keyword arg they see, so following that.
+ raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'")
+ dict_to_pass.update(
+ {
+ self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v)
+ for k, v in kwargs.items()
+ if v is not None
+ }
+ )
+ super().__init__(dict_to_pass)
+
+ def copy(self) -> "Model":
+ return Model(self.__dict__)
+
+ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self:
+ # we know the last three classes in mro are going to be 'Model', 'dict', and 'object'
+ mros = cls.__mro__[:-3][::-1] # ignore model, dict, and object parents, and reverse the mro order
+ attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property
+ k: v for mro_class in mros for k, v in mro_class.__dict__.items() if k[0] != "_" and hasattr(v, "_type")
+ }
+ annotations = {
+ k: v
+ for mro_class in mros
+ if hasattr(mro_class, "__annotations__")
+ for k, v in mro_class.__annotations__.items()
+ }
+ for attr, rf in attr_to_rest_field.items():
+ rf._module = cls.__module__
+ if not rf._type:
+ rf._type = rf._get_deserialize_callable_from_annotation(annotations.get(attr, None))
+ if not rf._rest_name_input:
+ rf._rest_name_input = attr
+ cls._attr_to_rest_field: typing.Dict[str, _RestField] = dict(attr_to_rest_field.items())
+
+ return super().__new__(cls) # pylint: disable=no-value-for-parameter
+
+ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None:
+ for base in cls.__bases__:
+ if hasattr(base, "__mapping__"):
+ base.__mapping__[discriminator or cls.__name__] = cls # type: ignore
+
+ @classmethod
+ def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]:
+ for v in cls.__dict__.values():
+ if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators:
+ return v._rest_name # pylint: disable=protected-access
+ return None
+
+ @classmethod
+ def _deserialize(cls, data, exist_discriminators):
+ if not hasattr(cls, "__mapping__"):
+ return cls(data)
+ discriminator = cls._get_discriminator(exist_discriminators)
+ exist_discriminators.append(discriminator)
+ mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pyright: ignore
+ if mapped_cls == cls:
+ return cls(data)
+ return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access
+
+ def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]:
+ """Return a dict that can be JSONify using json.dump.
+
+ :keyword bool exclude_readonly: Whether to remove the readonly properties.
+ :returns: A dict JSON compatible object
+ :rtype: dict
+ """
+
+ result = {}
+ if exclude_readonly:
+ readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
+ for k, v in self.items():
+ if (
+ exclude_readonly
+ and k in readonly_props # pyright: ignore # pylint: disable=possibly-used-before-assignment
+ ):
+ continue
+ is_multipart_file_input = False
+ try:
+ is_multipart_file_input = next(
+ rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k
+ )._is_multipart_file_input
+ except StopIteration:
+ pass
+ result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly)
+ return result
+
+ @staticmethod
+ def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any:
+ if v is None or isinstance(v, type(NULL)):
+ return None
+ if isinstance(v, (list, tuple, set)):
+ return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v)
+ if isinstance(v, dict):
+ return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()}
+ return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v
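+ # Illustrative sketch of the Model contract (Pet and its fields are hypothetical, not part
+ # of this module): attributes map to REST wire names via rest_field (defined below), and the
+ # instance behaves like a dict keyed by those wire names.
+ #
+ #     class Pet(Model):
+ #         name: str = rest_field(name="petName")
+ #         age: int = rest_field()
+ #
+ #     p = Pet(name="Rex", age=3)
+ #     p["petName"]   # -> "Rex" (wire name view)
+ #     p.age          # -> 3 (attribute view)
+ #     p.as_dict()    # -> {"petName": "Rex", "age": 3}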
+
+
+def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
+ if _is_model(obj):
+ return obj
+ return _deserialize(model_deserializer, obj)
+
+
+def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
+ if obj is None:
+ return obj
+ return _deserialize_with_callable(if_obj_deserializer, obj)
+
+
+def _deserialize_with_union(deserializers, obj):
+ for deserializer in deserializers:
+ try:
+ return _deserialize(deserializer, obj)
+ except DeserializationError:
+ pass
+ raise DeserializationError()
+
+
+def _deserialize_dict(
+ value_deserializer: typing.Optional[typing.Callable],
+ module: typing.Optional[str],
+ obj: typing.Dict[typing.Any, typing.Any],
+):
+ if obj is None:
+ return obj
+ return {k: _deserialize(value_deserializer, v, module) for k, v in obj.items()}
+
+
+def _deserialize_multiple_sequence(
+ entry_deserializers: typing.List[typing.Optional[typing.Callable]],
+ module: typing.Optional[str],
+ obj,
+):
+ if obj is None:
+ return obj
+ return type(obj)(_deserialize(deserializer, entry, module) for entry, deserializer in zip(obj, entry_deserializers))
+
+
+def _deserialize_sequence(
+ deserializer: typing.Optional[typing.Callable],
+ module: typing.Optional[str],
+ obj,
+):
+ if obj is None:
+ return obj
+ return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
+
+
+def _get_deserialize_callable_from_annotation( # pylint: disable=R0911
+ annotation: typing.Any,
+ module: typing.Optional[str],
+ rf: typing.Optional["_RestField"] = None,
+) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
+ if not annotation or annotation in [int, float]:
+ return None
+
+ # is it a type alias?
+ if isinstance(annotation, str):
+ if module is not None:
+ annotation = _get_type_alias_type(module, annotation)
+
+ # is it a forward ref / in quotes?
+ if isinstance(annotation, (str, typing.ForwardRef)):
+ try:
+ model_name = annotation.__forward_arg__ # type: ignore
+ except AttributeError:
+ model_name = annotation
+ if module is not None:
+ annotation = _get_model(module, model_name)
+
+ try:
+ if module and _is_model(annotation):
+ if rf:
+ rf._is_model = True
+
+ return functools.partial(_deserialize_model, annotation) # pyright: ignore
+ except Exception:
+ pass
+
+ # is it a literal?
+ try:
+ if annotation.__origin__ is typing.Literal: # pyright: ignore
+ return None
+ except AttributeError:
+ pass
+
+ # is it optional?
+ try:
+ if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
+ if_obj_deserializer = _get_deserialize_callable_from_annotation(
+ next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
+ )
+
+ return functools.partial(_deserialize_with_optional, if_obj_deserializer)
+ except AttributeError:
+ pass
+
+ if getattr(annotation, "__origin__", None) is typing.Union:
+ # initial ordering: make `str` the last deserialization option, because it is often the most generic
+ deserializers = [
+ _get_deserialize_callable_from_annotation(arg, module, rf)
+ for arg in sorted(
+ annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
+ )
+ ]
+
+ return functools.partial(_deserialize_with_union, deserializers)
+
+ try:
+ if annotation._name == "Dict": # pyright: ignore
+ value_deserializer = _get_deserialize_callable_from_annotation(
+ annotation.__args__[1], module, rf # pyright: ignore
+ )
+
+ return functools.partial(
+ _deserialize_dict,
+ value_deserializer,
+ module,
+ )
+ except (AttributeError, IndexError):
+ pass
+ try:
+ if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore
+ if len(annotation.__args__) > 1: # pyright: ignore
+
+ entry_deserializers = [
+ _get_deserialize_callable_from_annotation(dt, module, rf)
+ for dt in annotation.__args__ # pyright: ignore
+ ]
+ return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module)
+ deserializer = _get_deserialize_callable_from_annotation(
+ annotation.__args__[0], module, rf # pyright: ignore
+ )
+
+ return functools.partial(_deserialize_sequence, deserializer, module)
+ except (TypeError, IndexError, AttributeError, SyntaxError):
+ pass
+
+ def _deserialize_default(
+ deserializer,
+ obj,
+ ):
+ if obj is None:
+ return obj
+ try:
+ return _deserialize_with_callable(deserializer, obj)
+ except Exception:
+ pass
+ return obj
+
+ if get_deserializer(annotation, rf):
+ return functools.partial(_deserialize_default, get_deserializer(annotation, rf))
+
+ return functools.partial(_deserialize_default, annotation)
+
+
+def _deserialize_with_callable(
+ deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]],
+ value: typing.Any,
+):
+ try:
+ if value is None or isinstance(value, type(NULL)):
+ return None
+ if deserializer is None:
+ return value
+ if isinstance(deserializer, CaseInsensitiveEnumMeta):
+ try:
+ return deserializer(value)
+ except ValueError:
+ # for unknown value, return raw value
+ return value
+ if isinstance(deserializer, type) and issubclass(deserializer, Model):
+ return deserializer._deserialize(value, [])
+ return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value)
+ except Exception as e:
+ raise DeserializationError() from e
+
+
+def _deserialize(
+ deserializer: typing.Any,
+ value: typing.Any,
+ module: typing.Optional[str] = None,
+ rf: typing.Optional["_RestField"] = None,
+ format: typing.Optional[str] = None,
+) -> typing.Any:
+ if isinstance(value, PipelineResponse):
+ value = value.http_response.json()
+ if rf is None and format:
+ rf = _RestField(format=format)
+ if not isinstance(deserializer, functools.partial):
+ deserializer = _get_deserialize_callable_from_annotation(deserializer, module, rf)
+ return _deserialize_with_callable(deserializer, value)
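+ # Illustrative sketch of _deserialize (reusing the hypothetical Pet model sketched above):
+ # given an annotation or a Model class and a JSON-shaped value, it builds the matching
+ # deserializer and applies it.
+ #
+ #     _deserialize(typing.List[int], [1, 2, 3])          # -> [1, 2, 3]
+ #     _deserialize(Pet, {"petName": "Rex", "age": 3})    # -> Pet instance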
+
+
+class _RestField:
+ def __init__(
+ self,
+ *,
+ name: typing.Optional[str] = None,
+ type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
+ is_discriminator: bool = False,
+ visibility: typing.Optional[typing.List[str]] = None,
+ default: typing.Any = _UNSET,
+ format: typing.Optional[str] = None,
+ is_multipart_file_input: bool = False,
+ is_required: bool = False,
+ ):
+ self._type = type
+ self._rest_name_input = name
+ self._module: typing.Optional[str] = None
+ self._is_discriminator = is_discriminator
+ self._visibility = visibility
+ self._is_model = False
+ self._default = default
+ self._format = format
+ self._is_multipart_file_input = is_multipart_file_input
+ self._is_required = is_required
+
+ @property
+ def _class_type(self) -> typing.Any:
+ return getattr(self._type, "args", [None])[0]
+
+ @property
+ def _rest_name(self) -> str:
+ if self._rest_name_input is None:
+ raise ValueError("Rest name was never set")
+ return self._rest_name_input
+
+ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin
+ # by this point, type and rest_name will have a value because we default
+ # them in __new__ of the Model class
+ item = obj.get(self._rest_name)
+ if item is None:
+ return item
+ if self._is_model:
+ return item
+ return _deserialize(self._type, _serialize(item, self._format), rf=self)
+
+ def __set__(self, obj: Model, value) -> None:
+ if value is None:
+ # we want to wipe out entries if users set attr to None
+ try:
+ obj.__delitem__(self._rest_name)
+ except KeyError:
+ pass
+ return
+ if self._is_model:
+ if not _is_model(value):
+ value = _deserialize(self._type, value)
+ obj.__setitem__(self._rest_name, value)
+ return
+ obj.__setitem__(self._rest_name, _serialize(value, self._format))
+
+ def _get_deserialize_callable_from_annotation(
+ self, annotation: typing.Any
+ ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
+ return _get_deserialize_callable_from_annotation(annotation, self._module, self)
+
+
+def rest_field(
+ *,
+ name: typing.Optional[str] = None,
+ type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
+ visibility: typing.Optional[typing.List[str]] = None,
+ default: typing.Any = _UNSET,
+ format: typing.Optional[str] = None,
+ is_multipart_file_input: bool = False,
+ is_required: bool = False,
+) -> typing.Any:
+ return _RestField(
+ name=name,
+ type=type,
+ visibility=visibility,
+ default=default,
+ format=format,
+ is_multipart_file_input=is_multipart_file_input,
+ is_required=is_required,
+ )
+
+
+def rest_discriminator(
+ *,
+ name: typing.Optional[str] = None,
+ type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
+) -> typing.Any:
+ return _RestField(name=name, type=type, is_discriminator=True)
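+ # Illustrative sketch of polymorphic models (Shape/Circle are hypothetical): a base class
+ # declares __mapping__ plus a rest_discriminator field, subclasses register themselves
+ # through __init_subclass__, and Model._deserialize dispatches on the discriminator value.
+ #
+ #     class Shape(Model):
+ #         __mapping__: typing.Dict[str, Model] = {}
+ #         kind: str = rest_discriminator(name="kind")
+ #
+ #     class Circle(Shape, discriminator="circle"):
+ #         kind: str = rest_field(default="circle")
+ #         radius: float = rest_field()
+ #
+ #     Shape._deserialize({"kind": "circle", "radius": 2.0}, [])  # -> Circle instance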
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_patch.py
new file mode 100644
index 00000000..f7dd3251
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_patch.py
@@ -0,0 +1,20 @@
+# ------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# ------------------------------------
+"""Customize generated code here.
+
+Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
+"""
+from typing import List
+
+__all__: List[str] = [] # Add all objects you want publicly available to users at this package level
+
+
+def patch_sdk():
+ """Do not remove from this file.
+
+ `patch_sdk` is a last resort escape hatch that allows you to do customizations
+ you can't accomplish using the techniques described in
+ https://aka.ms/azsdk/python/dpcodegen/python/customize
+ """
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_serialization.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_serialization.py
new file mode 100644
index 00000000..2f781d74
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/_serialization.py
@@ -0,0 +1,1998 @@
+# --------------------------------------------------------------------------
+#
+# Copyright (c) Microsoft Corporation. All rights reserved.
+#
+# The MIT License (MIT)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the ""Software""), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+#
+# --------------------------------------------------------------------------
+
+# pylint: skip-file
+# pyright: reportUnnecessaryTypeIgnoreComment=false
+
+from base64 import b64decode, b64encode
+import calendar
+import datetime
+import decimal
+import email
+from enum import Enum
+import json
+import logging
+import re
+import sys
+import codecs
+from typing import (
+ Dict,
+ Any,
+ cast,
+ Optional,
+ Union,
+ AnyStr,
+ IO,
+ Mapping,
+ Callable,
+ TypeVar,
+ MutableMapping,
+ Type,
+ List,
+)
+
+try:
+ from urllib import quote # type: ignore
+except ImportError:
+ from urllib.parse import quote
+import xml.etree.ElementTree as ET
+
+import isodate # type: ignore
+
+from azure.core.exceptions import DeserializationError, SerializationError
+from azure.core.serialization import NULL as CoreNull
+
+_BOM = codecs.BOM_UTF8.decode(encoding="utf-8")
+
+ModelType = TypeVar("ModelType", bound="Model")
+JSON = MutableMapping[str, Any]
+
+
+class RawDeserializer:
+
+ # Accept "text" because we're open minded people...
+ JSON_REGEXP = re.compile(r"^(application|text)/([a-z+.]+\+)?json$")
+
+ # Name used in context
+ CONTEXT_NAME = "deserialized_data"
+
+ @classmethod
+ def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: Optional[str] = None) -> Any:
+ """Decode data according to content-type.
+
+ Accept a stream of data as well, but it will be loaded into memory at once for now.
+
+ If no content-type, will return the string version (not bytes, not stream)
+
+ :param data: Input, could be bytes or stream (will be decoded with UTF8) or text
+ :type data: str or bytes or IO
+ :param str content_type: The content type.
+ """
+ if hasattr(data, "read"):
+ # Assume a stream
+ data = cast(IO, data).read()
+
+ if isinstance(data, bytes):
+ data_as_str = data.decode(encoding="utf-8-sig")
+ else:
+ # Explain to mypy the correct type.
+ data_as_str = cast(str, data)
+
+ # Remove Byte Order Mark if present in string
+ data_as_str = data_as_str.lstrip(_BOM)
+
+ if content_type is None:
+ return data
+
+ if cls.JSON_REGEXP.match(content_type):
+ try:
+ return json.loads(data_as_str)
+ except ValueError as err:
+ raise DeserializationError("JSON is invalid: {}".format(err), err)
+ elif "xml" in (content_type or []):
+ try:
+
+ try:
+ if isinstance(data, unicode): # type: ignore
+ # On Python 2.7, ElementTree will scream if I try a "fromstring" on a unicode string
+ data_as_str = data_as_str.encode(encoding="utf-8") # type: ignore
+ except NameError:
+ pass
+
+ return ET.fromstring(data_as_str) # nosec
+ except ET.ParseError as err:
+ # It might be because the server has an issue, and returned JSON with
+ # content-type XML....
+ # So let's try a JSON load, and if it's still broken
+ # let's flow the initial exception
+ def _json_attempt(data):
+ try:
+ return True, json.loads(data)
+ except ValueError:
+ return False, None # Don't care about this one
+
+ success, json_result = _json_attempt(data)
+ if success:
+ return json_result
+ # If i'm here, it's not JSON, it's not XML, let's scream
+ # and raise the last context in this block (the XML exception)
+ # The function hack is because Py2.7 messes up with exception
+ # context otherwise.
+ _LOGGER.critical("Wasn't XML not JSON, failing")
+ raise DeserializationError("XML is invalid") from err
+ raise DeserializationError("Cannot deserialize content-type: {}".format(content_type))
+
+ @classmethod
+ def deserialize_from_http_generics(cls, body_bytes: Optional[Union[AnyStr, IO]], headers: Mapping) -> Any:
+ """Deserialize from HTTP response.
+
+ Use bytes and headers to NOT use any requests/aiohttp or whatever
+ specific implementation.
+ Headers will be tested for "content-type".
+ """
+ # Try to use content-type from headers if available
+ content_type = None
+ if "content-type" in headers:
+ content_type = headers["content-type"].split(";")[0].strip().lower()
+ # Ouch, this server did not declare what it sent...
+ # Let's guess it's JSON...
+ # Also, since Autorest was considering that an empty body was a valid JSON,
+ # need that test as well....
+ else:
+ content_type = "application/json"
+
+ if body_bytes:
+ return cls.deserialize_from_text(body_bytes, content_type)
+ return None
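+ # Illustrative usage sketch (payloads are hypothetical): the parser is chosen from the
+ # content type, and deserialize_from_http_generics defaults to JSON when the server did
+ # not declare one.
+ #
+ #     RawDeserializer.deserialize_from_text('{"a": 1}', "application/json")  # -> {"a": 1}
+ #     RawDeserializer.deserialize_from_text("<r><a>1</a></r>", "text/xml")   # -> ET.Element
+ #     RawDeserializer.deserialize_from_http_generics(b'{"a": 1}', {})        # -> {"a": 1}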
+
+
+_LOGGER = logging.getLogger(__name__)
+
+try:
+ _long_type = long # type: ignore
+except NameError:
+ _long_type = int
+
+
+class UTC(datetime.tzinfo):
+ """Time Zone info for handling UTC"""
+
+ def utcoffset(self, dt):
+ """UTF offset for UTC is 0."""
+ return datetime.timedelta(0)
+
+ def tzname(self, dt):
+ """Timestamp representation."""
+ return "Z"
+
+ def dst(self, dt):
+ """No daylight saving for UTC."""
+ return datetime.timedelta(0)
+
+
+try:
+ from datetime import timezone as _FixedOffset # type: ignore
+except ImportError: # Python 2.7
+
+ class _FixedOffset(datetime.tzinfo): # type: ignore
+ """Fixed offset in minutes east from UTC.
+ Copy/pasted from Python doc
+ :param datetime.timedelta offset: offset in timedelta format
+ """
+
+ def __init__(self, offset):
+ self.__offset = offset
+
+ def utcoffset(self, dt):
+ return self.__offset
+
+ def tzname(self, dt):
+ return str(self.__offset.total_seconds() / 3600)
+
+ def __repr__(self):
+ return "<FixedOffset {}>".format(self.tzname(None))
+
+ def dst(self, dt):
+ return datetime.timedelta(0)
+
+ def __getinitargs__(self):
+ return (self.__offset,)
+
+
+try:
+ from datetime import timezone
+
+ TZ_UTC = timezone.utc
+except ImportError:
+ TZ_UTC = UTC() # type: ignore
+
+_FLATTEN = re.compile(r"(?<!\\)\.")
+
+
+def attribute_transformer(key, attr_desc, value):
+ """A key transformer that returns the Python attribute.
+
+ :param str key: The attribute name
+ :param dict attr_desc: The attribute metadata
+ :param object value: The value
+ :returns: A key using attribute name
+ """
+ return (key, value)
+
+
+def full_restapi_key_transformer(key, attr_desc, value):
+ """A key transformer that returns the full RestAPI key path.
+
+ :param str key: The attribute name
+ :param dict attr_desc: The attribute metadata
+ :param object value: The value
+ :returns: A list of keys using RestAPI syntax.
+ """
+ keys = _FLATTEN.split(attr_desc["key"])
+ return ([_decode_attribute_map_key(k) for k in keys], value)
+
+
+def last_restapi_key_transformer(key, attr_desc, value):
+ """A key transformer that returns the last RestAPI key.
+
+ :param str key: The attribute name
+ :param dict attr_desc: The attribute metadata
+ :param object value: The value
+ :returns: The last RestAPI key.
+ """
+ key, value = full_restapi_key_transformer(key, attr_desc, value)
+ return (key[-1], value)
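+ # Illustrative comparison of the three transformers for a flattened attribute whose
+ # (hypothetical) _attribute_map entry is {"key": "properties.name", "type": "str"}:
+ #
+ #     attribute_transformer("name", {"key": "properties.name"}, "x")         # -> ("name", "x")
+ #     full_restapi_key_transformer("name", {"key": "properties.name"}, "x")  # -> (["properties", "name"], "x")
+ #     last_restapi_key_transformer("name", {"key": "properties.name"}, "x")  # -> ("name", "x")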
+
+
+def _create_xml_node(tag, prefix=None, ns=None):
+ """Create a XML node."""
+ if prefix and ns:
+ ET.register_namespace(prefix, ns)
+ if ns:
+ return ET.Element("{" + ns + "}" + tag)
+ else:
+ return ET.Element(tag)
+
+
+class Model(object):
+ """Mixin for all client request body/response body models to support
+ serialization and deserialization.
+ """
+
+ _subtype_map: Dict[str, Dict[str, Any]] = {}
+ _attribute_map: Dict[str, Dict[str, Any]] = {}
+ _validation: Dict[str, Dict[str, Any]] = {}
+
+ def __init__(self, **kwargs: Any) -> None:
+ self.additional_properties: Optional[Dict[str, Any]] = {}
+ for k in kwargs:
+ if k not in self._attribute_map:
+ _LOGGER.warning("%s is not a known attribute of class %s and will be ignored", k, self.__class__)
+ elif k in self._validation and self._validation[k].get("readonly", False):
+ _LOGGER.warning("Readonly attribute %s will be ignored in class %s", k, self.__class__)
+ else:
+ setattr(self, k, kwargs[k])
+
+ def __eq__(self, other: Any) -> bool:
+ """Compare objects by comparing all attributes."""
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ return False
+
+ def __ne__(self, other: Any) -> bool:
+ """Compare objects by comparing all attributes."""
+ return not self.__eq__(other)
+
+ def __str__(self) -> str:
+ return str(self.__dict__)
+
+ @classmethod
+ def enable_additional_properties_sending(cls) -> None:
+ cls._attribute_map["additional_properties"] = {"key": "", "type": "{object}"}
+
+ @classmethod
+ def is_xml_model(cls) -> bool:
+ try:
+ cls._xml_map # type: ignore
+ except AttributeError:
+ return False
+ return True
+
+ @classmethod
+ def _create_xml_node(cls):
+ """Create XML node."""
+ try:
+ xml_map = cls._xml_map # type: ignore
+ except AttributeError:
+ xml_map = {}
+
+ return _create_xml_node(xml_map.get("name", cls.__name__), xml_map.get("prefix", None), xml_map.get("ns", None))
+
+ def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON:
+ """Return the JSON that would be sent to server from this model.
+
+ This is an alias to `as_dict(full_restapi_key_transformer, keep_readonly=False)`.
+
+ If you want XML serialization, you can pass the kwarg is_xml=True.
+
+ :param bool keep_readonly: If you want to serialize the readonly attributes
+ :returns: A dict JSON compatible object
+ :rtype: dict
+ """
+ serializer = Serializer(self._infer_class_models())
+ return serializer._serialize(self, keep_readonly=keep_readonly, **kwargs) # type: ignore
+
+ def as_dict(
+ self,
+ keep_readonly: bool = True,
+ key_transformer: Callable[[str, Dict[str, Any], Any], Any] = attribute_transformer,
+ **kwargs: Any
+ ) -> JSON:
+ """Return a dict that can be serialized using json.dump.
+
+ Advanced usage might optionally use a callback as parameter:
+
+ .. code::python
+
+ def my_key_transformer(key, attr_desc, value):
+ return key
+
+ Key is the attribute name used in Python. Attr_desc
+ is a dict of metadata. Currently contains 'type' with the
+ msrest type and 'key' with the RestAPI encoded key.
+ Value is the current value in this object.
+
+ The string returned will be used to serialize the key.
+ If the return type is a list, this is considered a hierarchical
+ result dict.
+
+ See the three examples in this file:
+
+ - attribute_transformer
+ - full_restapi_key_transformer
+ - last_restapi_key_transformer
+
+ If you want XML serialization, you can pass the kwarg is_xml=True.
+
+ :param function key_transformer: A key transformer function.
+ :returns: A dict JSON compatible object
+ :rtype: dict
+ """
+ serializer = Serializer(self._infer_class_models())
+ return serializer._serialize(self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs) # type: ignore
+
+ @classmethod
+ def _infer_class_models(cls):
+ try:
+ str_models = cls.__module__.rsplit(".", 1)[0]
+ models = sys.modules[str_models]
+ client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)}
+ if cls.__name__ not in client_models:
+ raise ValueError("Not Autorest generated code")
+ except Exception:
+ # Assume it's not Autorest generated (tests?). Add ourselves as dependencies.
+ client_models = {cls.__name__: cls}
+ return client_models
+
+ @classmethod
+ def deserialize(cls: Type[ModelType], data: Any, content_type: Optional[str] = None) -> ModelType:
+ """Parse a str using the RestAPI syntax and return a model.
+
+ :param str data: A str using RestAPI structure. JSON by default.
+ :param str content_type: JSON by default, set application/xml if XML.
+ :returns: An instance of this model
+ :raises: DeserializationError if something went wrong
+ """
+ deserializer = Deserializer(cls._infer_class_models())
+ return deserializer(cls.__name__, data, content_type=content_type) # type: ignore
+
+ @classmethod
+ def from_dict(
+ cls: Type[ModelType],
+ data: Any,
+ key_extractors: Optional[Callable[[str, Dict[str, Any], Any], Any]] = None,
+ content_type: Optional[str] = None,
+ ) -> ModelType:
+ """Parse a dict using given key extractor return a model.
+
+ By default, considers the key
+ extractors (rest_key_case_insensitive_extractor, attribute_key_case_insensitive_extractor
+ and last_rest_key_case_insensitive_extractor).
+
+ :param dict data: A dict using RestAPI structure
+ :param str content_type: JSON by default, set application/xml if XML.
+ :returns: An instance of this model
+ :raises: DeserializationError if something went wrong
+ """
+ deserializer = Deserializer(cls._infer_class_models())
+ deserializer.key_extractors = ( # type: ignore
+ [ # type: ignore
+ attribute_key_case_insensitive_extractor,
+ rest_key_case_insensitive_extractor,
+ last_rest_key_case_insensitive_extractor,
+ ]
+ if key_extractors is None
+ else key_extractors
+ )
+ return deserializer(cls.__name__, data, content_type=content_type) # type: ignore
+
+ @classmethod
+ def _flatten_subtype(cls, key, objects):
+ if "_subtype_map" not in cls.__dict__:
+ return {}
+ result = dict(cls._subtype_map[key])
+ for valuetype in cls._subtype_map[key].values():
+ result.update(objects[valuetype]._flatten_subtype(key, objects))
+ return result
+
+ @classmethod
+ def _classify(cls, response, objects):
+ """Check the class _subtype_map for any child classes.
+ We want to ignore any inherited _subtype_maps.
+ Remove the polymorphic key from the initial data.
+ """
+ for subtype_key in cls.__dict__.get("_subtype_map", {}).keys():
+ subtype_value = None
+
+ if not isinstance(response, ET.Element):
+ rest_api_response_key = cls._get_rest_key_parts(subtype_key)[-1]
+ subtype_value = response.pop(rest_api_response_key, None) or response.pop(subtype_key, None)
+ else:
+ subtype_value = xml_key_extractor(subtype_key, cls._attribute_map[subtype_key], response)
+ if subtype_value:
+ # Try to match base class. Can be class name only
+ # (bug to fix in Autorest to support x-ms-discriminator-name)
+ if cls.__name__ == subtype_value:
+ return cls
+ flatten_mapping_type = cls._flatten_subtype(subtype_key, objects)
+ try:
+ return objects[flatten_mapping_type[subtype_value]] # type: ignore
+ except KeyError:
+ _LOGGER.warning(
+ "Subtype value %s has no mapping, use base class %s.",
+ subtype_value,
+ cls.__name__,
+ )
+ break
+ else:
+ _LOGGER.warning("Discriminator %s is absent or null, use base class %s.", subtype_key, cls.__name__)
+ break
+ return cls
+
+ @classmethod
+ def _get_rest_key_parts(cls, attr_key):
+ """Get the RestAPI key of this attr, split it and decode part
+ :param str attr_key: Attribute key must be in attribute_map.
+ :returns: A list of RestAPI part
+ :rtype: list
+ """
+ rest_split_key = _FLATTEN.split(cls._attribute_map[attr_key]["key"])
+ return [_decode_attribute_map_key(key_part) for key_part in rest_split_key]
+
+
+def _decode_attribute_map_key(key):
+ """This decode a key in an _attribute_map to the actual key we want to look at
+ inside the received data.
+
+ :param str key: A key string from the generated code
+ """
+ return key.replace("\\.", ".")
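+ # Illustrative example: a literal dot in a RestAPI key is escaped as "\\." in the
+ # _attribute_map, and this helper turns it back into a plain dot after splitting:
+ #
+ #     _decode_attribute_map_key("odata\\.type")  # -> "odata.type"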
+
+
+class Serializer(object):
+ """Request object model serializer."""
+
+ basic_types = {str: "str", int: "int", bool: "bool", float: "float"}
+
+ _xml_basic_types_serializers = {"bool": lambda x: str(x).lower()}
+ days = {0: "Mon", 1: "Tue", 2: "Wed", 3: "Thu", 4: "Fri", 5: "Sat", 6: "Sun"}
+ months = {
+ 1: "Jan",
+ 2: "Feb",
+ 3: "Mar",
+ 4: "Apr",
+ 5: "May",
+ 6: "Jun",
+ 7: "Jul",
+ 8: "Aug",
+ 9: "Sep",
+ 10: "Oct",
+ 11: "Nov",
+ 12: "Dec",
+ }
+ validation = {
+ "min_length": lambda x, y: len(x) < y,
+ "max_length": lambda x, y: len(x) > y,
+ "minimum": lambda x, y: x < y,
+ "maximum": lambda x, y: x > y,
+ "minimum_ex": lambda x, y: x <= y,
+ "maximum_ex": lambda x, y: x >= y,
+ "min_items": lambda x, y: len(x) < y,
+ "max_items": lambda x, y: len(x) > y,
+ "pattern": lambda x, y: not re.match(y, x, re.UNICODE),
+ "unique": lambda x, y: len(x) != len(set(x)),
+ "multiple": lambda x, y: x % y != 0,
+ }
+
+ def __init__(self, classes: Optional[Mapping[str, type]] = None):
+ self.serialize_type = {
+ "iso-8601": Serializer.serialize_iso,
+ "rfc-1123": Serializer.serialize_rfc,
+ "unix-time": Serializer.serialize_unix,
+ "duration": Serializer.serialize_duration,
+ "date": Serializer.serialize_date,
+ "time": Serializer.serialize_time,
+ "decimal": Serializer.serialize_decimal,
+ "long": Serializer.serialize_long,
+ "bytearray": Serializer.serialize_bytearray,
+ "base64": Serializer.serialize_base64,
+ "object": self.serialize_object,
+ "[]": self.serialize_iter,
+ "{}": self.serialize_dict,
+ }
+ self.dependencies: Dict[str, type] = dict(classes) if classes else {}
+ self.key_transformer = full_restapi_key_transformer
+ self.client_side_validation = True
+
+ def _serialize(self, target_obj, data_type=None, **kwargs):
+ """Serialize data into a string according to type.
+
+ :param target_obj: The data to be serialized.
+ :param str data_type: The type to be serialized from.
+ :rtype: str, dict
+ :raises: SerializationError if serialization fails.
+ """
+ key_transformer = kwargs.get("key_transformer", self.key_transformer)
+ keep_readonly = kwargs.get("keep_readonly", False)
+ if target_obj is None:
+ return None
+
+ attr_name = None
+ class_name = target_obj.__class__.__name__
+
+ if data_type:
+ return self.serialize_data(target_obj, data_type, **kwargs)
+
+ if not hasattr(target_obj, "_attribute_map"):
+ data_type = type(target_obj).__name__
+ if data_type in self.basic_types.values():
+ return self.serialize_data(target_obj, data_type, **kwargs)
+
+ # Force "is_xml" kwargs if we detect a XML model
+ try:
+ is_xml_model_serialization = kwargs["is_xml"]
+ except KeyError:
+ is_xml_model_serialization = kwargs.setdefault("is_xml", target_obj.is_xml_model())
+
+ serialized = {}
+ if is_xml_model_serialization:
+ serialized = target_obj._create_xml_node()
+ try:
+ attributes = target_obj._attribute_map
+ for attr, attr_desc in attributes.items():
+ attr_name = attr
+ if not keep_readonly and target_obj._validation.get(attr_name, {}).get("readonly", False):
+ continue
+
+ if attr_name == "additional_properties" and attr_desc["key"] == "":
+ if target_obj.additional_properties is not None:
+ serialized.update(target_obj.additional_properties)
+ continue
+ try:
+
+ orig_attr = getattr(target_obj, attr)
+ if is_xml_model_serialization:
+ pass # Don't provide "transformer" for XML for now. Keep "orig_attr"
+ else: # JSON
+ keys, orig_attr = key_transformer(attr, attr_desc.copy(), orig_attr)
+ keys = keys if isinstance(keys, list) else [keys]
+
+ kwargs["serialization_ctxt"] = attr_desc
+ new_attr = self.serialize_data(orig_attr, attr_desc["type"], **kwargs)
+
+ if is_xml_model_serialization:
+ xml_desc = attr_desc.get("xml", {})
+ xml_name = xml_desc.get("name", attr_desc["key"])
+ xml_prefix = xml_desc.get("prefix", None)
+ xml_ns = xml_desc.get("ns", None)
+ if xml_desc.get("attr", False):
+ if xml_ns:
+ ET.register_namespace(xml_prefix, xml_ns)
+ xml_name = "{{{}}}{}".format(xml_ns, xml_name)
+ serialized.set(xml_name, new_attr) # type: ignore
+ continue
+ if xml_desc.get("text", False):
+ serialized.text = new_attr # type: ignore
+ continue
+ if isinstance(new_attr, list):
+ serialized.extend(new_attr) # type: ignore
+ elif isinstance(new_attr, ET.Element):
+ # If the nested XML has no XML/Name, we MUST replace the tag with the local tag, but keep the namespaces.
+ if "name" not in getattr(orig_attr, "_xml_map", {}):
+ splitted_tag = new_attr.tag.split("}")
+ if len(splitted_tag) == 2: # Namespace
+ new_attr.tag = "}".join([splitted_tag[0], xml_name])
+ else:
+ new_attr.tag = xml_name
+ serialized.append(new_attr) # type: ignore
+ else: # That's a basic type
+ # Integrate namespace if necessary
+ local_node = _create_xml_node(xml_name, xml_prefix, xml_ns)
+ local_node.text = str(new_attr)
+ serialized.append(local_node) # type: ignore
+ else: # JSON
+ for k in reversed(keys): # type: ignore
+ new_attr = {k: new_attr}
+
+ _new_attr = new_attr
+ _serialized = serialized
+ for k in keys: # type: ignore
+ if k not in _serialized:
+ _serialized.update(_new_attr) # type: ignore
+ _new_attr = _new_attr[k] # type: ignore
+ _serialized = _serialized[k]
+ except ValueError as err:
+ if isinstance(err, SerializationError):
+ raise
+
+ except (AttributeError, KeyError, TypeError) as err:
+ msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj))
+ raise SerializationError(msg) from err
+ else:
+ return serialized
+
+ def body(self, data, data_type, **kwargs):
+ """Serialize data intended for a request body.
+
+ :param data: The data to be serialized.
+ :param str data_type: The type to be serialized from.
+ :rtype: dict
+ :raises: SerializationError if serialization fails.
+ :raises: ValueError if data is None
+ """
+
+ # Just in case this is a dict
+ internal_data_type_str = data_type.strip("[]{}")
+ internal_data_type = self.dependencies.get(internal_data_type_str, None)
+ try:
+ is_xml_model_serialization = kwargs["is_xml"]
+ except KeyError:
+ if internal_data_type and issubclass(internal_data_type, Model):
+ is_xml_model_serialization = kwargs.setdefault("is_xml", internal_data_type.is_xml_model())
+ else:
+ is_xml_model_serialization = False
+ if internal_data_type and not isinstance(internal_data_type, Enum):
+ try:
+ deserializer = Deserializer(self.dependencies)
+ # Since this is used during serialization, it's almost certain that the format is not JSON REST.
+ # We're not able to deal with additional properties for now.
+ deserializer.additional_properties_detection = False
+ if is_xml_model_serialization:
+ deserializer.key_extractors = [ # type: ignore
+ attribute_key_case_insensitive_extractor,
+ ]
+ else:
+ deserializer.key_extractors = [
+ rest_key_case_insensitive_extractor,
+ attribute_key_case_insensitive_extractor,
+ last_rest_key_case_insensitive_extractor,
+ ]
+ data = deserializer._deserialize(data_type, data)
+ except DeserializationError as err:
+ raise SerializationError("Unable to build a model: " + str(err)) from err
+
+ return self._serialize(data, data_type, **kwargs)
+
+ def url(self, name, data, data_type, **kwargs):
+ """Serialize data intended for a URL path.
+
+ :param data: The data to be serialized.
+ :param str data_type: The type to be serialized from.
+ :rtype: str
+ :raises: TypeError if serialization fails.
+ :raises: ValueError if data is None
+ """
+ try:
+ output = self.serialize_data(data, data_type, **kwargs)
+ if data_type == "bool":
+ output = json.dumps(output)
+
+ if kwargs.get("skip_quote") is True:
+ output = str(output)
+ output = output.replace("{", quote("{")).replace("}", quote("}"))
+ else:
+ output = quote(str(output), safe="")
+ except SerializationError:
+ raise TypeError("{} must be type {}.".format(name, data_type))
+ else:
+ return output
+
+ def query(self, name, data, data_type, **kwargs):
+ """Serialize data intended for a URL query.
+
+ :param data: The data to be serialized.
+ :param str data_type: The type to be serialized from.
+ :keyword bool skip_quote: Whether to skip quote the serialized result.
+ Defaults to False.
+ :rtype: str, list
+ :raises: TypeError if serialization fails.
+ :raises: ValueError if data is None
+ """
+ try:
+ # Treat the list aside, since we don't want to encode the div separator
+ if data_type.startswith("["):
+ internal_data_type = data_type[1:-1]
+ do_quote = not kwargs.get("skip_quote", False)
+ return self.serialize_iter(data, internal_data_type, do_quote=do_quote, **kwargs)
+
+ # Not a list, regular serialization
+ output = self.serialize_data(data, data_type, **kwargs)
+ if data_type == "bool":
+ output = json.dumps(output)
+ if kwargs.get("skip_quote") is True:
+ output = str(output)
+ else:
+ output = quote(str(output), safe="")
+ except SerializationError:
+ raise TypeError("{} must be type {}.".format(name, data_type))
+ else:
+ return str(output)
+
+ def header(self, name, data, data_type, **kwargs):
+ """Serialize data intended for a request header.
+
+ :param data: The data to be serialized.
+ :param str data_type: The type to be serialized from.
+ :rtype: str
+ :raises: TypeError if serialization fails.
+ :raises: ValueError if data is None
+ """
+ try:
+ if data_type in ["[str]"]:
+ data = ["" if d is None else d for d in data]
+
+ output = self.serialize_data(data, data_type, **kwargs)
+ if data_type == "bool":
+ output = json.dumps(output)
+ except SerializationError:
+ raise TypeError("{} must be type {}.".format(name, data_type))
+ else:
+ return str(output)
+
+ def serialize_data(self, data, data_type, **kwargs):
+ """Serialize generic data according to supplied data type.
+
+ :param data: The data to be serialized.
+ :param str data_type: The type to be serialized from.
+ :param bool required: Whether it's essential that the data not be
+ empty or None
+ :raises: AttributeError if required data is None.
+ :raises: ValueError if data is None
+ :raises: SerializationError if serialization fails.
+ """
+ if data is None:
+ raise ValueError("No value for given attribute")
+
+ try:
+ if data is CoreNull:
+ return None
+ if data_type in self.basic_types.values():
+ return self.serialize_basic(data, data_type, **kwargs)
+
+ elif data_type in self.serialize_type:
+ return self.serialize_type[data_type](data, **kwargs)
+
+ # If dependencies is empty, try with current data class
+ # It has to be a subclass of Enum anyway
+ enum_type = self.dependencies.get(data_type, data.__class__)
+ if issubclass(enum_type, Enum):
+ return Serializer.serialize_enum(data, enum_obj=enum_type)
+
+ iter_type = data_type[0] + data_type[-1]
+ if iter_type in self.serialize_type:
+ return self.serialize_type[iter_type](data, data_type[1:-1], **kwargs)
+
+ except (ValueError, TypeError) as err:
+ msg = "Unable to serialize value: {!r} as type: {!r}."
+ raise SerializationError(msg.format(data, data_type)) from err
+ else:
+ return self._serialize(data, **kwargs)
+
+ @classmethod
+ def _get_custom_serializers(cls, data_type, **kwargs):
+ custom_serializer = kwargs.get("basic_types_serializers", {}).get(data_type)
+ if custom_serializer:
+ return custom_serializer
+ if kwargs.get("is_xml", False):
+ return cls._xml_basic_types_serializers.get(data_type)
+
+ @classmethod
+ def serialize_basic(cls, data, data_type, **kwargs):
+ """Serialize basic builting data type.
+ Serializes objects to str, int, float or bool.
+
+ Possible kwargs:
+ - basic_types_serializers dict[str, callable] : If set, use the callable as serializer
+ - is_xml bool : If set, use xml_basic_types_serializers
+
+ :param data: Object to be serialized.
+ :param str data_type: Type of object in the iterable.
+ """
+ custom_serializer = cls._get_custom_serializers(data_type, **kwargs)
+ if custom_serializer:
+ return custom_serializer(data)
+ if data_type == "str":
+ return cls.serialize_unicode(data)
+ return eval(data_type)(data) # nosec
+
+ @classmethod
+ def serialize_unicode(cls, data):
+ """Special handling for serializing unicode strings in Py2.
+ Encode to UTF-8 if unicode, otherwise handle as a str.
+
+ :param data: Object to be serialized.
+ :rtype: str
+ """
+ try: # If I received an enum, return its value
+ return data.value
+ except AttributeError:
+ pass
+
+ try:
+ if isinstance(data, unicode): # type: ignore
+ # Don't change it, JSON and XML ElementTree are totally able
+ # to serialize correctly u'' strings
+ return data
+ except NameError:
+ return str(data)
+ else:
+ return str(data)
+
+ def serialize_iter(self, data, iter_type, div=None, **kwargs):
+ """Serialize iterable.
+
+ Supported kwargs:
+ - serialization_ctxt dict : The current entry of _attribute_map, or same format.
+ serialization_ctxt['type'] should be same as data_type.
+ - is_xml bool : If set, serialize as XML
+
+ :param list data: Object to be serialized.
+ :param str iter_type: Type of object in the iterable.
+ :param bool required: Whether the objects in the iterable must
+ not be None or empty.
+ :param str div: If set, this str will be used to combine the elements
+ in the iterable into a combined string. Default is 'None'.
+ :keyword bool do_quote: Whether to quote the serialized result of each iterable element.
+ Defaults to False.
+ :rtype: list, str
+ """
+ if isinstance(data, str):
+ raise SerializationError("Refuse str type as a valid iter type.")
+
+ serialization_ctxt = kwargs.get("serialization_ctxt", {})
+ is_xml = kwargs.get("is_xml", False)
+
+ serialized = []
+ for d in data:
+ try:
+ serialized.append(self.serialize_data(d, iter_type, **kwargs))
+ except ValueError as err:
+ if isinstance(err, SerializationError):
+ raise
+ serialized.append(None)
+
+ if kwargs.get("do_quote", False):
+ serialized = ["" if s is None else quote(str(s), safe="") for s in serialized]
+
+ if div:
+ serialized = ["" if s is None else str(s) for s in serialized]
+ serialized = div.join(serialized)
+
+ if "xml" in serialization_ctxt or is_xml:
+ # XML serialization is more complicated
+ xml_desc = serialization_ctxt.get("xml", {})
+ xml_name = xml_desc.get("name")
+ if not xml_name:
+ xml_name = serialization_ctxt["key"]
+
+ # Create a wrap node if necessary (use the fact that Element and list have "append")
+ is_wrapped = xml_desc.get("wrapped", False)
+ node_name = xml_desc.get("itemsName", xml_name)
+ if is_wrapped:
+ final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None))
+ else:
+ final_result = []
+ # All list elements to "local_node"
+ for el in serialized:
+ if isinstance(el, ET.Element):
+ el_node = el
+ else:
+ el_node = _create_xml_node(node_name, xml_desc.get("prefix", None), xml_desc.get("ns", None))
+ if el is not None: # Otherwise it writes "None" :-p
+ el_node.text = str(el)
+ final_result.append(el_node)
+ return final_result
+ return serialized
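+ # Illustrative sketch of the non-XML path: a plain list of strings, optionally joined
+ # with a div separator as used for comma-separated query values.
+ #
+ #     Serializer().serialize_iter(["a", "b"], "str")           # -> ["a", "b"]
+ #     Serializer().serialize_iter(["a", "b"], "str", div=",")  # -> "a,b"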
+
+ def serialize_dict(self, attr, dict_type, **kwargs):
+ """Serialize a dictionary of objects.
+
+ :param dict attr: Object to be serialized.
+ :param str dict_type: Type of object in the dictionary.
+ :param bool required: Whether the objects in the dictionary must
+ not be None or empty.
+ :rtype: dict
+ """
+ serialization_ctxt = kwargs.get("serialization_ctxt", {})
+ serialized = {}
+ for key, value in attr.items():
+ try:
+ serialized[self.serialize_unicode(key)] = self.serialize_data(value, dict_type, **kwargs)
+ except ValueError as err:
+ if isinstance(err, SerializationError):
+ raise
+ serialized[self.serialize_unicode(key)] = None
+
+ if "xml" in serialization_ctxt:
+ # XML serialization is more complicated
+ xml_desc = serialization_ctxt["xml"]
+ xml_name = xml_desc["name"]
+
+ final_result = _create_xml_node(xml_name, xml_desc.get("prefix", None), xml_desc.get("ns", None))
+ for key, value in serialized.items():
+ ET.SubElement(final_result, key).text = value
+ return final_result
+
+ return serialized
+
+ def serialize_object(self, attr, **kwargs):
+ """Serialize a generic object.
+ This will be handled as a dictionary. If object passed in is not
+ a basic type (str, int, float, dict, list) it will simply be
+ cast to str.
+
+ :param dict attr: Object to be serialized.
+ :rtype: dict or str
+ """
+ if attr is None:
+ return None
+ if isinstance(attr, ET.Element):
+ return attr
+ obj_type = type(attr)
+ if obj_type in self.basic_types:
+ return self.serialize_basic(attr, self.basic_types[obj_type], **kwargs)
+ if obj_type is _long_type:
+ return self.serialize_long(attr)
+ if obj_type is str:
+ return self.serialize_unicode(attr)
+ if obj_type is datetime.datetime:
+ return self.serialize_iso(attr)
+ if obj_type is datetime.date:
+ return self.serialize_date(attr)
+ if obj_type is datetime.time:
+ return self.serialize_time(attr)
+ if obj_type is datetime.timedelta:
+ return self.serialize_duration(attr)
+ if obj_type is decimal.Decimal:
+ return self.serialize_decimal(attr)
+
+ # If it's a model or I know this dependency, serialize as a Model
+ elif obj_type in self.dependencies.values() or isinstance(attr, Model):
+ return self._serialize(attr)
+
+ if obj_type == dict:
+ serialized = {}
+ for key, value in attr.items():
+ try:
+ serialized[self.serialize_unicode(key)] = self.serialize_object(value, **kwargs)
+ except ValueError:
+ serialized[self.serialize_unicode(key)] = None
+ return serialized
+
+ if obj_type == list:
+ serialized = []
+ for obj in attr:
+ try:
+ serialized.append(self.serialize_object(obj, **kwargs))
+ except ValueError:
+ pass
+ return serialized
+ return str(attr)
+
+ @staticmethod
+ def serialize_enum(attr, enum_obj=None):
+ try:
+ result = attr.value
+ except AttributeError:
+ result = attr
+ try:
+ enum_obj(result) # type: ignore
+ return result
+ except ValueError:
+ for enum_value in enum_obj: # type: ignore
+ if enum_value.value.lower() == str(attr).lower():
+ return enum_value.value
+ error = "{!r} is not valid value for enum {!r}"
+ raise SerializationError(error.format(attr, enum_obj))
+
+ @staticmethod
+ def serialize_bytearray(attr, **kwargs):
+ """Serialize bytearray into base-64 string.
+
+ :param attr: Object to be serialized.
+ :rtype: str
+ """
+ return b64encode(attr).decode()
+
+ @staticmethod
+ def serialize_base64(attr, **kwargs):
+ """Serialize str into base-64 string.
+
+ :param attr: Object to be serialized.
+ :rtype: str
+ """
+ encoded = b64encode(attr).decode("ascii")
+ return encoded.strip("=").replace("+", "-").replace("/", "_")
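+ # Illustrative example of the two byte encodings: "bytearray" is standard base64,
+ # while "base64" is the URL-safe variant with padding stripped.
+ #
+ #     Serializer.serialize_bytearray(b"hi")  # -> "aGk="
+ #     Serializer.serialize_base64(b"hi?>")   # -> "aGk_Pg"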
+
+ @staticmethod
+ def serialize_decimal(attr, **kwargs):
+ """Serialize Decimal object to float.
+
+ :param attr: Object to be serialized.
+ :rtype: float
+ """
+ return float(attr)
+
+ @staticmethod
+ def serialize_long(attr, **kwargs):
+ """Serialize long (Py2) or int (Py3).
+
+ :param attr: Object to be serialized.
+ :rtype: int/long
+ """
+ return _long_type(attr)
+
+ @staticmethod
+ def serialize_date(attr, **kwargs):
+ """Serialize Date object into ISO-8601 formatted string.
+
+ :param Date attr: Object to be serialized.
+ :rtype: str
+ """
+ if isinstance(attr, str):
+ attr = isodate.parse_date(attr)
+ t = "{:04}-{:02}-{:02}".format(attr.year, attr.month, attr.day)
+ return t
+
+ @staticmethod
+ def serialize_time(attr, **kwargs):
+ """Serialize Time object into ISO-8601 formatted string.
+
+ :param datetime.time attr: Object to be serialized.
+ :rtype: str
+ """
+ if isinstance(attr, str):
+ attr = isodate.parse_time(attr)
+ t = "{:02}:{:02}:{:02}".format(attr.hour, attr.minute, attr.second)
+ if attr.microsecond:
+ t += ".{:02}".format(attr.microsecond)
+ return t
+
+ @staticmethod
+ def serialize_duration(attr, **kwargs):
+ """Serialize TimeDelta object into ISO-8601 formatted string.
+
+ :param TimeDelta attr: Object to be serialized.
+ :rtype: str
+ """
+ if isinstance(attr, str):
+ attr = isodate.parse_duration(attr)
+ return isodate.duration_isoformat(attr)
+
+ @staticmethod
+ def serialize_rfc(attr, **kwargs):
+ """Serialize Datetime object into RFC-1123 formatted string.
+
+ :param Datetime attr: Object to be serialized.
+ :rtype: str
+ :raises: TypeError if format invalid.
+ """
+ try:
+ if not attr.tzinfo:
+ _LOGGER.warning("Datetime with no tzinfo will be considered UTC.")
+ utc = attr.utctimetuple()
+ except AttributeError:
+ raise TypeError("RFC1123 object must be valid Datetime object.")
+
+ return "{}, {:02} {} {:04} {:02}:{:02}:{:02} GMT".format(
+ Serializer.days[utc.tm_wday],
+ utc.tm_mday,
+ Serializer.months[utc.tm_mon],
+ utc.tm_year,
+ utc.tm_hour,
+ utc.tm_min,
+ utc.tm_sec,
+ )
+
+ @staticmethod
+ def serialize_iso(attr, **kwargs):
+ """Serialize Datetime object into ISO-8601 formatted string.
+
+ :param Datetime attr: Object to be serialized.
+ :rtype: str
+ :raises: SerializationError if format invalid.
+ """
+ if isinstance(attr, str):
+ attr = isodate.parse_datetime(attr)
+ try:
+ if not attr.tzinfo:
+ _LOGGER.warning("Datetime with no tzinfo will be considered UTC.")
+ utc = attr.utctimetuple()
+ if utc.tm_year > 9999 or utc.tm_year < 1:
+ raise OverflowError("Hit max or min date")
+
+ microseconds = str(attr.microsecond).rjust(6, "0").rstrip("0").ljust(3, "0")
+ if microseconds:
+ microseconds = "." + microseconds
+ date = "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}".format(
+ utc.tm_year, utc.tm_mon, utc.tm_mday, utc.tm_hour, utc.tm_min, utc.tm_sec
+ )
+ return date + microseconds + "Z"
+ except (ValueError, OverflowError) as err:
+ msg = "Unable to serialize datetime object."
+ raise SerializationError(msg) from err
+ except AttributeError as err:
+ msg = "ISO-8601 object must be valid Datetime object."
+ raise TypeError(msg) from err
+
+ @staticmethod
+ def serialize_unix(attr, **kwargs):
+ """Serialize Datetime object into IntTime format.
+ This is represented as seconds since the Unix epoch.
+
+ :param Datetime attr: Object to be serialized.
+ :rtype: int
+ :raises: SerializationError if format invalid
+ """
+ if isinstance(attr, int):
+ return attr
+ try:
+ if not attr.tzinfo:
+ _LOGGER.warning("Datetime with no tzinfo will be considered UTC.")
+ return int(calendar.timegm(attr.utctimetuple()))
+ except AttributeError:
+ raise TypeError("Unix time object must be valid Datetime object.")
+
+
+def rest_key_extractor(attr, attr_desc, data):
+ key = attr_desc["key"]
+ working_data = data
+
+ while "." in key:
+ # Need the cast, as for some reasons "split" is typed as list[str | Any]
+ dict_keys = cast(List[str], _FLATTEN.split(key))
+ if len(dict_keys) == 1:
+ key = _decode_attribute_map_key(dict_keys[0])
+ break
+ working_key = _decode_attribute_map_key(dict_keys[0])
+ working_data = working_data.get(working_key, data)
+ if working_data is None:
+ # If at any point while following the flattened JSON path we see None, it means
+ # that all properties under it are None as well
+ return None
+ key = ".".join(dict_keys[1:])
+
+ return working_data.get(key)
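+ # Illustrative example of the flattened-key walk (attr_desc and data are hypothetical):
+ #
+ #     rest_key_extractor("name", {"key": "properties.name"}, {"properties": {"name": "x"}})
+ #     # -> "x"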
+
+
+def rest_key_case_insensitive_extractor(attr, attr_desc, data):
+ key = attr_desc["key"]
+ working_data = data
+
+ while "." in key:
+ dict_keys = _FLATTEN.split(key)
+ if len(dict_keys) == 1:
+ key = _decode_attribute_map_key(dict_keys[0])
+ break
+ working_key = _decode_attribute_map_key(dict_keys[0])
+ working_data = attribute_key_case_insensitive_extractor(working_key, None, working_data)
+ if working_data is None:
+ # If at any point while following the flattened JSON path we see None, it means
+ # that all properties under it are None as well
+ return None
+ key = ".".join(dict_keys[1:])
+
+ if working_data:
+ return attribute_key_case_insensitive_extractor(key, None, working_data)
+
+
+def last_rest_key_extractor(attr, attr_desc, data):
+ """Extract the attribute in "data" based on the last part of the JSON path key."""
+ key = attr_desc["key"]
+ dict_keys = _FLATTEN.split(key)
+ return attribute_key_extractor(dict_keys[-1], None, data)
+
+
+def last_rest_key_case_insensitive_extractor(attr, attr_desc, data):
+ """Extract the attribute in "data" based on the last part of the JSON path key.
+
+ This is the case insensitive version of "last_rest_key_extractor"
+ """
+ key = attr_desc["key"]
+ dict_keys = _FLATTEN.split(key)
+ return attribute_key_case_insensitive_extractor(dict_keys[-1], None, data)
+
+
+def attribute_key_extractor(attr, _, data):
+ return data.get(attr)
+
+
+def attribute_key_case_insensitive_extractor(attr, _, data):
+ found_key = None
+ lower_attr = attr.lower()
+ for key in data:
+ if lower_attr == key.lower():
+ found_key = key
+ break
+
+ return data.get(found_key)
+
+
+def _extract_name_from_internal_type(internal_type):
+ """Given an internal type XML description, extract correct XML name with namespace.
+
+ :param type internal_type: A model type
+ :rtype: str
+ :returns: The XML name, qualified with its namespace if any
+ """
+ internal_type_xml_map = getattr(internal_type, "_xml_map", {})
+ xml_name = internal_type_xml_map.get("name", internal_type.__name__)
+ xml_ns = internal_type_xml_map.get("ns", None)
+ if xml_ns:
+ xml_name = "{{{}}}{}".format(xml_ns, xml_name)
+ return xml_name
+
+
+def xml_key_extractor(attr, attr_desc, data):
+ if isinstance(data, dict):
+ return None
+
+ # Test if this model is XML ready first
+ if not isinstance(data, ET.Element):
+ return None
+
+ xml_desc = attr_desc.get("xml", {})
+ xml_name = xml_desc.get("name", attr_desc["key"])
+
+ # Look for children
+ is_iter_type = attr_desc["type"].startswith("[")
+ is_wrapped = xml_desc.get("wrapped", False)
+ internal_type = attr_desc.get("internalType", None)
+ internal_type_xml_map = getattr(internal_type, "_xml_map", {})
+
+ # Integrate namespace if necessary
+ xml_ns = xml_desc.get("ns", internal_type_xml_map.get("ns", None))
+ if xml_ns:
+ xml_name = "{{{}}}{}".format(xml_ns, xml_name)
+
+ # If it's an attribute, that's simple
+ if xml_desc.get("attr", False):
+ return data.get(xml_name)
+
+ # If it's x-ms-text, that's simple too
+ if xml_desc.get("text", False):
+ return data.text
+
+ # Scenario where I take the local name:
+ # - Wrapped node
+ # - Internal type is an enum (considered basic types)
+ # - Internal type has no XML/Name node
+ if is_wrapped or (internal_type and (issubclass(internal_type, Enum) or "name" not in internal_type_xml_map)):
+ children = data.findall(xml_name)
+ # If internal type has a local name and it's not a list, I use that name
+ elif not is_iter_type and internal_type and "name" in internal_type_xml_map:
+ xml_name = _extract_name_from_internal_type(internal_type)
+ children = data.findall(xml_name)
+ # That's an array
+ else:
+ if internal_type: # Complex type, ignore itemsName and use the complex type name
+ items_name = _extract_name_from_internal_type(internal_type)
+ else:
+ items_name = xml_desc.get("itemsName", xml_name)
+ children = data.findall(items_name)
+
+ if len(children) == 0:
+ if is_iter_type:
+ if is_wrapped:
+ return None # is_wrapped no node, we want None
+ else:
+ return [] # not wrapped, assume empty list
+ return None # Assume it's not there, maybe an optional node.
+
+ # If is_iter_type and not wrapped, return all found children
+ if is_iter_type:
+ if not is_wrapped:
+ return children
+ else: # Iter and wrapped, should have found one node only (the wrap one)
+ if len(children) != 1:
+ raise DeserializationError(
+ "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format(
+ xml_name
+ )
+ )
+ return list(children[0]) # Might be empty list and that's ok.
+
+ # Here it's not an iter type; we should have found exactly one element, or none
+ if len(children) > 1:
+ raise DeserializationError("Found several XML '{}' where it was not expected".format(xml_name))
+ return children[0]
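+
+ # Illustrative usage sketch (not part of the generated module), assuming a hypothetical
+ # XML payload parsed with ElementTree:
+ #
+ #     root = ET.fromstring("<Endpoint><Name>my-endpoint</Name></Endpoint>")
+ #     child = xml_key_extractor("name", {"key": "Name", "type": "str", "xml": {"name": "Name"}}, root)
+ #     # child is the <Name> element; Deserializer().deserialize_basic(child, "str") would yield "my-endpoint"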
+
+
+class Deserializer(object):
+ """Response object model deserializer.
+
+ :param dict classes: Class type dictionary for deserializing complex types.
+ :ivar list key_extractors: Ordered list of extractors to be used by this deserializer.
+ """
+
+ basic_types = {str: "str", int: "int", bool: "bool", float: "float"}
+
+ valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?")
+
+ def __init__(self, classes: Optional[Mapping[str, type]] = None):
+ self.deserialize_type = {
+ "iso-8601": Deserializer.deserialize_iso,
+ "rfc-1123": Deserializer.deserialize_rfc,
+ "unix-time": Deserializer.deserialize_unix,
+ "duration": Deserializer.deserialize_duration,
+ "date": Deserializer.deserialize_date,
+ "time": Deserializer.deserialize_time,
+ "decimal": Deserializer.deserialize_decimal,
+ "long": Deserializer.deserialize_long,
+ "bytearray": Deserializer.deserialize_bytearray,
+ "base64": Deserializer.deserialize_base64,
+ "object": self.deserialize_object,
+ "[]": self.deserialize_iter,
+ "{}": self.deserialize_dict,
+ }
+ self.deserialize_expected_types = {
+ "duration": (isodate.Duration, datetime.timedelta),
+ "iso-8601": (datetime.datetime),
+ }
+ self.dependencies: Dict[str, type] = dict(classes) if classes else {}
+ self.key_extractors = [rest_key_extractor, xml_key_extractor]
+ # Additional properties only work if the "rest_key_extractor" is used to
+ # extract the keys. Making it work with any key extractor would be too
+ # complicated, with no real scenario for now.
+ # So we add a flag to disable additional-properties detection. This flag should be
+ # used if you expect the deserialization to NOT come from a JSON REST syntax.
+ # Otherwise, results are unexpected.
+ self.additional_properties_detection = True
+
+ def __call__(self, target_obj, response_data, content_type=None):
+ """Call the deserializer to process a REST response.
+
+ :param str target_obj: Target data type to deserialize to.
+ :param requests.Response response_data: REST response object.
+ :param str content_type: Swagger "produces" if available.
+ :raises: DeserializationError if deserialization fails.
+ :return: Deserialized object.
+ """
+ data = self._unpack_content(response_data, content_type)
+ return self._deserialize(target_obj, data)
+
+ def _deserialize(self, target_obj, data):
+ """Call the deserializer on a model.
+
+ Data needs to be already deserialized as JSON or XML ElementTree
+
+ :param str target_obj: Target data type to deserialize to.
+ :param object data: Object to deserialize.
+ :raises: DeserializationError if deserialization fails.
+ :return: Deserialized object.
+ """
+ # This is already a model, go recursive just in case
+ if hasattr(data, "_attribute_map"):
+ constants = [name for name, config in getattr(data, "_validation", {}).items() if config.get("constant")]
+ try:
+ for attr, mapconfig in data._attribute_map.items():
+ if attr in constants:
+ continue
+ value = getattr(data, attr)
+ if value is None:
+ continue
+ local_type = mapconfig["type"]
+ internal_data_type = local_type.strip("[]{}")
+ if internal_data_type not in self.dependencies or isinstance(internal_data_type, Enum):
+ continue
+ setattr(data, attr, self._deserialize(local_type, value))
+ return data
+ except AttributeError:
+ return
+
+ response, class_name = self._classify_target(target_obj, data)
+
+ if isinstance(response, str):
+ return self.deserialize_data(data, response)
+ elif isinstance(response, type) and issubclass(response, Enum):
+ return self.deserialize_enum(data, response)
+
+ if data is None:
+ return data
+ try:
+ attributes = response._attribute_map # type: ignore
+ d_attrs = {}
+ for attr, attr_desc in attributes.items():
+ # Check empty string. If it's not empty, someone has a real "additionalProperties"...
+ if attr == "additional_properties" and attr_desc["key"] == "":
+ continue
+ raw_value = None
+ # Enhance attr_desc with some dynamic data
+ attr_desc = attr_desc.copy() # Do a copy, do not change the real one
+ internal_data_type = attr_desc["type"].strip("[]{}")
+ if internal_data_type in self.dependencies:
+ attr_desc["internalType"] = self.dependencies[internal_data_type]
+
+ for key_extractor in self.key_extractors:
+ found_value = key_extractor(attr, attr_desc, data)
+ if found_value is not None:
+ if raw_value is not None and raw_value != found_value:
+ msg = (
+ "Ignoring extracted value '%s' from %s for key '%s'"
+ " (duplicate extraction, follow extractors order)"
+ )
+ _LOGGER.warning(msg, found_value, key_extractor, attr)
+ continue
+ raw_value = found_value
+
+ value = self.deserialize_data(raw_value, attr_desc["type"])
+ d_attrs[attr] = value
+ except (AttributeError, TypeError, KeyError) as err:
+ msg = "Unable to deserialize to object: " + class_name # type: ignore
+ raise DeserializationError(msg) from err
+ else:
+ additional_properties = self._build_additional_properties(attributes, data)
+ return self._instantiate_model(response, d_attrs, additional_properties)
+
+ def _build_additional_properties(self, attribute_map, data):
+ if not self.additional_properties_detection:
+ return None
+ if "additional_properties" in attribute_map and attribute_map.get("additional_properties", {}).get("key") != "":
+ # Check empty string. If it's not empty, someone has a real "additionalProperties"
+ return None
+ if isinstance(data, ET.Element):
+ data = {el.tag: el.text for el in data}
+
+ known_keys = {
+ _decode_attribute_map_key(_FLATTEN.split(desc["key"])[0])
+ for desc in attribute_map.values()
+ if desc["key"] != ""
+ }
+ present_keys = set(data.keys())
+ missing_keys = present_keys - known_keys
+ return {key: data[key] for key in missing_keys}
+
+ def _classify_target(self, target, data):
+ """Check to see whether the deserialization target object can
+ be classified into a subclass.
+ Once classification has been determined, returns the target type and its class name.
+
+ :param str target: The target object type to deserialize to.
+ :param str/dict data: The response data to deserialize.
+ """
+ if target is None:
+ return None, None
+
+ if isinstance(target, str):
+ try:
+ target = self.dependencies[target]
+ except KeyError:
+ return target, target
+
+ try:
+ target = target._classify(data, self.dependencies) # type: ignore
+ except AttributeError:
+ pass # Target is not a Model, no classify
+ return target, target.__class__.__name__ # type: ignore
+
+ def failsafe_deserialize(self, target_obj, data, content_type=None):
+ """Ignores any errors encountered in deserialization,
+ and falls back to not deserializing the object. Recommended
+ for use in error deserialization, as we want to return the
+ HttpResponseError to users, and not have them deal with
+ a deserialization error.
+
+ :param str target_obj: The target object type to deserialize to.
+ :param str/dict data: The response data to deserialize.
+ :param str content_type: Swagger "produces" if available.
+ """
+ try:
+ return self(target_obj, data, content_type=content_type)
+ except:
+ _LOGGER.debug(
+ "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True
+ )
+ return None
+
+ @staticmethod
+ def _unpack_content(raw_data, content_type=None):
+ """Extract the correct structure for deserialization.
+
+ If raw_data is a PipelineResponse, try to extract the result of RawDeserializer.
+ If we can't, raise; your pipeline should have a RawDeserializer.
+
+ If not a pipeline response and raw_data is bytes or string, use content-type
+ to decode it. If no content-type, try JSON.
+
+ If raw_data is something else, bypass all logic and return it directly.
+
+ :param raw_data: Data to be processed.
+ :param content_type: How to parse if raw_data is a string/bytes.
+ :raises JSONDecodeError: If JSON is requested and parsing is impossible.
+ :raises UnicodeDecodeError: If bytes is not UTF8
+ """
+ # Assume this is enough to detect a Pipeline Response without importing it
+ context = getattr(raw_data, "context", {})
+ if context:
+ if RawDeserializer.CONTEXT_NAME in context:
+ return context[RawDeserializer.CONTEXT_NAME]
+ raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize")
+
+ # Assume this is enough to recognize universal_http.ClientResponse without importing it
+ if hasattr(raw_data, "body"):
+ return RawDeserializer.deserialize_from_http_generics(raw_data.text(), raw_data.headers)
+
+ # Assume this enough to recognize requests.Response without importing it.
+ if hasattr(raw_data, "_content_consumed"):
+ return RawDeserializer.deserialize_from_http_generics(raw_data.text, raw_data.headers)
+
+ if isinstance(raw_data, (str, bytes)) or hasattr(raw_data, "read"):
+ return RawDeserializer.deserialize_from_text(raw_data, content_type) # type: ignore
+ return raw_data
+
+ def _instantiate_model(self, response, attrs, additional_properties=None):
+ """Instantiate a response model passing in deserialized args.
+
+ :param response: The response model class.
+ :param attrs: The deserialized response attributes.
+ """
+ if callable(response):
+ subtype = getattr(response, "_subtype_map", {})
+ try:
+ readonly = [k for k, v in response._validation.items() if v.get("readonly")]
+ const = [k for k, v in response._validation.items() if v.get("constant")]
+ kwargs = {k: v for k, v in attrs.items() if k not in subtype and k not in readonly + const}
+ response_obj = response(**kwargs)
+ for attr in readonly:
+ setattr(response_obj, attr, attrs.get(attr))
+ if additional_properties:
+ response_obj.additional_properties = additional_properties
+ return response_obj
+ except TypeError as err:
+ msg = "Unable to deserialize {} into model {}. ".format(kwargs, response) # type: ignore
+ raise DeserializationError(msg + str(err))
+ else:
+ try:
+ for attr, value in attrs.items():
+ setattr(response, attr, value)
+ return response
+ except Exception as exp:
+ msg = "Unable to populate response model. "
+ msg += "Type: {}, Error: {}".format(type(response), exp)
+ raise DeserializationError(msg)
+
+ def deserialize_data(self, data, data_type):
+ """Process data for deserialization according to data type.
+
+ :param str data: The response string to be deserialized.
+ :param str data_type: The type to deserialize to.
+ :raises: DeserializationError if deserialization fails.
+ :return: Deserialized object.
+ """
+ if data is None:
+ return data
+
+ try:
+ if not data_type:
+ return data
+ if data_type in self.basic_types.values():
+ return self.deserialize_basic(data, data_type)
+ if data_type in self.deserialize_type:
+ if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())):
+ return data
+
+ is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"]
+ if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text:
+ return None
+ data_val = self.deserialize_type[data_type](data)
+ return data_val
+
+ iter_type = data_type[0] + data_type[-1]
+ if iter_type in self.deserialize_type:
+ return self.deserialize_type[iter_type](data, data_type[1:-1])
+
+ obj_type = self.dependencies[data_type]
+ if issubclass(obj_type, Enum):
+ if isinstance(data, ET.Element):
+ data = data.text
+ return self.deserialize_enum(data, obj_type)
+
+ except (ValueError, TypeError, AttributeError) as err:
+ msg = "Unable to deserialize response data."
+ msg += " Data: {}, {}".format(data, data_type)
+ raise DeserializationError(msg) from err
+ else:
+ return self._deserialize(obj_type, data)
+
+ def deserialize_iter(self, attr, iter_type):
+ """Deserialize an iterable.
+
+ :param list attr: Iterable to be deserialized.
+ :param str iter_type: The type of object in the iterable.
+ :rtype: list
+ """
+ if attr is None:
+ return None
+ if isinstance(attr, ET.Element): # If I receive an element here, get the children
+ attr = list(attr)
+ if not isinstance(attr, (list, set)):
+ raise DeserializationError("Cannot deserialize as [{}] an object of type {}".format(iter_type, type(attr)))
+ return [self.deserialize_data(a, iter_type) for a in attr]
+
+ def deserialize_dict(self, attr, dict_type):
+ """Deserialize a dictionary.
+
+ :param dict/list attr: Dictionary to be deserialized. Also accepts
+ a list of key, value pairs.
+ :param str dict_type: The object type of the items in the dictionary.
+ :rtype: dict
+ """
+ if isinstance(attr, list):
+ return {x["key"]: self.deserialize_data(x["value"], dict_type) for x in attr}
+
+ if isinstance(attr, ET.Element):
+ # Transform <Key>value</Key> into {"Key": "value"}
+ attr = {el.tag: el.text for el in attr}
+ return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()}
+
+ def deserialize_object(self, attr, **kwargs):
+ """Deserialize a generic object.
+ This will be handled as a dictionary.
+
+ :param dict attr: Dictionary to be deserialized.
+ :rtype: dict
+ :raises: TypeError if non-builtin datatype encountered.
+ """
+ if attr is None:
+ return None
+ if isinstance(attr, ET.Element):
+ # Do not recurse on XML; just return the tree as-is
+ return attr
+ if isinstance(attr, str):
+ return self.deserialize_basic(attr, "str")
+ obj_type = type(attr)
+ if obj_type in self.basic_types:
+ return self.deserialize_basic(attr, self.basic_types[obj_type])
+ if obj_type is _long_type:
+ return self.deserialize_long(attr)
+
+ if obj_type == dict:
+ deserialized = {}
+ for key, value in attr.items():
+ try:
+ deserialized[key] = self.deserialize_object(value, **kwargs)
+ except ValueError:
+ deserialized[key] = None
+ return deserialized
+
+ if obj_type == list:
+ deserialized = []
+ for obj in attr:
+ try:
+ deserialized.append(self.deserialize_object(obj, **kwargs))
+ except ValueError:
+ pass
+ return deserialized
+
+ else:
+ error = "Cannot deserialize generic object with type: "
+ raise TypeError(error + str(obj_type))
+
+ def deserialize_basic(self, attr, data_type):
+ """Deserialize basic builtin data type from string.
+ Will attempt to convert to str, int, float and bool.
+ This function will also accept '1', '0', 'true' and 'false' as
+ valid bool values.
+
+ :param str attr: response string to be deserialized.
+ :param str data_type: deserialization data type.
+ :rtype: str, int, float or bool
+ :raises: TypeError if string format is not valid.
+ """
+ # If we're here, data is supposed to be a basic type.
+ # If it's still an XML node, take the text
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ if not attr:
+ if data_type == "str":
+ # None or '', node <a/> is empty string.
+ return ""
+ else:
+ # None or '', node <a/> with a strong type is None.
+ # Don't try to model "empty bool" or "empty int"
+ return None
+
+ if data_type == "bool":
+ if attr in [True, False, 1, 0]:
+ return bool(attr)
+ elif isinstance(attr, str):
+ if attr.lower() in ["true", "1"]:
+ return True
+ elif attr.lower() in ["false", "0"]:
+ return False
+ raise TypeError("Invalid boolean value: {}".format(attr))
+
+ if data_type == "str":
+ return self.deserialize_unicode(attr)
+ return eval(data_type)(attr) # nosec
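+
+ # Illustrative usage sketch (not part of the generated module):
+ #
+ #     d = Deserializer()
+ #     d.deserialize_basic("true", "bool")   # -> True
+ #     d.deserialize_basic("42", "int")      # -> 42
+ #     d.deserialize_basic("", "str")        # -> ""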
+
+ @staticmethod
+ def deserialize_unicode(data):
+ """Preserve unicode objects in Python 2, otherwise return data
+ as a string.
+
+ :param str data: response string to be deserialized.
+ :rtype: str or unicode
+ """
+ # We might be here because we have an enum modeled as string,
+ # and we try to deserialize a partial dict with enum inside
+ if isinstance(data, Enum):
+ return data
+
+ # Consider this a real string
+ try:
+ if isinstance(data, unicode): # type: ignore
+ return data
+ except NameError:
+ return str(data)
+ else:
+ return str(data)
+
+ @staticmethod
+ def deserialize_enum(data, enum_obj):
+ """Deserialize string into enum object.
+
+ If the string is not a valid enum value it will be returned as-is
+ and a warning will be logged.
+
+ :param str data: Response string to be deserialized. If this value is
+ None or invalid it will be returned as-is.
+ :param Enum enum_obj: Enum object to deserialize to.
+ :rtype: Enum
+ """
+ if isinstance(data, enum_obj) or data is None:
+ return data
+ if isinstance(data, Enum):
+ data = data.value
+ if isinstance(data, int):
+ # Workaround. We might consider removing it in the future.
+ try:
+ return list(enum_obj.__members__.values())[data]
+ except IndexError:
+ error = "{!r} is not a valid index for enum {!r}"
+ raise DeserializationError(error.format(data, enum_obj))
+ try:
+ return enum_obj(str(data))
+ except ValueError:
+ for enum_value in enum_obj:
+ if enum_value.value.lower() == str(data).lower():
+ return enum_value
+ # We don't fail anymore for unknown value, we deserialize as a string
+ _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj)
+ return Deserializer.deserialize_unicode(data)
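+
+ # Illustrative usage sketch (not part of the generated module), with a hypothetical enum:
+ #
+ #     class Color(Enum):
+ #         RED = "red"
+ #
+ #     Deserializer.deserialize_enum("Red", Color)      # -> Color.RED (case-insensitive match)
+ #     Deserializer.deserialize_enum("purple", Color)   # -> "purple", with a warning logged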
+
+ @staticmethod
+ def deserialize_bytearray(attr):
+ """Deserialize string into bytearray.
+
+ :param str attr: response string to be deserialized.
+ :rtype: bytearray
+ :raises: TypeError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ return bytearray(b64decode(attr)) # type: ignore
+
+ @staticmethod
+ def deserialize_base64(attr):
+ """Deserialize base64 encoded string into string.
+
+ :param str attr: response string to be deserialized.
+ :rtype: bytearray
+ :raises: TypeError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore
+ attr = attr + padding # type: ignore
+ encoded = attr.replace("-", "+").replace("_", "/")
+ return b64decode(encoded)
+
+ @staticmethod
+ def deserialize_decimal(attr):
+ """Deserialize string into Decimal object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: Decimal
+ :raises: DeserializationError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ try:
+ return decimal.Decimal(str(attr)) # type: ignore
+ except decimal.DecimalException as err:
+ msg = "Invalid decimal {}".format(attr)
+ raise DeserializationError(msg) from err
+
+ @staticmethod
+ def deserialize_long(attr):
+ """Deserialize string into long (Py2) or int (Py3).
+
+ :param str attr: response string to be deserialized.
+ :rtype: long or int
+ :raises: ValueError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ return _long_type(attr) # type: ignore
+
+ @staticmethod
+ def deserialize_duration(attr):
+ """Deserialize ISO-8601 formatted string into TimeDelta object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: TimeDelta
+ :raises: DeserializationError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ try:
+ duration = isodate.parse_duration(attr)
+ except (ValueError, OverflowError, AttributeError) as err:
+ msg = "Cannot deserialize duration object."
+ raise DeserializationError(msg) from err
+ else:
+ return duration
+
+ @staticmethod
+ def deserialize_date(attr):
+ """Deserialize ISO-8601 formatted string into Date object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: Date
+ :raises: DeserializationError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore
+ raise DeserializationError("Date must have only digits and -. Received: %s" % attr)
+ # This must NOT use defaultmonth/defaultday. Passing 0 ensures an exception is raised for partial dates.
+ return isodate.parse_date(attr, defaultmonth=0, defaultday=0)
+
+ @staticmethod
+ def deserialize_time(attr):
+ """Deserialize ISO-8601 formatted string into time object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: datetime.time
+ :raises: DeserializationError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore
+ raise DeserializationError("Date must have only digits and -. Received: %s" % attr)
+ return isodate.parse_time(attr)
+
+ @staticmethod
+ def deserialize_rfc(attr):
+ """Deserialize RFC-1123 formatted string into Datetime object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: Datetime
+ :raises: DeserializationError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ try:
+ parsed_date = email.utils.parsedate_tz(attr) # type: ignore
+ date_obj = datetime.datetime(
+ *parsed_date[:6], tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0) / 60))
+ )
+ if not date_obj.tzinfo:
+ date_obj = date_obj.astimezone(tz=TZ_UTC)
+ except ValueError as err:
+ msg = "Cannot deserialize to rfc datetime object."
+ raise DeserializationError(msg) from err
+ else:
+ return date_obj
+
+ @staticmethod
+ def deserialize_iso(attr):
+ """Deserialize ISO-8601 formatted string into Datetime object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: Datetime
+ :raises: DeserializationError if string format invalid.
+ """
+ if isinstance(attr, ET.Element):
+ attr = attr.text
+ try:
+ attr = attr.upper() # type: ignore
+ match = Deserializer.valid_date.match(attr)
+ if not match:
+ raise ValueError("Invalid datetime string: " + attr)
+
+ check_decimal = attr.split(".")
+ if len(check_decimal) > 1:
+ decimal_str = ""
+ for digit in check_decimal[1]:
+ if digit.isdigit():
+ decimal_str += digit
+ else:
+ break
+ if len(decimal_str) > 6:
+ attr = attr.replace(decimal_str, decimal_str[0:6])
+
+ date_obj = isodate.parse_datetime(attr)
+ test_utc = date_obj.utctimetuple()
+ if test_utc.tm_year > 9999 or test_utc.tm_year < 1:
+ raise OverflowError("Hit max or min date")
+ except (ValueError, OverflowError, AttributeError) as err:
+ msg = "Cannot deserialize datetime object."
+ raise DeserializationError(msg) from err
+ else:
+ return date_obj
+
+ @staticmethod
+ def deserialize_unix(attr):
+ """Serialize Datetime object into IntTime format.
+ This is represented as seconds.
+
+ :param int attr: Object to be serialized.
+ :rtype: Datetime
+ :raises: DeserializationError if format invalid
+ """
+ if isinstance(attr, ET.Element):
+ attr = int(attr.text) # type: ignore
+ try:
+ attr = int(attr)
+ date_obj = datetime.datetime.fromtimestamp(attr, TZ_UTC)
+ except ValueError as err:
+ msg = "Cannot deserialize to unix datetime object."
+ raise DeserializationError(msg) from err
+ else:
+ return date_obj
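+
+ # Illustrative usage sketch (not part of the generated module): deserializing plain data
+ # without a model class, via the type dispatch table set up in __init__:
+ #
+ #     deserializer = Deserializer()
+ #     deserializer.deserialize_data("2024-05-01T12:30:00Z", "iso-8601")   # -> tz-aware datetime
+ #     deserializer.deserialize_data(["1", "2", "3"], "[int]")             # -> [1, 2, 3]
+ #     deserializer.deserialize_data({"a": "true"}, "{bool}")              # -> {"a": True}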
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/__init__.py
new file mode 100644
index 00000000..cda7689a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/__init__.py
@@ -0,0 +1,18 @@
+# coding=utf-8
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# Code generated by Microsoft (R) Python Code Generator.
+# Changes may cause incorrect behavior and will be lost if the code is regenerated.
+# --------------------------------------------------------------------------
+
+from ._models import AzureOpenAIDeployment
+from ._models import ServerlessEndpoint
+from ._models import MarketplaceSubscription
+from ._patch import __all__ as _patch_all
+from ._patch import * # pylint: disable=unused-wildcard-import
+from ._patch import patch_sdk as _patch_sdk
+
+__all__ = ["AzureOpenAIDeployment", "ServerlessEndpoint", "MarketplaceSubscription"]
+__all__.extend([p for p in _patch_all if p not in __all__])
+_patch_sdk()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_models.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_models.py
new file mode 100644
index 00000000..3b12203d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_models.py
@@ -0,0 +1,214 @@
+# coding=utf-8
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# Code generated by Microsoft (R) Python Code Generator.
+# Changes may cause incorrect behavior and will be lost if the code is regenerated.
+# --------------------------------------------------------------------------
+
+from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, overload
+
+from .. import _model_base
+from .._model_base import rest_field
+
+if TYPE_CHECKING:
+ from .. import models as _models
+
+
+class AzureOpenAIDeployment(_model_base.Model):
+ """Azure OpenAI Deployment Information.
+
+ Readonly variables are only populated by the server, and will be ignored when sending a request.
+
+ :ivar name: The deployment name.
+ :vartype name: str
+ :ivar model_name: The name of the model to deploy.
+ :vartype model_name: str
+ :ivar model_version: The model version to deploy.
+ :vartype model_version: str
+ :ivar connection_name: The name of the connection to deploy to.
+ :vartype connection_name: str
+ :ivar target_url: The target URL of the AOAI resource for the deployment.
+ :vartype target_url: str
+ :ivar id: The ARM resource id of the deployment.
+ :vartype id: str
+ :ivar properties: Properties of the deployment.
+ :vartype properties: dict[str, str]
+ :ivar tags: Tags of the deployment.
+ :vartype tags: dict[str, str]
+ """
+
+ name: Optional[str] = rest_field(visibility=["read"])
+ """The deployment name."""
+ model_name: Optional[str] = rest_field(visibility=["read"])
+ """The name of the model to deploy."""
+ model_version: Optional[str] = rest_field(visibility=["read"])
+ """The model version to deploy."""
+ connection_name: Optional[str] = rest_field(visibility=["read"])
+ """The name of the connection to deploy to."""
+ target_url: Optional[str] = rest_field(visibility=["read"])
+ """The target URL of the AOAI resource for the deployment."""
+ id: Optional[str] = rest_field(visibility=["read"])
+ """The ARM resource id of the deployment."""
+
+
+class MarketplacePlan(_model_base.Model):
+ """Marketplace Subscription Definition.
+
+ Readonly variables are only populated by the server, and will be ignored when sending a request.
+
+ :ivar publisher_id: The id of the publisher.
+ :vartype publisher_id: str
+ :ivar offer_id: The id of the offering associated with the plan.
+ :vartype offer_id: str
+ :ivar plan_id: The id of the plan.
+ :vartype plan_id: str
+ :ivar term_id: The term id.
+ :vartype term_id: str
+ """
+
+ publisher_id: Optional[str] = rest_field(visibility=["read"])
+ """The id of the publisher."""
+ offer_id: Optional[str] = rest_field(visibility=["read"])
+ """The id of the offering associated with the plan."""
+ plan_id: Optional[str] = rest_field(visibility=["read"])
+ """The id of the plan."""
+ term_id: Optional[str] = rest_field(visibility=["read"])
+ """The term id."""
+
+
+class MarketplaceSubscription(_model_base.Model):
+ """Marketplace Subscription Definition.
+
+ Readonly variables are only populated by the server, and will be ignored when sending a request.
+
+ All required parameters must be populated in order to send to server.
+
+ :ivar name: The marketplace subscription name. Required.
+ :vartype name: str
+ :ivar model_id: Model id for which to create marketplace subscription. Required.
+ :vartype model_id: str
+ :ivar marketplace_plan: The plan associated with the marketplace subscription.
+ :vartype marketplace_plan: ~azure.ai.ml.entities.models.MarketplacePlan
+ :ivar status: Status of the marketplace subscription. Possible values are:
+ "pending_fulfillment_start", "subscribed", "unsubscribed", "suspended".
+ :vartype status: str
+ :ivar provisioning_state: Provisioning state of the marketplace subscription. Possible values
+ are: "creating", "deleting", "succeeded", "failed", "updating", and "canceled".
+ :vartype provisioning_state: str
+ :ivar id: ARM resource id of the marketplace subscription.
+ :vartype id: str
+ """
+
+ name: str = rest_field()
+ """The marketplace subscription name. Required."""
+ model_id: str = rest_field()
+ """Model id for which to create marketplace subscription. Required."""
+ marketplace_plan: Optional["_models.MarketplacePlan"] = rest_field(visibility=["read"])
+ """The plan associated with the marketplace subscription."""
+ status: Optional[str] = rest_field(visibility=["read"])
+ """Status of the marketplace subscription. Possible values are: \"pending_fulfillment_start\",
+ \"subscribed\", \"unsubscribed\", \"suspended\"."""
+ provisioning_state: Optional[str] = rest_field(visibility=["read"])
+ """Provisioning state of the marketplace subscription. Possible values are: \"creating\",
+ \"deleting\", \"succeeded\", \"failed\", \"updating\", and \"canceled\"."""
+ id: Optional[str] = rest_field(visibility=["read"])
+ """ARM resource id of the marketplace subscription."""
+
+ @overload
+ def __init__(
+ self,
+ *,
+ name: str,
+ model_id: str,
+ ): ...
+
+ @overload
+ def __init__(self, mapping: Mapping[str, Any]):
+ """
+ :param mapping: raw JSON to initialize the model.
+ :type mapping: Mapping[str, Any]
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation
+ super().__init__(*args, **kwargs)
+
+
+class ServerlessEndpoint(_model_base.Model):
+ """Serverless Endpoint Definition.
+
+ Readonly variables are only populated by the server, and will be ignored when sending a request.
+
+ All required parameters must be populated in order to send to server.
+
+ :ivar name: The deployment name. Required.
+ :vartype name: str
+ :ivar auth_mode: Authentication mode of the endpoint.
+ :vartype auth_mode: str
+ :ivar model_id: The id of the model to deploy. Required.
+ :vartype model_id: str
+ :ivar location: Location in which to create endpoint.
+ :vartype location: str
+ :ivar provisioning_state: Provisioning state of the endpoint. Possible values are: "creating",
+ "deleting", "succeeded", "failed", "updating", and "canceled".
+ :vartype provisioning_state: str
+ :ivar tags: Tags for the endpoint.
+ :vartype tags: dict[str, str]
+ :ivar properties: Properties of the endpoint.
+ :vartype properties: dict[str, str]
+ :ivar description: Description of the endpoint.
+ :vartype description: str
+ :ivar scoring_uri: Scoring uri of the endpoint.
+ :vartype scoring_uri: str
+ :ivar id: ARM resource id of the endpoint.
+ :vartype id: str
+ :ivar headers: Headers required to hit the endpoint.
+ :vartype headers: dict[str, str]
+ """
+
+ name: str = rest_field()
+ """The deployment name. Required."""
+ auth_mode: Optional[str] = rest_field()
+ """Authentication mode of the endpoint. Possible values are: \"key\", \"aad\".
+ Defaults to \"key\" if not given."""
+ model_id: str = rest_field()
+ """The id of the model to deploy. Required."""
+ location: Optional[str] = rest_field(visibility=["read"])
+ """Location in which to create endpoint."""
+ provisioning_state: Optional[str] = rest_field(visibility=["read"])
+ """Provisioning state of the endpoint. Possible values are: \"creating\", \"deleting\",
+ \"succeeded\", \"failed\", \"updating\", and \"canceled\"."""
+ tags: Optional[Dict[str, str]] = rest_field()
+ """Tags for the endpoint."""
+ properties: Optional[Dict[str, str]] = rest_field()
+ """Properties of the endpoint."""
+ description: Optional[str] = rest_field()
+ """Descripton of the endpoint."""
+ scoring_uri: Optional[str] = rest_field(visibility=["read"])
+ """Scoring uri of the endpoint."""
+ id: Optional[str] = rest_field(visibility=["read"])
+ """ARM resource id of the endpoint."""
+ headers: Optional[Dict[str, str]] = rest_field(visibility=["read"])
+ """Headers required to hit the endpoint."""
+
+ @overload
+ def __init__(
+ self,
+ *,
+ name: str,
+ model_id: str,
+ auth_mode: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ properties: Optional[Dict[str, str]] = None,
+ description: Optional[str] = None,
+ ): ...
+
+ @overload
+ def __init__(self, mapping: Mapping[str, Any]):
+ """
+ :param mapping: raw JSON to initialize the model.
+ :type mapping: Mapping[str, Any]
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation
+ super().__init__(*args, **kwargs)
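+
+ # Illustrative usage sketch (not part of the generated module); the model id below is a
+ # hypothetical placeholder:
+ #
+ #     endpoint = ServerlessEndpoint(
+ #         name="my-endpoint",
+ #         model_id="azureml://registries/azureml/models/example-model/versions/1",
+ #     )
+ #     endpoint.name       # -> "my-endpoint"
+ #     subscription = MarketplaceSubscription(name="my-subscription", model_id=endpoint.model_id)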
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_patch.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_patch.py
new file mode 100644
index 00000000..da29aeb3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_autogen_entities/models/_patch.py
@@ -0,0 +1,223 @@
+# ------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# ------------------------------------
+
+# pylint: disable=protected-access
+
+"""Customize generated code here.
+
+Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
+"""
+import json
+from typing import Any, Dict, List, Optional
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import MarketplaceSubscription as RestMarketplaceSubscription
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ MarketplaceSubscriptionProperties as RestMarketplaceSubscriptionProperties,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ModelSettings as RestModelSettings
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ServerlessEndpoint as RestServerlessEndpoint
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ ServerlessEndpointProperties as RestServerlessEndpointProperties,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import Sku as RestSku
+from azure.ai.ml._restclient.v2024_04_01_preview.models import (
+ EndpointDeploymentResourcePropertiesBasicResource,
+ OpenAIEndpointDeploymentResourceProperties,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._system_data import SystemData
+
+from .._model_base import rest_field
+from ._models import AzureOpenAIDeployment as _AzureOpenAIDeployment
+from ._models import MarketplacePlan as _MarketplacePlan
+from ._models import MarketplaceSubscription as _MarketplaceSubscription
+from ._models import ServerlessEndpoint as _ServerlessEndpoint
+
+__all__: List[str] = [
+ "AzureOpenAIDeployment",
+ "ServerlessEndpoint",
+ "MarketplaceSubscription",
+ "MarketplacePlan",
+] # Add all objects you want publicly available to users at this package level
+
+_NULL = object()
+
+
+func_to_attr_type = {
+ "_deserialize_dict": dict,
+ "_deserialize_sequence": list,
+}
+
+
+def _get_rest_field_type(field):
+ if hasattr(field, "_type"):
+ if field._type.func.__name__ == "_deserialize_default":
+ return field._type.args[0]
+ if func_to_attr_type.get(field._type.func.__name__):
+ return func_to_attr_type[field._type.func.__name__]
+ return _get_rest_field_type(field._type.args[0])
+ if hasattr(field, "func") and func_to_attr_type.get(field.func.__name__):
+ return func_to_attr_type[field.func.__name__]
+ if hasattr(field, "args"):
+ return _get_rest_field_type(field.args[0])
+ return field
+
+
+class ValidationMixin:
+ def _validate(self) -> None:
+ # verify types
+ for attr, field in self._attr_to_rest_field.items(): # type: ignore
+ try:
+ attr_value = self.__getitem__(attr) # type: ignore
+ attr_type = type(attr_value)
+ except KeyError as exc:
+ if field._visibility and "read" in field._visibility:
+ # read-only field, no need to validate
+ continue
+ if field._type.func.__name__ != "_deserialize_with_optional":
+ # the attribute is required
+ raise ValueError(f"attr {attr} is a required property for {self.__class__.__name__}") from exc
+ else:
+ if getattr(attr_value, "_is_model", False):
+ attr_value._validate()
+ rest_field_type = _get_rest_field_type(field)
+ if attr_type != rest_field_type:
+ raise ValueError(f"Type of attr {attr} is of type {attr_type}, not {rest_field_type}")
+
+
+@experimental
+class AzureOpenAIDeployment(_AzureOpenAIDeployment):
+
+ system_data: Optional[SystemData] = rest_field(visibility=["read"])
+ """System data of the deployment."""
+
+ @classmethod
+ def _from_rest_object(cls, obj: EndpointDeploymentResourcePropertiesBasicResource) -> "AzureOpenAIDeployment":
+ properties: OpenAIEndpointDeploymentResourceProperties = obj.properties
+ return cls(
+ name=obj.name,
+ model_name=properties.model.name,
+ model_version=properties.model.version,
+ id=obj.id,
+ system_data=SystemData._from_rest_object(obj.system_data),
+ )
+
+ def as_dict(self, *, exclude_readonly: bool = False) -> Dict[str, Any]:
+ d = super().as_dict(exclude_readonly=exclude_readonly)
+ d["system_data"] = json.loads(json.dumps(self.system_data._to_dict())) # type: ignore
+ return d
+
+
+AzureOpenAIDeployment.__doc__ += (
+ _AzureOpenAIDeployment.__doc__.strip() # type: ignore
+ + """
+ :ivar system_data: System data of the deployment.
+ :vartype system_data: ~azure.ai.ml.entities.SystemData
+"""
+)
+
+
+@experimental
+class MarketplacePlan(_MarketplacePlan):
+ pass
+
+
+@experimental
+class ServerlessEndpoint(_ServerlessEndpoint, ValidationMixin):
+
+ system_data: Optional[SystemData] = rest_field(visibility=["read"])
+ """System data of the endpoint."""
+
+ def _to_rest_object(self) -> RestServerlessEndpoint:
+ return RestServerlessEndpoint(
+ properties=RestServerlessEndpointProperties(
+ model_settings=RestModelSettings(model_id=self.model_id),
+ ),
+ auth_mode="key", # only key is supported for now
+ tags=self.tags,
+ sku=RestSku(name="Consumption"),
+ location=self.location,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestServerlessEndpoint) -> "ServerlessEndpoint":
+ return cls( # type: ignore
+ name=obj.name,
+ id=obj.id,
+ tags=obj.tags,
+ location=obj.location,
+ auth_mode=obj.properties.auth_mode,
+ provisioning_state=camel_to_snake(obj.properties.provisioning_state),
+ model_id=obj.properties.model_settings.model_id if obj.properties.model_settings else None,
+ scoring_uri=obj.properties.inference_endpoint.uri if obj.properties.inference_endpoint else None,
+ system_data=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ headers=obj.properties.inference_endpoint.headers if obj.properties.inference_endpoint else None,
+ )
+
+ def as_dict(self, *, exclude_readonly: bool = False) -> Dict[str, Any]:
+ d = super().as_dict(exclude_readonly=exclude_readonly)
+ d["system_data"] = json.loads(json.dumps(self.system_data._to_dict())) # type: ignore
+ return d
+
+
+ServerlessEndpoint.__doc__ += (
+ _ServerlessEndpoint.__doc__.strip() # type: ignore
+ + """
+ :ivar system_data: System data of the endpoint.
+ :vartype system_data: ~azure.ai.ml.entities.SystemData
+"""
+)
+
+
+@experimental
+class MarketplaceSubscription(_MarketplaceSubscription, ValidationMixin):
+
+ system_data: Optional[SystemData] = rest_field(visibility=["read"])
+ """System data of the endpoint."""
+
+ def _to_rest_object(self) -> RestMarketplaceSubscription:
+ return RestMarketplaceSubscription(properties=RestMarketplaceSubscriptionProperties(model_id=self.model_id))
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMarketplaceSubscription) -> "MarketplaceSubscription":
+ properties = obj.properties
+ return cls( # type: ignore
+ name=obj.name,
+ id=obj.id,
+ model_id=properties.model_id,
+ marketplace_plan=MarketplacePlan(
+ publisher_id=properties.marketplace_plan.publisher_id,
+ offer_id=properties.marketplace_plan.offer_id,
+ plan_id=properties.marketplace_plan.plan_id,
+ ),
+ status=camel_to_snake(properties.marketplace_subscription_status),
+ provisioning_state=camel_to_snake(properties.provisioning_state),
+ system_data=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ )
+
+ def as_dict(self, *, exclude_readonly: bool = False) -> Dict[str, Any]:
+ d = super().as_dict(exclude_readonly=exclude_readonly)
+ if self.system_data:
+ d["system_data"] = json.loads(json.dumps(self.system_data._to_dict()))
+ return d
+
+
+MarketplaceSubscription.__doc__ = (
+ _MarketplaceSubscription.__doc__.strip() # type: ignore
+ + """
+ :ivar system_data: System data of the marketplace subscription.
+ :vartype system_data: ~azure.ai.ml.entities.SystemData
+"""
+)
+
+
+def patch_sdk():
+ """Do not remove from this file.
+
+ `patch_sdk` is a last resort escape hatch that allows you to do customizations
+ you can't accomplish using the techniques described in
+ https://aka.ms/azsdk/python/dpcodegen/python/customize
+ """
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/__init__.py
new file mode 100644
index 00000000..95dfca0a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/__init__.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from .base_node import BaseNode, parse_inputs_outputs
+from .command import Command
+from .do_while import DoWhile
+from .import_node import Import
+from .parallel import Parallel
+from .pipeline import Pipeline
+from .spark import Spark
+from .sweep import Sweep
+from .data_transfer import DataTransfer, DataTransferCopy, DataTransferImport, DataTransferExport
+
+__all__ = [
+ "BaseNode",
+ "Sweep",
+ "Parallel",
+ "Command",
+ "Import",
+ "Spark",
+ "Pipeline",
+ "parse_inputs_outputs",
+ "DoWhile",
+ "DataTransfer",
+ "DataTransferCopy",
+ "DataTransferImport",
+ "DataTransferExport",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/base_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/base_node.py
new file mode 100644
index 00000000..98eba6a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/base_node.py
@@ -0,0 +1,568 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+
+import logging
+import os
+import uuid
+from abc import abstractmethod
+from enum import Enum
+from functools import wraps
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml._utils._arm_id_utils import get_resource_name_from_arm_id_safe
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import CommonYamlFields
+from azure.ai.ml.constants._component import NodeType
+from azure.ai.ml.entities import Data, Model
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job._input_output_helpers import build_input_output
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.pipeline._attr_dict import _AttrDict
+from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput
+from azure.ai.ml.entities._job.pipeline._io.mixin import NodeWithGroupInputMixin
+from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression
+from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution
+from azure.ai.ml.entities._mixins import YamlTranslatableMixin
+from azure.ai.ml.entities._util import convert_ordered_dict_to_dict, resolve_pipeline_parameters
+from azure.ai.ml.entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin
+from azure.ai.ml.exceptions import ErrorTarget, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+def parse_inputs_outputs(data: dict) -> dict:
+ """Parse inputs and outputs from data. If data is a list, parse each item in the list.
+
+ :param data: A dict that may contain "inputs" or "outputs" keys
+ :type data: dict
+ :return: Dict with parsed "inputs" and "outputs" keys
+ :rtype: Dict
+ """
+
+ if "inputs" in data:
+ data["inputs"] = {key: build_input_output(val) for key, val in data["inputs"].items()}
+ if "outputs" in data:
+ data["outputs"] = {key: build_input_output(val, inputs=False) for key, val in data["outputs"].items()}
+ return data
+
+
+def pipeline_node_decorator(func: Any) -> Any:
+ """Wrap a function and add its return value to the current DSL pipeline.
+
+ :param func: The function to be wrapped.
+ :type func: callable
+ :return: The wrapped function.
+ :rtype: callable
+ """
+
+ @wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ automl_job = func(*args, **kwargs)
+ from azure.ai.ml.dsl._pipeline_component_builder import (
+ _add_component_to_current_definition_builder,
+ _is_inside_dsl_pipeline_func,
+ )
+
+ if _is_inside_dsl_pipeline_func():
+ # Build automl job to automl node if it's defined inside DSL pipeline func.
+ automl_job._instance_id = str(uuid.uuid4())
+ _add_component_to_current_definition_builder(automl_job)
+ return automl_job
+
+ return wrapper
+
+
+# pylint: disable=too-many-instance-attributes
+class BaseNode(Job, YamlTranslatableMixin, _AttrDict, PathAwareSchemaValidatableMixin, NodeWithGroupInputMixin):
+ """Base class for node in pipeline, used for component version consumption. Can't be instantiated directly.
+
+ You should not instantiate this class directly. Instead, you should
+ create from a builder function.
+
+ :param type: Type of pipeline node. Defaults to JobType.COMPONENT.
+ :type type: str
+ :param component: Id or instance of the component version to be run for the step
+ :type component: Component
+ :param inputs: The inputs for the node.
+ :type inputs: Optional[Dict[str, Union[
+ ~azure.ai.ml.entities._job.pipeline._io.PipelineInput,
+ ~azure.ai.ml.entities._job.pipeline._io.NodeOutput,
+ ~azure.ai.ml.entities.Input,
+ str,
+ bool,
+ int,
+ float,
+ Enum,
+ 'Input']]]
+ :param outputs: Mapping of output data bindings used in the job.
+ :type outputs: Optional[Dict[str, Union[str, ~azure.ai.ml.entities.Output, 'Output']]]
+ :param name: The name of the node.
+ :type name: Optional[str]
+ :param display_name: The display name of the node.
+ :type display_name: Optional[str]
+ :param description: The description of the node.
+ :type description: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: Optional[Dict]
+ :param properties: The properties of the job.
+ :type properties: Optional[Dict]
+ :param comment: Comment of the pipeline node, which will be shown in the designer canvas.
+ :type comment: Optional[str]
+ :param compute: Compute definition containing the compute information for the step.
+ :type compute: Optional[str]
+ :param experiment_name: Name of the experiment the job will be created under;
+ if None is provided, the default will be set to the current directory name.
+ Ignored when used as a pipeline step.
+ :type experiment_name: Optional[str]
+ :param kwargs: Additional keyword arguments for future compatibility.
+ """
+
+ def __init__(
+ self,
+ *,
+ type: str = JobType.COMPONENT, # pylint: disable=redefined-builtin
+ component: Any,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ comment: Optional[str] = None,
+ compute: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._init = True
+ # property _source can't be set
+ source = kwargs.pop("_source", None)
+ _from_component_func = kwargs.pop("_from_component_func", False)
+ self._name: Optional[str] = None
+ super(BaseNode, self).__init__(
+ type=type,
+ name=name,
+ display_name=display_name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ compute=compute,
+ experiment_name=experiment_name,
+ **kwargs,
+ )
+ self.comment = comment
+
+ # initialize io
+ inputs = resolve_pipeline_parameters(inputs)
+ inputs, outputs = inputs or {}, outputs or {}
+ # parse empty dict to None so we won't pass default mode, type to backend
+ # add `isinstance` to avoid converting to expression
+ for k, v in inputs.items():
+ if isinstance(v, dict) and v == {}:
+ inputs[k] = None
+
+ # TODO: get rid of self._job_inputs, self._job_outputs once we have unified Input
+ self._job_inputs, self._job_outputs = inputs, outputs
+ if isinstance(component, Component):
+ # Build the inputs from component input definition and given inputs, unfilled inputs will be None
+ self._inputs = self._build_inputs_dict(inputs or {}, input_definition_dict=component.inputs)
+ # Build the outputs from component output definition and given outputs, unfilled outputs will be None
+ self._outputs = self._build_outputs_dict(outputs or {}, output_definition_dict=component.outputs)
+ else:
+ # Build inputs/outputs dict without meta when definition not available
+ self._inputs = self._build_inputs_dict(inputs or {})
+ self._outputs = self._build_outputs_dict(outputs or {})
+
+ self._component = component
+ self._referenced_control_flow_node_instance_id: Optional[str] = None
+ self.kwargs = kwargs
+
+ # Generate an id for every instance
+ self._instance_id = str(uuid.uuid4())
+ if _from_component_func:
+ # add current component in pipeline stack for dsl scenario
+ self._register_in_current_pipeline_component_builder()
+
+ if source is None:
+ if isinstance(component, Component):
+ source = self._component._source
+ else:
+ source = Component._resolve_component_source_from_id(id=self._component)
+ self._source = source
+ self._validate_required_input_not_provided = True
+ self._init = False
+
+ @property
+ def name(self) -> Optional[str]:
+ """Get the name of the node.
+
+ :return: The name of the node.
+ :rtype: str
+ """
+ return self._name
+
+ @name.setter
+ def name(self, value: str) -> None:
+ """Set the name of the node.
+
+ :param value: The name to set for the node.
+ :type value: str
+ :return: None
+ """
+ # When the name is not lower case, lower it to make sure it's a valid node name
+ if value and value != value.lower():
+ module_logger.warning(
+ "Changing node name %s to lower case: %s since upper case is not allowed in node names.",
+ value,
+ value.lower(),
+ )
+ value = value.lower()
+ self._name = value
+
+ @classmethod
+ def _get_supported_inputs_types(cls) -> Any:
+ """Get the supported input types for node input.
+
+ :param cls: The class (or instance) to retrieve supported input types for.
+ :type cls: object
+
+ :return: A tuple of supported input types.
+ :rtype: tuple
+ """
+ # supported input types for node input
+ return (
+ PipelineInput,
+ NodeOutput,
+ Input,
+ Data,
+ Model,
+ str,
+ bool,
+ int,
+ float,
+ Enum,
+ PipelineExpression,
+ )
+
+ @property
+ def _skip_required_compute_missing_validation(self) -> bool:
+ return False
+
+ def _initializing(self) -> bool:
+ # Use this to indicate an ongoing init process so that attributes set during init won't be stored as
+ # arbitrary attributes in _AttrDict
+ # TODO: replace this hack
+ return self._init
+
+ def _set_base_path(self, base_path: Optional[Union[str, os.PathLike]]) -> None:
+ """Set the base path for the node.
+
+ Will be used for schema validation. If not set, will use Path.cwd() as the base path
+ (default logic defined in SchemaValidatableMixin._base_path_for_validation).
+
+ :param base_path: The new base path
+ :type base_path: Union[str, os.PathLike]
+ """
+ self._base_path = base_path
+
+ def _set_referenced_control_flow_node_instance_id(self, instance_id: str) -> None:
+ """Set the referenced control flow node instance id.
+
+ If this node is already referenced by a control flow node, the instance id will not be modified.
+
+ :param instance_id: The new instance id
+ :type instance_id: str
+ """
+ if not self._referenced_control_flow_node_instance_id:
+ self._referenced_control_flow_node_instance_id = instance_id
+
+ def _get_component_id(self) -> Union[str, Component]:
+ """Return component id if possible.
+
+ :return: The component id
+ :rtype: Union[str, Component]
+ """
+ if isinstance(self._component, Component) and self._component.id:
+ # If the component is remote, return its asset id
+ return self._component.id
+ # Otherwise, return the component version or arm id.
+ res: Union[str, Component] = self._component
+ return res
+
+ def _get_component_name(self) -> Optional[str]:
+ # First use the component version/job's display name or name as the component name;
+ # it is made unique when the pipeline build has finished.
+ if self._component is None:
+ return None
+ if isinstance(self._component, str):
+ return self._component
+ return str(self._component.name)
+
+ def _to_dict(self) -> Dict:
+ return dict(convert_ordered_dict_to_dict(self._dump_for_validation()))
+
+ @classmethod
+ def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException:
+ return ValidationException(
+ message=message,
+ no_personal_data_message=no_personal_data_message,
+ target=ErrorTarget.PIPELINE,
+ )
+
+ def _validate_inputs(self) -> MutableValidationResult:
+ validation_result = self._create_empty_validation_result()
+ if self._validate_required_input_not_provided:
+ # validate required inputs not provided
+ if isinstance(self._component, Component):
+ for key, meta in self._component.inputs.items():
+ # raise error when required input with no default value not set
+ if (
+ not self._is_input_set(input_name=key) # input not provided
+ and meta.optional is not True # and it's required
+ and meta.default is None # and it does not have default
+ ):
+ validation_result.append_error(
+ yaml_path=f"inputs.{key}",
+ message=f"Required input {key!r} for component {self.name!r} not provided.",
+ )
+
+ inputs = self._build_inputs()
+ for input_name, input_obj in inputs.items():
+ if isinstance(input_obj, SweepDistribution):
+ validation_result.append_error(
+ yaml_path=f"inputs.{input_name}",
+ message=f"Input of command {self.name} is a SweepDistribution, "
+ f"please use command.sweep to transform the command into a sweep node.",
+ )
+ return validation_result
+
+ def _customized_validate(self) -> MutableValidationResult:
+ """Validate the resource with customized logic.
+
+ Override this method to add customized validation logic.
+
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ validate_result = self._validate_inputs()
+ return validate_result
+
+ @classmethod
+ def _get_skip_fields_in_schema_validation(cls) -> List[str]:
+ return [
+ "inputs", # processed separately
+ "outputs", # processed separately
+ "name",
+ "display_name",
+ "experiment_name", # name is not part of schema but may be set in dsl/yml file
+ "kwargs",
+ ]
+
+ @classmethod
+ def _get_component_attr_name(cls) -> str:
+ return "component"
+
+ @abstractmethod
+ def _to_job(self) -> Job:
+ """This private function is used by the CLI to get a plain job object
+ so that the CLI can properly serialize the object.
+
+ It is needed as BaseNode._to_dict() dumps objects using pipeline child job schema instead of standalone job
+ schema; for example, a Command object's dump has a nested component property, which doesn't apply to standalone
+ command jobs. BaseNode._to_dict() needs to be able to dump to both the pipeline child job dict and the standalone
+ job dict based on context.
+ """
+
+ @classmethod
+ def _from_rest_object(cls, obj: dict) -> "BaseNode":
+ if CommonYamlFields.TYPE not in obj:
+ obj[CommonYamlFields.TYPE] = NodeType.COMMAND
+
+ from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
+
+ # TODO: refine. Hard-coded for now to support different task types for the DataTransfer node.
+ _type = obj[CommonYamlFields.TYPE]
+ if _type == NodeType.DATA_TRANSFER:
+ _type = "_".join([NodeType.DATA_TRANSFER, obj.get("task", "")])
+ instance: BaseNode = pipeline_node_factory.get_create_instance_func(_type)()
+ init_kwargs = instance._from_rest_object_to_init_params(obj)
+ # TODO: Bug Item number: 2883415
+ instance.__init__(**init_kwargs) # type: ignore
+ return instance
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
+ """Convert the rest object to a dict containing items to init the node.
+
+ Will be used in _from_rest_object. Please override this method instead of _from_rest_object to make the logic
+ reusable.
+
+ :param obj: The REST object
+ :type obj: dict
+ :return: The init params
+ :rtype: Dict
+ """
+ inputs = obj.get("inputs", {})
+ outputs = obj.get("outputs", {})
+
+ obj["inputs"] = BaseNode._from_rest_inputs(inputs)
+ obj["outputs"] = BaseNode._from_rest_outputs(outputs)
+
+ # Change computeId -> compute
+ compute_id = obj.pop("computeId", None)
+ obj["compute"] = get_resource_name_from_arm_id_safe(compute_id)
+
+ # Change componentId -> component. Note that sweep node has no componentId.
+ if "componentId" in obj:
+ obj["component"] = obj.pop("componentId")
+
+ # distribution, sweep won't have distribution
+ if "distribution" in obj and obj["distribution"]:
+ from azure.ai.ml.entities._job.distribution import DistributionConfiguration
+
+ obj["distribution"] = DistributionConfiguration._from_rest_object(obj["distribution"])
+
+ return obj
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ """List of fields to be picked from self._to_dict() in self._to_rest_object().
+
+ By default, returns an empty list.
+
+ Override this method to add custom fields.
+
+ :return: List of fields to pick
+ :rtype: List[str]
+ """
+
+ return []
+
+ def _to_rest_object(self, **kwargs: Any) -> dict: # pylint: disable=unused-argument
+ """Convert self to a rest object for remote call.
+
+ :return: The rest object
+ :rtype: dict
+ """
+ base_dict, rest_obj = self._to_dict(), {}
+ for key in self._picked_fields_from_dict_to_rest_object():
+ if key in base_dict:
+ rest_obj[key] = base_dict.get(key)
+
+ rest_obj.update(
+ dict( # pylint: disable=use-dict-literal
+ name=self.name,
+ type=self.type,
+ display_name=self.display_name,
+ tags=self.tags,
+ computeId=self.compute,
+ inputs=self._to_rest_inputs(),
+ outputs=self._to_rest_outputs(),
+ properties=self.properties,
+ _source=self._source,
+ # add all arbitrary attributes to support setting unknown attributes
+ **self._get_attrs(),
+ )
+ )
+ # only add comment in REST object when it is set
+ if self.comment is not None:
+ rest_obj.update({"comment": self.comment})
+
+ return dict(convert_ordered_dict_to_dict(rest_obj))
+
+ @property
+ def inputs(self) -> Dict:
+ """Get the inputs for the object.
+
+ :return: A dictionary containing the inputs for the object.
+ :rtype: Dict[str, Union[Input, str, bool, int, float]]
+ """
+ return self._inputs # type: ignore
+
+ @property
+ def outputs(self) -> Dict:
+ """Get the outputs of the object.
+
+ :return: A dictionary containing the outputs for the object.
+ :rtype: Dict[str, Union[str, Output]]
+ """
+ return self._outputs # type: ignore
+
+ def __str__(self) -> str:
+ try:
+ return str(self._to_yaml())
+ except BaseException: # pylint: disable=W0718
+ # add try/except in case the component job fails during schema parsing
+ _obj: _AttrDict = _AttrDict()
+ return _obj.__str__()
+
+ def __hash__(self) -> int: # type: ignore
+ return hash(self.__str__())
+
+ def __help__(self) -> Any:
+ # only show help when component has definition
+ if isinstance(self._component, Component):
+ # TODO: Bug Item number: 2883422
+ return self._component.__help__() # type: ignore
+ return None
+
+ def __bool__(self) -> bool:
+ # _attr_dict will return False if no extra attributes are set
+ return True
+
+ def _get_origin_job_outputs(self) -> Dict[str, Union[str, Output]]:
+ """Restore outputs to JobOutput/BindingString and return them.
+
+ :return: The origin job outputs
+ :rtype: Dict[str, Union[str, Output]]
+ """
+ outputs: Dict = {}
+ if self.outputs is not None:
+ for output_name, output_obj in self.outputs.items():
+ if isinstance(output_obj, NodeOutput):
+ outputs[output_name] = output_obj._data
+ else:
+ raise TypeError("unsupported built output type: {}: {}".format(output_name, type(output_obj)))
+ return outputs
+
+ def _get_telemetry_values(self) -> Dict:
+ telemetry_values = {"type": self.type, "source": self._source}
+ return telemetry_values
+
+ def _register_in_current_pipeline_component_builder(self) -> None:
+ """Register this node in current pipeline component builder by adding self to a global stack."""
+ from azure.ai.ml.dsl._pipeline_component_builder import _add_component_to_current_definition_builder
+
+ # TODO: would it be better if we make _add_component_to_current_definition_builder a public function of
+ # _PipelineComponentBuilderStack and make _PipelineComponentBuilderStack a singleton?
+ _add_component_to_current_definition_builder(self)
+
+ def _is_input_set(self, input_name: str) -> bool:
+ built_inputs = self._build_inputs()
+ return input_name in built_inputs and built_inputs[input_name] is not None
+
+ @classmethod
+ def _refine_optional_inputs_with_no_value(cls, node: "BaseNode", kwargs: Any) -> None:
+ """Refine optional inputs that have no default value and no value is provided when calling command/parallel
+ function.
+
+ This is to align with behavior of calling component to generate a pipeline node.
+
+ :param node: The node
+ :type node: BaseNode
+ :param kwargs: The kwargs
+ :type kwargs: dict
+ """
+ for key, value in node.inputs.items():
+ meta = value._data
+ if (
+ isinstance(meta, Input)
+ and meta._is_primitive_type is False
+ and meta.optional is True
+ and not meta.path
+ and key not in kwargs
+ ):
+ value._data = None
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command.py
new file mode 100644
index 00000000..0073307c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command.py
@@ -0,0 +1,1017 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access,too-many-lines
+import copy
+import logging
+import os
+from enum import Enum
+from os import PathLike
+from typing import Any, Dict, List, Optional, Tuple, Union, cast, overload
+
+from marshmallow import INCLUDE, Schema
+
+from azure.ai.ml._restclient.v2025_01_01_preview.models import CommandJob as RestCommandJob
+from azure.ai.ml._restclient.v2025_01_01_preview.models import JobBase
+from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, UnionField
+from azure.ai.ml._schema.job.command_job import CommandJobSchema
+from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from azure.ai.ml._schema.job.services import JobServiceSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LOCAL_COMPUTE_PROPERTY, LOCAL_COMPUTE_TARGET
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+from azure.ai.ml.entities._assets import Environment
+from azure.ai.ml.entities._component.command_component import CommandComponent
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+ _BaseJobIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, from_rest_inputs_to_dataset_literal
+from azure.ai.ml.entities._job.command_job import CommandJob
+from azure.ai.ml.entities._job.distribution import (
+ DistributionConfiguration,
+ MpiDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ TensorFlowDistribution,
+)
+from azure.ai.ml.entities._job.job_limits import CommandJobLimits
+from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
+from azure.ai.ml.entities._job.job_service import (
+ JobService,
+ JobServiceBase,
+ JupyterLabJobService,
+ SshJobService,
+ TensorBoardJobService,
+ VsCodeJobService,
+)
+from azure.ai.ml.entities._job.queue_settings import QueueSettings
+from azure.ai.ml.entities._job.sweep.early_termination_policy import EarlyTerminationPolicy
+from azure.ai.ml.entities._job.sweep.objective import Objective
+from azure.ai.ml.entities._job.sweep.search_space import (
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ SweepDistribution,
+ Uniform,
+)
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from ..._schema import PathAwareSchema
+from ..._schema.job.distribution import (
+ MPIDistributionSchema,
+ PyTorchDistributionSchema,
+ RayDistributionSchema,
+ TensorFlowDistributionSchema,
+)
+from .._job.pipeline._io import NodeWithGroupInputMixin
+from .._util import (
+ convert_ordered_dict_to_dict,
+ from_rest_dict_to_dummy_rest_object,
+ get_rest_dict_for_node_attrs,
+ load_from_dict,
+ validate_attribute_type,
+)
+from .base_node import BaseNode
+from .sweep import Sweep
+
+module_logger = logging.getLogger(__name__)
+
+
+class Command(BaseNode, NodeWithGroupInputMixin):
+ """Base class for command node, used for command component version consumption.
+
+ You should not instantiate this class directly. Instead, you should create it using the builder function: command().
+
+ :keyword component: The ID or instance of the command component or job to be run for the step.
+ :paramtype component: Union[str, ~azure.ai.ml.entities.CommandComponent]
+ :keyword compute: The compute target the job will run on.
+ :paramtype compute: Optional[str]
+ :keyword inputs: A mapping of input names to input data sources used in the job.
+ :paramtype inputs: Optional[dict[str, Union[
+ ~azure.ai.ml.Input, str, bool, int, float, Enum]]]
+ :keyword outputs: A mapping of output names to output data sources used in the job.
+ :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]]
+ :keyword limits: The limits for the command component or job.
+ :paramtype limits: ~azure.ai.ml.entities.CommandJobLimits
+ :keyword identity: The identity that the command job will use while running on compute.
+ :paramtype identity: Optional[Union[
+ dict[str, str],
+ ~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.AmlTokenConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration]]
+ :keyword distribution: The configuration for distributed jobs.
+ :paramtype distribution: Optional[Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
+ ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]]
+ :keyword environment: The environment that the job will run in.
+ :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ :keyword environment_variables: A dictionary of environment variable names and values.
+ These environment variables are set on the process where the user script is being executed.
+ :paramtype environment_variables: Optional[dict[str, str]]
+ :keyword resources: The compute resource configuration for the command.
+ :paramtype resources: Optional[~azure.ai.ml.entities.JobResourceConfiguration]
+ :keyword services: The interactive services for the node. This is an experimental parameter, and may change at any
+ time. Please see https://aka.ms/azuremlexperimental for more information.
+ :paramtype services: Optional[dict[str, Union[~azure.ai.ml.entities.JobService,
+ ~azure.ai.ml.entities.JupyterLabJobService,
+ ~azure.ai.ml.entities.SshJobService, ~azure.ai.ml.entities.TensorBoardJobService,
+ ~azure.ai.ml.entities.VsCodeJobService]]]
+ :keyword queue_settings: Queue settings for the job.
+ :paramtype queue_settings: Optional[~azure.ai.ml.entities.QueueSettings]
+ :keyword parent_job_name: The parent job id for the command job.
+ :paramtype parent_job_name: Optional[str]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Command cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ # pylint: disable=too-many-instance-attributes
+ def __init__(
+ self,
+ *,
+ component: Union[str, CommandComponent],
+ compute: Optional[str] = None,
+ inputs: Optional[
+ Dict[
+ str,
+ Union[
+ Input,
+ str,
+ bool,
+ int,
+ float,
+ Enum,
+ ],
+ ]
+ ] = None,
+ outputs: Optional[Dict[str, Union[str, Output]]] = None,
+ limits: Optional[CommandJobLimits] = None,
+ identity: Optional[
+ Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ distribution: Optional[
+ Union[
+ Dict,
+ MpiDistribution,
+ TensorFlowDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ DistributionConfiguration,
+ ]
+ ] = None,
+ environment: Optional[Union[Environment, str]] = None,
+ environment_variables: Optional[Dict] = None,
+ resources: Optional[JobResourceConfiguration] = None,
+ services: Optional[
+ Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
+ ] = None,
+ queue_settings: Optional[QueueSettings] = None,
+ parent_job_name: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate that init params are of valid types
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+
+ # resolve normal dict to dict[str, JobService]
+ services = _resolve_job_services(services)
+ kwargs.pop("type", None)
+ self._parameters: dict = kwargs.pop("parameters", {})
+ BaseNode.__init__(
+ self,
+ type=NodeType.COMMAND,
+ inputs=inputs,
+ outputs=outputs,
+ component=component,
+ compute=compute,
+ services=services,
+ **kwargs,
+ )
+
+ # init mark for _AttrDict
+ self._init = True
+ # initialize command job properties
+ self.limits = limits
+ self.identity = identity
+ self._distribution = distribution
+ self.environment_variables = {} if environment_variables is None else environment_variables
+ self.environment: Any = environment
+ self._resources = resources
+ self._services = services
+ self.queue_settings = queue_settings
+ self.parent_job_name = parent_job_name
+
+ if isinstance(self.component, CommandComponent):
+ self.resources = self.resources or self.component.resources # type: ignore[assignment]
+ self.distribution = self.distribution or self.component.distribution
+
+ self._swept: bool = False
+ self._init = False
+
+ @classmethod
+ def _get_supported_inputs_types(cls) -> Tuple:
+ supported_types = super()._get_supported_inputs_types() or ()
+ return (
+ SweepDistribution,
+ *supported_types,
+ )
+
+ @classmethod
+ def _get_supported_outputs_types(cls) -> Tuple:
+ return str, Output
+
+ @property
+ def parameters(self) -> Dict[str, str]:
+ """MLFlow parameters to be logged during the job.
+
+ :return: The MLFlow parameters to be logged during the job.
+ :rtype: dict[str, str]
+ """
+ return self._parameters
+
+ @property
+ def distribution(
+ self,
+ ) -> Optional[
+ Union[
+ Dict,
+ MpiDistribution,
+ TensorFlowDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ DistributionConfiguration,
+ ]
+ ]:
+ """The configuration for the distributed command component or job.
+
+ :return: The configuration for distributed jobs.
+ :rtype: Union[~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
+ ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]
+ """
+ return self._distribution
+
+ @distribution.setter
+ def distribution(
+ self,
+ value: Union[Dict, PyTorchDistribution, TensorFlowDistribution, MpiDistribution, RayDistribution],
+ ) -> None:
+ """Sets the configuration for the distributed command component or job.
+
+ :param value: The configuration for distributed jobs.
+ :type value: Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
+ ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]
+ """
+ if isinstance(value, dict):
+ dist_schema = UnionField(
+ [
+ NestedField(PyTorchDistributionSchema, unknown=INCLUDE),
+ NestedField(TensorFlowDistributionSchema, unknown=INCLUDE),
+ NestedField(MPIDistributionSchema, unknown=INCLUDE),
+ ExperimentalField(NestedField(RayDistributionSchema, unknown=INCLUDE)),
+ ]
+ )
+ value = dist_schema._deserialize(value=value, attr=None, data=None)
+ self._distribution = value
+
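+ # Illustrative sketch: the setter above accepts either a distribution object or a plain dict
+ # and deserializes the dict through the matching marshmallow schema. Assuming the usual YAML
+ # field names ("type", "process_count_per_instance"), the two assignments below should be
+ # equivalent:
+ #
+ #     from azure.ai.ml import PyTorchDistribution
+ #
+ #     node.distribution = PyTorchDistribution(process_count_per_instance=2)
+ #     node.distribution = {"type": "pytorch", "process_count_per_instance": 2}
+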
+ @property
+ def resources(self) -> JobResourceConfiguration:
+ """The compute resource configuration for the command component or job.
+
+ :rtype: ~azure.ai.ml.entities.JobResourceConfiguration
+ """
+ return cast(JobResourceConfiguration, self._resources)
+
+ @resources.setter
+ def resources(self, value: Union[Dict, JobResourceConfiguration]) -> None:
+ """Sets the compute resource configuration for the command component or job.
+
+ :param value: The compute resource configuration for the command component or job.
+ :type value: Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration]
+ """
+ if isinstance(value, dict):
+ value = JobResourceConfiguration(**value)
+ self._resources = value
+
+ @property
+ def queue_settings(self) -> Optional[QueueSettings]:
+ """The queue settings for the command component or job.
+
+ :return: The queue settings for the command component or job.
+ :rtype: ~azure.ai.ml.entities.QueueSettings
+ """
+ return self._queue_settings
+
+ @queue_settings.setter
+ def queue_settings(self, value: Union[Dict, QueueSettings]) -> None:
+ """Sets the queue settings for the command component or job.
+
+ :param value: The queue settings for the command component or job.
+ :type value: Union[dict, ~azure.ai.ml.entities.QueueSettings]
+ """
+ if isinstance(value, dict):
+ value = QueueSettings(**value)
+ self._queue_settings = value
+
+ @property
+ def identity(
+ self,
+ ) -> Optional[Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]]:
+ """The identity that the job will use while running on compute.
+
+ :return: The identity that the job will use while running on compute.
+ :rtype: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
+ ~azure.ai.ml.UserIdentityConfiguration]]
+ """
+ return self._identity
+
+ @identity.setter
+ def identity(
+ self,
+ value: Optional[Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]],
+ ) -> None:
+ """Sets the identity that the job will use while running on compute.
+
+ :param value: The identity that the job will use while running on compute.
+ :type value: Union[dict[str, str], ~azure.ai.ml.ManagedIdentityConfiguration,
+ ~azure.ai.ml.AmlTokenConfiguration, ~azure.ai.ml.UserIdentityConfiguration]
+ """
+ if isinstance(value, dict):
+ identity_schema = UnionField(
+ [
+ NestedField(ManagedIdentitySchema, unknown=INCLUDE),
+ NestedField(AMLTokenIdentitySchema, unknown=INCLUDE),
+ NestedField(UserIdentitySchema, unknown=INCLUDE),
+ ]
+ )
+ value = identity_schema._deserialize(value=value, attr=None, data=None)
+ self._identity = value
+
+ @property
+ def services(
+ self,
+ ) -> Optional[
+ Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
+ ]:
+ """The interactive services for the node.
+
+ This is an experimental parameter, and may change at any time.
+ Please see https://aka.ms/azuremlexperimental for more information.
+
+ :rtype: dict[str, Union[~azure.ai.ml.entities.JobService, ~azure.ai.ml.entities.JupyterLabJobService,
+ ~azure.ai.ml.entities.SshJobService, ~azure.ai.ml.entities.TensorBoardJobService,
+ ~azure.ai.ml.entities.VsCodeJobService]]
+ """
+ return self._services
+
+ @services.setter
+ def services(
+ self,
+ value: Dict,
+ ) -> None:
+ """Sets the interactive services for the node.
+
+ This is an experimental parameter, and may change at any time.
+ Please see https://aka.ms/azuremlexperimental for more information.
+
+ :param value: The interactive services for the node.
+ :type value: dict[str, Union[~azure.ai.ml.entities.JobService, ~azure.ai.ml.entities.JupyterLabJobService,
+ ~azure.ai.ml.entities.SshJobService, ~azure.ai.ml.entities.TensorBoardJobService,
+ ~azure.ai.ml.entities.VsCodeJobService]]
+ """
+ self._services = _resolve_job_services(value) # type: ignore[assignment]
+
+ @property
+ def component(self) -> Union[str, CommandComponent]:
+ """The ID or instance of the command component or job to be run for the step.
+
+ :return: The ID or instance of the command component or job to be run for the step.
+ :rtype: Union[str, ~azure.ai.ml.entities.CommandComponent]
+ """
+ return self._component
+
+ @property
+ def command(self) -> Optional[str]:
+ """The command to be executed.
+
+ :rtype: Optional[str]
+ """
+ # same handling pattern as the code property
+ if not isinstance(self.component, CommandComponent):
+ return None
+
+ if self.component.command is None:
+ return None
+ return str(self.component.command)
+
+ @command.setter
+ def command(self, value: str) -> None:
+ """Sets the command to be executed.
+
+ :param value: The command to be executed.
+ :type value: str
+ """
+ if isinstance(self.component, CommandComponent):
+ self.component.command = value
+ else:
+ msg = "Can't set command property for a registered component {}. Tried to set it to {}."
+ raise ValidationException(
+ message=msg.format(self.component, value),
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMMAND_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ @property
+ def code(self) -> Optional[Union[str, PathLike]]:
+ """The source code to run the job.
+
+ :rtype: Optional[Union[str, os.PathLike]]
+ """
+ # BaseNode is an _AttrDict to allow dynamic attributes, so that a lower version of the SDK can work with
+ # attributes added in a higher version of the SDK.
+ # self.code will be treated as an arbitrary attribute if getting it raises AttributeError
+ # (e.g. when self.component has no code attribute because self.component = 'azureml:xxx:1';
+ # see _AttrDict._is_arbitrary_attr for the detailed logic of that judgement),
+ # in which case its value would be set to an _AttrDict and deserialized as {"shape": {}} instead of None,
+ # which is invalid in schema validation.
+ if not isinstance(self.component, CommandComponent):
+ return None
+
+ if self.component.code is None:
+ return None
+
+ return str(self.component.code)
+
+ @code.setter
+ def code(self, value: str) -> None:
+ """Sets the source code to run the job.
+
+ :param value: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url
+ pointing to a remote location.
+ :type value: str
+ """
+ if isinstance(self.component, CommandComponent):
+ self.component.code = value
+ else:
+ msg = "Can't set code property for a registered component {}"
+ raise ValidationException(
+ message=msg.format(self.component),
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMMAND_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def set_resources(
+ self,
+ *,
+ instance_type: Optional[Union[str, List[str]]] = None,
+ instance_count: Optional[int] = None,
+ locations: Optional[List[str]] = None,
+ properties: Optional[Dict] = None,
+ docker_args: Optional[Union[str, List[str]]] = None,
+ shm_size: Optional[str] = None,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Set resources for Command.
+
+ :keyword instance_type: The type of compute instance to run the job on. If not specified, the job will run on
+ the default compute target.
+ :paramtype instance_type: Optional[Union[str, List[str]]]
+ :keyword instance_count: The number of instances to run the job on. If not specified, the job will run on a
+ single instance.
+ :paramtype instance_count: Optional[int]
+ :keyword locations: The list of locations where the job will run. If not specified, the job will run on the
+ default compute target.
+ :paramtype locations: Optional[List[str]]
+ :keyword properties: The properties of the job.
+ :paramtype properties: Optional[dict]
+ :keyword docker_args: The Docker arguments for the job.
+ :paramtype docker_args: Optional[Union[str,List[str]]]
+ :keyword shm_size: The size of the docker container's shared memory block. This should be in the
+ format of (number)(unit) where the number has to be greater than 0 and the unit can be one of
+ b(bytes), k(kilobytes), m(megabytes), or g(gigabytes).
+ :paramtype shm_size: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_command_configurations.py
+ :start-after: [START command_set_resources]
+ :end-before: [END command_set_resources]
+ :language: python
+ :dedent: 8
+ :caption: Setting resources on a Command.
+ """
+ if self.resources is None:
+ self.resources = JobResourceConfiguration()
+
+ if locations is not None:
+ self.resources.locations = locations
+ if instance_type is not None:
+ self.resources.instance_type = instance_type
+ if instance_count is not None:
+ self.resources.instance_count = instance_count
+ if properties is not None:
+ self.resources.properties = properties
+ if docker_args is not None:
+ self.resources.docker_args = docker_args
+ if shm_size is not None:
+ self.resources.shm_size = shm_size
+
+ # Save the resources to the internal component as well, otherwise calling sweep() will lose the settings
+ if isinstance(self.component, CommandComponent):
+ self.component.resources = self.resources
+
+ def set_limits(self, *, timeout: int, **kwargs: Any) -> None: # pylint: disable=unused-argument
+ """Set limits for Command.
+
+ :keyword timeout: The timeout for the job in seconds.
+ :paramtype timeout: int
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_command_configurations.py
+ :start-after: [START command_set_limits]
+ :end-before: [END command_set_limits]
+ :language: python
+ :dedent: 8
+ :caption: Setting a timeout limit of 10 seconds on a Command.
+ """
+ if isinstance(self.limits, CommandJobLimits):
+ self.limits.timeout = timeout
+ else:
+ self.limits = CommandJobLimits(timeout=timeout)
+
+ def set_queue_settings(self, *, job_tier: Optional[str] = None, priority: Optional[str] = None) -> None:
+ """Set QueueSettings for the job.
+
+ :keyword job_tier: The job tier. Accepted values are "Spot", "Basic", "Standard", or "Premium".
+ :paramtype job_tier: Optional[str]
+ :keyword priority: The priority of the job on the compute. Defaults to "Medium".
+ :paramtype priority: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_command_configurations.py
+ :start-after: [START command_set_queue_settings]
+ :end-before: [END command_set_queue_settings]
+ :language: python
+ :dedent: 8
+ :caption: Configuring queue settings on a Command.
+ """
+ if isinstance(self.queue_settings, QueueSettings):
+ self.queue_settings.job_tier = job_tier
+ self.queue_settings.priority = priority
+ else:
+ self.queue_settings = QueueSettings(job_tier=job_tier, priority=priority)
+
+ def sweep(
+ self,
+ *,
+ primary_metric: str,
+ goal: str,
+ sampling_algorithm: str = "random",
+ compute: Optional[str] = None,
+ max_concurrent_trials: Optional[int] = None,
+ max_total_trials: Optional[int] = None,
+ timeout: Optional[int] = None,
+ trial_timeout: Optional[int] = None,
+ early_termination_policy: Optional[Union[EarlyTerminationPolicy, str]] = None,
+ search_space: Optional[
+ Dict[
+ str,
+ Union[
+ Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ],
+ ]
+ ] = None,
+ identity: Optional[
+ Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ queue_settings: Optional[QueueSettings] = None,
+ job_tier: Optional[str] = None,
+ priority: Optional[str] = None,
+ ) -> Sweep:
+ """Turns the command into a sweep node with extra sweep run setting. The command component
+ in the current command node will be used as its trial component. A command node can sweep
+ multiple times, and the generated sweep node will share the same trial component.
+
+ :keyword primary_metric: The primary metric of the sweep objective - e.g. AUC (Area Under the Curve).
+ The metric must be logged while running the trial component.
+ :paramtype primary_metric: str
+ :keyword goal: The goal of the Sweep objective. Accepted values are "minimize" or "maximize".
+ :paramtype goal: str
+ :keyword sampling_algorithm: The sampling algorithm to use inside the search space.
+ Acceptable values are "random", "grid", or "bayesian". Defaults to "random".
+ :paramtype sampling_algorithm: str
+ :keyword compute: The target compute to run the node on. If not specified, the current node's compute
+ will be used.
+ :paramtype compute: Optional[str]
+ :keyword max_total_trials: The maximum number of total trials to run. This value will overwrite the value in
+ CommandJob.limits if specified.
+ :paramtype max_total_trials: Optional[int]
+ :keyword max_concurrent_trials: The maximum number of concurrent trials for the Sweep job.
+ :paramtype max_concurrent_trials: Optional[int]
+ :keyword timeout: The maximum run duration in seconds, after which the job will be cancelled.
+ :paramtype timeout: Optional[int]
+ :keyword trial_timeout: The Sweep Job trial timeout value, in seconds.
+ :paramtype trial_timeout: Optional[int]
+ :keyword early_termination_policy: The early termination policy of the sweep node. Acceptable
+ values are "bandit", "median_stopping", or "truncation_selection". Defaults to None.
+ :paramtype early_termination_policy: Optional[Union[~azure.ai.ml.sweep.BanditPolicy,
+ ~azure.ai.ml.sweep.TruncationSelectionPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy, str]]
+ :keyword identity: The identity that the job will use while running on compute.
+ :paramtype identity: Optional[Union[
+ ~azure.ai.ml.ManagedIdentityConfiguration,
+ ~azure.ai.ml.AmlTokenConfiguration,
+ ~azure.ai.ml.UserIdentityConfiguration]]
+ :keyword search_space: The search space to use for the sweep job.
+ :paramtype search_space: Optional[Dict[str, Union[Choice, LogNormal, LogUniform, Normal,
+ QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform]]]
+ :keyword queue_settings: The queue settings for the job.
+ :paramtype queue_settings: Optional[~azure.ai.ml.entities.QueueSettings]
+ :keyword job_tier: **Experimental** The job tier. Accepted values are "Spot", "Basic",
+ "Standard", or "Premium".
+ :paramtype job_tier: Optional[str]
+ :keyword priority: **Experimental** The compute priority. Accepted values are "low",
+ "medium", and "high".
+ :paramtype priority: Optional[str]
+ :return: A Sweep node with the component from current Command node as its trial component.
+ :rtype: ~azure.ai.ml.entities.Sweep
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_bandit_policy]
+ :end-before: [END configure_sweep_job_bandit_policy]
+ :language: python
+ :dedent: 8
+ :caption: Creating a Sweep node from a Command job.
+ """
+ self._swept = True
+ # inputs & outputs are already built in source Command obj
+ inputs, inputs_search_space = Sweep._get_origin_inputs_and_search_space(self.inputs)
+ if search_space:
+ inputs_search_space.update(search_space)
+
+ if not queue_settings:
+ queue_settings = self.queue_settings
+ if queue_settings is not None:
+ if job_tier is not None:
+ queue_settings.job_tier = job_tier
+ if priority is not None:
+ queue_settings.priority = priority
+
+ sweep_node = Sweep(
+ trial=copy.deepcopy(
+ self.component
+ ), # Make a copy of the underlying Component so that the original node can still be used.
+ compute=self.compute if compute is None else compute,
+ objective=Objective(goal=goal, primary_metric=primary_metric),
+ sampling_algorithm=sampling_algorithm,
+ inputs=inputs,
+ outputs=self._get_origin_job_outputs(),
+ search_space=inputs_search_space,
+ early_termination=early_termination_policy,
+ name=self.name,
+ description=self.description,
+ display_name=self.display_name,
+ tags=self.tags,
+ properties=self.properties,
+ experiment_name=self.experiment_name,
+ identity=self.identity if not identity else identity,
+ _from_component_func=True,
+ queue_settings=queue_settings,
+ )
+ sweep_node.set_limits(
+ max_total_trials=max_total_trials,
+ max_concurrent_trials=max_concurrent_trials,
+ timeout=timeout,
+ trial_timeout=trial_timeout,
+ )
+ return sweep_node
+
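+ # Illustrative sketch of sweep() usage (names such as my_command_component and "accuracy"
+ # are placeholders, not taken from the SDK):
+ #
+ #     from azure.ai.ml.sweep import Choice, Uniform
+ #
+ #     train_node = my_command_component(
+ #         learning_rate=Uniform(min_value=0.001, max_value=0.1),
+ #         batch_size=Choice([16, 32, 64]),
+ #     )
+ #     sweep_node = train_node.sweep(
+ #         primary_metric="accuracy",
+ #         goal="maximize",
+ #         sampling_algorithm="random",
+ #         max_total_trials=20,
+ #         max_concurrent_trials=4,
+ #     )
+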
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "component": (str, CommandComponent),
+ "environment": (str, Environment),
+ "environment_variables": dict,
+ "resources": (dict, JobResourceConfiguration),
+ "limits": (dict, CommandJobLimits),
+ "code": (str, os.PathLike),
+ }
+
+ def _to_job(self) -> CommandJob:
+ if isinstance(self.component, CommandComponent):
+ return CommandJob(
+ id=self.id,
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ properties=self.properties,
+ command=self.component.command,
+ experiment_name=self.experiment_name,
+ code=self.component.code,
+ compute=self.compute,
+ status=self.status,
+ environment=self.environment,
+ distribution=self.distribution,
+ identity=self.identity,
+ environment_variables=self.environment_variables,
+ resources=self.resources,
+ limits=self.limits,
+ inputs=self._job_inputs,
+ outputs=self._job_outputs,
+ services=self.services,
+ creation_context=self.creation_context,
+ parameters=self.parameters,
+ queue_settings=self.queue_settings,
+ parent_job_name=self.parent_job_name,
+ )
+
+ return CommandJob(
+ id=self.id,
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ properties=self.properties,
+ command=None,
+ experiment_name=self.experiment_name,
+ code=None,
+ compute=self.compute,
+ status=self.status,
+ environment=self.environment,
+ distribution=self.distribution,
+ identity=self.identity,
+ environment_variables=self.environment_variables,
+ resources=self.resources,
+ limits=self.limits,
+ inputs=self._job_inputs,
+ outputs=self._job_outputs,
+ services=self.services,
+ creation_context=self.creation_context,
+ parameters=self.parameters,
+ queue_settings=self.queue_settings,
+ parent_job_name=self.parent_job_name,
+ )
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ return ["resources", "distribution", "limits", "environment_variables", "queue_settings"]
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj = super()._to_rest_object(**kwargs)
+ for key, value in {
+ "componentId": self._get_component_id(),
+ "distribution": get_rest_dict_for_node_attrs(self.distribution, clear_empty_value=True),
+ "limits": get_rest_dict_for_node_attrs(self.limits, clear_empty_value=True),
+ "resources": get_rest_dict_for_node_attrs(self.resources, clear_empty_value=True),
+ "services": get_rest_dict_for_node_attrs(self.services),
+ "identity": get_rest_dict_for_node_attrs(self.identity),
+ "queue_settings": get_rest_dict_for_node_attrs(self.queue_settings, clear_empty_value=True),
+ }.items():
+ if value is not None:
+ rest_obj[key] = value
+ return cast(dict, convert_ordered_dict_to_dict(rest_obj))
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Command":
+ from .command_func import command
+
+ loaded_data = load_from_dict(CommandJobSchema, data, context, additional_message, **kwargs)
+
+ # The resources and limits properties are flattened in the command() function; extract them and set them separately.
+ resources = loaded_data.pop("resources", None)
+ limits = loaded_data.pop("limits", None)
+
+ command_job: Command = command(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ command_job.resources = resources
+ command_job.limits = limits
+ return command_job
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
+ obj = BaseNode._from_rest_object_to_init_params(obj)
+
+ if "resources" in obj and obj["resources"]:
+ obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])
+
+ # services, sweep won't have services
+ if "services" in obj and obj["services"]:
+ # pipeline node rest objects are dicts while _from_rest_job_services expects RestJobService
+ services = {}
+ for service_name, service in obj["services"].items():
+ # In the rest object of a pipeline job, a service is transferred as a dict because
+ # it is an attribute of a node, but JobService._from_rest_object expects a
+ # RestJobService, so we need to convert it back. Here we convert the dict to a
+ # dummy rest object which can work as a RestJobService instead.
+ services[service_name] = from_rest_dict_to_dummy_rest_object(service)
+ obj["services"] = JobServiceBase._from_rest_job_services(services)
+
+ # handle limits
+ if "limits" in obj and obj["limits"]:
+ obj["limits"] = CommandJobLimits._from_rest_object(obj["limits"])
+
+ if "identity" in obj and obj["identity"]:
+ obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"])
+
+ if "queue_settings" in obj and obj["queue_settings"]:
+ obj["queue_settings"] = QueueSettings._from_rest_object(obj["queue_settings"])
+
+ return obj
+
+ @classmethod
+ def _load_from_rest_job(cls, obj: JobBase) -> "Command":
+ from .command_func import command
+
+ rest_command_job: RestCommandJob = obj.properties
+
+ command_job: Command = command(
+ name=obj.name,
+ display_name=rest_command_job.display_name,
+ description=rest_command_job.description,
+ tags=rest_command_job.tags,
+ properties=rest_command_job.properties,
+ command=rest_command_job.command,
+ experiment_name=rest_command_job.experiment_name,
+ services=JobServiceBase._from_rest_job_services(rest_command_job.services),
+ status=rest_command_job.status,
+ creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ code=rest_command_job.code_id,
+ compute=rest_command_job.compute_id,
+ environment=rest_command_job.environment_id,
+ distribution=DistributionConfiguration._from_rest_object(rest_command_job.distribution),
+ parameters=rest_command_job.parameters,
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(rest_command_job.identity)
+ if rest_command_job.identity
+ else None
+ ),
+ environment_variables=rest_command_job.environment_variables,
+ inputs=from_rest_inputs_to_dataset_literal(rest_command_job.inputs),
+ outputs=from_rest_data_outputs(rest_command_job.outputs),
+ )
+ command_job._id = obj.id
+ command_job.resources = cast(
+ JobResourceConfiguration, JobResourceConfiguration._from_rest_object(rest_command_job.resources)
+ )
+ command_job.limits = CommandJobLimits._from_rest_object(rest_command_job.limits)
+ command_job.queue_settings = QueueSettings._from_rest_object(rest_command_job.queue_settings)
+ if isinstance(command_job.component, CommandComponent):
+ command_job.component._source = (
+ ComponentSource.REMOTE_WORKSPACE_JOB
+ ) # This is used by pipeline job telemetry.
+
+ # Handle special case of local job
+ if (
+ command_job.resources is not None
+ and command_job.resources.properties is not None
+ and command_job.resources.properties.get(LOCAL_COMPUTE_PROPERTY, None)
+ ):
+ command_job.compute = LOCAL_COMPUTE_TARGET
+ command_job.resources.properties.pop(LOCAL_COMPUTE_PROPERTY)
+ return command_job
+
+ def _build_inputs(self) -> Dict:
+ inputs = super(Command, self)._build_inputs()
+ built_inputs = {}
+ # Validate and remove non-specified inputs
+ for key, value in inputs.items():
+ if value is not None:
+ built_inputs[key] = value
+
+ return built_inputs
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline import CommandSchema
+
+ return CommandSchema(context=context)
+
+ # pylint: disable-next=docstring-missing-param
+ def __call__(self, *args: Any, **kwargs: Any) -> "Command":
+ """Call Command as a function will return a new instance each time.
+
+ :return: A Command node
+ :rtype: Command
+ """
+ if isinstance(self._component, CommandComponent):
+ # call this to validate inputs
+ node: Command = self._component(*args, **kwargs)
+ # merge inputs
+ for name, original_input in self.inputs.items():
+ if name not in kwargs:
+ # use setattr here to make sure the owner of the input won't change
+ setattr(node.inputs, name, original_input._data)
+ node._job_inputs[name] = original_input._data
+ # get outputs
+ for name, original_output in self.outputs.items():
+ # use setattr here to make sure the owner of the output won't change
+ if not isinstance(original_output, str):
+ setattr(node.outputs, name, original_output._data)
+ node._job_outputs[name] = original_output._data
+ self._refine_optional_inputs_with_no_value(node, kwargs)
+ # set default values: compute, environment_variables, outputs
+ # don't copy the name, so we can distinguish whether a node's name was assigned by the user
+ # e.g. node_1 = command_func()
+ # In the above example, node_1.name will be None, so we can use "node_1" as its name
+ node.compute = self.compute
+ node.tags = self.tags
+ # Pass through the display name only if the display name is not system generated.
+ node.display_name = self.display_name if self.display_name != self.name else None
+ node.environment = copy.deepcopy(self.environment)
+ # deep copy for complex object
+ node.environment_variables = copy.deepcopy(self.environment_variables)
+ node.limits = copy.deepcopy(self.limits)
+ node.distribution = copy.deepcopy(self.distribution)
+ node.resources = copy.deepcopy(self.resources)
+ node.queue_settings = copy.deepcopy(self.queue_settings)
+ node.services = copy.deepcopy(self.services)
+ node.identity = copy.deepcopy(self.identity)
+ return node
+ msg = "Command can be called as a function only when referenced component is {}, currently got {}."
+ raise ValidationException(
+ message=msg.format(type(CommandComponent), self._component),
+ no_personal_data_message=msg.format(type(CommandComponent), "self._component"),
+ target=ErrorTarget.COMMAND_JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
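+
+ # Illustrative sketch of the call-as-a-function behaviour described above (node names are
+ # placeholders):
+ #
+ #     node_a = my_command_component(data=some_input)
+ #     node_b = node_a()                # independent copy backed by the same component
+ #     node_b.compute = "gpu-cluster"   # does not affect node_a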
+
+
+@overload
+def _resolve_job_services(services: Optional[Dict]): ...
+
+
+@overload
+def _resolve_job_services(
+ services: Dict[str, Union[JobServiceBase, Dict]],
+) -> Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]: ...
+
+
+def _resolve_job_services(
+ services: Optional[Dict[str, Union[JobServiceBase, Dict]]],
+) -> Optional[Dict]:
+ """Resolve normal dict to dict[str, JobService]
+
+ :param services: A dict that maps service names to either a JobServiceBase object, or a Dict used to build one
+ :type services: Optional[Dict[str, Union[JobServiceBase, Dict]]]
+ :return:
+ * None if `services` is None
+ * A map of job service names to job services
+ :rtype: Optional[
+ Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
+ ]
+ """
+ if services is None:
+ return None
+
+ if not isinstance(services, dict):
+ msg = f"Services must be a dict, got {type(services)} instead."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMMAND_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ result = {}
+ for name, service in services.items():
+ if isinstance(service, dict):
+ service = load_from_dict(JobServiceSchema, service, context={BASE_PATH_CONTEXT_KEY: "."})
+ elif not isinstance(
+ service, (JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService)
+ ):
+ msg = f"Service value for key {name!r} must be a dict or JobService object, got {type(service)} instead."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMMAND_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ result[name] = service
+ return result
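+
+# Illustrative sketch of _resolve_job_services: both constructed JobService objects and plain
+# dicts are accepted and normalized to job service instances. The "type" value used for the
+# dict form follows the job-services YAML convention and is an assumption here, not taken from
+# this module:
+#
+#     from azure.ai.ml.entities import VsCodeJobService
+#
+#     services = _resolve_job_services({
+#         "my_vscode": VsCodeJobService(),
+#         "my_ssh": {"type": "ssh", "ssh_public_keys": "<public-key>"},
+#     })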
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command_func.py
new file mode 100644
index 00000000..c542f880
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/command_func.py
@@ -0,0 +1,314 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from azure.ai.ml.constants._common import AssetTypes, LegacyAssetTypes
+from azure.ai.ml.constants._component import ComponentSource
+from azure.ai.ml.entities._assets.environment import Environment
+from azure.ai.ml.entities._component.command_component import CommandComponent
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.distribution import (
+ DistributionConfiguration,
+ MpiDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ TensorFlowDistribution,
+)
+from azure.ai.ml.entities._job.job_service import (
+ JobService,
+ JupyterLabJobService,
+ SshJobService,
+ TensorBoardJobService,
+ VsCodeJobService,
+)
+from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
+from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution
+from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException
+
+from .command import Command
+
+SUPPORTED_INPUTS = [
+ LegacyAssetTypes.PATH,
+ AssetTypes.URI_FILE,
+ AssetTypes.URI_FOLDER,
+ AssetTypes.CUSTOM_MODEL,
+ AssetTypes.MLFLOW_MODEL,
+ AssetTypes.MLTABLE,
+ AssetTypes.TRITON_MODEL,
+]
+
+
+def _parse_input(input_value: Union[Input, Dict, SweepDistribution, str, bool, int, float]) -> Tuple:
+ component_input = None
+ job_input: Optional[Union[Input, Dict, SweepDistribution, str, bool, int, float]] = None
+
+ if isinstance(input_value, Input):
+ component_input = Input(**input_value._to_dict())
+ input_type = input_value.type
+ if input_type in SUPPORTED_INPUTS:
+ job_input = Input(**input_value._to_dict())
+ elif isinstance(input_value, dict):
+ # if the user provided a dict, try to parse it into an Input.
+ # for the job input, only parse it when the type is a supported path/asset type
+ input_type = input_value.get("type", None)
+ if input_type in SUPPORTED_INPUTS:
+ job_input = Input(**input_value)
+ component_input = Input(**input_value)
+ elif isinstance(input_value, (SweepDistribution, str, bool, int, float)):
+ # Input bindings are not supported
+ component_input = ComponentTranslatableMixin._to_input_builder_function(input_value)
+ job_input = input_value
+ else:
+ msg = f"Unsupported input type: {type(input_value)}"
+ msg += ", only Input, dict, str, bool, int and float are supported."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return component_input, job_input
+
+
+def _parse_output(output_value: Optional[Union[Output, Dict, str]]) -> Tuple:
+ component_output = None
+ job_output: Optional[Union[Output, Dict, str]] = None
+
+ if isinstance(output_value, Output):
+ component_output = Output(**output_value._to_dict())
+ job_output = Output(**output_value._to_dict())
+ elif not output_value:
+ # output value can be None or empty dictionary
+ # None output value will be packed into a JobOutput object with mode = ReadWriteMount & type = UriFolder
+ component_output = ComponentTranslatableMixin._to_output(output_value)
+ job_output = output_value
+ elif isinstance(output_value, dict): # When output value is a non-empty dictionary
+ job_output = Output(**output_value)
+ component_output = Output(**output_value)
+ elif isinstance(output_value, str): # When output is passed in from pipeline job yaml
+ job_output = output_value
+ else:
+ msg = f"Unsupported output type: {type(output_value)}, only Output and dict are supported."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return component_output, job_output
+
+
+def _parse_inputs_outputs(io_dict: Dict, parse_func: Callable) -> Tuple[Dict, Dict]:
+ component_io_dict, job_io_dict = {}, {}
+ if io_dict:
+ for key, val in io_dict.items():
+ component_io, job_io = parse_func(val)
+ component_io_dict[key] = component_io
+ job_io_dict[key] = job_io
+ return component_io_dict, job_io_dict
+
+
+def command(
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ command: Optional[str] = None, # pylint: disable=redefined-outer-name
+ experiment_name: Optional[str] = None,
+ environment: Optional[Union[str, Environment]] = None,
+ environment_variables: Optional[Dict] = None,
+ distribution: Optional[
+ Union[
+ Dict,
+ MpiDistribution,
+ TensorFlowDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ DistributionConfiguration,
+ ]
+ ] = None,
+ compute: Optional[str] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ instance_count: Optional[int] = None,
+ instance_type: Optional[str] = None,
+ locations: Optional[List[str]] = None,
+ docker_args: Optional[Union[str, List[str]]] = None,
+ shm_size: Optional[str] = None,
+ timeout: Optional[int] = None,
+ code: Optional[Union[str, os.PathLike]] = None,
+ identity: Optional[Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]] = None,
+ is_deterministic: bool = True,
+ services: Optional[
+ Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
+ ] = None,
+ job_tier: Optional[str] = None,
+ priority: Optional[str] = None,
+ parent_job_name: Optional[str] = None,
+ **kwargs: Any,
+) -> Command:
+ """Creates a Command object which can be used inside a dsl.pipeline function or used as a standalone Command job.
+
+ :keyword name: The name of the Command job or component.
+ :paramtype name: Optional[str]
+ :keyword description: The description of the Command. Defaults to None.
+ :paramtype description: Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :paramtype tags: Optional[dict[str, str]]
+ :keyword properties: The job property dictionary. Defaults to None.
+ :paramtype properties: Optional[dict[str, str]]
+ :keyword display_name: The display name of the job. Defaults to a randomly generated name.
+ :paramtype display_name: Optional[str]
+ :keyword command: The command to be executed. Defaults to None.
+ :paramtype command: Optional[str]
+ :keyword experiment_name: The name of the experiment that the job will be created under. Defaults to current
+ directory name.
+ :paramtype experiment_name: Optional[str]
+ :keyword environment: The environment that the job will run in.
+ :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ :keyword environment_variables: A dictionary of environment variable names and values.
+ These environment variables are set on the process where user script is being executed.
+ Defaults to None.
+ :paramtype environment_variables: Optional[dict[str, str]]
+ :keyword distribution: The configuration for distributed jobs. Defaults to None.
+ :paramtype distribution: Optional[Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
+ ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]]
+ :keyword compute: The compute target the job will run on. Defaults to default compute.
+ :paramtype compute: Optional[str]
+ :keyword inputs: A mapping of input names to input data sources used in the job. Defaults to None.
+ :paramtype inputs: Optional[dict[str, Union[~azure.ai.ml.Input, str, bool, int, float, Enum]]]
+ :keyword outputs: A mapping of output names to output data sources used in the job. Defaults to None.
+ :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]]
+ :keyword instance_count: The number of instances or nodes to be used by the compute target. Defaults to 1.
+ :paramtype instance_count: Optional[int]
+ :keyword instance_type: The type of VM to be used by the compute target.
+ :paramtype instance_type: Optional[str]
+ :keyword locations: The list of locations where the job will run.
+ :paramtype locations: Optional[List[str]]
+ :keyword docker_args: Extra arguments to pass to the Docker run command. This would override any
+ parameters that have already been set by the system, or in this section. This parameter is only
+ supported for Azure ML compute types. Defaults to None.
+ :paramtype docker_args: Optional[Union[str,List[str]]]
+ :keyword shm_size: The size of the Docker container's shared memory block. This should be in the
+ format of (number)(unit) where the number has to be greater than 0 and the unit can be one of
+ b(bytes), k(kilobytes), m(megabytes), or g(gigabytes).
+ :paramtype shm_size: Optional[str]
+ :keyword timeout: The number, in seconds, after which the job will be cancelled.
+ :paramtype timeout: Optional[int]
+ :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url
+ pointing to a remote location.
+ :paramtype code: Optional[Union[str, os.PathLike]]
+ :keyword identity: The identity that the command job will use while running on compute.
+ :paramtype identity: Optional[Union[
+ ~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.AmlTokenConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration]]
+ :keyword is_deterministic: Specifies whether the Command will return the same output given the same input.
+ Defaults to True. When True, if a Command Component is deterministic and has been run before in the
+ current workspace with the same input and settings, it will reuse results from a previously submitted
+ job when used as a node or step in a pipeline. In that scenario, no compute resources will be used.
+ :paramtype is_deterministic: bool
+ :keyword services: The interactive services for the node. Defaults to None. This is an experimental parameter,
+ and may change at any time. Please see https://aka.ms/azuremlexperimental for more information.
+ :paramtype services: Optional[dict[str, Union[~azure.ai.ml.entities.JobService,
+ ~azure.ai.ml.entities.JupyterLabJobService, ~azure.ai.ml.entities.SshJobService,
+ ~azure.ai.ml.entities.TensorBoardJobService, ~azure.ai.ml.entities.VsCodeJobService]]]
+ :keyword job_tier: The job tier. Accepted values are "Spot", "Basic", "Standard", or "Premium".
+ :paramtype job_tier: Optional[str]
+ :keyword priority: The priority of the job on the compute. Accepted values are "low", "medium", and "high".
+ Defaults to "medium".
+ :paramtype priority: Optional[str]
+ :keyword parent_job_name: The parent job id for the command job.
+ :paramtype parent_job_name: Optional[str]
+ :return: A Command object.
+ :rtype: ~azure.ai.ml.entities.Command
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_command_configurations.py
+ :start-after: [START command_function]
+ :end-before: [END command_function]
+ :language: python
+ :dedent: 8
+ :caption: Creating a Command Job using the command() builder method.
+ """
+ # pylint: disable=too-many-locals
+ inputs = inputs or {}
+ outputs = outputs or {}
+ component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input)
+ # job inputs cannot be None
+ job_inputs = {k: v for k, v in job_inputs.items() if v is not None}
+ component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output)
+
+ component = kwargs.pop("component", None)
+ if component is None:
+ component = CommandComponent(
+ name=name,
+ tags=tags,
+ code=code,
+ command=command,
+ environment=environment,
+ display_name=display_name,
+ description=description,
+ inputs=component_inputs,
+ outputs=component_outputs,
+ distribution=distribution,
+ environment_variables=environment_variables,
+ _source=ComponentSource.BUILDER,
+ is_deterministic=is_deterministic,
+ **kwargs,
+ )
+ command_obj = Command(
+ component=component,
+ name=name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ display_name=display_name,
+ experiment_name=experiment_name,
+ compute=compute,
+ inputs=job_inputs,
+ outputs=job_outputs,
+ identity=identity,
+ distribution=distribution,
+ environment=environment,
+ environment_variables=environment_variables,
+ services=services,
+ parent_job_name=parent_job_name,
+ **kwargs,
+ )
+
+ if (
+ locations is not None
+ or instance_count is not None
+ or instance_type is not None
+ or docker_args is not None
+ or shm_size is not None
+ ):
+ command_obj.set_resources(
+ locations=locations,
+ instance_count=instance_count,
+ instance_type=instance_type,
+ docker_args=docker_args,
+ shm_size=shm_size,
+ )
+
+ if timeout is not None:
+ command_obj.set_limits(timeout=timeout)
+
+ if job_tier is not None or priority is not None:
+ command_obj.set_queue_settings(job_tier=job_tier, priority=priority)
+
+ return command_obj
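
For reference, a minimal sketch of how the keyword arguments above flow into the returned node, using the public command builder from azure.ai.ml; the environment reference and compute name are placeholders, and nothing is submitted here.

from azure.ai.ml import Input, command

# Build a Command node locally; no workspace connection is needed for this step.
job = command(
    command="python train.py --data ${{inputs.data}}",
    environment="azureml:my-training-env:1",   # placeholder environment reference
    inputs={"data": Input(type="uri_folder", path="./data")},
    compute="cpu-cluster",                      # placeholder compute name
    instance_count=2,
    shm_size="8g",
    timeout=3600,
    job_tier="Standard",
    priority="medium",
)

# Per the builder above: instance_count/shm_size land in the node's resources via
# set_resources(), timeout in its limits via set_limits(), and job_tier/priority in
# its queue settings via set_queue_settings().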
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/condition_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/condition_node.py
new file mode 100644
index 00000000..5a5ad58b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/condition_node.py
@@ -0,0 +1,146 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, List, Optional
+
+from azure.ai.ml._schema import PathAwareSchema
+from azure.ai.ml._utils.utils import is_data_binding_expression
+from azure.ai.ml.constants._component import ControlFlowType
+from azure.ai.ml.entities._builders import BaseNode
+from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode
+from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+from azure.ai.ml.entities._job.pipeline._io import InputOutputBase
+from azure.ai.ml.entities._validation import MutableValidationResult
+
+
+class ConditionNode(ControlFlowNode):
+ """Conditional node in the pipeline.
+
+ Please do not directly use this class.
+
+ :param condition: The condition for the conditional node.
+ :type condition: Any
+ :param true_block: The list of nodes to execute when the condition is true.
+ :type true_block: List[~azure.ai.ml.entities._builders.BaseNode]
+ :param false_block: The list of nodes to execute when the condition is false.
+ :type false_block: List[~azure.ai.ml.entities._builders.BaseNode]
+ """
+
+ def __init__(
+ self, condition: Any, *, true_block: Optional[List] = None, false_block: Optional[List] = None, **kwargs: Any
+ ) -> None:
+ kwargs.pop("type", None)
+ super(ConditionNode, self).__init__(type=ControlFlowType.IF_ELSE, **kwargs)
+ self.condition = condition
+ if true_block and not isinstance(true_block, list):
+ true_block = [true_block]
+ self._true_block = true_block
+ if false_block and not isinstance(false_block, list):
+ false_block = [false_block]
+ self._false_block = false_block
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema:
+ from azure.ai.ml._schema.pipeline.condition_node import ConditionNodeSchema
+
+ return ConditionNodeSchema(context=context)
+
+ @classmethod
+ def _from_rest_object(cls, obj: dict) -> "ConditionNode":
+ return cls(**obj)
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ConditionNode":
+ """Create a condition node instance from schema parsed dict.
+
+ :param loaded_data: The loaded data
+ :type loaded_data: Dict
+ :return: The ConditionNode node
+ :rtype: ConditionNode
+ """
+ return cls(**loaded_data)
+
+ @property
+ def true_block(self) -> Optional[List]:
+ """Get the list of nodes to execute when the condition is true.
+
+ :return: The list of nodes to execute when the condition is true.
+ :rtype: List[~azure.ai.ml.entities._builders.BaseNode]
+ """
+ return self._true_block
+
+ @property
+ def false_block(self) -> Optional[List]:
+ """Get the list of nodes to execute when the condition is false.
+
+ :return: The list of nodes to execute when the condition is false.
+ :rtype: List[~azure.ai.ml.entities._builders.BaseNode]
+ """
+ return self._false_block
+
+ def _customized_validate(self) -> MutableValidationResult:
+ return self._validate_params()
+
+ def _validate_params(self) -> MutableValidationResult:
+ # pylint: disable=protected-access
+ validation_result = self._create_empty_validation_result()
+ if not isinstance(self.condition, (str, bool, InputOutputBase)):
+ validation_result.append_error(
+ yaml_path="condition",
+ message=f"'condition' of dsl.condition node must be an instance of "
+ f"{str}, {bool} or {InputOutputBase}, got {type(self.condition)}.",
+ )
+
+ # Check if output is a control output.
+ # pylint: disable=protected-access
+ if isinstance(self.condition, InputOutputBase) and self.condition._meta is not None:
+ # pylint: disable=protected-access
+ output_definition = self.condition._meta
+ if output_definition is not None and not output_definition._is_primitive_type:
+ validation_result.append_error(
+ yaml_path="condition",
+ message=f"'condition' of dsl.condition node must be primitive type "
+ f"with value 'True', got {output_definition._is_primitive_type}",
+ )
+
+ # check if condition is valid binding
+ if isinstance(self.condition, str) and not is_data_binding_expression(
+ self.condition, ["parent"], is_singular=False
+ ):
+ error_tail = "for example, ${{parent.jobs.xxx.outputs.output}}"
+ validation_result.append_error(
+ yaml_path="condition",
+ message=f"'condition' of dsl.condition has invalid binding expression: {self.condition}, {error_tail}",
+ )
+
+ error_msg = (
+ "{!r} of dsl.condition node must be an instance of " f"{BaseNode}, {AutoMLJob} or {str}," "got {!r}."
+ )
+ blocks = self.true_block if self.true_block else []
+ for block in blocks:
+ if block is not None and not isinstance(block, (BaseNode, AutoMLJob, str)):
+ validation_result.append_error(
+ yaml_path="true_block", message=error_msg.format("true_block", type(block))
+ )
+ blocks = self.false_block if self.false_block else []
+ for block in blocks:
+ if block is not None and not isinstance(block, (BaseNode, AutoMLJob, str)):
+ validation_result.append_error(
+ yaml_path="false_block", message=error_msg.format("false_block", type(block))
+ )
+
+ # check if true/false block is valid binding
+ for name, blocks in {"true_block": self.true_block, "false_block": self.false_block}.items(): # type: ignore
+ blocks = blocks if blocks else []
+ for block in blocks:
+ if block is None or not isinstance(block, str):
+ continue
+ error_tail = "for example, ${{parent.jobs.xxx}}"
+ if not is_data_binding_expression(block, ["parent", "jobs"], is_singular=False):
+ validation_result.append_error(
+ yaml_path=name,
+ message=f"'{name}' of dsl.condition has invalid binding expression: {block}, {error_tail}",
+ )
+
+ return validation_result
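
As a rough illustration of the shapes _validate_params accepts, the sketch below builds a ConditionNode directly from parent-level binding strings. In practice such nodes come from dsl.condition or pipeline YAML, and the job names here are hypothetical.

from azure.ai.ml.entities._builders.condition_node import ConditionNode

node = ConditionNode(
    condition="${{parent.jobs.check_step.outputs.is_ok}}",  # must be a str, bool, or an output
    true_block="${{parent.jobs.train_step}}",               # string blocks must be ${{parent.jobs.<name>}} bindings
    false_block="${{parent.jobs.stop_step}}",
)

# Single blocks are normalized to one-element lists by __init__.
assert node.true_block == ["${{parent.jobs.train_step}}"]
assert node.false_block == ["${{parent.jobs.stop_step}}"]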
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/control_flow_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/control_flow_node.py
new file mode 100644
index 00000000..c757a1e4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/control_flow_node.py
@@ -0,0 +1,170 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import logging
+import re
+import uuid
+from abc import ABC
+from typing import Any, Dict, Union, cast # pylint: disable=unused-import
+
+from marshmallow import ValidationError
+
+from azure.ai.ml._utils.utils import is_data_binding_expression
+from azure.ai.ml.constants._common import CommonYamlFields
+from azure.ai.ml.constants._component import ComponentSource, ControlFlowType
+from azure.ai.ml.entities._mixins import YamlTranslatableMixin
+from azure.ai.ml.entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .._util import convert_ordered_dict_to_dict
+from .base_node import BaseNode
+
+module_logger = logging.getLogger(__name__)
+
+
+# ControlFlowNode does not inherit from BaseNode since it doesn't have inputs/outputs like other nodes.
+class ControlFlowNode(YamlTranslatableMixin, PathAwareSchemaValidatableMixin, ABC):
+ """Base class for control flow node in the pipeline.
+
+ Please do not directly use this class.
+
+ :param kwargs: Additional keyword arguments.
+ :type kwargs: Dict[str, Union[Any]]
+ """
+
+ def __init__(self, **kwargs: Any) -> None:
+ # TODO(1979547): refactor this
+ _source = kwargs.pop("_source", None)
+ self._source = _source if _source else ComponentSource.DSL
+ _from_component_func = kwargs.pop("_from_component_func", False)
+ self._type = kwargs.get("type", None)
+ self._instance_id = str(uuid.uuid4())
+ self.name = None
+ if _from_component_func:
+ # add current control flow node in pipeline stack for dsl scenario and remove the body from the pipeline
+ # stack.
+ self._register_in_current_pipeline_component_builder()
+
+ @property
+ def type(self) -> Any:
+ """Get the type of the control flow node.
+
+ :return: The type of the control flow node.
+ :rtype: Any
+ """
+ return self._type
+
+ def _to_dict(self) -> Dict:
+ return dict(self._dump_for_validation())
+
+ def _to_rest_object(self, **kwargs: Any) -> dict: # pylint: disable=unused-argument
+ """Convert self to a rest object for remote call.
+
+ :return: The rest object
+ :rtype: dict
+ """
+ rest_obj = self._to_dict()
+ rest_obj["_source"] = self._source
+ return cast(dict, convert_ordered_dict_to_dict(rest_obj))
+
+ def _register_in_current_pipeline_component_builder(self) -> None:
+ """Register this node in current pipeline component builder by adding self to a global stack."""
+ from azure.ai.ml.dsl._pipeline_component_builder import _add_component_to_current_definition_builder
+
+ _add_component_to_current_definition_builder(self) # type: ignore[arg-type]
+
+ @classmethod
+ def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException:
+ return ValidationException(
+ message=message,
+ no_personal_data_message=no_personal_data_message,
+ target=ErrorTarget.PIPELINE,
+ )
+
+
+class LoopNode(ControlFlowNode, ABC):
+ """Base class for loop node in the pipeline.
+
+ Please do not directly use this class.
+
+ :param body: The body of the loop node.
+ :type body: ~azure.ai.ml.entities._builders.BaseNode
+ :param kwargs: Additional keyword arguments.
+ :type kwargs: Dict[str, Union[Any]]
+ """
+
+ def __init__(self, *, body: BaseNode, **kwargs: Any) -> None:
+ self._body = body
+ super(LoopNode, self).__init__(**kwargs)
+ # always set the referenced control flow node instance id to the body.
+ self.body._set_referenced_control_flow_node_instance_id(self._instance_id)
+
+ @property
+ def body(self) -> BaseNode:
+ """Get the body of the loop node.
+
+ :return: The body of the loop node.
+ :rtype: ~azure.ai.ml.entities._builders.BaseNode
+ """
+ return self._body
+
+ _extra_body_types = None
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ from .command import Command
+ from .pipeline import Pipeline
+
+ enable_body_type = (Command, Pipeline)
+ if cls._extra_body_types is not None:
+ enable_body_type = enable_body_type + cls._extra_body_types
+ return {
+ "body": enable_body_type,
+ }
+
+ @classmethod
+ def _get_body_from_pipeline_jobs(cls, pipeline_jobs: Dict[str, BaseNode], body_name: str) -> BaseNode:
+ # Get body object from pipeline job list.
+ if body_name not in pipeline_jobs:
+ raise ValidationError(
+ message=f'Cannot find the do-while loop body "{body_name}" in the pipeline.',
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return pipeline_jobs[body_name]
+
+ def _validate_body(self) -> MutableValidationResult:
+ # pylint: disable=protected-access
+ validation_result = self._create_empty_validation_result()
+
+ if self._instance_id != self.body._referenced_control_flow_node_instance_id:
+ # When the body is used in another loop node, record the error message in the validation result.
+ validation_result.append_error("body", "The body of a loop node cannot be used as the body of another loop node.")
+ return validation_result
+
+ def _get_body_binding_str(self) -> str:
+ return "${{parent.jobs.%s}}" % self.body.name
+
+ @staticmethod
+ def _get_data_binding_expression_value(expression: str, regex: str) -> str:
+ try:
+ if is_data_binding_expression(expression):
+ return str(re.findall(regex, expression)[0])
+
+ return expression
+ except Exception: # pylint: disable=W0718
+ module_logger.warning("Cannot get the value from data binding expression %s.", expression)
+ return expression
+
+ @staticmethod
+ def _is_loop_node_dict(obj: Any) -> bool:
+ return obj.get(CommonYamlFields.TYPE, None) in [ControlFlowType.DO_WHILE, ControlFlowType.PARALLEL_FOR]
+
+ @classmethod
+ def _from_rest_object(cls, obj: dict, pipeline_jobs: dict) -> "LoopNode":
+ from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
+
+ node_type = obj.get(CommonYamlFields.TYPE, None)
+ load_from_rest_obj_func = pipeline_node_factory.get_load_from_rest_object_func(_type=node_type)
+ return load_from_rest_obj_func(obj, pipeline_jobs) # type: ignore
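
A small sketch of the data-binding helpers used above, under the assumption that a loop body is referenced as ${{parent.jobs.<name>}}; the job name below is made up.

from azure.ai.ml._utils.utils import is_data_binding_expression
from azure.ai.ml.entities._builders.control_flow_node import LoopNode

expr = "${{parent.jobs.loop_body}}"

# The same check the validation code performs on string references.
print(is_data_binding_expression(expr, ["parent", "jobs"], is_singular=False))  # True

# _get_data_binding_expression_value extracts the referenced job name with a regex,
# mirroring how DoWhile._create_instance_from_schema_dict resolves its body.
print(LoopNode._get_data_binding_expression_value(expr, regex=r"\{\{.*\.jobs\.(.*)\}\}"))  # loop_body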
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer.py
new file mode 100644
index 00000000..83e88a48
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer.py
@@ -0,0 +1,575 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+from marshmallow import Schema
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import JobBase
+from azure.ai.ml._schema.job.data_transfer_job import (
+ DataTransferCopyJobSchema,
+ DataTransferExportJobSchema,
+ DataTransferImportJobSchema,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes
+from azure.ai.ml.constants._component import DataTransferTaskType, ExternalDataType, NodeType
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._component.datatransfer_component import (
+ DataTransferComponent,
+ DataTransferCopyComponent,
+ DataTransferExportComponent,
+ DataTransferImportComponent,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem
+from azure.ai.ml.entities._job.data_transfer.data_transfer_job import (
+ DataTransferCopyJob,
+ DataTransferExportJob,
+ DataTransferImportJob,
+)
+from azure.ai.ml.entities._validation.core import MutableValidationResult
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from ..._schema import PathAwareSchema
+from .._job.pipeline._io import NodeOutput
+from .._util import convert_ordered_dict_to_dict, load_from_dict, validate_attribute_type
+from .base_node import BaseNode
+
+module_logger = logging.getLogger(__name__)
+
+
+def _build_source_sink(io_dict: Optional[Union[Dict, Database, FileSystem]]) -> Optional[Union[Database, FileSystem]]:
+ if io_dict is None:
+ return io_dict
+ if isinstance(io_dict, (Database, FileSystem)):
+ component_io = io_dict
+ else:
+ if isinstance(io_dict, dict):
+ data_type = io_dict.pop("type", None)
+ if data_type == ExternalDataType.DATABASE:
+ component_io = Database(**io_dict)
+ elif data_type == ExternalDataType.FILE_SYSTEM:
+ component_io = FileSystem(**io_dict)
+ else:
+ msg = "Type in source or sink only support {} and {}, currently got {}."
+ raise ValidationException(
+ message=msg.format(
+ ExternalDataType.DATABASE,
+ ExternalDataType.FILE_SYSTEM,
+ data_type,
+ ),
+ no_personal_data_message=msg.format(
+ ExternalDataType.DATABASE,
+ ExternalDataType.FILE_SYSTEM,
+ "data_type",
+ ),
+ target=ErrorTarget.DATA_TRANSFER_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ else:
+ msg = "Source or sink only support dict, Database and FileSystem"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.DATA_TRANSFER_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ return component_io
+
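
For illustration, the normalization performed by _build_source_sink on plain dicts, assuming the YAML type strings "database" and "file_system"; the connection names are placeholders.

from azure.ai.ml.entities._builders.data_transfer import _build_source_sink
from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem

db = _build_source_sink({"type": "database", "query": "select * from my_table", "connection": "azureml:my_sql_connection"})
fs = _build_source_sink({"type": "file_system", "path": "some/folder", "connection": "azureml:my_s3_connection"})

assert isinstance(db, Database) and isinstance(fs, FileSystem)
# Any other "type" value, or a non-dict/non-Database/non-FileSystem object, raises ValidationException.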
+
+class DataTransfer(BaseNode):
+ """Base class for data transfer node, used for data transfer component version consumption.
+
+ You should not instantiate this class directly.
+ """
+
+ def __init__(
+ self,
+ *,
+ component: Union[str, DataTransferCopyComponent, DataTransferImportComponent],
+ compute: Optional[str] = None,
+ inputs: Optional[Dict[str, Union[NodeOutput, Input, str]]] = None,
+ outputs: Optional[Dict[str, Union[str, Output]]] = None,
+ **kwargs: Any,
+ ):
+ # resolve normal dict to dict[str, JobService]
+ kwargs.pop("type", None)
+ super().__init__(
+ type=NodeType.DATA_TRANSFER,
+ inputs=inputs,
+ outputs=outputs,
+ component=component,
+ compute=compute,
+ **kwargs,
+ )
+
+ @property
+ def component(self) -> Union[str, DataTransferComponent]:
+ res: Union[str, DataTransferComponent] = self._component
+ return res
+
+ @classmethod
+ def _load_from_rest_job(cls, obj: JobBase) -> "DataTransfer":
+ # TODO: needs REST API update
+ raise NotImplementedError("Submitting a standalone data transfer job is not supported yet")
+
+ @classmethod
+ def _get_supported_outputs_types(cls) -> Tuple:
+ return str, Output
+
+ def _build_inputs(self) -> Dict:
+ inputs = super(DataTransfer, self)._build_inputs()
+ built_inputs = {}
+ # Validate and remove non-specified inputs
+ for key, value in inputs.items():
+ if value is not None:
+ built_inputs[key] = value
+
+ return built_inputs
+
+
+@experimental
+class DataTransferCopy(DataTransfer):
+ """Base class for data transfer copy node.
+
+ You should not instantiate this class directly. Instead, you should
+ create from builder function: copy_data.
+
+ :param component: Id or instance of the data transfer component/job to be run for the step
+ :type component: DataTransferCopyComponent
+ :param inputs: Inputs to the data transfer.
+ :type inputs: Dict[str, Union[NodeOutput, Input, str]]
+ :param outputs: Mapping of output data bindings used in the job.
+ :type outputs: Dict[str, Union[str, Output, dict]]
+ :param name: Name of the data transfer.
+ :type name: str
+ :param description: Description of the data transfer.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param display_name: Display name of the job.
+ :type display_name: str
+ :param experiment_name: Name of the experiment the job will be created under,
+ if None is provided, default will be set to current directory name.
+ :type experiment_name: str
+ :param compute: The compute target the job runs on.
+ :type compute: str
+ :param data_copy_mode: The data copy mode of the copy task. Possible values are "merge_with_overwrite" and "fail_if_conflict".
+ :type data_copy_mode: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if DataTransferCopy cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ component: Union[str, DataTransferCopyComponent],
+ compute: Optional[str] = None,
+ inputs: Optional[Dict[str, Union[NodeOutput, Input, str]]] = None,
+ outputs: Optional[Dict[str, Union[str, Output]]] = None,
+ data_copy_mode: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+ super().__init__(
+ inputs=inputs,
+ outputs=outputs,
+ component=component,
+ compute=compute,
+ **kwargs,
+ )
+ # init mark for _AttrDict
+ self._init = True
+ self.task = DataTransferTaskType.COPY_DATA
+ self.data_copy_mode = data_copy_mode
+ is_component = isinstance(component, DataTransferCopyComponent)
+ if is_component:
+ _component: DataTransferCopyComponent = cast(DataTransferCopyComponent, component)
+ self.task = _component.task or self.task
+ self.data_copy_mode = _component.data_copy_mode or self.data_copy_mode
+ self._init = False
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "component": (str, DataTransferCopyComponent),
+ }
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline import DataTransferCopySchema
+
+ return DataTransferCopySchema(context=context)
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ return ["type", "task", "data_copy_mode"]
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj = super()._to_rest_object(**kwargs)
+ for key, value in {
+ "componentId": self._get_component_id(),
+ "data_copy_mode": self.data_copy_mode,
+ }.items():
+ if value is not None:
+ rest_obj[key] = value
+ return cast(dict, convert_ordered_dict_to_dict(rest_obj))
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> Any:
+ from .data_transfer_func import copy_data
+
+ loaded_data = load_from_dict(DataTransferCopyJobSchema, data, context, additional_message, **kwargs)
+ data_transfer_job = copy_data(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ return data_transfer_job
+
+ def _to_job(self) -> DataTransferCopyJob:
+ return DataTransferCopyJob(
+ experiment_name=self.experiment_name,
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ status=self.status,
+ inputs=self._job_inputs,
+ outputs=self._job_outputs,
+ services=self.services,
+ compute=self.compute,
+ data_copy_mode=self.data_copy_mode,
+ )
+
+ # pylint: disable-next=docstring-missing-param
+ def __call__(self, *args: Any, **kwargs: Any) -> "DataTransferCopy":
+ """Call DataTransferCopy as a function will return a new instance each time.
+
+ :return: A DataTransferCopy node
+ :rtype: DataTransferCopy
+ """
+ if isinstance(self._component, Component):
+ # call this to validate inputs
+ node: DataTransferCopy = self._component(*args, **kwargs)
+ # merge inputs
+ for name, original_input in self.inputs.items():
+ if name not in kwargs:
+ # use setattr here to make sure owner of input won't change
+ setattr(node.inputs, name, original_input._data)
+ node._job_inputs[name] = original_input._data
+ # get outputs
+ for name, original_output in self.outputs.items():
+ # use setattr here to make sure owner of input won't change
+ if not isinstance(original_output, str):
+ setattr(node.outputs, name, original_output._data)
+ self._refine_optional_inputs_with_no_value(node, kwargs)
+ # set default values: compute, environment_variables, outputs
+ node._name = self.name
+ node.compute = self.compute
+ node.tags = self.tags
+ # Pass through the display name only if the display name is not system generated.
+ node.display_name = self.display_name if self.display_name != self.name else None
+ return node
+ msg = "copy_data can be called as a function only when referenced component is {}, currently got {}."
+ raise ValidationException(
+ message=msg.format(Component, self._component),
+ no_personal_data_message=msg.format(Component, "self._component"),
+ target=ErrorTarget.DATA_TRANSFER_JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+
+@experimental
+class DataTransferImport(DataTransfer):
+ """Base class for data transfer import node.
+
+ You should not instantiate this class directly. Instead, you should
+ create from builder function: import_data.
+
+ :param component: Id of the data transfer built-in component to be run for the step
+ :type component: str
+ :param source: The data source of file system or database
+ :type source: Union[Dict, Database, FileSystem]
+ :param outputs: Mapping of output data bindings used in the job.
+ :type outputs: Dict[str, Union[str, Output, dict]]
+ :param name: Name of the data transfer.
+ :type name: str
+ :param description: Description of the data transfer.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param display_name: Display name of the job.
+ :type display_name: str
+ :param experiment_name: Name of the experiment the job will be created under,
+ if None is provided, default will be set to current directory name.
+ :type experiment_name: str
+ :param compute: The compute target the job runs on.
+ :type compute: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if DataTransferImport cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ component: Union[str, DataTransferImportComponent],
+ compute: Optional[str] = None,
+ source: Optional[Union[Dict, Database, FileSystem]] = None,
+ outputs: Optional[Dict[str, Union[str, Output]]] = None,
+ **kwargs: Any,
+ ):
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+ super(DataTransferImport, self).__init__(
+ component=component,
+ outputs=outputs,
+ compute=compute,
+ **kwargs,
+ )
+ # init mark for _AttrDict
+ self._init = True
+ self.task = DataTransferTaskType.IMPORT_DATA
+ is_component = isinstance(component, DataTransferImportComponent)
+ if is_component:
+ _component: DataTransferImportComponent = cast(DataTransferImportComponent, component)
+ self.task = _component.task or self.task
+ self.source = _build_source_sink(source)
+ self._init = False
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "component": (str, DataTransferImportComponent),
+ }
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline import DataTransferImportSchema
+
+ return DataTransferImportSchema(context=context)
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ return ["type", "task", "source"]
+
+ def _customized_validate(self) -> MutableValidationResult:
+ result = super()._customized_validate()
+ if self.source is None:
+ result.append_error(
+ yaml_path="source",
+ message="Source is a required field for import data task in DataTransfer job",
+ )
+ if len(self.outputs) != 1 or list(self.outputs.keys())[0] != "sink":
+ result.append_error(
+ yaml_path="outputs.sink",
+ message="Outputs field only support one output called sink in import task",
+ )
+ if (
+ "sink" in self.outputs
+ and not isinstance(self.outputs["sink"], str)
+ and isinstance(self.outputs["sink"]._data, Output)
+ ):
+ sink_output = self.outputs["sink"]._data
+ if self.source is not None:
+
+ if (self.source.type == ExternalDataType.DATABASE and sink_output.type != AssetTypes.MLTABLE) or (
+ self.source.type == ExternalDataType.FILE_SYSTEM and sink_output.type != AssetTypes.URI_FOLDER
+ ):
+ result.append_error(
+ yaml_path="outputs.sink.type",
+ message="Outputs field only support type {} for {} and {} for {}".format(
+ AssetTypes.MLTABLE,
+ ExternalDataType.DATABASE,
+ AssetTypes.URI_FOLDER,
+ ExternalDataType.FILE_SYSTEM,
+ ),
+ )
+ return result
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj = super()._to_rest_object(**kwargs)
+ for key, value in {
+ "componentId": self._get_component_id(),
+ }.items():
+ if value is not None:
+ rest_obj[key] = value
+ return cast(dict, convert_ordered_dict_to_dict(rest_obj))
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "DataTransferImport":
+ from .data_transfer_func import import_data
+
+ loaded_data = load_from_dict(DataTransferImportJobSchema, data, context, additional_message, **kwargs)
+ data_transfer_job: DataTransferImport = import_data(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ return data_transfer_job
+
+ def _to_job(self) -> DataTransferImportJob:
+ return DataTransferImportJob(
+ experiment_name=self.experiment_name,
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ status=self.status,
+ source=self.source,
+ outputs=self._job_outputs,
+ services=self.services,
+ compute=self.compute,
+ )
+
+
+@experimental
+class DataTransferExport(DataTransfer):
+ """Base class for data transfer export node.
+
+ You should not instantiate this class directly. Instead, you should
+ create from builder function: export_data.
+
+ :param component: Id of the data transfer built-in component to be run for the step
+ :type component: str
+ :param sink: The sink of external data and databases.
+ :type sink: Union[Dict, Database, FileSystem]
+ :param inputs: Mapping of input data bindings used in the job.
+ :type inputs: Dict[str, Union[NodeOutput, Input, str, Input]]
+ :param name: Name of the data transfer.
+ :type name: str
+ :param description: Description of the data transfer.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param display_name: Display name of the job.
+ :type display_name: str
+ :param experiment_name: Name of the experiment the job will be created under,
+ if None is provided, default will be set to current directory name.
+ :type experiment_name: str
+ :param compute: The compute target the job runs on.
+ :type compute: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if DataTransferExport cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ component: Union[str, DataTransferCopyComponent, DataTransferImportComponent],
+ compute: Optional[str] = None,
+ sink: Optional[Union[Dict, Database, FileSystem]] = None,
+ inputs: Optional[Dict[str, Union[NodeOutput, Input, str]]] = None,
+ **kwargs: Any,
+ ):
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+ super(DataTransferExport, self).__init__(
+ component=component,
+ inputs=inputs,
+ compute=compute,
+ **kwargs,
+ )
+ # init mark for _AttrDict
+ self._init = True
+ self.task = DataTransferTaskType.EXPORT_DATA
+ is_component = isinstance(component, DataTransferExportComponent)
+ if is_component:
+ _component: DataTransferExportComponent = cast(DataTransferExportComponent, component)
+ self.task = _component.task or self.task
+ self.sink = sink
+ self._init = False
+
+ @property
+ def sink(self) -> Optional[Union[Dict, Database, FileSystem]]:
+ """The sink of external data and databases.
+
+ :return: The sink of external data and databases.
+ :rtype: Union[None, Database, FileSystem]
+ """
+ return self._sink
+
+ @sink.setter
+ def sink(self, value: Union[Dict, Database, FileSystem]) -> None:
+ self._sink = _build_source_sink(value)
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "component": (str, DataTransferExportComponent),
+ }
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline import DataTransferExportSchema
+
+ return DataTransferExportSchema(context=context)
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ return ["type", "task", "sink"]
+
+ def _customized_validate(self) -> MutableValidationResult:
+ result = super()._customized_validate()
+ if self.sink is None:
+ result.append_error(
+ yaml_path="sink",
+ message="Sink is a required field for export data task in DataTransfer job",
+ )
+ if len(self.inputs) != 1 or list(self.inputs.keys())[0] != "source":
+ result.append_error(
+ yaml_path="inputs.source",
+ message="Inputs field only support one input called source in export task",
+ )
+ if "source" in self.inputs and isinstance(self.inputs["source"]._data, Input):
+ source_input = self.inputs["source"]._data
+ if self.sink is not None and not isinstance(self.sink, Dict):
+ if (self.sink.type == ExternalDataType.DATABASE and source_input.type != AssetTypes.URI_FILE) or (
+ self.sink.type == ExternalDataType.FILE_SYSTEM and source_input.type != AssetTypes.URI_FOLDER
+ ):
+ result.append_error(
+ yaml_path="inputs.source.type",
+ message="Inputs field only support type {} for {} and {} for {}".format(
+ AssetTypes.URI_FILE,
+ ExternalDataType.DATABASE,
+ AssetTypes.URI_FOLDER,
+ ExternalDataType.FILE_SYSTEM,
+ ),
+ )
+
+ return result
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj = super()._to_rest_object(**kwargs)
+ for key, value in {
+ "componentId": self._get_component_id(),
+ }.items():
+ if value is not None:
+ rest_obj[key] = value
+ return cast(dict, convert_ordered_dict_to_dict(rest_obj))
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "DataTransferExport":
+ from .data_transfer_func import export_data
+
+ loaded_data = load_from_dict(DataTransferExportJobSchema, data, context, additional_message, **kwargs)
+ data_transfer_job: DataTransferExport = export_data(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ return data_transfer_job
+
+ def _to_job(self) -> DataTransferExportJob:
+ return DataTransferExportJob(
+ experiment_name=self.experiment_name,
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ status=self.status,
+ sink=self.sink,
+ inputs=self._job_inputs,
+ services=self.services,
+ compute=self.compute,
+ )
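
To summarize the pairing rules enforced by the two _customized_validate methods above, a small sketch of matching Input/Output types; the registered asset name is hypothetical.

from azure.ai.ml import Input, Output
from azure.ai.ml.constants import AssetTypes

# Import task: a database source must land in an mltable sink,
# a file_system source in a uri_folder sink.
database_import_sink = Output(type=AssetTypes.MLTABLE)
filesystem_import_sink = Output(type=AssetTypes.URI_FOLDER)

# Export task: a database sink must read from a uri_file input,
# a file_system sink from a uri_folder input.
database_export_source = Input(type=AssetTypes.URI_FILE, path="azureml:my_file_asset:1")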
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer_func.py
new file mode 100644
index 00000000..423c125b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/data_transfer_func.py
@@ -0,0 +1,335 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import AssetTypes, LegacyAssetTypes
+from azure.ai.ml.constants._component import ComponentSource, DataTransferBuiltinComponentUri, ExternalDataType
+from azure.ai.ml.entities._builders.base_node import pipeline_node_decorator
+from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem
+from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
+from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput
+from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException
+
+from .data_transfer import DataTransferCopy, DataTransferExport, DataTransferImport, _build_source_sink
+
+SUPPORTED_INPUTS = [
+ LegacyAssetTypes.PATH,
+ AssetTypes.URI_FILE,
+ AssetTypes.URI_FOLDER,
+ AssetTypes.CUSTOM_MODEL,
+ AssetTypes.MLFLOW_MODEL,
+ AssetTypes.MLTABLE,
+ AssetTypes.TRITON_MODEL,
+]
+
+
+def _parse_input(input_value: Union[Input, dict, str, PipelineInput, NodeOutput]) -> Tuple:
+ component_input = None
+ job_input: Union[Input, dict, str, PipelineInput, NodeOutput] = ""
+
+ if isinstance(input_value, Input):
+ component_input = Input(**input_value._to_dict())
+ input_type = input_value.type
+ if input_type in SUPPORTED_INPUTS:
+ job_input = Input(**input_value._to_dict())
+ elif isinstance(input_value, dict):
+ # if user provided dict, we try to parse it to Input.
+ # for job input, only parse for path type
+ input_type = input_value.get("type", None)
+ if input_type in SUPPORTED_INPUTS:
+ job_input = Input(**input_value)
+ component_input = Input(**input_value)
+ elif isinstance(input_value, str):
+ # Input bindings
+ component_input = ComponentTranslatableMixin._to_input_builder_function(input_value)
+ job_input = input_value
+ elif isinstance(input_value, (PipelineInput, NodeOutput)):
+ data: Any = None
+ # datatransfer node can accept PipelineInput/NodeOutput for export task.
+ if input_value._data is None or isinstance(input_value._data, Output):
+ data = Input(type=input_value.type, mode=input_value.mode)
+ else:
+ data = input_value._data
+ component_input, _ = _parse_input(data)
+ job_input = input_value
+ else:
+ msg = (
+ f"Unsupported input type: {type(input_value)}, only Input, dict, str, PipelineInput and NodeOutput are "
+ f"supported."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return component_input, job_input
+
+
+def _parse_output(output_value: Union[Output, Dict]) -> Tuple:
+ component_output = None
+ job_output: Union[Output, Dict] = {}
+
+ if isinstance(output_value, Output):
+ component_output = Output(**output_value._to_dict())
+ job_output = Output(**output_value._to_dict())
+ elif not output_value:
+ # output value can be None or empty dictionary
+ # None output value will be packed into a JobOutput object with mode = ReadWriteMount & type = UriFolder
+ component_output = ComponentTranslatableMixin._to_output(output_value)
+ job_output = output_value
+ elif isinstance(output_value, dict): # When output value is a non-empty dictionary
+ job_output = Output(**output_value)
+ component_output = Output(**output_value)
+ elif isinstance(output_value, str): # When output is passed in from pipeline job yaml
+ job_output = output_value
+ else:
+ msg = f"Unsupported output type: {type(output_value)}, only Output and dict are supported."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return component_output, job_output
+
+
+def _parse_inputs_outputs(io_dict: Optional[Dict], parse_func: Callable) -> Tuple[Dict, Dict]:
+ component_io_dict, job_io_dict = {}, {}
+ if io_dict:
+ for key, val in io_dict.items():
+ component_io, job_io = parse_func(val)
+ component_io_dict[key] = component_io
+ job_io_dict[key] = job_io
+ return component_io_dict, job_io_dict
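
A small sketch of how these parsers split a user-supplied mapping into a component-side definition and a job-side binding; the binding string and paths are placeholders.

from azure.ai.ml import Input, Output
from azure.ai.ml.entities._builders.data_transfer_func import (
    _parse_input,
    _parse_inputs_outputs,
    _parse_output,
)

component_inputs, job_inputs = _parse_inputs_outputs(
    {"folder": Input(type="uri_folder", path="./data"), "bound": "${{parent.inputs.raw}}"},
    parse_func=_parse_input,
)
component_outputs, job_outputs = _parse_inputs_outputs(
    {"sink": Output(type="mltable")},
    parse_func=_parse_output,
)

# For the literal Input/Output values both sides get a copy; for the "bound" string the
# component side gets a generated Input placeholder while the job side keeps the binding.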
+
+
+@experimental
+def copy_data(
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ is_deterministic: bool = True,
+ data_copy_mode: Optional[str] = None,
+ **kwargs: Any,
+) -> DataTransferCopy:
+ """Create a DataTransferCopy object which can be used inside dsl.pipeline as a function.
+
+ :keyword name: The name of the job.
+ :paramtype name: str
+ :keyword description: Description of the job.
+ :paramtype description: str
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated.
+ :paramtype tags: dict[str, str]
+ :keyword display_name: Display name of the job.
+ :paramtype display_name: str
+ :keyword experiment_name: Name of the experiment the job will be created under.
+ :paramtype experiment_name: str
+ :keyword compute: The compute resource the job runs on.
+ :paramtype compute: str
+ :keyword inputs: Mapping of inputs data bindings used in the job.
+ :paramtype inputs: dict
+ :keyword outputs: Mapping of outputs data bindings used in the job.
+ :paramtype outputs: dict
+ :keyword is_deterministic: Specifies whether the command will return the same output given the same input.
+ Defaults to True. If a command (component) is deterministic and is used as a node/step in a pipeline,
+ it will reuse results from a previously submitted job in the current workspace that has the same inputs
+ and settings. In that case, the step will not use any compute resources.
+ Specify is_deterministic=False if you would like to avoid such reuse behavior.
+ :paramtype is_deterministic: bool
+ :keyword data_copy_mode: The data copy mode of the copy task. Possible values are "merge_with_overwrite" and "fail_if_conflict".
+ :paramtype data_copy_mode: str
+ :return: A DataTransferCopy object.
+ :rtype: ~azure.ai.ml.entities._builders.data_transfer.DataTransferCopy
+ """
+ inputs = inputs or {}
+ outputs = outputs or {}
+ component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input)
+ # job inputs can not be None
+ job_inputs = {k: v for k, v in job_inputs.items() if v is not None}
+ component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output)
+ component = kwargs.pop("component", None)
+ if component is None:
+ component = DataTransferCopyComponent(
+ name=name,
+ tags=tags,
+ display_name=display_name,
+ description=description,
+ inputs=component_inputs,
+ outputs=component_outputs,
+ data_copy_mode=data_copy_mode,
+ _source=ComponentSource.BUILDER,
+ is_deterministic=is_deterministic,
+ **kwargs,
+ )
+ data_transfer_copy_obj = DataTransferCopy(
+ component=component,
+ name=name,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ experiment_name=experiment_name,
+ compute=compute,
+ inputs=job_inputs,
+ outputs=job_outputs,
+ data_copy_mode=data_copy_mode,
+ **kwargs,
+ )
+ return data_transfer_copy_obj
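
A hedged usage sketch for the builder above (a public alias is commonly available under azure.ai.ml.data_transfer as well); the compute name and asset paths are placeholders.

from azure.ai.ml import Input, Output
from azure.ai.ml.entities._builders.data_transfer_func import copy_data

copy_node = copy_data(
    name="merge_files",
    display_name="Merge files",
    compute="adf_compute",   # placeholder data transfer compute
    inputs={
        "folder1": Input(type="uri_folder", path="azureml:cosmos_folder:1"),    # placeholder assets
        "folder2": Input(type="uri_folder", path="azureml:cosmos_folder2:1"),
    },
    outputs={"output_folder": Output(type="uri_folder")},
    data_copy_mode="merge_with_overwrite",
)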
+
+
+@experimental
+@pipeline_node_decorator
+def import_data(
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ source: Optional[Union[Dict, Database, FileSystem]] = None,
+ outputs: Optional[Dict] = None,
+ **kwargs: Any,
+) -> DataTransferImport:
+ """Create a DataTransferImport object which can be used inside dsl.pipeline.
+
+ :keyword name: The name of the job.
+ :paramtype name: str
+ :keyword description: Description of the job.
+ :paramtype description: str
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated.
+ :paramtype tags: dict[str, str]
+ :keyword display_name: Display name of the job.
+ :paramtype display_name: str
+ :keyword experiment_name: Name of the experiment the job will be created under.
+ :paramtype experiment_name: str
+ :keyword compute: The compute resource the job runs on.
+ :paramtype compute: str
+ :keyword source: The data source of file system or database.
+ :paramtype source: Union[Dict, ~azure.ai.ml.entities._inputs_outputs.external_data.Database,
+ ~azure.ai.ml.entities._inputs_outputs.external_data.FileSystem]
+ :keyword outputs: Mapping of outputs data bindings used in the job.
+ The default will be an output port with the key "sink" and type "mltable".
+ :paramtype outputs: dict
+ :return: A DataTransferImport object.
+ :rtype: ~azure.ai.ml.entities._builders.data_transfer.DataTransferImport
+ """
+ source = _build_source_sink(source)
+ outputs = outputs or {"sink": Output(type=AssetTypes.MLTABLE)}
+ # # job inputs can not be None
+ # job_inputs = {k: v for k, v in job_inputs.items() if v is not None}
+ _, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output)
+ component = kwargs.pop("component", None)
+ update_source = False
+ if component is None:
+ if source and source.type == ExternalDataType.DATABASE:
+ component = DataTransferBuiltinComponentUri.IMPORT_DATABASE
+ else:
+ component = DataTransferBuiltinComponentUri.IMPORT_FILE_SYSTEM
+ update_source = True
+
+ data_transfer_import_obj = DataTransferImport(
+ component=component,
+ name=name,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ experiment_name=experiment_name,
+ compute=compute,
+ source=source,
+ outputs=job_outputs,
+ **kwargs,
+ )
+ if update_source:
+ data_transfer_import_obj._source = ComponentSource.BUILTIN
+
+ return data_transfer_import_obj
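
A usage sketch for import_data with a database source; the connection and compute names are placeholders, and the default "sink" output (an mltable) is supplied by the builder itself.

from azure.ai.ml.entities._builders.data_transfer_func import import_data
from azure.ai.ml.entities._inputs_outputs.external_data import Database

import_node = import_data(
    name="import_from_sql",
    compute="adf_compute",   # placeholder
    source=Database(query="select * from my_table", connection="azureml:my_sql_connection"),
)
# Because no component was passed, the builder picks the built-in
# DataTransferBuiltinComponentUri.IMPORT_DATABASE component.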
+
+
+@experimental
+@pipeline_node_decorator
+def export_data(
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ sink: Optional[Union[Dict, Database, FileSystem]] = None,
+ inputs: Optional[Dict] = None,
+ **kwargs: Any,
+) -> DataTransferExport:
+ """Create a DataTransferExport object which can be used inside dsl.pipeline.
+
+ :keyword name: The name of the job.
+ :paramtype name: str
+ :keyword description: Description of the job.
+ :paramtype description: str
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated.
+ :paramtype tags: dict[str, str]
+ :keyword display_name: Display name of the job.
+ :paramtype display_name: str
+ :keyword experiment_name: Name of the experiment the job will be created under.
+ :paramtype experiment_name: str
+ :keyword compute: The compute resource the job runs on.
+ :paramtype compute: str
+ :keyword sink: The sink of external data and databases.
+ :paramtype sink: Union[
+ Dict,
+ ~azure.ai.ml.entities._inputs_outputs.external_data.Database,
+ ~azure.ai.ml.entities._inputs_outputs.external_data.FileSystem]
+ :keyword inputs: Mapping of inputs data bindings used in the job.
+ :paramtype inputs: dict
+ :return: A DataTransferExport object.
+ :rtype: ~azure.ai.ml.entities._builders.data_transfer.DataTransferExport
+ :raises ValidationException: If sink is not provided, or if it is a file system (exporting to a file system is not supported yet).
+ """
+ sink = _build_source_sink(sink)
+ _, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input)
+ # job inputs can not be None
+ job_inputs = {k: v for k, v in job_inputs.items() if v is not None}
+ component = kwargs.pop("component", None)
+ update_source = False
+ if component is None:
+ if sink and sink.type == ExternalDataType.DATABASE:
+ component = DataTransferBuiltinComponentUri.EXPORT_DATABASE
+ else:
+ msg = "Sink is a required field for export data task and we don't support exporting file system for now."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ update_source = True
+
+ data_transfer_export_obj = DataTransferExport(
+ component=component,
+ name=name,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ experiment_name=experiment_name,
+ compute=compute,
+ sink=sink,
+ inputs=job_inputs,
+ **kwargs,
+ )
+ if update_source:
+ data_transfer_export_obj._source = ComponentSource.BUILTIN
+
+ return data_transfer_export_obj
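
And a matching sketch for export_data; as the code above enforces, only database sinks are supported for now and the source input must be a uri_file. All names below are placeholders.

from azure.ai.ml import Input
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities._builders.data_transfer_func import export_data
from azure.ai.ml.entities._inputs_outputs.external_data import Database

export_node = export_data(
    name="export_to_sql",
    compute="adf_compute",   # placeholder
    sink=Database(table_name="dbo.my_table", connection="azureml:my_sql_connection"),
    inputs={"source": Input(type=AssetTypes.URI_FILE, path="azureml:my_file_asset:1")},
)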
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/do_while.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/do_while.py
new file mode 100644
index 00000000..ecfd51ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/do_while.py
@@ -0,0 +1,357 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import logging
+from typing import Any, Dict, Optional, Union
+
+from typing_extensions import Literal
+
+from azure.ai.ml._schema.pipeline.control_flow_job import DoWhileSchema
+from azure.ai.ml.constants._component import DO_WHILE_MAX_ITERATION, ControlFlowType
+from azure.ai.ml.entities._job.job_limits import DoWhileJobLimits
+from azure.ai.ml.entities._job.pipeline._io import InputOutputBase, NodeInput, NodeOutput
+from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob
+from azure.ai.ml.entities._validation import MutableValidationResult
+
+from .._util import load_from_dict, validate_attribute_type
+from .base_node import BaseNode
+from .control_flow_node import LoopNode
+from .pipeline import Pipeline
+
+module_logger = logging.getLogger(__name__)
+
+
+class DoWhile(LoopNode):
+ """Do-while loop node in the pipeline job. By specifying the loop body and loop termination condition in this class,
+ a job-level do while loop can be implemented. It will be initialized when calling dsl.do_while or when loading the
+ pipeline yml containing do_while node. Please do not manually initialize this class.
+
+ :param body: Pipeline job for the do-while loop body.
+ :type body: ~azure.ai.ml.entities._builders.pipeline.Pipeline
+ :param condition: Boolean type control output of body as do-while loop condition.
+ :type condition: ~azure.ai.ml.entities.Output
+ :param mapping: Output-Input mapping for each round of the do-while loop.
+ Key is the last round output of the body. Value is the input port for the current body.
+ :type mapping: dict[Union[str, ~azure.ai.ml.entities.Output],
+ Union[str, ~azure.ai.ml.entities.Input, list]]
+ :param limits: Limits in running the do-while node.
+ :type limits: Union[dict, ~azure.ai.ml.entities._job.job_limits.DoWhileJobLimits]
+ :raises ValidationError: If the initialization parameters are not of valid types.
+ """
+
+ def __init__(
+ self,
+ *,
+ body: Union[Pipeline, BaseNode],
+ condition: Optional[Union[str, NodeInput, NodeOutput]],
+ mapping: Dict,
+ limits: Optional[Union[dict, DoWhileJobLimits]] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+
+ kwargs.pop("type", None)
+ super(DoWhile, self).__init__(
+ type=ControlFlowType.DO_WHILE,
+ body=body,
+ **kwargs,
+ )
+
+ # init mark for _AttrDict
+ self._init = True
+ self._mapping = mapping or {}
+ self._condition = condition
+ self._limits = limits
+ self._init = False
+
+ @property
+ def mapping(self) -> Dict:
+ """Get the output-input mapping for each round of the do-while loop.
+
+ :return: Output-Input mapping for each round of the do-while loop.
+ :rtype: dict[Union[str, ~azure.ai.ml.entities.Output],
+ Union[str, ~azure.ai.ml.entities.Input, list]]
+ """
+ return self._mapping
+
+ @property
+ def condition(self) -> Optional[Union[str, NodeInput, NodeOutput]]:
+ """Get the boolean type control output of the body as the do-while loop condition.
+
+ :return: Control output of the body as the do-while loop condition.
+ :rtype: ~azure.ai.ml.entities.Output
+ """
+ return self._condition
+
+ @property
+ def limits(self) -> Union[Dict, DoWhileJobLimits, None]:
+ """Get the limits in running the do-while node.
+
+ :return: Limits in running the do-while node.
+ :rtype: Union[dict, ~azure.ai.ml.entities._job.job_limits.DoWhileJobLimits]
+ """
+ return self._limits
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ **super(DoWhile, cls)._attr_type_map(),
+ "mapping": dict,
+ "limits": (dict, DoWhileJobLimits),
+ }
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "DoWhile":
+ loaded_data = load_from_dict(DoWhileSchema, data, context, additional_message, **kwargs)
+
+ return cls(**loaded_data)
+
+ @classmethod
+ def _get_port_obj(
+ cls, body: BaseNode, port_name: str, is_input: bool = True, validate_port: bool = True
+ ) -> Union[str, NodeInput, NodeOutput]:
+ if is_input:
+ port = body.inputs.get(port_name, None)
+ else:
+ port = body.outputs.get(port_name, None)
+ if port is None:
+ if validate_port:
+ raise cls._create_validation_error(
+ message=f"Cannot find {port_name} in do_while loop body {'inputs' if is_input else 'outputs'}.",
+ no_personal_data_message=f"Miss port in do_while loop body {'inputs' if is_input else 'outputs'}.",
+ )
+ return port_name
+
+ res: Union[str, NodeInput, NodeOutput] = port
+ return res
+
+ @classmethod
+ def _create_instance_from_schema_dict(
+ cls, pipeline_jobs: Dict[str, BaseNode], loaded_data: Dict, validate_port: bool = True
+ ) -> "DoWhile":
+ """Create a do_while instance from schema parsed dict.
+
+ :param pipeline_jobs: The pipeline jobs
+ :type pipeline_jobs: Dict[str, BaseNode]
+ :param loaded_data: The loaded data
+ :type loaded_data: Dict
+ :param validate_port: Whether to raise if inputs/outputs are not present. Defaults to True
+ :type validate_port: bool
+ :return: The DoWhile node
+ :rtype: DoWhile
+ """
+
+ # Get body object from pipeline job list.
+ body_name = cls._get_data_binding_expression_value(loaded_data.pop("body"), regex=r"\{\{.*\.jobs\.(.*)\}\}")
+ body = cls._get_body_from_pipeline_jobs(pipeline_jobs, body_name)
+
+ # Convert mapping key-value pairs to input/output objects
+ mapping = {}
+ for k, v in loaded_data.pop("mapping", {}).items():
+ output_name = cls._get_data_binding_expression_value(k, regex=r"\{\{.*\.%s\.outputs\.(.*)\}\}" % body_name)
+ input_names = v if isinstance(v, list) else [v]
+ input_names = [
+ cls._get_data_binding_expression_value(item, regex=r"\{\{.*\.%s\.inputs\.(.*)\}\}" % body_name)
+ for item in input_names
+ ]
+ mapping[output_name] = [cls._get_port_obj(body, item, validate_port=validate_port) for item in input_names]
+
+ limits = loaded_data.pop("limits", None)
+
+ if "condition" in loaded_data:
+ # Convert condition to output object
+ condition_name = cls._get_data_binding_expression_value(
+ loaded_data.pop("condition"), regex=r"\{\{.*\.%s\.outputs\.(.*)\}\}" % body_name
+ )
+ condition_value = cls._get_port_obj(body, condition_name, is_input=False, validate_port=validate_port)
+ else:
+ condition_value = None
+
+ do_while_instance = DoWhile(
+ body=body,
+ mapping=mapping,
+ condition=condition_value,
+ **loaded_data,
+ )
+ do_while_instance.set_limits(**limits)
+
+ return do_while_instance
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> DoWhileSchema:
+ return DoWhileSchema(context=context)
+
+ @classmethod
+ def _from_rest_object(cls, obj: dict, pipeline_jobs: dict) -> "DoWhile":
+ # pylint: disable=protected-access
+
+ obj = BaseNode._from_rest_object_to_init_params(obj)
+ return cls._create_instance_from_schema_dict(pipeline_jobs, obj, validate_port=False)
+
+ def set_limits(
+ self,
+ *,
+ max_iteration_count: int,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """
+ Set the maximum iteration count for the do-while job.
+
+ The range of the iteration count is (0, 1000].
+
+ :keyword max_iteration_count: The maximum iteration count for the do-while job.
+ :paramtype max_iteration_count: int
+ """
+ if isinstance(self.limits, DoWhileJobLimits):
+ self.limits._max_iteration_count = max_iteration_count # pylint: disable=protected-access
+ else:
+ self._limits = DoWhileJobLimits(max_iteration_count=max_iteration_count)
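
A tiny sketch of what set_limits stores, using the same DoWhileJobLimits constructor call as the method above.

from azure.ai.ml.entities._job.job_limits import DoWhileJobLimits

limits = DoWhileJobLimits(max_iteration_count=10)
# _validate_do_while_limit below requires 0 < max_iteration_count <= DO_WHILE_MAX_ITERATION,
# matching the (0, 1000] range documented in set_limits.
print(limits.max_iteration_count)  # expected: 10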
+
+ def _customized_validate(self) -> MutableValidationResult:
+ validation_result = self._validate_loop_condition()
+ validation_result.merge_with(self._validate_body())
+ validation_result.merge_with(self._validate_do_while_limit())
+ validation_result.merge_with(self._validate_body_output_mapping())
+ return validation_result
+
+ def _validate_port(
+ self,
+ port: Union[str, NodeInput, NodeOutput],
+ node_ports: Dict[str, Union[NodeInput, NodeOutput]],
+ port_type: Literal["input", "output"],
+ yaml_path: str,
+ ) -> MutableValidationResult:
+ """Validate input/output port is exist in the dowhile body.
+
+ :param port: Either:
+ * The name of an input or output
+ * An input object
+ * An output object
+ :type port: Union[str, NodeInput, NodeOutput],
+ :param node_ports: The node input/outputs
+ :type node_ports: Union[Dict[str, Union[NodeInput, NodeOutput]]]
+ :param port_type: The port type
+ :type port_type: Literal["input", "output"],
+ :param yaml_path: The yaml path
+ :type yaml_path: str,
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ validation_result = self._create_empty_validation_result()
+ if isinstance(port, str):
+ port_obj = node_ports.get(port, None)
+ else:
+ port_obj = port
+ if (
+ port_obj is not None
+ and port_obj._owner is not None # pylint: disable=protected-access
+ and not isinstance(port_obj._owner, PipelineJob) # pylint: disable=protected-access
+ and port_obj._owner._instance_id != self.body._instance_id # pylint: disable=protected-access
+ ):
+ # Check that the port owner is the do-while body.
+ validation_result.append_error(
+ yaml_path=yaml_path,
+ message=(
+ f"{port_obj._port_name} is the {port_type} of {port_obj._owner.name}, " # pylint: disable=protected-access
+ f"dowhile only accept {port_type} of the body: {self.body.name}."
+ ),
+ )
+ elif port_obj is None or port_obj._port_name not in node_ports: # pylint: disable=protected-access
+ # Check that the port exists in the do-while body.
+ validation_result.append_error(
+ yaml_path=yaml_path,
+ message=(
+ f"The {port_type} of mapping {port_obj._port_name if port_obj else port} does not " # pylint: disable=protected-access
+ f"exist in the {port_type}s of {self.body.name}; existing {port_type}s: {list(node_ports.keys())}"
+ ),
+ )
+ return validation_result
+
+ def _validate_loop_condition(self) -> MutableValidationResult:
+ # pylint: disable=protected-access
+ validation_result = self._create_empty_validation_result()
+ if self.condition is not None:
+ # Check condition exists in dowhile body.
+ validation_result.merge_with(
+ self._validate_port(self.condition, self.body.outputs, port_type="output", yaml_path="condition")
+ )
+ if validation_result.passed:
+ # Check condition is a control output.
+ condition_name = self.condition if isinstance(self.condition, str) else self.condition._port_name
+ if not self.body._outputs[condition_name]._is_primitive_type:
+ validation_result.append_error(
+ yaml_path="condition",
+ message=(
+ f"{condition_name} is not a control output and is not primitive type. "
+ "The condition of dowhile must be the control output or primitive type of the body."
+ ),
+ )
+ return validation_result
+
+ def _validate_do_while_limit(self) -> MutableValidationResult:
+ validation_result = self._create_empty_validation_result()
+ if isinstance(self.limits, DoWhileJobLimits):
+ if not self.limits or self.limits.max_iteration_count is None:
+ return validation_result
+ if isinstance(self.limits.max_iteration_count, InputOutputBase):
+ validation_result.append_error(
+ yaml_path="limit.max_iteration_count",
+ message="The max iteration count cannot be linked with an primitive type input.",
+ )
+ elif self.limits.max_iteration_count > DO_WHILE_MAX_ITERATION or self.limits.max_iteration_count < 0:
+ validation_result.append_error(
+ yaml_path="limit.max_iteration_count",
+ message=f"The max iteration count cannot be less than 0 or larger than {DO_WHILE_MAX_ITERATION}.",
+ )
+ return validation_result
+
+ def _validate_body_output_mapping(self) -> MutableValidationResult:
+ # pylint: disable=protected-access
+ validation_result = self._create_empty_validation_result()
+ if not isinstance(self.mapping, dict):
+ validation_result.append_error(
+ yaml_path="mapping", message=f"Mapping expects a dict type but passes in a {type(self.mapping)} type."
+ )
+ else:
+ # Record the mapping relationship between input and output
+ input_output_mapping: Dict = {}
+ # Validate mapping input&output should come from while body
+ for output, inputs in self.mapping.items():
+ # pylint: disable=protected-access
+ output_name = output if isinstance(output, str) else output._port_name
+ validate_results = self._validate_port(
+ output, self.body.outputs, port_type="output", yaml_path="mapping"
+ )
+ if validate_results.passed:
+ is_primitive_output = self.body._outputs[output_name]._is_primitive_type
+ inputs = inputs if isinstance(inputs, list) else [inputs]
+ for item in inputs:
+ input_validate_results = self._validate_port(
+ item, self.body.inputs, port_type="input", yaml_path="mapping"
+ )
+ validation_result.merge_with(input_validate_results)
+ # pylint: disable=protected-access
+ input_name = item if isinstance(item, str) else item._port_name
+ input_output_mapping[input_name] = input_output_mapping.get(input_name, []) + [output_name]
+ is_primitive_type = self.body._inputs[input_name]._meta._is_primitive_type
+
+ if input_validate_results.passed and not is_primitive_output and is_primitive_type:
+ validate_results.append_error(
+ yaml_path="mapping",
+ message=(
+ f"{output_name} is a non-primitive type output and {input_name} "
+ "is a primitive input. Non-primitive type output cannot be connected "
+ "to an a primitive type input."
+ ),
+ )
+
+ validation_result.merge_with(validate_results)
+ # Validate whether input is linked to multiple outputs
+ for _input, outputs in input_output_mapping.items():
+ if len(outputs) > 1:
+ validation_result.append_error(
+ yaml_path="mapping", message=f"Input {_input} has been linked to multiple outputs {outputs}."
+ )
+ return validation_result
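+
+# Editor's illustration (not part of the SDK): a mapping that passes the checks above links each
+# body output to one or more body inputs, e.g. {"model": ["model_in"]} or, with port objects,
+# {body.outputs.model: [body.inputs.model_in]}. The validation rejects mappings where one input is
+# fed by more than one output, or where a non-primitive output (e.g. uri_folder) is wired to a
+# primitive input (e.g. string/int), and it requires the loop condition to be a control/primitive
+# output of the body.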
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/fl_scatter_gather.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/fl_scatter_gather.py
new file mode 100644
index 00000000..0ad6b0e2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/fl_scatter_gather.py
@@ -0,0 +1,886 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import re
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from azure.ai.ml import Output
+from azure.ai.ml._schema import PathAwareSchema
+from azure.ai.ml._schema.pipeline.control_flow_job import FLScatterGatherSchema
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import AssetTypes
+from azure.ai.ml.dsl import pipeline
+from azure.ai.ml.dsl._do_while import do_while
+from azure.ai.ml.entities._assets.federated_learning_silo import FederatedLearningSilo
+from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode
+from azure.ai.ml.entities._builders.pipeline import Pipeline
+from azure.ai.ml.entities._component.command_component import CommandComponent
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._inputs_outputs.input import Input
+from azure.ai.ml.entities._job.pipeline._io.mixin import NodeIOMixin
+from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob
+from azure.ai.ml.entities._util import convert_ordered_dict_to_dict
+from azure.ai.ml.entities._validation import MutableValidationResult
+
+from .subcomponents import create_scatter_output_table
+
+# TODO 2293610: add support for more types of outputs besides uri_folder and mltable
+# Likely types that ought to be mergeable: string, int, uri_file
+MERGE_COMPONENT_MAPPING = {
+ "mltable": create_scatter_output_table,
+ "uri_folder": create_scatter_output_table,
+}
+
+
+ANCHORABLE_OUTPUT_TYPES = {AssetTypes.MLTABLE, AssetTypes.URI_FOLDER}
+
+ANCHORING_PATH_ROOT = "root"
+
+
+# big TODO: For some reason, surfacing this file in __init__.py causes
+# a circular import exception on the first attempted import
+# In notebooks, the second import succeeds, but then causes a silent failure where the
+# MLDesigner component created by the subcomponents.create_scatter_output_table function
+# will produce a ComponentExecutor object instead of the actual component.
+# TODO 2293541: Add telemetry of some sort
+# pylint: disable=too-many-instance-attributes
+class FLScatterGather(ControlFlowNode, NodeIOMixin):
+ """A node which creates a federated learning scatter-gather loop as a pipeline subgraph.
+ Intended for use inside a pipeline job. This is initialized when calling
+ `dsl.fl_scatter_gather()` or when loading a serialized version of this node from YAML.
+ Please do not manually initialize this class.
+
+ :param silo_configs: List of federated learning silo configurations.
+ :type silo_configs: List[~azure.ai.ml.entities._assets.federated_learning_silo.FederatedLearningSilo]
+ :param silo_component: Component representing the silo for federated learning.
+ :type silo_component: ~azure.ai.ml.entities.Component
+ :param aggregation_component: Component representing the aggregation step.
+ :type aggregation_component: ~azure.ai.ml.entities.Component
+ :param aggregation_compute: The compute resource for the aggregation step.
+ :type aggregation_compute: str
+ :param aggregation_datastore: The datastore for the aggregation step.
+ :type aggregation_datastore: str
+ :param shared_silo_kwargs: Keyword arguments shared across all silos.
+ :type shared_silo_kwargs: dict
+ :param aggregation_kwargs: Keyword arguments specific to the aggregation step.
+ :type aggregation_kwargs: dict
+ :param silo_to_aggregation_argument_map: Mapping of silo to aggregation arguments.
+ :type silo_to_aggregation_argument_map: dict
+ :param aggregation_to_silo_argument_map: Mapping of aggregation to silo arguments.
+ :type aggregation_to_silo_argument_map: dict
+ :param max_iterations: The maximum number of iterations for the scatter-gather loop.
+ :type max_iterations: int
+ :param create_default_mappings_if_needed: Whether to create default argument mappings if needed.
+ :type create_default_mappings_if_needed: bool
+ """
+
+ # See node class for input descriptions; no point maintaining
+ # duplicate descriptions between a wrapper and its interior.
+ def __init__(
+ self,
+ *,
+ silo_configs: List[FederatedLearningSilo],
+ silo_component: Component,
+ aggregation_component: Component,
+ aggregation_compute: Optional[str] = None,
+ aggregation_datastore: Optional[str] = None,
+ shared_silo_kwargs: Optional[Dict] = None,
+ aggregation_kwargs: Optional[Dict] = None,
+ silo_to_aggregation_argument_map: Optional[Dict] = None,
+ aggregation_to_silo_argument_map: Optional[Dict] = None,
+ max_iterations: int = 1,
+ create_default_mappings_if_needed: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ # auto-create X_to_Y_argument_map values if allowed and needed.
+ if create_default_mappings_if_needed:
+ (
+ silo_to_aggregation_argument_map,
+ aggregation_to_silo_argument_map,
+ ) = FLScatterGather._try_create_default_mappings(
+ silo_component,
+ aggregation_component,
+ silo_to_aggregation_argument_map,
+ aggregation_to_silo_argument_map,
+ )
+
+ # input validation.
+ FLScatterGather.validate_inputs(
+ silo_configs=silo_configs,
+ silo_component=silo_component,
+ aggregation_component=aggregation_component,
+ shared_silo_kwargs=shared_silo_kwargs,
+ aggregation_compute=aggregation_compute,
+ aggregation_datastore=aggregation_datastore,
+ aggregation_kwargs=aggregation_kwargs,
+ silo_to_aggregation_argument_map=silo_to_aggregation_argument_map,
+ aggregation_to_silo_argument_map=aggregation_to_silo_argument_map,
+ max_iterations=max_iterations,
+ )
+
+ # store inputs
+ self.silo_configs = silo_configs
+ self.aggregation_compute = aggregation_compute
+ self.aggregation_datastore = aggregation_datastore
+ self.silo_component = silo_component
+ self.aggregation_component = aggregation_component
+ self.shared_silo_kwargs = shared_silo_kwargs
+ self.aggregation_kwargs = aggregation_kwargs
+ self.silo_to_aggregation_argument_map = silo_to_aggregation_argument_map
+ self.aggregation_to_silo_argument_map = aggregation_to_silo_argument_map
+ self.max_iterations = max_iterations
+ self._init = True # Needed by parent class to work properly
+
+ self.scatter_gather_graph = self.scatter_gather()
+
+ # set SG node flag for telemetry
+ self.scatter_gather_graph.properties["azureml.telemetry.attribution"] = "FederatedLearningSGJobFlag"
+ self.scatter_gather_graph._to_rest_object()
+
+ # set output to final aggregation step's output
+ self._outputs = self.scatter_gather_graph.outputs
+ super(FLScatterGather, self).__init__(
+ type=JobType.COMPONENT,
+ component=None,
+ inputs=None,
+ outputs=self.scatter_gather_graph.outputs,
+ name=None,
+ display_name=None,
+ description=None,
+ tags=None,
+ properties=None,
+ comment=None,
+ compute=None,
+ experiment_name=None,
+ )
+
+ def scatter_gather(self) -> PipelineJob:
+ """Executes the scatter-gather loop by creating and executing a pipeline subgraph.
+ Returns the outputs of the final aggregation step.
+
+ :return: Outputs of the final aggregation step.
+ :rtype: list[~azure.ai.ml.Output]
+ """
+
+ @pipeline(
+ func=None,
+ name="Scatter gather",
+ description="It includes all steps that need to be executed in silo and aggregation",
+ )
+ # pylint: disable-next=docstring-missing-return,docstring-missing-rtype
+ def scatter_gather_iteration_body(**silo_inputs: Input) -> PipelineJob:
+ """
+ Performs a scatter-gather iteration by running copies of the silo step on different
+ computes/datastores according to this node's silo configs. The outputs of these
+ silo components are then merged by an internal helper component. The merged values
+ are then fed into the user-provided aggregation component. Returns the executed aggregation component.
+
+ Kwargs are a dictionary of names and Inputs to be injected into each executed silo step. This dictionary is
+ merged with silo-specific inputs before each execution.
+ """
+
+ silo_outputs = []
+ # TODO 2293586 replace this for-loop with a parallel-for node
+ for silo_config in self.silo_configs:
+ silo_inputs.update(silo_config.inputs)
+ executed_silo_component = self.silo_component(**silo_inputs)
+ for input_name, input_obj in executed_silo_component.inputs.items():
+ if input_name in silo_config.inputs and input_obj.type == "uri_folder":
+ input_obj.mode = "ro_mount"
+ FLScatterGather._anchor_step(
+ pipeline_step=executed_silo_component,
+ compute=silo_config.compute,
+ internal_datastore=silo_config.datastore,
+ orchestrator_datastore=self.aggregation_datastore,
+ )
+ # add to silo outputs list
+ silo_outputs.append(executed_silo_component)
+
+ # produce internal argument-merging components and record them in local subgraph
+ merge_comp_mapping = self._inject_merge_components(silo_outputs)
+
+ # produce aggregate step inputs by merging static kwargs and mapped arguments from
+ # internal merge components
+ agg_inputs: Dict = {}
+ if self.aggregation_kwargs is not None:
+ agg_inputs.update(self.aggregation_kwargs)
+ internal_merge_outputs = {
+ self._get_aggregator_input_name(k): v.outputs.aggregated_output for k, v in merge_comp_mapping.items()
+ }
+ agg_inputs.update(internal_merge_outputs)
+
+ # run the user aggregation step
+ executed_aggregation_component = self.aggregation_component(**agg_inputs)
+ # Set mode of aggregated mltable inputs as eval mount to allow files referenced within the table
+ # to be accessible by the component
+ for name, agg_input in executed_aggregation_component.inputs.items():
+ if (
+ self.silo_to_aggregation_argument_map is not None
+ and name in self.silo_to_aggregation_argument_map.values()
+ and agg_input.type == "mltable"
+ ):
+ agg_input.mode = "eval_download"
+
+ # Anchor both the internal merge components and the user-supplied aggregation step
+ # to the aggregation compute and datastore
+ if self.aggregation_compute is not None and self.aggregation_datastore is not None:
+ # internal merge component is also siloed to wherever the aggregation component lives.
+ for executed_merge_component in merge_comp_mapping.values():
+ FLScatterGather._anchor_step(
+ pipeline_step=executed_merge_component,
+ compute=self.aggregation_compute,
+ internal_datastore=self.aggregation_datastore,
+ orchestrator_datastore=self.aggregation_datastore,
+ )
+ FLScatterGather._anchor_step(
+ pipeline_step=executed_aggregation_component,
+ compute=self.aggregation_compute,
+ internal_datastore=self.aggregation_datastore,
+ orchestrator_datastore=self.aggregation_datastore,
+ )
+ res: PipelineJob = executed_aggregation_component.outputs
+ return res
+
+ @pipeline(func=None, name="Scatter gather graph")
+ # pylint: disable-next=docstring-missing-return,docstring-missing-rtype
+ def create_scatter_gather_graph() -> PipelineJob:
+ """
+ Creates a scatter-gather graph by executing the scatter_gather_iteration_body
+ function in a do-while loop. The loop terminates when the user-supplied
+ termination condition is met.
+ """
+
+ silo_inputs: Dict = {}
+ if self.shared_silo_kwargs is not None:
+ # Start with static inputs
+ silo_inputs.update(self.shared_silo_kwargs)
+
+ # merge in inputs passed in from the previous iteration's aggregate step
+ if self.aggregation_to_silo_argument_map is not None:
+ silo_inputs.update({v: None for v in self.aggregation_to_silo_argument_map.values()})
+
+ scatter_gather_body = scatter_gather_iteration_body(**silo_inputs)
+
+ # map aggregation outputs to scatter inputs
+ if self.aggregation_to_silo_argument_map is not None:
+ do_while_mapping = {
+ k: getattr(scatter_gather_body.inputs, v) for k, v in self.aggregation_to_silo_argument_map.items()
+ }
+
+ do_while(
+ body=scatter_gather_body, # type: ignore[arg-type]
+ mapping=do_while_mapping, # pylint: disable=possibly-used-before-assignment
+ max_iteration_count=self.max_iterations,
+ )
+ res_scatter: PipelineJob = scatter_gather_body.outputs # type: ignore[assignment]
+ return res_scatter
+
+ res: PipelineJob = create_scatter_gather_graph()
+ return res
+
+ @classmethod
+ def _get_fl_datastore_path(
+ cls,
+ datastore_name: Optional[str],
+ output_name: str,
+ unique_id: str = "${{name}}",
+ iteration_num: Optional[int] = None,
+ ) -> str:
+ """Construct a path string using the inputted values. The important aspect is that this produces a
+ path with a specified datastore.
+
+ :param datastore_name: The datastore to use in the constructed path.
+ :type datastore_name: str
+ :param output_name: The name of the output value that this path is assumed to belong to.
+ Is injected into the path.
+ :type output_name: str
+ :param unique_id: An additional string to inject if needed. Defaults to ${{name}}, which injects
+ the run id at runtime.
+ :type unique_id: str
+ :param iteration_num: The iteration number of the current scatter-gather iteration.
+ If set, inject this into the resulting path string.
+ :type iteration_num: Optional[int]
+ :return: A data path string containing the various aforementioned inputs.
+ :rtype: str
+
+ """
+ data_path = f"azureml://datastores/{datastore_name}/paths/federated_learning/{output_name}/{unique_id}/"
+ if iteration_num:
+ data_path += f"iteration_{iteration_num}/"
+ return data_path
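+
+        # Editor's illustration of the value produced above (hypothetical names):
+        #   _get_fl_datastore_path("agg_store", "model_out", iteration_num=2)
+        #   -> "azureml://datastores/agg_store/paths/federated_learning/model_out/${{name}}/iteration_2/"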
+
+ @classmethod
+ def _check_datastore(cls, path: str, expected_datastore: Optional[str]) -> bool:
+ """Perform a simple regex check to try determine if the datastore in the inputted path string
+ matches the expected_datastore.
+
+
+ :param path: An output pathstring.
+ :type path: str
+ :param expected_datastore: A datastore name.
+ :type expected_datastore: str
+ :return: Whether or not the expected_datastore was found in the path at the expected location.
+ :rtype: bool
+ """
+ match = re.match("(.*datastore/)([^/]*)(/.*)", path)
+ if match:
+ groups = match.groups()
+ if groups[1] == expected_datastore:
+ return True
+ return False
+
+ @classmethod
+ def _check_or_set_datastore(
+ cls,
+ name: str,
+ output: Output,
+ target_datastore: Optional[str],
+ iteration_num: Optional[int] = None,
+ ) -> MutableValidationResult:
+ """Tries to assign output.path to a value which includes the target_datastore if it's not already
+ set. If the output's path is already set, return a warning if it doesn't match the target_datastore.
+
+ :param name: The name of the output to modify
+ :type name: str
+ :param output: The output object to examine and potentially change the datastore of.
+ :type output: Output
+ :param target_datastore: The name of the datastore to try applying to the output
+ :type target_datastore: str
+ :param iteration_num: the current iteration in the scatter gather loop. If set, include this in the generated
+ path.
+ :type iteration_num: Optional[int]
+ :return: A validation result containing any problems that arose. Contains a warning if the examined output
+ already contains a datastore that does not match 'target_datastore'.
+ :rtype: MutableValidationResult
+ """
+ validation_result = cls._create_empty_validation_result()
+ if not hasattr(output, "path") or not output.path:
+ output.path = cls._get_fl_datastore_path(target_datastore, name, iteration_num=iteration_num)
+ # Double check the path's datastore leads to the target if it's already set.
+ elif not cls._check_datastore(output.path, target_datastore):
+ validation_result.append_warning(
+ yaml_path=name,
+ message=f"Output '{name}' has an undetermined datastore, or a datstore"
+ + f" that does not match the expected datastore for this output, which is '{target_datastore}'."
+ + " Make sure this is intended.",
+ )
+ return validation_result
+
+ # TODO 2293705: Add anchoring for more resource types.
+ @classmethod
+ def _anchor_step(
+ cls,
+ pipeline_step: Union[Pipeline, CommandComponent],
+ compute: str,
+ internal_datastore: str,
+ orchestrator_datastore: Optional[str],
+ iteration: Optional[int] = 0,
+ _path: str = "root",
+ ) -> MutableValidationResult:
+ """Take a pipeline step and recursively enforces the right compute/datastore config.
+
+ :param pipeline_step: a step to anchor
+ :type pipeline_step: Union[Pipeline, CommandComponent]
+ :param compute: name of the compute target
+ :type compute: str
+ :param internal_datastore: The name of the datastore that should be used for internal output anchoring.
+ :type internal_datastore: str
+ :param orchestrator_datastore: The name of the orchestrator/aggregation datastore that should be used for
+ 'real' output anchoring.
+ :type orchestrator_datastore: str
+ :param iteration: The current iteration number in the scatter gather loop. Defaults to 0.
+ :type iteration: Optional[int]
+ :param _path: for recursive anchoring, codes the "path" inside the pipeline for messaging
+ :type _path: str
+ :return: A validation result containing any issues that were uncovered during anchoring. This function adds
+ warnings when outputs already have assigned paths which don't contain the expected datastore.
+ :rtype: MutableValidationResult
+ """
+
+ validation_result = cls._create_empty_validation_result()
+
+ # Current step is a pipeline, which means we need to inspect its steps (jobs) and
+ # potentially anchor those as well.
+ if pipeline_step.type == "pipeline":
+ if hasattr(pipeline_step, "component"):
+ # Current step is probably not the root of the graph
+ # its outputs should be anchored to the internal_datastore.
+ for name, output in pipeline_step.outputs.items():
+ if not isinstance(output, str):
+ if output.type in ANCHORABLE_OUTPUT_TYPES:
+ validation_result.merge_with(
+ cls._check_or_set_datastore(
+ name=name,
+ output=output,
+ target_datastore=orchestrator_datastore,
+ iteration_num=iteration,
+ )
+ )
+
+ # then we need to anchor the internal component of this step
+ # The outputs of this sub-component are a deep copy of the outputs of this step
+ # This is dangerous, and we need to make sure they both use the same datastore,
+ # so we keep datastore types identical across this recursive call.
+ cls._anchor_step(
+ pipeline_step.component, # type: ignore
+ compute,
+ internal_datastore=internal_datastore,
+ orchestrator_datastore=orchestrator_datastore,
+ _path=f"{_path}.component",
+ )
+
+ else:
+ # This is a pipeline step with multiple jobs beneath it.
+ # Anchor its outputs...
+ for name, output in pipeline_step.outputs.items():
+ if not isinstance(output, str):
+ if output.type in ANCHORABLE_OUTPUT_TYPES:
+ validation_result.merge_with(
+ cls._check_or_set_datastore(
+ name=name,
+ output=output,
+ target_datastore=orchestrator_datastore,
+ iteration_num=iteration,
+ )
+ )
+ # ...then recursively anchor each job inside the pipeline
+ if not isinstance(pipeline_step, CommandComponent):
+ for job_key in pipeline_step.jobs:
+ job = pipeline_step.jobs[job_key]
+ # replace orchestrator with internal datastore; job components
+ # should either use the local datastore
+ # or have already had their outputs re-assigned.
+ cls._anchor_step(
+ job,
+ compute,
+ internal_datastore=internal_datastore,
+ orchestrator_datastore=internal_datastore,
+ _path=f"{_path}.jobs.{job_key}",
+ )
+
+ elif pipeline_step.type == "command":
+ # if the current step is a command component
+ # make sure the compute corresponds to the silo
+ if not isinstance(pipeline_step, CommandComponent) and pipeline_step.compute is None:
+ pipeline_step.compute = compute
+ # then anchor each of the job's outputs
+ for name, output in pipeline_step.outputs.items():
+ if not isinstance(output, str):
+ if output.type in ANCHORABLE_OUTPUT_TYPES:
+ validation_result.merge_with(
+ cls._check_or_set_datastore(
+ name=name,
+ output=output,
+ target_datastore=orchestrator_datastore,
+ iteration_num=iteration,
+ )
+ )
+ else:
+ # TODO revisit this and add support for anchoring more things
+ raise NotImplementedError(f"under path={_path}: step type={pipeline_step.type} is not supported")
+
+ return validation_result
+
+ # Making this a class method allows for easier, isolated testing, and allows careful
+ # users to call this as a pre-init step.
+ # TODO: Might be worth migrating this to a schema validation class, but out of scope for now.
+ # pylint: disable=too-many-statements,too-many-branches, too-many-locals
+ @classmethod
+ def validate_inputs(
+ cls,
+ *,
+ silo_configs: List[FederatedLearningSilo],
+ silo_component: Component,
+ aggregation_component: Component,
+ shared_silo_kwargs: Optional[Dict],
+ aggregation_compute: Optional[str],
+ aggregation_datastore: Optional[str],
+ aggregation_kwargs: Optional[Dict],
+ silo_to_aggregation_argument_map: Optional[Dict],
+ aggregation_to_silo_argument_map: Optional[Dict],
+ max_iterations: int,
+ raise_error: bool = False,
+ ) -> MutableValidationResult:
+ """Validates the inputs for the scatter-gather node.
+
+ :keyword silo_configs: List of federated learning silo configurations.
+ :paramtype silo_configs: List[~azure.ai.ml.entities._assets.federated_learning_silo.FederatedLearningSilo]
+ :keyword silo_component: Component representing the silo for federated learning.
+ :paramtype silo_component: ~azure.ai.ml.entities.Component
+ :keyword aggregation_component: Component representing the aggregation step.
+ :paramtype aggregation_component: ~azure.ai.ml.entities.Component
+ :keyword shared_silo_kwargs: Keyword arguments shared across all silos.
+ :paramtype shared_silo_kwargs: Dict
+ :keyword aggregation_compute: The compute resource for the aggregation step.
+ :paramtype aggregation_compute: str
+ :keyword aggregation_datastore: The datastore for the aggregation step.
+ :paramtype aggregation_datastore: str
+ :keyword aggregation_kwargs: Keyword arguments specific to the aggregation step.
+ :paramtype aggregation_kwargs: Dict
+ :keyword silo_to_aggregation_argument_map: Mapping of silo to aggregation arguments.
+ :paramtype silo_to_aggregation_argument_map: Dict
+ :keyword aggregation_to_silo_argument_map: Mapping of aggregation to silo arguments.
+ :paramtype aggregation_to_silo_argument_map: Dict
+ :keyword max_iterations: The maximum number of iterations for the scatter-gather loop.
+ :paramtype max_iterations: int
+ :keyword raise_error: Whether to raise an exception if validation fails. Defaults to False.
+ :paramtype raise_error: bool
+ :return: The validation result.
+ :rtype: ~azure.ai.ml.entities._validation.MutableValidationResult
+ """
+ validation_result = cls._create_empty_validation_result()
+
+ # saved values for validation later on
+ silo_inputs = None
+ silo_outputs = None
+ agg_inputs = None
+ agg_outputs = None
+ # validate silo component
+ if silo_component is None:
+ validation_result.append_error(
+ yaml_path="silo_component",
+ message="silo_component is a required argument for the scatter gather node.",
+ )
+ else:
+ # ensure that silo component has both inputs and outputs
+ if not hasattr(silo_component, "inputs"):
+ validation_result.append_error(
+ yaml_path="silo_component",
+ message="silo_component is missing 'inputs' attribute;"
+ + "it does not appear to be a valid component that can be used in a scatter-gather loop.",
+ )
+ else:
+ silo_inputs = silo_component.inputs
+ if not hasattr(silo_component, "outputs"):
+ validation_result.append_error(
+ yaml_path="silo_component",
+ message="silo_component is missing 'outputs' attribute;"
+ + "it does not appear to be a valid component that can be used in a scatter-gather loop.",
+ )
+ else:
+ silo_outputs = silo_component.outputs
+ # validate aggregation component
+ if aggregation_component is None:
+ validation_result.append_error(
+ yaml_path="aggregation_component",
+ message="aggregation_component is a required argument for the scatter gather node.",
+ )
+ else:
+ # ensure that aggregation component has both inputs and outputs
+ if not hasattr(aggregation_component, "inputs"):
+ validation_result.append_error(
+ yaml_path="aggregation_component",
+ message="aggregation_component is missing 'inputs' attribute;"
+ + "it does not appear to be a valid component that can be used in a scatter-gather loop.",
+ )
+ else:
+ agg_inputs = aggregation_component.inputs
+ if not hasattr(aggregation_component, "outputs"):
+ validation_result.append_error(
+ yaml_path="aggregation_component",
+ message="aggregation_component is missing 'outputs' attribute;"
+ + " it does not appear to be a valid component that can be used in a scatter-gather loop.",
+ )
+ else:
+ agg_outputs = aggregation_component.outputs
+
+ # validate silos configs
+ if silo_configs is None:
+ validation_result.append_error(
+ yaml_path="silo_configs",
+ message="silo_configs is a required argument for the scatter gather node.",
+ )
+ elif len(silo_configs) == 0:
+ validation_result.append_error(
+ yaml_path="silo_configs",
+ message="silo_configs cannot be an empty list.",
+ )
+ else:
+ first_silo = silo_configs[0]
+ expected_inputs: List = []
+ if hasattr(first_silo, "inputs"):
+ expected_inputs = first_silo.inputs.keys() # type: ignore
+ num_expected_inputs = len(expected_inputs)
+ # pylint: disable=consider-using-enumerate
+ for i in range(len(silo_configs)):
+ silo = silo_configs[i]
+ if not hasattr(silo, "compute"):
+ validation_result.append_error(
+ yaml_path="silo_configs",
+ message=f"Silo at index {i} in silo_configs is missing its compute value.",
+ )
+ if not hasattr(silo, "datastore"):
+ validation_result.append_error(
+ yaml_path="silo_configs",
+ message=f"Silo at index {i} in silo_configs is missing its datastore value.",
+ )
+ silo_input_len = 0
+ if hasattr(silo, "inputs"):
+ silo_input_len = len(silo.inputs)
+ # if inputs exist, make sure the inputs names are consistent across silo configs
+ for expected_input_name in expected_inputs:
+ if expected_input_name not in silo.inputs:
+ validation_result.append_error(
+ yaml_path="silo_configs",
+ message=f"Silo at index {i} has is missing inputs named '{expected_input_name}',"
+ + "which was listed in the first silo config. "
+ + "Silos must have consistent inputs names.",
+ )
+ if silo_input_len != num_expected_inputs:
+ validation_result.append_error(
+ yaml_path="silo_configs",
+ message=f"Silo at index {i} has {silo_input_len} inputs, but the first silo established that"
+ + f"each silo would have {num_expected_inputs} silo-specific inputs.",
+ )
+
+ # Make sure both aggregation overrides are set, or neither is
+ if aggregation_datastore is None and aggregation_compute is not None:
+ validation_result.append_error(
+ yaml_path="aggregation_datastore",
+ message="aggregation_datastore cannot be unset if aggregation_compute is set.",
+ )
+ elif aggregation_datastore is not None and aggregation_compute is None:
+ validation_result.append_error(
+ yaml_path="aggregation_compute",
+ message="aggregation_compute cannot be unset if aggregation_datastore is set.",
+ )
+
+ # validate component kwargs, ensuring that the relevant components contain the specified inputs
+ if shared_silo_kwargs is None:
+ validation_result.append_error(
+ yaml_path="shared_silo_kwargs",
+ message="shared_silo_kwargs should never be None. Input an empty dictionary instead.",
+ )
+ elif silo_inputs is not None:
+ for k in shared_silo_kwargs.keys():
+ if k not in silo_inputs:
+ validation_result.append_error(
+ yaml_path="shared_silo_kwargs",
+ message=f"shared_silo_kwargs keyword {k} not listed in silo_component's inputs",
+ )
+ if aggregation_kwargs is None:
+ validation_result.append_error(
+ yaml_path="aggregation_kwargs",
+ message="aggregation_kwargs should never be None. Input an empty dictionary instead.",
+ )
+ elif agg_inputs is not None:
+ for k in aggregation_kwargs.keys():
+ if k not in agg_inputs:
+ validation_result.append_error(
+ yaml_path="aggregation_kwargs",
+ message=f"aggregation_kwargs keyword {k} not listed in aggregation_component's inputs",
+ )
+
+ # validate that argument mappings leverage inputs and outputs that actually exist
+ if aggregation_to_silo_argument_map is None:
+ validation_result.append_error(
+ yaml_path="aggregation_to_silo_argument_map",
+ message="aggregation_to_silo_argument_map should never be None. Input an empty dictionary instead.",
+ )
+ elif silo_inputs is not None and agg_outputs is not None:
+ for k, v in aggregation_to_silo_argument_map.items():
+ if k not in agg_outputs:
+ validation_result.append_error(
+ yaml_path="aggregation_to_silo_argument_map",
+ message=f"aggregation_to_silo_argument_map key {k} "
+ + "is not a known output of the aggregation component.",
+ )
+ if v not in silo_inputs:
+ validation_result.append_error(
+ yaml_path="aggregation_to_silo_argument_map",
+ message=f"aggregation_to_silo_argument_map value {v} "
+ + "is not a known input of the silo component.",
+ )
+ # and check the other mapping
+ if silo_to_aggregation_argument_map is None:
+ validation_result.append_error(
+ yaml_path="silo_to_aggregation_argument_map",
+ message="silo_to_aggregation_argument_map should never be None. "
+ + "Input an empty dictionary instead.",
+ )
+ elif agg_inputs is not None and silo_outputs is not None:
+ for k, v in silo_to_aggregation_argument_map.items():
+ if k not in silo_outputs:
+ validation_result.append_error(
+ yaml_path="silo_to_aggregation_argument_map",
+ message=f"silo_to_aggregation_argument_map key {k }"
+ + " is not a known output of the silo component.",
+ )
+ if v not in agg_inputs:
+ validation_result.append_error(
+ yaml_path="silo_to_aggregation_argument_map",
+ message=f"silo_to_aggregation_argument_map value {v}"
+ + " is not a known input of the aggregation component.",
+ )
+
+ if max_iterations < 1:
+ validation_result.append_error(
+ yaml_path="max_iterations",
+ message=f"max_iterations must be a positive value, not '{max_iterations}'.",
+ )
+
+ return cls._try_raise(validation_result, raise_error=raise_error)
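+
+    # Editor's usage sketch (not SDK code): careful users can run this as a pre-init check.
+    # `silos`, `silo_comp` and `agg_comp` are hypothetical, previously loaded objects.
+    #
+    #   FLScatterGather.validate_inputs(
+    #       silo_configs=silos,
+    #       silo_component=silo_comp,
+    #       aggregation_component=agg_comp,
+    #       shared_silo_kwargs={},
+    #       aggregation_compute="agg-cluster",
+    #       aggregation_datastore="agg_store",
+    #       aggregation_kwargs={},
+    #       silo_to_aggregation_argument_map={"model_out": "silo_outputs"},
+    #       aggregation_to_silo_argument_map={"agg_model": "model_in"},
+    #       max_iterations=3,
+    #       raise_error=True,
+    #   )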
+
+ @classmethod
+ def _custom_fl_data_path(
+ cls,
+ datastore_name: str,
+ output_name: str,
+ unique_id: str = "${{name}}",
+ iteration_num: str = "${{iteration_num}}",
+ ) -> str:
+ """Produces a path to store the data during FL training.
+
+ :param datastore_name: name of the Azure ML datastore
+ :type datastore_name: str
+ :param output_name: a name unique to this output
+ :type output_name: str
+ :param unique_id: a unique id for the run (default: inject run id with ${{name}})
+ :type unique_id: str
+ :param iteration_num: an iteration number if relevant
+ :type iteration_num: str
+ :return: direct url to the data path to store the data
+ :rtype: str
+ """
+ data_path = f"azureml://datastores/{datastore_name}/paths/federated_learning/{output_name}/{unique_id}/"
+ if iteration_num is not None:
+ data_path += f"iteration_{iteration_num}/"
+
+ return data_path
+
+ def _get_aggregator_input_name(self, silo_output_name: str) -> Optional[str]:
+ """Retrieves the aggregator input name
+
+ :param silo_output_name: The silo output name
+ :type silo_output_name: str
+ :return:
+ * Returns aggregator input name that maps to silo_output.
+ * Returns None if silo_output_name not in silo_to_aggregation_argument_map
+ :rtype: Optional[str]
+ """
+ if self.silo_to_aggregation_argument_map is None:
+ return None
+
+ return self.silo_to_aggregation_argument_map.get(silo_output_name)
+
+ @classmethod
+ def _try_create_default_mappings(
+ cls,
+ silo_comp: Optional[Component],
+ agg_comp: Optional[Component],
+ silo_agg_map: Optional[Dict],
+ agg_silo_map: Optional[Dict],
+ ) -> Tuple[Optional[Dict], Optional[Dict]]:
+ """
+ This function tries to produce dictionaries that link the silo and aggregation
+ components' outputs to the other's inputs.
+ The mapping only occurs for inputted mappings that are None, otherwise
+ the inputted mapping is returned unchanged.
+ These auto-generated mappings are naive, and simply map each output of one component that has an
+ identically-named input in the other component.
+
+ This function does nothing if either inputted component is None. It will also do nothing
+ for a given mapping if any of the relevant input or output dictionaries is None (empty dictionaries are fine).
+
+ Example inputs:
+ silo_comp.inputs = {"silo_input" : value }
+ silo_comp.outputs = {"c" : ..., "silo_output2" : ... }
+ agg_comp.inputs = {"silo_output1" : ... }
+ agg_comp.outputs = {"agg_output" : ... }
+ silo_agg_map = None
+ agg_silo_map = {}
+
+ Example returns:
+ {"silo_output1" : "silo_output1"}, {}
+
+ :param silo_comp: The silo component
+ :type silo_comp: Optional[Component]
+ :param agg_comp: The aggregation component
+ :type agg_comp: Optional[Component]
+ :param silo_agg_map: Mapping of silo to aggregation arguments.
+ :type silo_agg_map: Optional[Dict]
+ :param agg_silo_map: Mapping of aggregation to silo arguments.
+ :type agg_silo_map: Optional[Dict]
+ :return: Returns a tuple of the potentially modified silo to aggregation mapping, followed by the aggregation
+ to silo mapping.
+ :rtype: Tuple[Optional[Dict], Optional[Dict]]
+ """
+ if silo_comp is None or agg_comp is None:
+ return silo_agg_map, agg_silo_map
+ if silo_agg_map is None and silo_comp.outputs is not None and agg_comp.inputs is not None:
+ silo_agg_map = {output: output for output in silo_comp.outputs.keys() if output in agg_comp.inputs}
+ if agg_silo_map is None:
+ agg_silo_map = {output: output for output in agg_comp.outputs.keys() if output in silo_comp.inputs}
+ return silo_agg_map, agg_silo_map
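+
+    # Editor's minimal sketch of the default-mapping rule above (stand-in objects, not real SDK components):
+    #
+    #   from types import SimpleNamespace
+    #   silo = SimpleNamespace(inputs={"model_in": None}, outputs={"model_out": None})
+    #   agg = SimpleNamespace(inputs={"model_out": None}, outputs={"agg_model": None})
+    #   FLScatterGather._try_create_default_mappings(silo, agg, None, None)
+    #   # -> ({"model_out": "model_out"}, {})   # "agg_model" has no identically-named silo input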
+
+ @staticmethod
+ # pylint: disable-next=docstring-missing-rtype
+ def _get_merge_component(output_type: str) -> Any:
+ """Gets the merge component to be used based on type of output
+
+ :param output_type: The output type
+ :type output_type: str
+ :return: The merge component
+ """
+ return MERGE_COMPONENT_MAPPING[output_type]
+
+ def _inject_merge_components(self, executed_silo_components: Any) -> Dict:
+ """Add a merge component for each silo output in the silo_to_aggregation_argument_map.
+ These merge components act as mediators between the user silo and aggregation steps, reducing
+ the variable number of silo outputs into a single input for the aggregation step.
+
+ :param executed_silo_components: A list of executed silo steps to extract outputs from.
+ :type executed_silo_components: list
+ :return: A mapping from silo output names to the corresponding newly created and executed merge component
+ :rtype: dict
+ """
+ executed_component = executed_silo_components[0]
+
+ merge_comp_mapping = {}
+ if self.silo_to_aggregation_argument_map is not None:
+ for (
+ silo_output_argument_name,
+ _,
+ ) in self.silo_to_aggregation_argument_map.items():
+ merge_comp = self._get_merge_component(executed_component.outputs[silo_output_argument_name].type)
+ merge_component_inputs = {
+ silo_output_argument_name
+ + "_silo_"
+ + str(i): executed_silo_components[i].outputs[silo_output_argument_name]
+ for i in range(0, len(executed_silo_components))
+ }
+ executed_merge_component = merge_comp(**merge_component_inputs)
+ for input_obj in executed_merge_component.inputs.values():
+ input_obj.mode = "direct"
+ for output_obj in executed_merge_component.outputs.values():
+ output_obj.type = "mltable"
+ merge_comp_mapping.update({silo_output_argument_name: executed_merge_component})
+
+ return merge_comp_mapping
+
+ # boilerplate functions - largely copied from other node builders
+
+ @property
+ def outputs(self) -> Dict[str, Union[str, Output]]:
+ """Get the outputs of the scatter-gather node.
+
+ :return: The outputs of the scatter-gather node.
+ :rtype: Dict[str, Union[str, ~azure.ai.ml.Output]]
+ """
+ return self._outputs
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema:
+ return FLScatterGatherSchema(context=context)
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ """Convert self to a rest object for remote call.
+
+ :return: The rest object
+ :rtype: dict
+ """
+ rest_node = super(FLScatterGather, self)._to_rest_object(**kwargs)
+ rest_node.update({"outputs": self._to_rest_outputs()})
+ # TODO: Bug Item number: 2897665
+ res: dict = convert_ordered_dict_to_dict(rest_node) # type: ignore
+ return res
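+
+
+# Editor's usage sketch (not part of the SDK). Per the class docstring, this node is normally created
+# through the dsl builder rather than instantiated directly; the builder is assumed here to forward
+# the constructor arguments documented above, and all resource/output names are hypothetical.
+#
+#   from azure.ai.ml import dsl, load_component
+#
+#   silo_step = load_component("silo_step.yaml")
+#   agg_step = load_component("aggregate.yaml")
+#
+#   @dsl.pipeline()
+#   def fl_pipeline():
+#       fl_node = dsl.fl_scatter_gather(
+#           silo_configs=silos,                      # list of FederatedLearningSilo objects
+#           silo_component=silo_step,
+#           aggregation_component=agg_step,
+#           aggregation_compute="agg-cluster",
+#           aggregation_datastore="agg_store",
+#           shared_silo_kwargs={},
+#           aggregation_kwargs={},
+#           max_iterations=3,
+#           create_default_mappings_if_needed=True,
+#       )
+#       return {"final_model": fl_node.outputs["aggregated_model"]}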
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_func.py
new file mode 100644
index 00000000..c9ecabd8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_func.py
@@ -0,0 +1,93 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml.constants._component import ComponentSource
+from azure.ai.ml.entities._component.import_component import ImportComponent
+from azure.ai.ml.entities._inputs_outputs import Output
+from azure.ai.ml.entities._job.import_job import ImportSource
+
+from .command_func import _parse_input, _parse_inputs_outputs, _parse_output
+from .import_node import Import
+
+
+def import_job(
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ source: Optional[ImportSource] = None,
+ output: Optional[Output] = None,
+ is_deterministic: bool = True,
+ **kwargs: Any,
+) -> Import:
+ """Create an Import object which can be used inside dsl.pipeline as a function
+ and can also be created as a standalone import job.
+
+ :keyword name: Name of the import job or component created.
+ :paramtype name: str
+ :keyword description: A friendly description of the import.
+ :paramtype description: str
+ :keyword tags: Tags to be attached to this import.
+ :paramtype tags: Dict
+ :keyword display_name: A friendly name.
+ :paramtype display_name: str
+ :keyword experiment_name: Name of the experiment the job will be created under.
+ If None is provided, the default will be set to the current directory name.
+ Will be ignored as a pipeline step.
+ :paramtype experiment_name: str
+ :keyword source: Input source parameters used by this import.
+ :paramtype source: ~azure.ai.ml.entities._job.import_job.ImportSource
+ :keyword output: The output of this import.
+ :paramtype output: ~azure.ai.ml.entities.Output
+ :keyword is_deterministic: Specify whether the command will return the same output given the same input.
+ If a command (component) is deterministic, when used as a node/step in a pipeline,
+ it will reuse results from a previously submitted job in the current workspace
+ which has the same inputs and settings.
+ In this case, this step will not use any compute resource.
+ Defaults to True.
+ :paramtype is_deterministic: bool
+ :returns: The Import object.
+ :rtype: ~azure.ai.ml.entities._builders.import_node.Import
+ """
+ inputs = source._to_job_inputs() if source else kwargs.pop("inputs")
+ outputs = {"output": output} if output else kwargs.pop("outputs")
+ component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input)
+ # job inputs can not be None
+ job_inputs = {k: v for k, v in job_inputs.items() if v is not None}
+ component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output)
+
+ component = kwargs.pop("component", None)
+
+ if component is None:
+ component = ImportComponent(
+ name=name,
+ tags=tags,
+ display_name=display_name,
+ description=description,
+ source=component_inputs,
+ output=component_outputs["output"],
+ _source=ComponentSource.BUILDER,
+ is_deterministic=is_deterministic,
+ **kwargs,
+ )
+
+ import_obj = Import(
+ component=component,
+ name=name,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ experiment_name=experiment_name,
+ inputs=job_inputs,
+ outputs=job_outputs,
+ **kwargs,
+ )
+
+ return import_obj
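+
+
+# Editor's usage sketch (not part of the SDK): building a standalone import job from this factory.
+# `my_import_source` stands in for an ImportSource instance (e.g. a database or file-system source),
+# and the output path/name are hypothetical.
+#
+#   from azure.ai.ml.entities import Output
+#
+#   node = import_job(
+#       name="import-sales-data",
+#       display_name="Import sales data",
+#       source=my_import_source,
+#       output=Output(type="mltable", path="azureml://datastores/workspaceblobstore/paths/imports/sales/"),
+#   )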
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_node.py
new file mode 100644
index 00000000..144753d5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/import_node.py
@@ -0,0 +1,205 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+from marshmallow import Schema
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import CommandJob as RestCommandJob
+from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData
+from azure.ai.ml._schema.job.import_job import ImportJobSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._component.import_component import ImportComponent
+from azure.ai.ml.entities._inputs_outputs import Output
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, from_rest_inputs_to_dataset_literal
+from azure.ai.ml.entities._job.import_job import ImportJob, ImportSource
+from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException
+
+from ..._schema import PathAwareSchema
+from .._inputs_outputs import Output
+from .._util import convert_ordered_dict_to_dict, load_from_dict, validate_attribute_type
+from .base_node import BaseNode
+
+module_logger = logging.getLogger(__name__)
+
+
+class Import(BaseNode):
+ """Base class for import node, used for import component version consumption.
+
+ You should not instantiate this class directly. Instead, you should
+ create it from a builder function.
+
+ :param component: Id or instance of the import component/job to be run for the step.
+ :type component: ~azure.ai.ml.entities._component.import_component.ImportComponent
+ :param inputs: Input parameters to the import.
+ :type inputs: Dict[str, str]
+ :param outputs: Mapping of output data bindings used in the job.
+ :type outputs: Dict[str, Union[str, ~azure.ai.ml.entities.Output]]
+ :param name: Name of the import.
+ :type name: str
+ :param description: Description of the import.
+ :type description: str
+ :param display_name: Display name of the job.
+ :type display_name: str
+ :param experiment_name: Name of the experiment the job will be created under,
+ if None is provided, the default will be set to the current directory name.
+ :type experiment_name: str
+ """
+
+ def __init__(
+ self,
+ *,
+ component: Union[str, ImportComponent],
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+
+ kwargs.pop("type", None)
+ kwargs.pop("compute", None)
+
+ self._parameters = kwargs.pop("parameters", {})
+ BaseNode.__init__(
+ self,
+ type=NodeType.IMPORT,
+ inputs=inputs,
+ outputs=outputs,
+ component=component,
+ compute=ComputeType.ADF,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_supported_inputs_types(cls) -> Type[str]:
+ # import source parameters type, connection, query, path are always str
+ return str
+
+ @classmethod
+ def _get_supported_outputs_types(cls) -> Tuple:
+ return str, Output
+
+ @property
+ def component(self) -> Union[str, ImportComponent]:
+ res: Union[str, ImportComponent] = self._component
+ return res
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "component": (str, ImportComponent),
+ }
+
+ def _to_job(self) -> ImportJob:
+ return ImportJob(
+ id=self.id,
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ status=self.status,
+ source=ImportSource._from_job_inputs(self._job_inputs),
+ output=self._job_outputs.get("output"),
+ creation_context=self.creation_context,
+ )
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ return []
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj: dict = super()._to_rest_object(**kwargs)
+ rest_obj.update(
+ convert_ordered_dict_to_dict(
+ {
+ "componentId": self._get_component_id(),
+ }
+ )
+ )
+ return rest_obj
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Import":
+ from .import_func import import_job
+
+ loaded_data = load_from_dict(ImportJobSchema, data, context, additional_message, **kwargs)
+
+ _import_job: Import = import_job(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ return _import_job
+
+ @classmethod
+ def _load_from_rest_job(cls, obj: JobBaseData) -> "Import":
+ from .import_func import import_job
+
+ rest_command_job: RestCommandJob = obj.properties
+ inputs = from_rest_inputs_to_dataset_literal(rest_command_job.inputs)
+ outputs = from_rest_data_outputs(rest_command_job.outputs)
+
+ _import_job: Import = import_job(
+ name=obj.name,
+ display_name=rest_command_job.display_name,
+ description=rest_command_job.description,
+ experiment_name=rest_command_job.experiment_name,
+ status=rest_command_job.status,
+ creation_context=obj.system_data,
+ inputs=inputs,
+ output=outputs["output"] if "output" in outputs else None,
+ )
+ _import_job._id = obj.id
+ if isinstance(_import_job.component, ImportComponent):
+ _import_job.component._source = (
+ ComponentSource.REMOTE_WORKSPACE_JOB
+ ) # This is used by pipeline job telemetries.
+
+ return _import_job
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline import ImportSchema
+
+ return ImportSchema(context=context)
+
+ # pylint: disable-next=docstring-missing-param
+ def __call__(self, *args: Any, **kwargs: Any) -> "Import":
+ """Call Import as a function will return a new instance each time.
+
+ :return: An Import node.
+ :rtype: Import
+ """
+ if isinstance(self._component, Component):
+ # call this to validate inputs
+ node: Import = self._component(*args, **kwargs)
+ # merge inputs
+ for name, original_input in self.inputs.items():
+ if name not in kwargs:
+ # use setattr here to make sure owner of input won't change
+ setattr(node.inputs, name, original_input._data)
+ node._job_inputs[name] = original_input._data
+ # get outputs
+ for name, original_output in self.outputs.items():
+ # use setattr here to make sure owner of input won't change
+ if not isinstance(original_output, str):
+ setattr(node.outputs, name, original_output._data)
+ self._refine_optional_inputs_with_no_value(node, kwargs)
+ # set default values: compute, environment_variables, outputs
+ node._name = self.name
+ node.compute = self.compute
+ node.tags = self.tags
+ # Pass through the display name only if the display name is not system generated.
+ node.display_name = self.display_name if self.display_name != self.name else None
+ return node
+ msg = "Import can be called as a function only when referenced component is {}, currently got {}."
+ raise ValidationException(
+ message=msg.format(type(Component), self._component),
+ no_personal_data_message=msg.format(type(Component), "self._component"),
+ target=ErrorTarget.COMMAND_JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
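+
+
+# Editor's note (illustrative, not part of the SDK): per __call__ above, an existing Import node that
+# wraps a Component can be re-invoked to get a fresh copy that inherits its inputs, compute and tags,
+# e.g. `second_import = first_import()` inside a pipeline definition, where `first_import` is a
+# hypothetical Import node backed by an ImportComponent.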
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel.py
new file mode 100644
index 00000000..db1de797
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel.py
@@ -0,0 +1,551 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import copy
+import json
+import logging
+import os
+import re
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
+
+from marshmallow import INCLUDE, Schema
+
+from azure.ai.ml._schema.core.fields import NestedField, UnionField
+from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+ _BaseJobIdentityConfiguration,
+)
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.parallel.run_function import RunFunction
+from azure.ai.ml.entities._job.pipeline._io import NodeOutput
+from azure.ai.ml.exceptions import MlException
+
+from ..._schema import PathAwareSchema
+from ..._utils.utils import is_data_binding_expression
+from ...constants._common import ARM_ID_PREFIX
+from ...constants._component import NodeType
+from .._component.component import Component
+from .._component.flow import FlowComponent
+from .._component.parallel_component import ParallelComponent
+from .._inputs_outputs import Input, Output
+from .._job.job_resource_configuration import JobResourceConfiguration
+from .._job.parallel.parallel_job import ParallelJob
+from .._job.parallel.parallel_task import ParallelTask
+from .._job.parallel.retry_settings import RetrySettings
+from .._job.pipeline._io import NodeWithGroupInputMixin
+from .._util import convert_ordered_dict_to_dict, get_rest_dict_for_node_attrs, validate_attribute_type
+from .base_node import BaseNode
+
+module_logger = logging.getLogger(__name__)
+
+
+class Parallel(BaseNode, NodeWithGroupInputMixin): # pylint: disable=too-many-instance-attributes
+ """Base class for parallel node, used for parallel component version consumption.
+
+ You should not instantiate this class directly. Instead, you should
+ create it from the builder function: parallel.
+
+ :param component: Id or instance of the parallel component/job to be run for the step
+ :type component: ~azure.ai.ml.entities._component.parallel_component.ParallelComponent
+ :param name: Name of the parallel
+ :type name: str
+ :param description: Description of the command
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated
+ :type tags: dict[str, str]
+ :param properties: The job property dictionary
+ :type properties: dict[str, str]
+ :param display_name: Display name of the job
+ :type display_name: str
+ :param retry_settings: Parallel job run failed retry
+ :type retry_settings: BatchRetrySettings
+ :param logging_level: A string of the logging level name
+ :type logging_level: str
+ :param max_concurrency_per_instance: The max parallelism that each compute instance has
+ :type max_concurrency_per_instance: int
+ :param error_threshold: The number of item processing failures that should be ignored
+ :type error_threshold: int
+ :param mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored
+ :type mini_batch_error_threshold: int
+ :param task: The parallel task
+ :type task: ParallelTask
+ :param mini_batch_size: For FileDataset input, this field is the number of files
+ a user script can process in one run() call.
+ For TabularDataset input, this field is the approximate size of data
+ the user script can process in one run() call.
+ Example values are 1024, 1024KB, 10MB, and 1GB. (optional, default value is 10 files
+ for FileDataset and 1MB for TabularDataset.)
+ This value could be set through PipelineParameter
+ :type mini_batch_size: str
+ :param partition_keys: The keys used to partition dataset into mini-batches. If specified,
+ the data with the same key will be partitioned into the same mini-batch.
+ If both partition_keys and mini_batch_size are specified,
+ the partition keys will take effect.
+ The input(s) must be partitioned dataset(s),
+ and the partition_keys must be a subset of the keys of every input dataset for this to work.
+ :type partition_keys: List
+ :keyword identity: The identity that the command job will use while running on compute.
+ :paramtype identity: Optional[Union[
+ dict[str, str],
+ ~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.AmlTokenConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration]]
+ :param input_data: The input data
+ :type input_data: str
+ :param inputs: Inputs of the component/job
+ :type inputs: dict
+ :param outputs: Outputs of the component/job
+ :type outputs: dict
+ """
+
+ # pylint: disable=too-many-statements
+ def __init__(
+ self,
+ *,
+ component: Union[ParallelComponent, str],
+ compute: Optional[str] = None,
+ inputs: Optional[Dict[str, Union[NodeOutput, Input, str, bool, int, float, Enum]]] = None,
+ outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None,
+ retry_settings: Optional[Union[RetrySettings, Dict[str, str]]] = None,
+ logging_level: Optional[str] = None,
+ max_concurrency_per_instance: Optional[int] = None,
+ error_threshold: Optional[int] = None,
+ mini_batch_error_threshold: Optional[int] = None,
+ input_data: Optional[str] = None,
+ task: Optional[Union[ParallelTask, RunFunction, Dict]] = None,
+ partition_keys: Optional[List] = None,
+ mini_batch_size: Optional[Union[str, int]] = None,
+ resources: Optional[JobResourceConfiguration] = None,
+ environment_variables: Optional[Dict] = None,
+ identity: Optional[
+ Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, Dict]
+ ] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+ kwargs.pop("type", None)
+
+ if isinstance(component, FlowComponent):
+ # make input definition fit actual inputs for flow component
+ with component._inputs._fit_inputs(inputs): # type: ignore[attr-defined]
+ BaseNode.__init__(
+ self,
+ type=NodeType.PARALLEL,
+ component=component,
+ inputs=inputs,
+ outputs=outputs,
+ compute=compute,
+ **kwargs,
+ )
+ else:
+ BaseNode.__init__(
+ self,
+ type=NodeType.PARALLEL,
+ component=component,
+ inputs=inputs,
+ outputs=outputs,
+ compute=compute,
+ **kwargs,
+ )
+ # init mark for _AttrDict
+ self._init = True
+
+ self._task = task
+
+ if (
+ mini_batch_size is not None
+ and not isinstance(mini_batch_size, int)
+ and not is_data_binding_expression(mini_batch_size)
+ ):
+ """Convert str to int.""" # pylint: disable=pointless-string-statement
+ pattern = re.compile(r"^\d+([kKmMgG][bB])*$")
+ if not pattern.match(mini_batch_size):
+ raise ValueError(r"Parameter mini_batch_size must follow regex rule ^\d+([kKmMgG][bB])*$")
+
+ try:
+ mini_batch_size = int(mini_batch_size)
+ except ValueError as e:
+ if not isinstance(mini_batch_size, int):
+ unit = mini_batch_size[-2:].lower()
+ if unit == "kb":
+ mini_batch_size = int(mini_batch_size[0:-2]) * 1024
+ elif unit == "mb":
+ mini_batch_size = int(mini_batch_size[0:-2]) * 1024 * 1024
+ elif unit == "gb":
+ mini_batch_size = int(mini_batch_size[0:-2]) * 1024 * 1024 * 1024
+ else:
+ raise ValueError("mini_batch_size unit must be kb, mb or gb") from e
+
+ self.mini_batch_size = mini_batch_size
+ self.partition_keys = partition_keys
+ self.input_data = input_data
+ self._retry_settings = retry_settings
+ self.logging_level = logging_level
+ self.max_concurrency_per_instance = max_concurrency_per_instance
+ self.error_threshold = error_threshold
+ self.mini_batch_error_threshold = mini_batch_error_threshold
+ self._resources = resources
+ self.environment_variables = {} if environment_variables is None else environment_variables
+ self._identity = identity
+ if isinstance(self.component, ParallelComponent):
+ self.resources = cast(JobResourceConfiguration, self.resources) or cast(
+ JobResourceConfiguration, copy.deepcopy(self.component.resources)
+ )
+ # TODO: Bug Item number: 2897665
+ self.retry_settings = self.retry_settings or copy.deepcopy(self.component.retry_settings) # type: ignore
+ self.input_data = self.input_data or self.component.input_data
+ self.max_concurrency_per_instance = (
+ self.max_concurrency_per_instance or self.component.max_concurrency_per_instance
+ )
+ self.mini_batch_error_threshold = (
+ self.mini_batch_error_threshold or self.component.mini_batch_error_threshold
+ )
+ self.mini_batch_size = self.mini_batch_size or self.component.mini_batch_size
+ self.partition_keys = self.partition_keys or copy.deepcopy(self.component.partition_keys)
+
+ if not self.task:
+ self.task = self.component.task
+ # task.code is based on self.component.base_path
+ self._base_path = self.component.base_path
+
+ self._init = False
+
+ @classmethod
+ def _get_supported_outputs_types(cls) -> Tuple:
+ return str, Output
+
+ @property
+ def retry_settings(self) -> RetrySettings:
+ """Get the retry settings for the parallel job.
+
+ :return: The retry settings for the parallel job.
+ :rtype: ~azure.ai.ml.entities._job.parallel.retry_settings.RetrySettings
+ """
+ return self._retry_settings # type: ignore
+
+ @retry_settings.setter
+ def retry_settings(self, value: Union[RetrySettings, Dict]) -> None:
+ """Set the retry settings for the parallel job.
+
+ :param value: The retry settings for the parallel job.
+ :type value: ~azure.ai.ml.entities._job.parallel.retry_settings.RetrySettings or dict
+ """
+ if isinstance(value, dict):
+ value = RetrySettings(**value)
+ self._retry_settings = value
+
+ @property
+ def resources(self) -> Optional[JobResourceConfiguration]:
+ """Get the resource configuration for the parallel job.
+
+ :return: The resource configuration for the parallel job.
+ :rtype: ~azure.ai.ml.entities._job.job_resource_configuration.JobResourceConfiguration
+ """
+ return self._resources
+
+ @resources.setter
+ def resources(self, value: Union[JobResourceConfiguration, Dict]) -> None:
+ """Set the resource configuration for the parallel job.
+
+ :param value: The resource configuration for the parallel job.
+ :type value: ~azure.ai.ml.entities._job.job_resource_configuration.JobResourceConfiguration or dict
+ """
+ if isinstance(value, dict):
+ value = JobResourceConfiguration(**value)
+ self._resources = value
+
+ @property
+ def identity(
+ self,
+ ) -> Optional[Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, Dict]]:
+ """The identity that the job will use while running on compute.
+
+ :return: The identity that the job will use while running on compute.
+ :rtype: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
+ ~azure.ai.ml.UserIdentityConfiguration]]
+ """
+ return self._identity
+
+ @identity.setter
+ def identity(
+ self,
+ value: Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, None],
+ ) -> None:
+ """Sets the identity that the job will use while running on compute.
+
+ :param value: The identity that the job will use while running on compute.
+ :type value: Union[dict[str, str], ~azure.ai.ml.ManagedIdentityConfiguration,
+ ~azure.ai.ml.AmlTokenConfiguration, ~azure.ai.ml.UserIdentityConfiguration]
+ """
+ if isinstance(value, dict):
+ identity_schema = UnionField(
+ [
+ NestedField(ManagedIdentitySchema, unknown=INCLUDE),
+ NestedField(AMLTokenIdentitySchema, unknown=INCLUDE),
+ NestedField(UserIdentitySchema, unknown=INCLUDE),
+ ]
+ )
+ value = identity_schema._deserialize(value=value, attr=None, data=None)
+ self._identity = value
+
+ @property
+ def component(self) -> Union[str, ParallelComponent]:
+ """Get the component of the parallel job.
+
+ :return: The component of the parallel job.
+ :rtype: str or ~azure.ai.ml.entities._component.parallel_component.ParallelComponent
+ """
+ res: Union[str, ParallelComponent] = self._component
+ return res
+
+ @property
+ def task(self) -> Optional[ParallelTask]:
+ """Get the parallel task.
+
+ :return: The parallel task.
+ :rtype: ~azure.ai.ml.entities._job.parallel.parallel_task.ParallelTask
+ """
+ return self._task # type: ignore
+
+ @task.setter
+ def task(self, value: Union[ParallelTask, Dict]) -> None:
+ """Set the parallel task.
+
+ :param value: The parallel task.
+ :type value: ~azure.ai.ml.entities._job.parallel.parallel_task.ParallelTask or dict
+ """
+ # base path should be reset if task is set via sdk
+ self._base_path: Optional[Union[str, os.PathLike]] = None
+ if isinstance(value, dict):
+ value = ParallelTask(**value)
+ self._task = value
+
+ def _set_base_path(self, base_path: Optional[Union[str, os.PathLike]]) -> None:
+ if self._base_path:
+ return
+ super(Parallel, self)._set_base_path(base_path)
+
+ def set_resources(
+ self,
+ *,
+ instance_type: Optional[Union[str, List[str]]] = None,
+ instance_count: Optional[int] = None,
+ properties: Optional[Dict] = None,
+ docker_args: Optional[str] = None,
+ shm_size: Optional[str] = None,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Set the resources for the parallel job.
+
+ :keyword instance_type: The instance type or a list of instance types used as supported by the compute target.
+ :paramtype instance_type: Union[str, List[str]]
+ :keyword instance_count: The number of instances or nodes used by the compute target.
+ :paramtype instance_count: int
+ :keyword properties: The property dictionary for the resources.
+ :paramtype properties: dict
+ :keyword docker_args: Extra arguments to pass to the Docker run command.
+ :paramtype docker_args: str
+ :keyword shm_size: Size of the Docker container's shared memory block.
+ :paramtype shm_size: str
+ """
+ if self.resources is None:
+ self.resources = JobResourceConfiguration()
+
+ if instance_type is not None:
+ self.resources.instance_type = instance_type
+ if instance_count is not None:
+ self.resources.instance_count = instance_count
+ if properties is not None:
+ self.resources.properties = properties
+ if docker_args is not None:
+ self.resources.docker_args = docker_args
+ if shm_size is not None:
+ self.resources.shm_size = shm_size
+
+        # Save the resources to internal component as well, otherwise calling sweep() will lose the settings
+ if isinstance(self.component, Component):
+ self.component.resources = self.resources
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "component": (str, ParallelComponent, FlowComponent),
+ "retry_settings": (dict, RetrySettings),
+ "resources": (dict, JobResourceConfiguration),
+ "task": (dict, ParallelTask),
+ "logging_level": str,
+ "max_concurrency_per_instance": (str, int),
+ "error_threshold": (str, int),
+ "mini_batch_error_threshold": (str, int),
+ "environment_variables": dict,
+ }
+
+ def _to_job(self) -> ParallelJob:
+ return ParallelJob(
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ properties=self.properties,
+ compute=self.compute,
+ resources=self.resources,
+ partition_keys=self.partition_keys,
+ mini_batch_size=self.mini_batch_size,
+ task=self.task,
+ retry_settings=self.retry_settings,
+ input_data=self.input_data,
+ logging_level=self.logging_level,
+ identity=self.identity,
+ max_concurrency_per_instance=self.max_concurrency_per_instance,
+ error_threshold=self.error_threshold,
+ mini_batch_error_threshold=self.mini_batch_error_threshold,
+ environment_variables=self.environment_variables,
+ inputs=self._job_inputs,
+ outputs=self._job_outputs,
+ )
+
+ def _parallel_attr_to_dict(self, attr: str, base_type: Type) -> dict:
+ # Convert parallel attribute to dict
+ rest_attr = {}
+ parallel_attr = getattr(self, attr)
+ if parallel_attr is not None:
+ if isinstance(parallel_attr, base_type):
+ rest_attr = parallel_attr._to_dict()
+ elif isinstance(parallel_attr, dict):
+ rest_attr = parallel_attr
+ else:
+ msg = f"Expecting {base_type} for {attr}, got {type(parallel_attr)} instead."
+ raise MlException(message=msg, no_personal_data_message=msg)
+ # TODO: Bug Item number: 2897665
+ res: dict = convert_ordered_dict_to_dict(rest_attr) # type: ignore
+ return res
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ return [
+ "type",
+ "resources",
+ "error_threshold",
+ "mini_batch_error_threshold",
+ "environment_variables",
+ "max_concurrency_per_instance",
+ "task",
+ "input_data",
+ ]
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj: Dict = super(Parallel, self)._to_rest_object(**kwargs)
+ rest_obj.update(
+ convert_ordered_dict_to_dict(
+ {
+ "componentId": self._get_component_id(),
+ "retry_settings": get_rest_dict_for_node_attrs(self.retry_settings),
+ "logging_level": self.logging_level,
+ "mini_batch_size": self.mini_batch_size,
+ "partition_keys": (
+ json.dumps(self.partition_keys) if self.partition_keys is not None else self.partition_keys
+ ),
+ "identity": get_rest_dict_for_node_attrs(self.identity),
+ "resources": get_rest_dict_for_node_attrs(self.resources),
+ }
+ )
+ )
+ return rest_obj
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
+ obj = super()._from_rest_object_to_init_params(obj)
+ # retry_settings
+ if "retry_settings" in obj and obj["retry_settings"]:
+ obj["retry_settings"] = RetrySettings._from_dict(obj["retry_settings"])
+
+ if "task" in obj and obj["task"]:
+ obj["task"] = ParallelTask._from_dict(obj["task"])
+ task_code = obj["task"].code
+ task_env = obj["task"].environment
+ # remove azureml: prefix in code and environment which is added in _to_rest_object
+ if task_code and isinstance(task_code, str) and task_code.startswith(ARM_ID_PREFIX):
+ obj["task"].code = task_code[len(ARM_ID_PREFIX) :]
+ if task_env and isinstance(task_env, str) and task_env.startswith(ARM_ID_PREFIX):
+ obj["task"].environment = task_env[len(ARM_ID_PREFIX) :]
+
+ if "resources" in obj and obj["resources"]:
+ obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])
+
+ if "partition_keys" in obj and obj["partition_keys"]:
+ obj["partition_keys"] = json.dumps(obj["partition_keys"])
+ if "identity" in obj and obj["identity"]:
+ obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"])
+ return obj
+
+ def _build_inputs(self) -> Dict:
+ inputs = super(Parallel, self)._build_inputs()
+ built_inputs = {}
+ # Validate and remove non-specified inputs
+ for key, value in inputs.items():
+ if value is not None:
+ built_inputs[key] = value
+ return built_inputs
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline import ParallelSchema
+
+ return ParallelSchema(context=context)
+
+ # pylint: disable-next=docstring-missing-param
+ def __call__(self, *args: Any, **kwargs: Any) -> "Parallel":
+ """Call Parallel as a function will return a new instance each time.
+
+ :return: A Parallel node
+ :rtype: Parallel
+ """
+ if isinstance(self._component, Component):
+ # call this to validate inputs
+ node: Parallel = self._component(*args, **kwargs)
+ # merge inputs
+ for name, original_input in self.inputs.items():
+ if name not in kwargs:
+ # use setattr here to make sure owner of input won't change
+ setattr(node.inputs, name, original_input._data)
+ # get outputs
+ for name, original_output in self.outputs.items():
+            # use setattr here to make sure owner of output won't change
+ if not isinstance(original_output, str):
+ setattr(node.outputs, name, original_output._data)
+ self._refine_optional_inputs_with_no_value(node, kwargs)
+ # set default values: compute, environment_variables, outputs
+ node._name = self.name
+ node.compute = self.compute
+ node.tags = self.tags
+ node.display_name = self.display_name
+ node.mini_batch_size = self.mini_batch_size
+ node.partition_keys = self.partition_keys
+ node.logging_level = self.logging_level
+ node.max_concurrency_per_instance = self.max_concurrency_per_instance
+ node.error_threshold = self.error_threshold
+ # deep copy for complex object
+ node.retry_settings = copy.deepcopy(self.retry_settings)
+ node.input_data = self.input_data
+ node.task = copy.deepcopy(self.task)
+ node._base_path = self.base_path
+ node.resources = copy.deepcopy(self.resources)
+ node.environment_variables = copy.deepcopy(self.environment_variables)
+ node.identity = copy.deepcopy(self.identity)
+ return node
+ msg = f"Parallel can be called as a function only when referenced component is {type(Component)}, \
+ currently got {self._component}."
+ raise MlException(message=msg, no_personal_data_message=msg)
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Job":
+ raise NotImplementedError()
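
The mini_batch_size handling in Parallel.__init__ above accepts either an int or a size string such as "100KB",
"10MB" or "1GB" and normalizes it to a byte count. A minimal standalone sketch of that conversion rule (the helper
name below is illustrative, not part of the SDK):

import re

def parse_mini_batch_size(value):
    """Convert "1024", "100KB", "10MB" or "1GB" to a byte count, mirroring the rule above."""
    if isinstance(value, int):
        return value
    if not re.match(r"^\d+([kKmMgG][bB])*$", value):
        raise ValueError(r"mini_batch_size must follow regex rule ^\d+([kKmMgG][bB])*$")
    multipliers = {"kb": 1024, "mb": 1024 ** 2, "gb": 1024 ** 3}
    unit = value[-2:].lower()
    if unit in multipliers:
        return int(value[:-2]) * multipliers[unit]
    return int(value)  # plain digits, e.g. "1024"

assert parse_mini_batch_size("10MB") == 10 * 1024 * 1024
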
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_for.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_for.py
new file mode 100644
index 00000000..1e888f50
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_for.py
@@ -0,0 +1,362 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import json
+import os
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml import Input, Output
+from azure.ai.ml._schema import PathAwareSchema
+from azure.ai.ml._schema.pipeline.control_flow_job import ParallelForSchema
+from azure.ai.ml._utils.utils import is_data_binding_expression
+from azure.ai.ml.constants import AssetTypes
+from azure.ai.ml.constants._component import ComponentParameterTypes, ControlFlowType
+from azure.ai.ml.entities import Component, Pipeline
+from azure.ai.ml.entities._builders import BaseNode
+from azure.ai.ml.entities._builders.control_flow_node import LoopNode
+from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput
+from azure.ai.ml.entities._job.pipeline._io.mixin import NodeIOMixin
+from azure.ai.ml.entities._util import convert_ordered_dict_to_dict, validate_attribute_type
+from azure.ai.ml.entities._validation import MutableValidationResult
+from azure.ai.ml.exceptions import UserErrorException
+
+
+class ParallelFor(LoopNode, NodeIOMixin):
+ """Parallel for loop node in the pipeline job. By specifying the loop body and aggregated items, a job-level
+ parallel for loop can be implemented. It will be initialized when calling dsl.parallel_for or when loading the
+ pipeline yml containing parallel_for node. Please do not manually initialize this class.
+
+ :param body: Pipeline job for the parallel for loop body.
+ :type body: ~azure.ai.ml.entities.Pipeline
+ :param items: The loop body's input which will bind to the loop node.
+ :type items: typing.Union[list, dict, str, ~azure.ai.ml.entities._job.pipeline._io.NodeOutput,
+ ~azure.ai.ml.entities._job.pipeline._io.PipelineInput]
+ :param max_concurrency: Maximum number of concurrent iterations to run. All loop body nodes will be executed
+ in parallel if not specified.
+ :type max_concurrency: int
+ """
+
+ OUT_TYPE_MAPPING = {
+ AssetTypes.URI_FILE: AssetTypes.MLTABLE,
+ AssetTypes.URI_FOLDER: AssetTypes.MLTABLE,
+ AssetTypes.MLTABLE: AssetTypes.MLTABLE,
+ AssetTypes.MLFLOW_MODEL: AssetTypes.MLTABLE,
+ AssetTypes.TRITON_MODEL: AssetTypes.MLTABLE,
+ AssetTypes.CUSTOM_MODEL: AssetTypes.MLTABLE,
+ # legacy path support
+ "path": AssetTypes.MLTABLE,
+ ComponentParameterTypes.NUMBER: ComponentParameterTypes.STRING,
+ ComponentParameterTypes.STRING: ComponentParameterTypes.STRING,
+ ComponentParameterTypes.BOOLEAN: ComponentParameterTypes.STRING,
+ ComponentParameterTypes.INTEGER: ComponentParameterTypes.STRING,
+ }
+
+ def __init__(
+ self,
+ *,
+ body: "Pipeline",
+ items: Union[list, dict, str, PipelineInput, NodeOutput],
+ max_concurrency: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+
+ kwargs.pop("type", None)
+ super(ParallelFor, self).__init__(
+ type=ControlFlowType.PARALLEL_FOR,
+ body=body,
+ **kwargs,
+ )
+        # loop body is incomplete at submission time, so we won't validate required inputs
+ self.body._validate_required_input_not_provided = False
+ self._outputs: dict = {}
+
+ actual_outputs = kwargs.get("outputs", {})
+ # parallel for node shares output meta with body
+ try:
+ outputs = self.body._component.outputs
+ # transform body outputs to aggregate types when available
+ self._outputs = self._build_outputs_dict(
+ outputs=actual_outputs, output_definition_dict=self._convert_output_meta(outputs)
+ )
+ except AttributeError:
+            # when body output is not available, create a default output builder without meta
+ self._outputs = self._build_outputs_dict(outputs=actual_outputs)
+
+ self._items = items
+
+ self.max_concurrency = max_concurrency
+
+ @property
+ def outputs(self) -> Dict[str, Union[str, Output]]:
+ """Get the outputs of the parallel for loop.
+
+ :return: The dictionary containing the outputs of the parallel for loop.
+ :rtype: dict[str, Union[str, ~azure.ai.ml.Output]]
+ """
+ return self._outputs
+
+ @property
+ def items(self) -> Union[list, dict, str, PipelineInput, NodeOutput]:
+ """Get the loop body's input which will bind to the loop node.
+
+ :return: The input for the loop body.
+ :rtype: typing.Union[list, dict, str, ~azure.ai.ml.entities._job.pipeline._io.NodeOutput,
+ ~azure.ai.ml.entities._job.pipeline._io.PipelineInput]
+ """
+ return self._items
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema:
+ return ParallelForSchema(context=context)
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ **super(ParallelFor, cls)._attr_type_map(),
+ "items": (dict, list, str, PipelineInput, NodeOutput),
+ }
+
+ @classmethod
+ # pylint: disable-next=docstring-missing-param
+ def _to_rest_item(cls, item: dict) -> dict:
+ """Convert item to rest object.
+
+ :return: The rest object
+ :rtype: dict
+ """
+ primitive_inputs, asset_inputs = {}, {}
+ # validate item
+ for key, val in item.items():
+ if isinstance(val, Input):
+ asset_inputs[key] = val
+ elif isinstance(val, (PipelineInput, NodeOutput)):
+ # convert binding object to string
+ primitive_inputs[key] = str(val)
+ else:
+ primitive_inputs[key] = val
+ return {
+ # asset type inputs will be converted to JobInput dict:
+ # {"asset_param": {"uri": "xxx", "job_input_type": "uri_file"}}
+ **cls._input_entity_to_rest_inputs(input_entity=asset_inputs),
+            # primitive inputs have primitive type values like this
+ # {"int_param": 1}
+ **primitive_inputs,
+ }
+
+ @classmethod
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def _to_rest_items(cls, items: Union[list, dict, str, NodeOutput, PipelineInput]) -> str:
+ """Convert items to rest object."""
+ # validate items.
+ cls._validate_items(items=items, raise_error=True, body_component=None)
+ result: str = ""
+ # convert items to rest object
+ if isinstance(items, list):
+ rest_items_list = [cls._to_rest_item(item=i) for i in items]
+ result = json.dumps(rest_items_list)
+ elif isinstance(items, dict):
+ rest_items_dict = {k: cls._to_rest_item(item=v) for k, v in items.items()}
+ result = json.dumps(rest_items_dict)
+ elif isinstance(items, (NodeOutput, PipelineInput)):
+ result = str(items)
+ elif isinstance(items, str):
+ result = items
+ else:
+ raise UserErrorException("Unsupported items type: {}".format(type(items)))
+ return result
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ """Convert self to a rest object for remote call.
+
+ :return: The rest object
+ :rtype: dict
+ """
+ rest_node = super(ParallelFor, self)._to_rest_object(**kwargs)
+ # convert items to rest object
+ rest_items = self._to_rest_items(items=self.items)
+ rest_node.update({"items": rest_items, "outputs": self._to_rest_outputs()})
+ # TODO: Bug Item number: 2897665
+ res: dict = convert_ordered_dict_to_dict(rest_node) # type: ignore
+ return res
+
+ @classmethod
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def _from_rest_item(cls, rest_item: Any) -> Dict:
+ """Convert rest item to item."""
+ primitive_inputs, asset_inputs = {}, {}
+ for key, val in rest_item.items():
+ if isinstance(val, dict) and val.get("job_input_type"):
+ asset_inputs[key] = val
+ else:
+ primitive_inputs[key] = val
+ return {**cls._from_rest_inputs(inputs=asset_inputs), **primitive_inputs}
+
+ @classmethod
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def _from_rest_items(cls, rest_items: str) -> Union[dict, list, str]:
+ """Convert items from rest object."""
+ try:
+ items = json.loads(rest_items)
+ except json.JSONDecodeError:
+ # return original items when failed to load
+ return rest_items
+ if isinstance(items, list):
+ return [cls._from_rest_item(rest_item=i) for i in items]
+ if isinstance(items, dict):
+ return {k: cls._from_rest_item(rest_item=v) for k, v in items.items()}
+ return rest_items
+
+ @classmethod
+ def _from_rest_object(cls, obj: dict, pipeline_jobs: dict) -> "ParallelFor":
+ # pylint: disable=protected-access
+ obj = BaseNode._from_rest_object_to_init_params(obj)
+ obj["items"] = cls._from_rest_items(rest_items=obj.get("items", ""))
+ return cls._create_instance_from_schema_dict(pipeline_jobs=pipeline_jobs, loaded_data=obj)
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, pipeline_jobs: Dict, loaded_data: Dict, **kwargs: Any) -> "ParallelFor":
+ body_name = cls._get_data_binding_expression_value(loaded_data.pop("body"), regex=r"\{\{.*\.jobs\.(.*)\}\}")
+
+ loaded_data["body"] = cls._get_body_from_pipeline_jobs(pipeline_jobs=pipeline_jobs, body_name=body_name)
+ return cls(**loaded_data, **kwargs)
+
+ def _convert_output_meta(self, outputs: Dict[str, Union[NodeOutput, Output]]) -> Dict[str, Output]:
+ """Convert output meta to aggregate types.
+
+ :param outputs: Output meta
+ :type outputs: Dict[str, Union[NodeOutput, Output]]
+ :return: Dictionary of aggregate types
+ :rtype: Dict[str, Output]
+ """
+ # pylint: disable=protected-access
+ aggregate_outputs = {}
+ for name, output in outputs.items():
+ if output.type in self.OUT_TYPE_MAPPING:
+ new_type = self.OUT_TYPE_MAPPING[output.type]
+ else:
+                # when the loop body introduces some new output type, this will be raised as a reminder to support it
+                # in parallel_for
+ raise UserErrorException(
+ "Referencing output with type {} is not supported in parallel_for node.".format(output.type)
+ )
+ if isinstance(output, NodeOutput):
+ output = output._to_job_output() # type: ignore
+ if isinstance(output, Output):
+ out_dict = output._to_dict()
+ out_dict["type"] = new_type
+ resolved_output = Output(**out_dict)
+ else:
+ resolved_output = Output(type=new_type)
+ aggregate_outputs[name] = resolved_output
+ return aggregate_outputs
+
+ def _customized_validate(self) -> MutableValidationResult:
+ """Customized validation for parallel for node.
+
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ # pylint: disable=protected-access
+ validation_result = self._validate_body()
+ validation_result.merge_with(
+ self._validate_items(items=self.items, raise_error=False, body_component=self.body._component)
+ )
+ return validation_result
+
+ @classmethod
+ def _validate_items(
+ cls,
+ items: Union[list, dict, str, NodeOutput, PipelineInput],
+ raise_error: bool = True,
+ body_component: Optional[Union[str, Component]] = None,
+ ) -> MutableValidationResult:
+ validation_result = cls._create_empty_validation_result()
+ if items is not None:
+ if isinstance(items, str):
+ # TODO: remove the validation
+ # try to deserialize str if it's a json string
+ try:
+ items = json.loads(items)
+ except json.JSONDecodeError as e:
+ if not is_data_binding_expression(items, ["parent"]):
+ validation_result.append_error(
+ yaml_path="items",
+ message=f"Items is neither a valid JSON string due to {e} or a binding string.",
+ )
+ if isinstance(items, dict):
+ # Validate dict keys
+ items = list(items.values())
+ if isinstance(items, list):
+ if len(items) > 0:
+ cls._validate_items_list(items, validation_result, body_component=body_component)
+ else:
+ validation_result.append_error(yaml_path="items", message="Items is an empty list/dict.")
+ else:
+ validation_result.append_error(
+ yaml_path="items",
+ message="Items is required for parallel_for node",
+ )
+ return cls._try_raise(validation_result, raise_error=raise_error)
+
+ @classmethod
+ def _validate_items_list(
+ cls,
+ items: list,
+ validation_result: MutableValidationResult,
+ body_component: Optional[Union[str, Component]] = None,
+ ) -> None:
+ meta: dict = {}
+ # all items have to be dict and have matched meta
+ for item in items:
+ # item has to be dict
+            # Note: item can be an empty dict when the loop_body doesn't have foreach inputs.
+ if not isinstance(item, dict):
+ validation_result.append_error(
+ yaml_path="items",
+ message=f"Items has to be list/dict of dict as value, " f"but got {type(item)} for {item}.",
+ )
+ else:
+ # item has to have matched meta
+ if meta.keys() != item.keys():
+ if not meta.keys():
+ meta = item
+ else:
+ msg = f"Items should have same keys with body inputs, but got {item.keys()} and {meta.keys()}."
+ validation_result.append_error(yaml_path="items", message=msg)
+ # items' keys should appear in body's inputs
+ if isinstance(body_component, Component) and (not item.keys() <= body_component.inputs.keys()):
+ msg = f"Item {item} got unmatched inputs with loop body component inputs {body_component.inputs}."
+ validation_result.append_error(yaml_path="items", message=msg)
+ # validate item value type
+ cls._validate_item_value_type(item=item, validation_result=validation_result)
+
+ @classmethod
+ def _validate_item_value_type(cls, item: dict, validation_result: MutableValidationResult) -> None:
+ supported_types = (Input, str, bool, int, float, PipelineInput)
+ for _, val in item.items():
+ if not isinstance(val, supported_types):
+ validation_result.append_error(
+ yaml_path="items",
+ message="Unsupported type {} in parallel_for items. Supported types are: {}".format(
+ type(val), supported_types
+ ),
+ )
+ if isinstance(val, Input):
+ cls._validate_input_item_value(entry=val, validation_result=validation_result)
+
+ @classmethod
+ def _validate_input_item_value(cls, entry: Input, validation_result: MutableValidationResult) -> None:
+ if not isinstance(entry, Input):
+ return
+ if not entry.path:
+ validation_result.append_error(
+ yaml_path="items",
+ message=f"Input path not provided for {entry}.",
+ )
+ if isinstance(entry.path, str) and os.path.exists(entry.path):
+ validation_result.append_error(
+ yaml_path="items",
+ message=f"Local file input {entry} is not supported, please create it as a dataset.",
+ )
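
The ParallelFor node above requires items to be a non-empty list (or dict) of dicts that all share the same keys,
with those keys forming a subset of the loop body's inputs, and it serializes the result to a JSON string for the
REST payload. A minimal sketch of that contract in isolation (the function name and sample values are illustrative,
not part of the SDK):

import json

def check_and_serialize_items(items, body_input_names):
    """Re-state the items rules enforced by _validate_items / _to_rest_items (illustration only)."""
    if isinstance(items, dict):
        items = list(items.values())
    if not isinstance(items, list) or not items:
        raise ValueError("items must be a non-empty list or dict")
    expected_keys = None
    for item in items:
        if not isinstance(item, dict):
            raise ValueError(f"each item must be a dict, got {type(item)}")
        if expected_keys is None:
            expected_keys = item.keys()
        elif item.keys() != expected_keys:
            raise ValueError("all items must share the same keys")
        if not item.keys() <= set(body_input_names):
            raise ValueError("item keys must be a subset of the loop body inputs")
    return json.dumps(items)

print(check_and_serialize_items(
    [{"silo": "eastus", "lr": 0.01}, {"silo": "westus", "lr": 0.1}],
    body_input_names={"silo", "lr"},
))
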
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_func.py
new file mode 100644
index 00000000..a8f08d1e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/parallel_func.py
@@ -0,0 +1,285 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml.constants._component import ComponentSource
+from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+)
+from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
+from azure.ai.ml.entities._job.parallel.run_function import RunFunction
+
+from .command_func import _parse_input, _parse_inputs_outputs, _parse_output
+from .parallel import Parallel
+
+
+def parallel_run_function(
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ retry_settings: Optional[BatchRetrySettings] = None,
+ environment_variables: Optional[Dict] = None,
+ logging_level: Optional[str] = None,
+ max_concurrency_per_instance: Optional[int] = None,
+ error_threshold: Optional[int] = None,
+ mini_batch_error_threshold: Optional[int] = None,
+ task: Optional[RunFunction] = None,
+ mini_batch_size: Optional[str] = None,
+ partition_keys: Optional[List] = None,
+ input_data: Optional[str] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ instance_count: Optional[int] = None,
+ instance_type: Optional[str] = None,
+ docker_args: Optional[str] = None,
+ shm_size: Optional[str] = None,
+ identity: Optional[Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]] = None,
+ is_deterministic: bool = True,
+ **kwargs: Any,
+) -> Parallel:
+ """Create a Parallel object which can be used inside dsl.pipeline as a function and can also be created as a
+ standalone parallel job.
+
+ For an example of using ParallelRunStep, see the notebook
+ https://aka.ms/parallel-example-notebook
+
+ .. note::
+
+ To use parallel_run_function:
+
+    * Create a :class:`azure.ai.ml.entities._builders.Parallel` object to specify how the parallel run is performed,
+      with parameters to control the batch size, the number of nodes per compute target, and a
+      reference to your custom Python script.
+
+    * Build the pipeline with the parallel object as a function, defining the inputs and
+      outputs for the step.
+
+    * Submit the pipeline to run.
+
+ .. code:: python
+
+ from azure.ai.ml import Input, Output, parallel
+
+ parallel_run = parallel_run_function(
+ name="batch_score_with_tabular_input",
+ display_name="Batch Score with Tabular Dataset",
+ description="parallel component for batch score",
+ inputs=dict(
+ job_data_path=Input(
+ type=AssetTypes.MLTABLE,
+ description="The data to be split and scored in parallel",
+ ),
+ score_model=Input(
+ type=AssetTypes.URI_FOLDER, description="The model for batch score."
+ ),
+ ),
+ outputs=dict(job_output_path=Output(type=AssetTypes.MLTABLE)),
+ input_data="${{inputs.job_data_path}}",
+ max_concurrency_per_instance=2, # Optional, default is 1
+ mini_batch_size="100", # optional
+ mini_batch_error_threshold=5, # Optional, allowed failed count on mini batch items, default is -1
+ logging_level="DEBUG", # Optional, default is INFO
+ error_threshold=5, # Optional, allowed failed count totally, default is -1
+ retry_settings=dict(max_retries=2, timeout=60), # Optional
+ task=RunFunction(
+ code="./src",
+ entry_script="tabular_batch_inference.py",
+ environment=Environment(
+ image="mcr.microsoft.com/azureml/openmpi3.1.2-ubuntu18.04",
+ conda_file="./src/environment_parallel.yml",
+ ),
+ program_arguments="--model ${{inputs.score_model}}",
+ append_row_to="${{outputs.job_output_path}}", # Optional, if not set, summary_only
+ ),
+ )
+
+ :keyword name: Name of the parallel job or component created.
+ :paramtype name: str
+ :keyword description: A friendly description of the parallel.
+ :paramtype description: str
+ :keyword tags: Tags to be attached to this parallel.
+ :paramtype tags: Dict
+ :keyword properties: The asset property dictionary.
+ :paramtype properties: Dict
+ :keyword display_name: A friendly name.
+ :paramtype display_name: str
+ :keyword experiment_name: Name of the experiment the job will be created under,
+        if None is provided, the default will be set to the current directory name.
+        Will be ignored when used as a pipeline step.
+ :paramtype experiment_name: str
+ :keyword compute: The name of the compute where the parallel job is executed (will not be used
+ if the parallel is used as a component/function).
+ :paramtype compute: str
+ :keyword retry_settings: Parallel component run failed retry
+ :paramtype retry_settings: ~azure.ai.ml.entities._deployment.deployment_settings.BatchRetrySettings
+ :keyword environment_variables: A dictionary of environment variables names and values.
+ These environment variables are set on the process
+ where user script is being executed.
+ :paramtype environment_variables: Dict[str, str]
+ :keyword logging_level: A string of the logging level name, which is defined in 'logging'.
+ Possible values are 'WARNING', 'INFO', and 'DEBUG'. (optional, default value is 'INFO'.)
+ This value could be set through PipelineParameter.
+ :paramtype logging_level: str
+    :keyword max_concurrency_per_instance: The max parallelism that each compute instance has.
+ :paramtype max_concurrency_per_instance: int
+ :keyword error_threshold: The number of record failures for Tabular Dataset and file failures for File Dataset
+ that should be ignored during processing.
+ If the error count goes above this value, then the job will be aborted.
+ Error threshold is for the entire input rather
+ than the individual mini-batch sent to run() method.
+        The range is [-1, int.max]. -1 indicates to ignore all failures during processing.
+ :paramtype error_threshold: int
+    :keyword mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored.
+ :paramtype mini_batch_error_threshold: int
+ :keyword task: The parallel task
+ :paramtype task: ~azure.ai.ml.entities._job.parallel.run_function.RunFunction
+ :keyword mini_batch_size: For FileDataset input,
+ this field is the number of files a user script can process in one run() call.
+ For TabularDataset input, this field is the approximate size of data
+ the user script can process in one run() call.
+ Example values are 1024, 1024KB, 10MB, and 1GB.
+ (optional, default value is 10 files for FileDataset and 1MB for TabularDataset.)
+ This value could be set through PipelineParameter.
+ :paramtype mini_batch_size: str
+    :keyword partition_keys: The keys used to partition the dataset into mini-batches. If specified,
+ the data with the same key will be partitioned into the same mini-batch.
+ If both partition_keys and mini_batch_size are specified,
+ the partition keys will take effect.
+ The input(s) must be partitioned dataset(s),
+ and the partition_keys must be a subset of the keys of every input dataset for this to work
+ :paramtype partition_keys: List
+ :keyword input_data: The input data.
+ :paramtype input_data: str
+ :keyword inputs: A dict of inputs used by this parallel.
+ :paramtype inputs: Dict
+ :keyword outputs: The outputs of this parallel
+ :paramtype outputs: Dict
+    :keyword instance_count: Optional number of instances or nodes used by the compute target.
+        Defaults to 1.
+    :paramtype instance_count: int
+    :keyword instance_type: Optional type of VM used, as supported by the compute target.
+ :paramtype instance_type: str
+ :keyword docker_args: Extra arguments to pass to the Docker run command.
+ This would override any parameters that have already been set by the system,
+ or in this section.
+ This parameter is only supported for Azure ML compute types.
+ :paramtype docker_args: str
+ :keyword shm_size: Size of the docker container's shared memory block.
+        This should be in the format of (number)(unit) where the number has to be greater than 0
+ and the unit can be one of b(bytes), k(kilobytes), m(megabytes), or g(gigabytes).
+ :paramtype shm_size: str
+ :keyword identity: Identity that PRS job will use while running on compute.
+ :paramtype identity: Optional[Union[
+ ~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.AmlTokenConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration]]
+    :keyword is_deterministic: Specify whether the parallel will return the same output given the same input.
+        If a parallel (component) is deterministic, when it is used as a node/step in a pipeline,
+        it will reuse results from a previously submitted job in the current workspace
+        which has the same inputs and settings.
+        In this case, this step will not use any compute resource. Defaults to True;
+        specify is_deterministic=False if you would like to avoid such reuse behavior.
+ :paramtype is_deterministic: bool
+ :return: The parallel node
+ :rtype: ~azure.ai.ml._builders.parallel.Parallel
+ """
+ # pylint: disable=too-many-locals
+ inputs = inputs or {}
+ outputs = outputs or {}
+ component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input)
+ # job inputs can not be None
+ job_inputs = {k: v for k, v in job_inputs.items() if v is not None}
+ component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output)
+
+ component = kwargs.pop("component", None)
+
+ if component is None:
+ if task is None:
+ component = ParallelComponent(
+ base_path=os.getcwd(), # base path should be current folder
+ name=name,
+ tags=tags,
+ code=None,
+ display_name=display_name,
+ description=description,
+ inputs=component_inputs,
+ outputs=component_outputs,
+ retry_settings=retry_settings, # type: ignore[arg-type]
+ logging_level=logging_level,
+ max_concurrency_per_instance=max_concurrency_per_instance,
+ error_threshold=error_threshold,
+ mini_batch_error_threshold=mini_batch_error_threshold,
+ task=task,
+ mini_batch_size=mini_batch_size,
+ partition_keys=partition_keys,
+ input_data=input_data,
+ _source=ComponentSource.BUILDER,
+ is_deterministic=is_deterministic,
+ **kwargs,
+ )
+ else:
+ component = ParallelComponent(
+ base_path=os.getcwd(), # base path should be current folder
+ name=name,
+ tags=tags,
+ code=task.code,
+ display_name=display_name,
+ description=description,
+ inputs=component_inputs,
+ outputs=component_outputs,
+ retry_settings=retry_settings, # type: ignore[arg-type]
+ logging_level=logging_level,
+ max_concurrency_per_instance=max_concurrency_per_instance,
+ error_threshold=error_threshold,
+ mini_batch_error_threshold=mini_batch_error_threshold,
+ task=task,
+ mini_batch_size=mini_batch_size,
+ partition_keys=partition_keys,
+ input_data=input_data,
+ _source=ComponentSource.BUILDER,
+ is_deterministic=is_deterministic,
+ **kwargs,
+ )
+
+ parallel_obj = Parallel(
+ component=component,
+ name=name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ display_name=display_name,
+ experiment_name=experiment_name,
+ compute=compute,
+ inputs=job_inputs,
+ outputs=job_outputs,
+ identity=identity,
+ environment_variables=environment_variables,
+ retry_settings=retry_settings, # type: ignore[arg-type]
+ logging_level=logging_level,
+ max_concurrency_per_instance=max_concurrency_per_instance,
+ error_threshold=error_threshold,
+ mini_batch_error_threshold=mini_batch_error_threshold,
+ task=task,
+ mini_batch_size=mini_batch_size,
+ partition_keys=partition_keys,
+ input_data=input_data,
+ **kwargs,
+ )
+
+ if instance_count is not None or instance_type is not None or docker_args is not None or shm_size is not None:
+ parallel_obj.set_resources(
+ instance_count=instance_count, instance_type=instance_type, docker_args=docker_args, shm_size=shm_size
+ )
+
+ return parallel_obj
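
When any of instance_count, instance_type, docker_args or shm_size is supplied, parallel_run_function forwards them
to Parallel.set_resources after the node is built, so they end up on node.resources without touching the component.
A small offline sketch of that behavior (the import path follows the public docs, and the script, environment
reference and instance type below are placeholders, not verified against a workspace):

from azure.ai.ml import Input, Output
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.parallel import parallel_run_function, RunFunction  # assumed public import path

node = parallel_run_function(
    name="score_in_parallel",
    display_name="Score in parallel",
    inputs=dict(job_data=Input(type=AssetTypes.MLTABLE)),
    outputs=dict(scored=Output(type=AssetTypes.MLTABLE)),
    input_data="${{inputs.job_data}}",
    mini_batch_size="10MB",  # size string, normalized as shown in parallel.py
    task=RunFunction(
        code="./src",                  # placeholder folder
        entry_script="score.py",       # placeholder script
        environment="azureml:AzureML-sklearn-1.0-ubuntu20.04-py38-cpu:1",  # illustrative environment reference
    ),
    instance_count=4,                  # forwarded to set_resources(...)
    instance_type="Standard_DS3_v2",   # forwarded to set_resources(...)
)
assert node.resources.instance_count == 4 and node.resources.instance_type == "Standard_DS3_v2"
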
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/pipeline.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/pipeline.py
new file mode 100644
index 00000000..188d9044
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/pipeline.py
@@ -0,0 +1,225 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import logging
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+
+from marshmallow import Schema
+
+from azure.ai.ml.entities._component.component import Component, NodeType
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._validation import MutableValidationResult
+
+from ..._schema import PathAwareSchema
+from .._job.pipeline.pipeline_job_settings import PipelineJobSettings
+from .._util import convert_ordered_dict_to_dict, copy_output_setting, validate_attribute_type
+from .base_node import BaseNode
+
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob
+
+module_logger = logging.getLogger(__name__)
+
+
+class Pipeline(BaseNode):
+ """Base class for pipeline node, used for pipeline component version consumption. You should not instantiate this
+ class directly. Instead, you should use @pipeline decorator to create a pipeline node.
+
+ :param component: Id or instance of the pipeline component/job to be run for the step.
+ :type component: Union[~azure.ai.ml.entities.Component, str]
+ :param inputs: Inputs of the pipeline node.
+ :type inputs: Optional[Dict[str, Union[
+ ~azure.ai.ml.entities.Input,
+ str, bool, int, float, Enum, "Input"]]].
+ :param outputs: Outputs of the pipeline node.
+ :type outputs: Optional[Dict[str, Union[str, ~azure.ai.ml.entities.Output, "Output"]]]
+ :param settings: Setting of pipeline node, only taking effect for root pipeline job.
+ :type settings: Optional[~azure.ai.ml.entities._job.pipeline.pipeline_job_settings.PipelineJobSettings]
+ """
+
+ def __init__(
+ self,
+ *,
+ component: Union[Component, str],
+ inputs: Optional[
+ Dict[
+ str,
+ Union[
+ Input,
+ str,
+ bool,
+ int,
+ float,
+ Enum,
+ "Input",
+ ],
+ ]
+ ] = None,
+ outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None,
+ settings: Optional[PipelineJobSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+ kwargs.pop("type", None)
+
+ BaseNode.__init__(
+ self,
+ type=NodeType.PIPELINE,
+ component=component,
+ inputs=inputs,
+ outputs=outputs,
+ **kwargs,
+ )
+ # copy pipeline component output's setting to node level
+ self._copy_pipeline_component_out_setting_to_node()
+ self._settings: Optional[PipelineJobSettings] = None
+ self.settings = settings
+
+ @property
+ def component(self) -> Optional[Union[str, Component]]:
+ """Id or instance of the pipeline component/job to be run for the step.
+
+ :return: Id or instance of the pipeline component/job.
+ :rtype: Union[str, ~azure.ai.ml.entities.Component]
+ """
+ res: Union[str, Component] = self._component
+ return res
+
+ @property
+ def settings(self) -> Optional[PipelineJobSettings]:
+ """Settings of the pipeline.
+
+        Note: settings is available only when creating the node as a job,
+        i.e. ml_client.jobs.create_or_update(node).
+
+ :return: Settings of the pipeline.
+ :rtype: ~azure.ai.ml.entities.PipelineJobSettings
+ """
+ if self._settings is None:
+ self._settings = PipelineJobSettings()
+ return self._settings
+
+ @settings.setter
+ def settings(self, value: Union[PipelineJobSettings, Dict]) -> None:
+ """Set the settings of the pipeline.
+
+ :param value: The settings of the pipeline.
+ :type value: Union[~azure.ai.ml.entities.PipelineJobSettings, dict]
+ :raises TypeError: If the value is not an instance of PipelineJobSettings or a dict.
+ """
+ if value is not None:
+ if isinstance(value, PipelineJobSettings):
+                # since PipelineJobSettings inherits _AttrDict, we need to add this branch to distinguish it from dict
+ pass
+ elif isinstance(value, dict):
+ value = PipelineJobSettings(**value)
+ else:
+ raise TypeError("settings must be PipelineJobSettings or dict but got {}".format(type(value)))
+ self._settings = value
+
+ @classmethod
+ def _get_supported_inputs_types(cls) -> None:
+ # Return None here to skip validation,
+ # as input could be custom class object(parameter group).
+ return None
+
+ @property
+ def _skip_required_compute_missing_validation(self) -> bool:
+ return True
+
+ @classmethod
+ def _get_skip_fields_in_schema_validation(cls) -> List[str]:
+ # pipeline component must be a file reference when loading from yaml,
+ # so the created object can't pass schema validation.
+ return ["component"]
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ # Use local import to avoid recursive reference as BaseNode is imported in PipelineComponent.
+ from azure.ai.ml.entities import PipelineComponent
+
+ return {
+ "component": (str, PipelineComponent),
+ }
+
+ def _to_job(self) -> "PipelineJob":
+ from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob
+
+ return PipelineJob(
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ properties=self.properties,
+            # Filter None out to avoid the case below failing the conflicting keys check:
+            # group: None (user not specified)
+            # group.xx: 1 (user specified)
+ inputs={k: v for k, v in self._job_inputs.items() if v},
+ outputs=self._job_outputs,
+ component=self.component,
+ settings=self.settings,
+ )
+
+ def _customized_validate(self) -> MutableValidationResult:
+ """Check unsupported settings when use as a node.
+
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+        # Note: settings is not supported on the node;
+        # jobs.create_or_update(node) will call node._to_job() first,
+        # thus it won't reach here.
+ # pylint: disable=protected-access
+ from azure.ai.ml.entities import PipelineComponent
+
+ validation_result = super(Pipeline, self)._customized_validate()
+ ignored_keys = PipelineComponent._check_ignored_keys(self)
+ if ignored_keys:
+ validation_result.append_warning(message=f"{ignored_keys} ignored on node {self.name!r}.")
+ if isinstance(self.component, PipelineComponent):
+ validation_result.merge_with(self.component._customized_validate())
+ return validation_result
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj: Dict = super()._to_rest_object(**kwargs)
+ rest_obj.update(
+ convert_ordered_dict_to_dict(
+ {
+ "componentId": self._get_component_id(),
+ }
+ )
+ )
+ return rest_obj
+
+ def _build_inputs(self) -> Dict:
+ inputs = super(Pipeline, self)._build_inputs()
+ built_inputs = {}
+ # Validate and remove non-specified inputs
+ for key, value in inputs.items():
+ if value is not None:
+ built_inputs[key] = value
+ return built_inputs
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline.pipeline_component import PipelineSchema
+
+ return PipelineSchema(context=context)
+
+ def _copy_pipeline_component_out_setting_to_node(self) -> None:
+ """Copy pipeline component output's setting to node level."""
+ from azure.ai.ml.entities import PipelineComponent
+ from azure.ai.ml.entities._job.pipeline._io import NodeOutput
+
+ if not isinstance(self.component, PipelineComponent):
+ return
+ for key, val in self.component.outputs.items():
+ node_output = cast(NodeOutput, self.outputs.get(key))
+ copy_output_setting(source=val, target=node_output)
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Job":
+ raise NotImplementedError()
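
The settings property above lazily creates a PipelineJobSettings on first access, and its setter accepts either a
PipelineJobSettings instance or a plain dict that gets coerced. The same dict-or-object coercion pattern also
appears in Parallel.retry_settings and Parallel.resources. A minimal standalone illustration of the pattern (the
class names here are stand-ins, not the SDK types):

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class Settings:
    default_compute: Optional[str] = None
    continue_on_step_failure: bool = False

class Node:
    def __init__(self) -> None:
        self._settings: Optional[Settings] = None

    @property
    def settings(self) -> Settings:
        if self._settings is None:   # lazily create defaults, as the SDK property does
            self._settings = Settings()
        return self._settings

    @settings.setter
    def settings(self, value: Union[Settings, dict, None]) -> None:
        if isinstance(value, dict):  # accept plain dicts and coerce them
            value = Settings(**value)
        elif not isinstance(value, (Settings, type(None))):
            raise TypeError(f"settings must be Settings or dict, got {type(value)}")
        self._settings = value

node = Node()
node.settings = {"default_compute": "cpu-cluster"}
assert node.settings.default_compute == "cpu-cluster"
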
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark.py
new file mode 100644
index 00000000..e72f1334
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark.py
@@ -0,0 +1,663 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access, too-many-instance-attributes
+
+import copy
+import logging
+import re
+from enum import Enum
+from os import PathLike, path
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+from marshmallow import INCLUDE, Schema
+
+from ..._restclient.v2023_04_01_preview.models import JobBase as JobBaseData
+from ..._restclient.v2023_04_01_preview.models import SparkJob as RestSparkJob
+from ..._schema import NestedField, PathAwareSchema, UnionField
+from ..._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from ..._schema.job.parameterized_spark import CONF_KEY_MAP
+from ..._schema.job.spark_job import SparkJobSchema
+from ..._utils.utils import is_url
+from ...constants._common import (
+ ARM_ID_PREFIX,
+ BASE_PATH_CONTEXT_KEY,
+ REGISTRY_URI_FORMAT,
+ SPARK_ENVIRONMENT_WARNING_MESSAGE,
+)
+from ...constants._component import NodeType
+from ...constants._job.job import SparkConfKey
+from ...entities._assets import Environment
+from ...entities._component.component import Component
+from ...entities._component.spark_component import SparkComponent
+from ...entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+ _BaseJobIdentityConfiguration,
+)
+from ...entities._inputs_outputs import Input, Output
+from ...entities._job._input_output_helpers import (
+ from_rest_data_outputs,
+ from_rest_inputs_to_dataset_literal,
+ validate_inputs_for_args,
+)
+from ...entities._job.spark_job import SparkJob
+from ...entities._job.spark_job_entry import SparkJobEntryType
+from ...entities._job.spark_resource_configuration import SparkResourceConfiguration
+from ...entities._validation import MutableValidationResult
+from ...exceptions import ErrorCategory, ErrorTarget, ValidationException
+from .._job.pipeline._io import NodeOutput
+from .._job.spark_helpers import (
+ _validate_compute_or_resources,
+ _validate_input_output_mode,
+ _validate_spark_configurations,
+)
+from .._job.spark_job_entry_mixin import SparkJobEntry, SparkJobEntryMixin
+from .._util import convert_ordered_dict_to_dict, get_rest_dict_for_node_attrs, load_from_dict, validate_attribute_type
+from .base_node import BaseNode
+
+module_logger = logging.getLogger(__name__)
+
+
+class Spark(BaseNode, SparkJobEntryMixin):
+ """Base class for spark node, used for spark component version consumption.
+
+ You should not instantiate this class directly. Instead, you should
+ create it from the builder function: spark.
+
+ :param component: The ID or instance of the Spark component or job to be run during the step.
+ :type component: Union[str, ~azure.ai.ml.entities.SparkComponent]
+ :param identity: The identity that the Spark job will use while running on compute.
+ :type identity: Union[Dict[str, str],
+ ~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.AmlTokenConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration
+
+ ]
+
+ :param driver_cores: The number of cores to use for the driver process, only in cluster mode.
+ :type driver_cores: int
+ :param driver_memory: The amount of memory to use for the driver process, formatted as strings with a size unit
+ suffix ("k", "m", "g" or "t") (e.g. "512m", "2g").
+ :type driver_memory: str
+ :param executor_cores: The number of cores to use on each executor.
+ :type executor_cores: int
+ :param executor_memory: The amount of memory to use per executor process, formatted as strings with a size unit
+ suffix ("k", "m", "g" or "t") (e.g. "512m", "2g").
+ :type executor_memory: str
+ :param executor_instances: The initial number of executors.
+ :type executor_instances: int
+ :param dynamic_allocation_enabled: Whether to use dynamic resource allocation, which scales the number of
+ executors registered with this application up and down based on the workload.
+ :type dynamic_allocation_enabled: bool
+ :param dynamic_allocation_min_executors: The lower bound for the number of executors if dynamic allocation
+ is enabled.
+ :type dynamic_allocation_min_executors: int
+ :param dynamic_allocation_max_executors: The upper bound for the number of executors if dynamic allocation
+ is enabled.
+ :type dynamic_allocation_max_executors: int
+ :param conf: A dictionary with pre-defined Spark configurations key and values.
+ :type conf: Dict[str, str]
+ :param inputs: A mapping of input names to input data sources used in the job.
+ :type inputs: Dict[str, Union[
+ str,
+ bool,
+ int,
+ float,
+ Enum,
+ ~azure.ai.ml.entities._job.pipeline._io.NodeOutput,
+ ~azure.ai.ml.Input
+
+ ]]
+
+ :param outputs: A mapping of output names to output data sources used in the job.
+ :type outputs: Dict[str, Union[str, ~azure.ai.ml.Output]]
+ :param args: The arguments for the job.
+ :type args: str
+ :param compute: The compute resource the job runs on.
+ :type compute: str
+ :param resources: The compute resource configuration for the job.
+ :type resources: Union[Dict, ~azure.ai.ml.entities.SparkResourceConfiguration]
+ :param entry: The file or class entry point.
+ :type entry: Dict[str, str]
+ :param py_files: The list of .zip, .egg or .py files to place on the PYTHONPATH for Python apps.
+ :type py_files: List[str]
+ :param jars: The list of .JAR files to include on the driver and executor classpaths.
+ :type jars: List[str]
+ :param files: The list of files to be placed in the working directory of each executor.
+ :type files: List[str]
+ :param archives: The list of archives to be extracted into the working directory of each executor.
+ :type archives: List[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ component: Union[str, SparkComponent],
+ identity: Optional[
+ Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ driver_cores: Optional[Union[int, str]] = None,
+ driver_memory: Optional[str] = None,
+ executor_cores: Optional[Union[int, str]] = None,
+ executor_memory: Optional[str] = None,
+ executor_instances: Optional[Union[int, str]] = None,
+ dynamic_allocation_enabled: Optional[Union[bool, str]] = None,
+ dynamic_allocation_min_executors: Optional[Union[int, str]] = None,
+ dynamic_allocation_max_executors: Optional[Union[int, str]] = None,
+ conf: Optional[Dict[str, str]] = None,
+ inputs: Optional[
+ Dict[
+ str,
+ Union[
+ NodeOutput,
+ Input,
+ str,
+ bool,
+ int,
+ float,
+ Enum,
+ "Input",
+ ],
+ ]
+ ] = None,
+ outputs: Optional[Dict[str, Union[str, Output, "Output"]]] = None,
+ compute: Optional[str] = None,
+ resources: Optional[Union[Dict, SparkResourceConfiguration]] = None,
+ entry: Union[Dict[str, str], SparkJobEntry, None] = None,
+ py_files: Optional[List[str]] = None,
+ jars: Optional[List[str]] = None,
+ files: Optional[List[str]] = None,
+ archives: Optional[List[str]] = None,
+ args: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate that init params are of the expected types
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+ kwargs.pop("type", None)
+
+ BaseNode.__init__(
+ self, type=NodeType.SPARK, inputs=inputs, outputs=outputs, component=component, compute=compute, **kwargs
+ )
+
+ # init mark for _AttrDict
+ self._init = True
+ SparkJobEntryMixin.__init__(self, entry=entry)
+ self.conf = conf
+ self.driver_cores = driver_cores
+ self.driver_memory = driver_memory
+ self.executor_cores = executor_cores
+ self.executor_memory = executor_memory
+ self.executor_instances = executor_instances
+ self.dynamic_allocation_enabled = dynamic_allocation_enabled
+ self.dynamic_allocation_min_executors = dynamic_allocation_min_executors
+ self.dynamic_allocation_max_executors = dynamic_allocation_max_executors
+
+ is_spark_component = isinstance(component, SparkComponent)
+ if is_spark_component:
+ # conf is a dict, so we need to copy the component conf here; otherwise setting conf on the node
+ # would also mutate the component's conf
+ _component = cast(SparkComponent, component)
+ self.conf = self.conf or copy.copy(_component.conf)
+ self.driver_cores = self.driver_cores or _component.driver_cores
+ self.driver_memory = self.driver_memory or _component.driver_memory
+ self.executor_cores = self.executor_cores or _component.executor_cores
+ self.executor_memory = self.executor_memory or _component.executor_memory
+ self.executor_instances = self.executor_instances or _component.executor_instances
+ self.dynamic_allocation_enabled = self.dynamic_allocation_enabled or _component.dynamic_allocation_enabled
+ self.dynamic_allocation_min_executors = (
+ self.dynamic_allocation_min_executors or _component.dynamic_allocation_min_executors
+ )
+ self.dynamic_allocation_max_executors = (
+ self.dynamic_allocation_max_executors or _component.dynamic_allocation_max_executors
+ )
+ if self.executor_instances is None and str(self.dynamic_allocation_enabled).lower() == "true":
+ self.executor_instances = self.dynamic_allocation_min_executors
+ # When creating a standalone job or a pipeline job, the following fields always take their values from the
+ # component (or default to None), because we do not pass those fields to Spark. In the following cases,
+ # however, we expect the correct values to come from spark._from_rest_object(), so the fields are taken from
+ # their respective keyword arguments instead:
+ # 1. When we call regenerated_spark_node = Spark._from_rest_object(spark_node._to_rest_object()) in a local
+ #    test, we expect regenerated_spark_node and spark_node to be identical.
+ # 2. When we fetch a created remote job through Job._from_rest_object(result) in a job operation where the
+ #    component is an ARM id, we expect the values returned by the service.
+ # 3. When we load a remote job, the component is an ARM id, so we need to get the entry from the node-level
+ #    data returned by the service.
+ self.entry = _component.entry if is_spark_component else entry
+ self.py_files = _component.py_files if is_spark_component else py_files
+ self.jars = _component.jars if is_spark_component else jars
+ self.files = _component.files if is_spark_component else files
+ self.archives = _component.archives if is_spark_component else archives
+ self.args = _component.args if is_spark_component else args
+ self.environment: Any = _component.environment if is_spark_component else None
+
+ self.resources = resources
+ self.identity = identity
+ self._swept = False
+ self._init = False
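+
+ # Fallback behavior sketch (illustrative; values are placeholders): for the promoted Spark settings above,
+ # a value set on the node wins, and the component value is only used when the node value is unset/falsy.
+ #
+ #   component = SparkComponent(executor_instances=2, ...)
+ #   spark(component=component, executor_instances=4, ...).executor_instances  # -> 4 (node value)
+ #   spark(component=component, ...).executor_instances                        # -> 2 (from component)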
+
+ @classmethod
+ def _get_supported_outputs_types(cls) -> Tuple:
+ return str, Output
+
+ @property
+ def component(self) -> Union[str, SparkComponent]:
+ """The ID or instance of the Spark component or job to be run during the step.
+
+ :rtype: Union[str, ~azure.ai.ml.entities.SparkComponent]
+ """
+ res: Union[str, SparkComponent] = self._component
+ return res
+
+ @property
+ def resources(self) -> Optional[Union[Dict, SparkResourceConfiguration]]:
+ """The compute resource configuration for the job.
+
+ :rtype: ~azure.ai.ml.entities.SparkResourceConfiguration
+ """
+ return self._resources # type: ignore
+
+ @resources.setter
+ def resources(self, value: Optional[Union[Dict, SparkResourceConfiguration]]) -> None:
+ """Sets the compute resource configuration for the job.
+
+ :param value: The compute resource configuration for the job.
+ :type value: Union[Dict[str, str], ~azure.ai.ml.entities.SparkResourceConfiguration]
+ """
+ if isinstance(value, dict):
+ value = SparkResourceConfiguration(**value)
+ self._resources = value
+
+ @property
+ def identity(
+ self,
+ ) -> Optional[Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]]:
+ """The identity that the Spark job will use while running on compute.
+
+ :rtype: Union[~azure.ai.ml.entities.ManagedIdentityConfiguration, ~azure.ai.ml.entities.AmlTokenConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration]
+ """
+ # If no identity was provided via CLI/SDK input: for jobs running on Synapse compute (MLCompute clusters),
+ # managed identity is the default; for clusterless jobs, user identity is the default.
+ # Otherwise, the user-provided identity is used.
+ if self._identity is None:
+ if self.compute is not None:
+ return ManagedIdentityConfiguration()
+ if self.resources is not None:
+ return UserIdentityConfiguration()
+ return self._identity
+
+ @identity.setter
+ def identity(
+ self,
+ value: Union[Dict[str, str], ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration],
+ ) -> None:
+ """Sets the identity that the Spark job will use while running on compute.
+
+ :param value: The identity that the Spark job will use while running on compute.
+ :type value: Union[Dict[str, str], ~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.AmlTokenConfiguration, ~azure.ai.ml.entities.UserIdentityConfiguration]
+ """
+ if isinstance(value, dict):
+ identify_schema = UnionField(
+ [
+ NestedField(ManagedIdentitySchema, unknown=INCLUDE),
+ NestedField(AMLTokenIdentitySchema, unknown=INCLUDE),
+ NestedField(UserIdentitySchema, unknown=INCLUDE),
+ ]
+ )
+ value = identify_schema._deserialize(value=value, attr=None, data=None)
+ self._identity = value
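+
+ # Default-identity sketch (illustrative; the compute/resource values are placeholders):
+ #   node.compute = "my-synapse-compute"      -> node.identity defaults to ManagedIdentityConfiguration()
+ #   node.resources set, no compute           -> node.identity defaults to UserIdentityConfiguration()
+ #   node.identity = AmlTokenConfiguration()  -> the explicitly set identity is returned as-is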
+
+ @property
+ def code(self) -> Optional[Union[str, PathLike]]:
+ """The local or remote path pointing at source code.
+
+ :rtype: Union[str, PathLike]
+ """
+ if isinstance(self.component, Component):
+ _code: Optional[Union[str, PathLike]] = self.component.code
+ return _code
+ return None
+
+ @code.setter
+ def code(self, value: str) -> None:
+ """Sets the source code to be used for the job.
+
+ :param value: The local or remote path pointing at source code.
+ :type value: Union[str, PathLike]
+ """
+ if isinstance(self.component, Component):
+ self.component.code = value
+ else:
+ msg = "Can't set code property for a registered component {}"
+ raise ValidationException(
+ message=msg.format(self.component),
+ no_personal_data_message=msg.format(self.component),
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
+ obj = super()._from_rest_object_to_init_params(obj)
+
+ if "resources" in obj and obj["resources"]:
+ obj["resources"] = SparkResourceConfiguration._from_rest_object(obj["resources"])
+
+ if "identity" in obj and obj["identity"]:
+ obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"])
+
+ if "entry" in obj and obj["entry"]:
+ obj["entry"] = SparkJobEntry._from_rest_object(obj["entry"])
+ if "conf" in obj and obj["conf"]:
+ # promote settings from the conf dict onto their corresponding fields
+ for field_name, _ in CONF_KEY_MAP.items():
+ value = obj["conf"].get(field_name, None)
+ if value is not None:
+ obj[field_name] = value
+
+ return obj
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Spark":
+ from .spark_func import spark
+
+ loaded_data = load_from_dict(SparkJobSchema, data, context, additional_message, **kwargs)
+ spark_job: Spark = spark(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ return spark_job
+
+ @classmethod
+ def _load_from_rest_job(cls, obj: JobBaseData) -> "Spark":
+ from .spark_func import spark
+
+ rest_spark_job: RestSparkJob = obj.properties
+ rest_spark_conf = copy.copy(rest_spark_job.conf) or {}
+
+ spark_job: Spark = spark(
+ name=obj.name,
+ id=obj.id,
+ entry=SparkJobEntry._from_rest_object(rest_spark_job.entry),
+ display_name=rest_spark_job.display_name,
+ description=rest_spark_job.description,
+ tags=rest_spark_job.tags,
+ properties=rest_spark_job.properties,
+ experiment_name=rest_spark_job.experiment_name,
+ services=rest_spark_job.services,
+ status=rest_spark_job.status,
+ creation_context=obj.system_data,
+ code=rest_spark_job.code_id,
+ compute=rest_spark_job.compute_id,
+ environment=rest_spark_job.environment_id,
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(rest_spark_job.identity)
+ if rest_spark_job.identity
+ else None
+ ),
+ args=rest_spark_job.args,
+ conf=rest_spark_conf,
+ driver_cores=rest_spark_conf.get(
+ SparkConfKey.DRIVER_CORES, None
+ ), # copy fields from conf into the promoted attributes on the Spark node
+ driver_memory=rest_spark_conf.get(SparkConfKey.DRIVER_MEMORY, None),
+ executor_cores=rest_spark_conf.get(SparkConfKey.EXECUTOR_CORES, None),
+ executor_memory=rest_spark_conf.get(SparkConfKey.EXECUTOR_MEMORY, None),
+ executor_instances=rest_spark_conf.get(SparkConfKey.EXECUTOR_INSTANCES, None),
+ dynamic_allocation_enabled=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_ENABLED, None),
+ dynamic_allocation_min_executors=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_MIN_EXECUTORS, None),
+ dynamic_allocation_max_executors=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_MAX_EXECUTORS, None),
+ resources=SparkResourceConfiguration._from_rest_object(rest_spark_job.resources),
+ inputs=from_rest_inputs_to_dataset_literal(rest_spark_job.inputs),
+ outputs=from_rest_data_outputs(rest_spark_job.outputs),
+ )
+ return spark_job
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ # hack: allow using InternalSparkComponent as the component
+ # "component": (str, SparkComponent),
+ "environment": (str, Environment),
+ "resources": (dict, SparkResourceConfiguration),
+ "code": (str, PathLike),
+ }
+
+ @property
+ def _skip_required_compute_missing_validation(self) -> bool:
+ return self.resources is not None
+
+ def _to_job(self) -> SparkJob:
+ if isinstance(self.component, SparkComponent):
+ return SparkJob(
+ experiment_name=self.experiment_name,
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ code=self.component.code,
+ entry=self.entry,
+ py_files=self.py_files,
+ jars=self.jars,
+ files=self.files,
+ archives=self.archives,
+ identity=self.identity,
+ driver_cores=self.driver_cores,
+ driver_memory=self.driver_memory,
+ executor_cores=self.executor_cores,
+ executor_memory=self.executor_memory,
+ executor_instances=self.executor_instances,
+ dynamic_allocation_enabled=self.dynamic_allocation_enabled,
+ dynamic_allocation_min_executors=self.dynamic_allocation_min_executors,
+ dynamic_allocation_max_executors=self.dynamic_allocation_max_executors,
+ conf=self.conf,
+ environment=self.environment,
+ status=self.status,
+ inputs=self._job_inputs,
+ outputs=self._job_outputs,
+ services=self.services,
+ args=self.args,
+ compute=self.compute,
+ resources=self.resources,
+ )
+
+ return SparkJob(
+ experiment_name=self.experiment_name,
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ code=self.component,
+ entry=self.entry,
+ py_files=self.py_files,
+ jars=self.jars,
+ files=self.files,
+ archives=self.archives,
+ identity=self.identity,
+ driver_cores=self.driver_cores,
+ driver_memory=self.driver_memory,
+ executor_cores=self.executor_cores,
+ executor_memory=self.executor_memory,
+ executor_instances=self.executor_instances,
+ dynamic_allocation_enabled=self.dynamic_allocation_enabled,
+ dynamic_allocation_min_executors=self.dynamic_allocation_min_executors,
+ dynamic_allocation_max_executors=self.dynamic_allocation_max_executors,
+ conf=self.conf,
+ environment=self.environment,
+ status=self.status,
+ inputs=self._job_inputs,
+ outputs=self._job_outputs,
+ services=self.services,
+ args=self.args,
+ compute=self.compute,
+ resources=self.resources,
+ )
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline import SparkSchema
+
+ return SparkSchema(context=context)
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ return [
+ "type",
+ "resources",
+ "py_files",
+ "jars",
+ "files",
+ "archives",
+ "identity",
+ "conf",
+ "args",
+ ]
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj: dict = super()._to_rest_object(**kwargs)
+ rest_obj.update(
+ convert_ordered_dict_to_dict(
+ {
+ "componentId": self._get_component_id(),
+ "identity": get_rest_dict_for_node_attrs(self.identity),
+ "resources": get_rest_dict_for_node_attrs(self.resources),
+ "entry": get_rest_dict_for_node_attrs(self.entry),
+ }
+ )
+ )
+ return rest_obj
+
+ def _build_inputs(self) -> dict:
+ inputs = super(Spark, self)._build_inputs()
+ built_inputs = {}
+ # Validate and remove non-specified inputs
+ for key, value in inputs.items():
+ if value is not None:
+ built_inputs[key] = value
+ return built_inputs
+
+ def _customized_validate(self) -> MutableValidationResult:
+ result = super()._customized_validate()
+ if (
+ isinstance(self.component, SparkComponent)
+ and isinstance(self.component._environment, Environment)
+ and self.component._environment.image is not None
+ ):
+ result.append_warning(
+ yaml_path="environment.image",
+ message=SPARK_ENVIRONMENT_WARNING_MESSAGE,
+ )
+ result.merge_with(self._validate_entry_exist())
+ result.merge_with(self._validate_fields())
+ return result
+
+ def _validate_entry_exist(self) -> MutableValidationResult:
+ is_remote_code = isinstance(self.code, str) and (
+ self.code.startswith("git+")
+ or self.code.startswith(REGISTRY_URI_FORMAT)
+ or self.code.startswith(ARM_ID_PREFIX)
+ or is_url(self.code)
+ or bool(self.CODE_ID_RE_PATTERN.match(self.code))
+ )
+ validation_result = self._create_empty_validation_result()
+ # validate that the component entry exists to ensure the code path is correct, especially when code is the default value
+ if self.code is None or is_remote_code or not isinstance(self.entry, SparkJobEntry):
+ # skip validation when code is None or not a local path, or when self.entry is not a SparkJobEntry object
+ pass
+ else:
+ if not path.isabs(self.code):
+ _component: SparkComponent = self.component # type: ignore
+ code_path = Path(_component.base_path) / self.code
+ if code_path.exists():
+ code_path = code_path.resolve().absolute()
+ else:
+ validation_result.append_error(
+ message=f"Code path {code_path} doesn't exist.", yaml_path="component.code"
+ )
+ entry_path = code_path / self.entry.entry
+ else:
+ entry_path = Path(self.code) / self.entry.entry
+
+ if (
+ isinstance(self.entry, SparkJobEntry)
+ and self.entry.entry_type == SparkJobEntryType.SPARK_JOB_FILE_ENTRY
+ ):
+ if not entry_path.exists():
+ validation_result.append_error(
+ message=f"Entry {entry_path} doesn't exist.", yaml_path="component.entry"
+ )
+ return validation_result
+
+ def _validate_fields(self) -> MutableValidationResult:
+ validation_result = self._create_empty_validation_result()
+ try:
+ _validate_compute_or_resources(self.compute, self.resources)
+ except ValidationException as e:
+ validation_result.append_error(message=str(e), yaml_path="resources")
+ validation_result.append_error(message=str(e), yaml_path="compute")
+
+ try:
+ _validate_input_output_mode(self.inputs, self.outputs)
+ except ValidationException as e:
+ msg = str(e)
+ m = re.match(r"(Input|Output) '(\w+)'", msg)
+ if m:
+ io_type, io_name = m.groups()
+ if io_type == "Input":
+ validation_result.append_error(message=msg, yaml_path=f"inputs.{io_name}")
+ else:
+ validation_result.append_error(message=msg, yaml_path=f"outputs.{io_name}")
+
+ try:
+ _validate_spark_configurations(self)
+ except ValidationException as e:
+ validation_result.append_error(message=str(e), yaml_path="conf")
+
+ try:
+ self._validate_entry()
+ except ValidationException as e:
+ validation_result.append_error(message=str(e), yaml_path="entry")
+
+ if self.args:
+ try:
+ validate_inputs_for_args(self.args, self.inputs)
+ except ValidationException as e:
+ validation_result.append_error(message=str(e), yaml_path="args")
+ return validation_result
+
+ # pylint: disable-next=docstring-missing-param
+ def __call__(self, *args: Any, **kwargs: Any) -> "Spark":
+ """Call Spark as a function will return a new instance each time.
+
+ :return: A Spark object
+ :rtype: Spark
+ """
+ if isinstance(self._component, Component):
+ # call this to validate inputs
+ node: Spark = self._component(*args, **kwargs)
+ # merge inputs
+ for name, original_input in self.inputs.items():
+ if name not in kwargs:
+ # use setattr here to make sure owner of input won't change
+ setattr(node.inputs, name, original_input._data)
+ node._job_inputs[name] = original_input._data
+ # get outputs
+ for name, original_output in self.outputs.items():
+ # use setattr here to make sure owner of output won't change
+ if not isinstance(original_output, str):
+ setattr(node.outputs, name, original_output._data)
+ self._refine_optional_inputs_with_no_value(node, kwargs)
+ node._name = self.name
+ node.compute = self.compute
+ node.environment = copy.deepcopy(self.environment)
+ node.resources = copy.deepcopy(self.resources)
+ return node
+
+ msg = "Spark can be called as a function only when referenced component is {}, currently got {}."
+ raise ValidationException(
+ message=msg.format(type(Component), self._component),
+ no_personal_data_message=msg.format(type(Component), "self._component"),
+ target=ErrorTarget.SPARK_JOB,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark_func.py
new file mode 100644
index 00000000..342f8c44
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/spark_func.py
@@ -0,0 +1,306 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access, too-many-locals
+
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AmlToken, ManagedIdentity, UserIdentity
+from azure.ai.ml.constants._common import AssetTypes
+from azure.ai.ml.constants._component import ComponentSource
+from azure.ai.ml.entities import Environment
+from azure.ai.ml.entities._component.spark_component import SparkComponent
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
+from azure.ai.ml.entities._job.spark_job_entry import SparkJobEntry
+from azure.ai.ml.entities._job.spark_resource_configuration import SparkResourceConfiguration
+from azure.ai.ml.exceptions import ErrorTarget, ValidationException
+
+from .spark import Spark
+
+SUPPORTED_INPUTS = [AssetTypes.URI_FILE, AssetTypes.URI_FOLDER, AssetTypes.MLTABLE]
+
+
+def _parse_input(input_value: Union[Input, dict, str, bool, int, float]) -> Tuple:
+ component_input = None
+ job_input: Union[Input, dict, str, bool, int, float] = ""
+
+ if isinstance(input_value, Input):
+ component_input = Input(**input_value._to_dict())
+ input_type = input_value.type
+ if input_type in SUPPORTED_INPUTS:
+ job_input = Input(**input_value._to_dict())
+ elif isinstance(input_value, dict):
+ # if the user provided a dict, try to parse it into an Input.
+ # for job inputs, only path types are parsed
+ input_type = input_value.get("type", None)
+ if input_type in SUPPORTED_INPUTS:
+ job_input = Input(**input_value)
+ component_input = Input(**input_value)
+ elif isinstance(input_value, (str, bool, int, float)):
+ # Input bindings are not supported
+ component_input = ComponentTranslatableMixin._to_input_builder_function(input_value)
+ job_input = input_value
+ else:
+ msg = f"Unsupported input type: {type(input_value)}, only Input, dict, str, bool, int and float are supported."
+ raise ValidationException(message=msg, no_personal_data_message=msg, target=ErrorTarget.JOB)
+ return component_input, job_input
+
+
+def _parse_output(output_value: Union[Output, dict]) -> Tuple:
+ component_output = None
+ job_output: Union[Output, dict] = {}
+
+ if isinstance(output_value, Output):
+ component_output = Output(**output_value._to_dict())
+ job_output = Output(**output_value._to_dict())
+ elif not output_value:
+ # The output value can be None or an empty dictionary.
+ # A None output value will be packed into a JobOutput object with mode=ReadWriteMount and type=UriFolder
+ component_output = ComponentTranslatableMixin._to_output(output_value)
+ job_output = output_value
+ elif isinstance(output_value, dict): # When output value is a non-empty dictionary
+ job_output = Output(**output_value)
+ component_output = Output(**output_value)
+ elif isinstance(output_value, str): # When output is passed in from pipeline job yaml
+ job_output = output_value
+ else:
+ msg = f"Unsupported output type: {type(output_value)}, only Output and dict are supported."
+ raise ValidationException(message=msg, no_personal_data_message=msg, target=ErrorTarget.JOB)
+ return component_output, job_output
+
+
+def _parse_inputs_outputs(io_dict: Dict, parse_func: Callable) -> Tuple[Dict, Dict]:
+ component_io_dict, job_io_dict = {}, {}
+ if io_dict:
+ for key, val in io_dict.items():
+ component_io, job_io = parse_func(val)
+ component_io_dict[key] = component_io
+ job_io_dict[key] = job_io
+ return component_io_dict, job_io_dict
+
+
+def spark(
+ *,
+ experiment_name: Optional[str] = None,
+ name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ code: Optional[Union[str, os.PathLike]] = None,
+ entry: Union[Dict[str, str], SparkJobEntry, None] = None,
+ py_files: Optional[List[str]] = None,
+ jars: Optional[List[str]] = None,
+ files: Optional[List[str]] = None,
+ archives: Optional[List[str]] = None,
+ identity: Optional[Union[Dict[str, str], ManagedIdentity, AmlToken, UserIdentity]] = None,
+ driver_cores: Optional[int] = None,
+ driver_memory: Optional[str] = None,
+ executor_cores: Optional[int] = None,
+ executor_memory: Optional[str] = None,
+ executor_instances: Optional[int] = None,
+ dynamic_allocation_enabled: Optional[bool] = None,
+ dynamic_allocation_min_executors: Optional[int] = None,
+ dynamic_allocation_max_executors: Optional[int] = None,
+ conf: Optional[Dict[str, str]] = None,
+ environment: Optional[Union[str, Environment]] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ args: Optional[str] = None,
+ compute: Optional[str] = None,
+ resources: Optional[Union[Dict, SparkResourceConfiguration]] = None,
+ **kwargs: Any,
+) -> Spark:
+ """Creates a Spark object which can be used inside a dsl.pipeline function or used as a standalone Spark job.
+
+ :keyword experiment_name: The name of the experiment the job will be created under.
+ :paramtype experiment_name: Optional[str]
+ :keyword name: The name of the job.
+ :paramtype name: Optional[str]
+ :keyword display_name: The job display name.
+ :paramtype display_name: Optional[str]
+ :keyword description: The description of the job. Defaults to None.
+ :paramtype description: Optional[str]
+ :keyword tags: The dictionary of tags for the job. Tags can be added, removed, and updated. Defaults to None.
+ :paramtype tags: Optional[dict[str, str]]
+ :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url
+ pointing to a remote location.
+ :paramtype code: Optional[Union[str, os.PathLike]]
+ :keyword entry: The file or class entry point.
+ :paramtype entry: Optional[Union[dict[str, str], ~azure.ai.ml.entities.SparkJobEntry]]
+ :keyword py_files: The list of .zip, .egg or .py files to place on the PYTHONPATH for Python apps.
+ Defaults to None.
+ :paramtype py_files: Optional[List[str]]
+ :keyword jars: The list of .JAR files to include on the driver and executor classpaths. Defaults to None.
+ :paramtype jars: Optional[List[str]]
+ :keyword files: The list of files to be placed in the working directory of each executor. Defaults to None.
+ :paramtype files: Optional[List[str]]
+ :keyword archives: The list of archives to be extracted into the working directory of each executor.
+ Defaults to None.
+ :paramtype archives: Optional[List[str]]
+ :keyword identity: The identity that the Spark job will use while running on compute.
+ :paramtype identity: Optional[Union[
+ dict[str, str],
+ ~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.AmlTokenConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration]]
+ :keyword driver_cores: The number of cores to use for the driver process, only in cluster mode.
+ :paramtype driver_cores: Optional[int]
+ :keyword driver_memory: The amount of memory to use for the driver process, formatted as strings with a size unit
+ suffix ("k", "m", "g" or "t") (e.g. "512m", "2g").
+ :paramtype driver_memory: Optional[str]
+ :keyword executor_cores: The number of cores to use on each executor.
+ :paramtype executor_cores: Optional[int]
+ :keyword executor_memory: The amount of memory to use per executor process, formatted as strings with a size unit
+ suffix ("k", "m", "g" or "t") (e.g. "512m", "2g").
+ :paramtype executor_memory: Optional[str]
+ :keyword executor_instances: The initial number of executors.
+ :paramtype executor_instances: Optional[int]
+ :keyword dynamic_allocation_enabled: Whether to use dynamic resource allocation, which scales the number of
+ executors registered with this application up and down based on the workload.
+ :paramtype dynamic_allocation_enabled: Optional[bool]
+ :keyword dynamic_allocation_min_executors: The lower bound for the number of executors if dynamic allocation is
+ enabled.
+ :paramtype dynamic_allocation_min_executors: Optional[int]
+ :keyword dynamic_allocation_max_executors: The upper bound for the number of executors if dynamic allocation is
+ enabled.
+ :paramtype dynamic_allocation_max_executors: Optional[int]
+ :keyword conf: A dictionary of pre-defined Spark configuration keys and values. Defaults to None.
+ :paramtype conf: Optional[dict[str, str]]
+ :keyword environment: The Azure ML environment to run the job in.
+ :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ :keyword inputs: A mapping of input names to input data used in the job. Defaults to None.
+ :paramtype inputs: Optional[dict[str, ~azure.ai.ml.Input]]
+ :keyword outputs: A mapping of output names to output data used in the job. Defaults to None.
+ :paramtype outputs: Optional[dict[str, ~azure.ai.ml.Output]]
+ :keyword args: The arguments for the job.
+ :paramtype args: Optional[str]
+ :keyword compute: The compute resource the job runs on.
+ :paramtype compute: Optional[str]
+ :keyword resources: The compute resource configuration for the job.
+ :paramtype resources: Optional[Union[dict, ~azure.ai.ml.entities.SparkResourceConfiguration]]
+ :return: A Spark object.
+ :rtype: ~azure.ai.ml.entities.Spark
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_function_configuration_1]
+ :end-before: [END spark_function_configuration_1]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a SparkJob.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_function_configuration_2]
+ :end-before: [END spark_function_configuration_2]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a SparkJob.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_dsl_pipeline]
+ :end-before: [END spark_dsl_pipeline]
+ :language: python
+ :dedent: 8
+ :caption: Building a Spark pipeline using the DSL pipeline decorator
+ """
+
+ inputs = inputs or {}
+ outputs = outputs or {}
+ component_inputs, job_inputs = _parse_inputs_outputs(inputs, parse_func=_parse_input)
+ # job inputs cannot be None
+ job_inputs = {k: v for k, v in job_inputs.items() if v is not None}
+ component_outputs, job_outputs = _parse_inputs_outputs(outputs, parse_func=_parse_output)
+ component = kwargs.pop("component", None)
+
+ if component is None:
+ component = SparkComponent(
+ name=name,
+ display_name=display_name,
+ tags=tags,
+ description=description,
+ code=code,
+ entry=entry,
+ py_files=py_files,
+ jars=jars,
+ files=files,
+ archives=archives,
+ driver_cores=driver_cores,
+ driver_memory=driver_memory,
+ executor_cores=executor_cores,
+ executor_memory=executor_memory,
+ executor_instances=executor_instances,
+ dynamic_allocation_enabled=dynamic_allocation_enabled,
+ dynamic_allocation_min_executors=dynamic_allocation_min_executors,
+ dynamic_allocation_max_executors=dynamic_allocation_max_executors,
+ conf=conf,
+ environment=environment,
+ inputs=component_inputs,
+ outputs=component_outputs,
+ args=args,
+ _source=ComponentSource.BUILDER,
+ **kwargs,
+ )
+ if isinstance(component, SparkComponent):
+ spark_obj = Spark(
+ experiment_name=experiment_name,
+ name=name,
+ display_name=display_name,
+ tags=tags,
+ description=description,
+ component=component,
+ identity=identity,
+ driver_cores=driver_cores,
+ driver_memory=driver_memory,
+ executor_cores=executor_cores,
+ executor_memory=executor_memory,
+ executor_instances=executor_instances,
+ dynamic_allocation_enabled=dynamic_allocation_enabled,
+ dynamic_allocation_min_executors=dynamic_allocation_min_executors,
+ dynamic_allocation_max_executors=dynamic_allocation_max_executors,
+ conf=conf,
+ inputs=job_inputs,
+ outputs=job_outputs,
+ compute=compute,
+ resources=resources,
+ **kwargs,
+ )
+ else:
+ # when we load a remote job, the component is an ARM id, so we need to get the entry from the node-level
+ # data returned by the service
+ spark_obj = Spark(
+ experiment_name=experiment_name,
+ name=name,
+ display_name=display_name,
+ tags=tags,
+ description=description,
+ component=component,
+ identity=identity,
+ driver_cores=driver_cores,
+ driver_memory=driver_memory,
+ executor_cores=executor_cores,
+ executor_memory=executor_memory,
+ executor_instances=executor_instances,
+ dynamic_allocation_enabled=dynamic_allocation_enabled,
+ dynamic_allocation_min_executors=dynamic_allocation_min_executors,
+ dynamic_allocation_max_executors=dynamic_allocation_max_executors,
+ conf=conf,
+ inputs=job_inputs,
+ outputs=job_outputs,
+ compute=compute,
+ resources=resources,
+ entry=entry,
+ py_files=py_files,
+ jars=jars,
+ files=files,
+ archives=archives,
+ args=args,
+ **kwargs,
+ )
+ return spark_obj
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/subcomponents.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/subcomponents.py
new file mode 100644
index 00000000..9b9ed5d2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/subcomponents.py
@@ -0,0 +1,59 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# This file contains mldesigner decorator-produced components
+# that are used within node constructors. Keep imports and
+# general complexity in this file to a minimum.
+
+from typing import List
+
+from mldesigner import Output, command_component
+
+from azure.ai.ml.constants._common import DefaultOpenEncoding
+
+
+def save_mltable_yaml(path: str, mltable_paths: List[str]) -> None:
+ """Save MLTable YAML.
+
+ :param path: The path to save the MLTable YAML file.
+ :type path: str
+ :param mltable_paths: List of paths to be included in the MLTable.
+ :type mltable_paths: List[str]
+ :raises ValueError: If the given path points to a file.
+ """
+ import os
+
+ path = os.path.abspath(path)
+
+ if os.path.isfile(path):
+ raise ValueError(f"The given path {path} points to a file.")
+
+ if not os.path.exists(path):
+ os.makedirs(path, exist_ok=True)
+
+ save_path = os.path.join(path, "MLTable")
+ # Do not touch - this is MLTable syntax that is needed to mount these paths
+ # to the MLTable's inputs
+ mltable_file_content = "\n".join(["paths:"] + [f"- folder : {path}" for path in mltable_paths])
+
+ with open(save_path, "w", encoding=DefaultOpenEncoding.WRITE) as f:
+ f.write(mltable_file_content)
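+
+# Illustrative result (sketch; the datastore paths are placeholders): calling
+#   save_mltable_yaml("./out", ["azureml://datastores/d/paths/silo1", "azureml://datastores/d/paths/silo2"])
+# writes a file named "MLTable" under ./out containing:
+#   paths:
+#   - folder : azureml://datastores/d/paths/silo1
+#   - folder : azureml://datastores/d/paths/silo2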
+
+
+# TODO 2293610: add support for more types of outputs besides uri_folder and mltable
+@command_component()
+def create_scatter_output_table(aggregated_output: Output(type="mltable"), **kwargs: str) -> Output: # type: ignore
+ """Create scatter output table.
+
+ This function is used by the FL scatter gather node to reduce a dynamic number of silo outputs
+ into a single input for the user-supplied aggregation step.
+
+ :param aggregated_output: The aggregated output MLTable.
+ :type aggregated_output: ~mldesigner.Output(type="mltable")
+
+ Keyword arguments represent input names and URI folder paths.
+ """
+ # kwargs keys are input names (e.g. silo_output_silo_1)
+ # and values are uri_folder paths
+ save_mltable_yaml(aggregated_output, list(kwargs.values()))
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py
new file mode 100644
index 00000000..603babbe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/sweep.py
@@ -0,0 +1,454 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import pydash
+from marshmallow import EXCLUDE, Schema
+
+from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import NodeType
+from azure.ai.ml.constants._job.sweep import SearchSpace
+from azure.ai.ml.entities._component.command_component import CommandComponent
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.job_limits import SweepJobLimits
+from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
+from azure.ai.ml.entities._job.pipeline._io import NodeInput
+from azure.ai.ml.entities._job.queue_settings import QueueSettings
+from azure.ai.ml.entities._job.sweep.early_termination_policy import (
+ BanditPolicy,
+ EarlyTerminationPolicy,
+ MedianStoppingPolicy,
+ TruncationSelectionPolicy,
+)
+from azure.ai.ml.entities._job.sweep.objective import Objective
+from azure.ai.ml.entities._job.sweep.parameterized_sweep import ParameterizedSweep
+from azure.ai.ml.entities._job.sweep.sampling_algorithm import SamplingAlgorithm
+from azure.ai.ml.entities._job.sweep.search_space import (
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ SweepDistribution,
+ Uniform,
+)
+from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationErrorType, ValidationException
+from azure.ai.ml.sweep import SweepJob
+
+from ..._restclient.v2022_10_01.models import ComponentVersion
+from ..._schema import PathAwareSchema
+from ..._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
+from ..._utils.utils import camel_to_snake
+from .base_node import BaseNode
+
+module_logger = logging.getLogger(__name__)
+
+
+class Sweep(ParameterizedSweep, BaseNode):
+ """Base class for sweep node.
+
+ This class should not be instantiated directly. Instead, it should be created via the builder function: sweep.
+
+ :param trial: The ID or instance of the command component or job to be run for the step.
+ :type trial: Union[~azure.ai.ml.entities.CommandComponent, str]
+ :param compute: The compute definition containing the compute information for the step.
+ :type compute: str
+ :param limits: The limits for the sweep node.
+ :type limits: ~azure.ai.ml.sweep.SweepJobLimits
+ :param sampling_algorithm: The sampling algorithm to use to sample inside the search space.
+ Accepted values are: "random", "grid", or "bayesian".
+ :type sampling_algorithm: str
+ :param objective: The objective used to determine the target run with the locally optimal
+ hyperparameters in the search space.
+ :type objective: ~azure.ai.ml.sweep.Objective
+ :param early_termination_policy: The early termination policy of the sweep node.
+ :type early_termination_policy: Union[
+
+ ~azure.mgmt.machinelearningservices.models.BanditPolicy,
+ ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy,
+ ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy
+
+ ]
+
+ :param search_space: The hyperparameter search space to run trials in.
+ :type search_space: Dict[str, Union[
+
+ ~azure.ai.ml.entities.Choice,
+ ~azure.ai.ml.entities.LogNormal,
+ ~azure.ai.ml.entities.LogUniform,
+ ~azure.ai.ml.entities.Normal,
+ ~azure.ai.ml.entities.QLogNormal,
+ ~azure.ai.ml.entities.QLogUniform,
+ ~azure.ai.ml.entities.QNormal,
+ ~azure.ai.ml.entities.QUniform,
+ ~azure.ai.ml.entities.Randint,
+ ~azure.ai.ml.entities.Uniform
+
+ ]]
+
+ :param inputs: Mapping of input data bindings used in the job.
+ :type inputs: Dict[str, Union[
+
+ ~azure.ai.ml.Input,
+
+ str,
+ bool,
+ int,
+ float
+
+ ]]
+
+ :param outputs: Mapping of output data bindings used in the job.
+ :type outputs: Dict[str, Union[str, ~azure.ai.ml.Output]]
+ :param identity: The identity that the training job will use while running on compute.
+ :type identity: Union[
+
+ ~azure.ai.ml.ManagedIdentityConfiguration,
+ ~azure.ai.ml.AmlTokenConfiguration,
+ ~azure.ai.ml.UserIdentityConfiguration
+
+ ]
+
+ :param queue_settings: The queue settings for the job.
+ :type queue_settings: ~azure.ai.ml.entities.QueueSettings
+ :param resources: The compute resource configuration for the job.
+ :type resources: Optional[Union[dict, ~azure.ai.ml.entities.ResourceConfiguration]]
+ """
+
+ def __init__(
+ self,
+ *,
+ trial: Optional[Union[CommandComponent, str]] = None,
+ compute: Optional[str] = None,
+ limits: Optional[SweepJobLimits] = None,
+ sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None,
+ objective: Optional[Objective] = None,
+ early_termination: Optional[
+ Union[BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy, EarlyTerminationPolicy, str]
+ ] = None,
+ search_space: Optional[
+ Dict[
+ str,
+ Union[
+ Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ],
+ ]
+ ] = None,
+ inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
+ outputs: Optional[Dict[str, Union[str, Output]]] = None,
+ identity: Optional[
+ Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ queue_settings: Optional[QueueSettings] = None,
+ resources: Optional[Union[dict, JobResourceConfiguration]] = None,
+ **kwargs: Any,
+ ) -> None:
+ # TODO: get rid of self._job_inputs, self._job_outputs once we have general Input
+ self._job_inputs, self._job_outputs = inputs, outputs
+
+ kwargs.pop("type", None)
+ BaseNode.__init__(
+ self,
+ type=NodeType.SWEEP,
+ component=trial,
+ inputs=inputs,
+ outputs=outputs,
+ compute=compute,
+ **kwargs,
+ )
+ # init mark for _AttrDict
+ self._init = True
+ ParameterizedSweep.__init__(
+ self,
+ sampling_algorithm=sampling_algorithm,
+ objective=objective,
+ limits=limits,
+ early_termination=early_termination,
+ search_space=search_space,
+ queue_settings=queue_settings,
+ resources=resources,
+ )
+
+ self.identity: Any = identity
+ self._init = False
+
+ @property
+ def trial(self) -> CommandComponent:
+ """The ID or instance of the command component or job to be run for the step.
+
+ :rtype: ~azure.ai.ml.entities.CommandComponent
+ """
+ res: CommandComponent = self._component
+ return res
+
+ @property
+ def search_space(
+ self,
+ ) -> Optional[
+ Dict[
+ str,
+ Union[Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform],
+ ]
+ ]:
+ """Dictionary of the hyperparameter search space.
+
+ Each key is the name of a hyperparameter and its value is the parameter expression.
+
+ :rtype: Dict[str, Union[~azure.ai.ml.entities.Choice, ~azure.ai.ml.entities.LogNormal,
+ ~azure.ai.ml.entities.LogUniform, ~azure.ai.ml.entities.Normal, ~azure.ai.ml.entities.QLogNormal,
+ ~azure.ai.ml.entities.QLogUniform, ~azure.ai.ml.entities.QNormal, ~azure.ai.ml.entities.QUniform,
+ ~azure.ai.ml.entities.Randint, ~azure.ai.ml.entities.Uniform]]
+ """
+ return self._search_space
+
+ @search_space.setter
+ def search_space(self, values: Dict[str, Dict[str, Union[str, int, float, dict]]]) -> None:
+ """Sets the search space for the sweep job.
+
+ :param values: The search space to set.
+ :type values: Dict[str, Dict[str, Union[str, int, float, dict]]]
+ """
+ search_space: Dict = {}
+ for name, value in values.items():
+ # If the value is already a search space object, pass it through to search_space[name] directly; dicts are converted to the corresponding class
+ search_space[name] = self._value_type_to_class(value) if isinstance(value, dict) else value
+ self._search_space = search_space
+
+ @classmethod
+ def _value_type_to_class(cls, value: Any) -> Dict:
+ value_type = value["type"]
+ search_space_dict = {
+ SearchSpace.CHOICE: Choice,
+ SearchSpace.RANDINT: Randint,
+ SearchSpace.LOGNORMAL: LogNormal,
+ SearchSpace.NORMAL: Normal,
+ SearchSpace.LOGUNIFORM: LogUniform,
+ SearchSpace.UNIFORM: Uniform,
+ SearchSpace.QLOGNORMAL: QLogNormal,
+ SearchSpace.QNORMAL: QNormal,
+ SearchSpace.QLOGUNIFORM: QLogUniform,
+ SearchSpace.QUNIFORM: QUniform,
+ }
+
+ res: dict = search_space_dict[value_type](**value)
+ return res
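+
+ # Illustrative mapping (sketch; the key names follow the YAML search space schema and are assumptions here):
+ # a value such as
+ #   {"type": "choice", "values": [16, 32, 64]}
+ # is expanded into Choice(**value), and
+ #   {"type": "uniform", "min_value": 0.001, "max_value": 0.1}
+ # into Uniform(**value).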
+
+ @classmethod
+ def _get_supported_inputs_types(cls) -> Tuple:
+ supported_types = super()._get_supported_inputs_types() or ()
+ return (
+ SweepDistribution,
+ *supported_types,
+ )
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Sweep":
+ raise NotImplementedError("Sweep._load_from_dict is not supported")
+
+ @classmethod
+ def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
+ return [
+ "limits",
+ "sampling_algorithm",
+ "objective",
+ "early_termination",
+ "search_space",
+ "queue_settings",
+ "resources",
+ ]
+
+ def _to_rest_object(self, **kwargs: Any) -> dict:
+ rest_obj: dict = super(Sweep, self)._to_rest_object(**kwargs)
+ # hack: ParameterizedSweep.early_termination is not allowed to be None
+ for key in ["early_termination"]:
+ if key in rest_obj and rest_obj[key] is None:
+ del rest_obj[key]
+
+ # hack: early termination policy is currently the only field that does not follow the YAML schema; remove
+ # this once the server-side change has been made
+ if "early_termination" in rest_obj:
+ _early_termination: EarlyTerminationPolicy = self.early_termination # type: ignore
+ rest_obj["early_termination"] = _early_termination._to_rest_object().as_dict()
+
+ rest_obj.update(
+ {
+ "type": self.type,
+ "trial": self._get_trial_component_rest_obj(),
+ }
+ )
+ return rest_obj
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
+ obj = super()._from_rest_object_to_init_params(obj)
+
+ # hack: early termination policy is currently the only field that does not follow the YAML schema; remove
+ # this once the server-side change has been made
+ if "early_termination" in obj and "policy_type" in obj["early_termination"]:
+ # can't use _from_rest_object here, because obj is a dict instead of an EarlyTerminationPolicy rest object
+ obj["early_termination"]["type"] = camel_to_snake(obj["early_termination"].pop("policy_type"))
+
+ # TODO: use cls._get_schema() to load from rest object
+ from azure.ai.ml._schema._sweep.parameterized_sweep import ParameterizedSweepSchema
+
+ schema = ParameterizedSweepSchema(context={BASE_PATH_CONTEXT_KEY: "./"})
+ support_data_binding_expression_for_fields(schema, ["type", "component", "trial"])
+
+ base_sweep = schema.load(obj, unknown=EXCLUDE, partial=True)
+ for key, value in base_sweep.items():
+ obj[key] = value
+
+ # trial
+ trial_component_id = pydash.get(obj, "trial.componentId", None)
+ obj["trial"] = trial_component_id # check this
+
+ return obj
+
+ def _get_trial_component_rest_obj(self) -> Union[Dict, ComponentVersion, None]:
+ # converting the trial component to a rest object differs from the usual component conversion
+ trial_component_id = self._get_component_id()
+ if trial_component_id is None:
+ return None
+ if isinstance(trial_component_id, str):
+ return {"componentId": trial_component_id}
+ if isinstance(trial_component_id, CommandComponent):
+ return trial_component_id._to_rest_object()
+ raise UserErrorException(f"invalid trial in sweep node {self.name}: {str(self.trial)}")
+
+ def _to_job(self) -> SweepJob:
+ command = self.trial.command
+ if self.search_space is not None:
+ for key, _ in self.search_space.items():
+ if command is not None:
+ # Double curly brackets to escape
+ command = command.replace(f"${{{{inputs.{key}}}}}", f"${{{{search_space.{key}}}}}")
+
+ # TODO: raise exception when the trial is a pre-registered component
+ if command != self.trial.command and isinstance(self.trial, CommandComponent):
+ self.trial.command = command
+
+ return SweepJob(
+ name=self.name,
+ display_name=self.display_name,
+ description=self.description,
+ properties=self.properties,
+ tags=self.tags,
+ experiment_name=self.experiment_name,
+ trial=self.trial,
+ compute=self.compute,
+ sampling_algorithm=self.sampling_algorithm,
+ search_space=self.search_space,
+ limits=self.limits,
+ early_termination=self.early_termination, # type: ignore[arg-type]
+ objective=self.objective,
+ inputs=self._job_inputs,
+ outputs=self._job_outputs,
+ identity=self.identity,
+ queue_settings=self.queue_settings,
+ resources=self.resources,
+ )
+
+ @classmethod
+ def _get_component_attr_name(cls) -> str:
+ return "trial"
+
+ def _build_inputs(self) -> Dict:
+ inputs = super(Sweep, self)._build_inputs()
+ built_inputs = {}
+ # Validate and remove non-specified inputs
+ for key, value in inputs.items():
+ if value is not None:
+ built_inputs[key] = value
+
+ return built_inputs
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ from azure.ai.ml._schema.pipeline.component_job import SweepSchema
+
+ return SweepSchema(context=context)
+
+ @classmethod
+ def _get_origin_inputs_and_search_space(cls, built_inputs: Optional[Dict[str, NodeInput]]) -> Tuple:
+ """Separate mixed true inputs & search space definition from inputs of
+ this node and return them.
+
+ Input will be restored to Input/LiteralInput before returned.
+
+ :param built_inputs: The built inputs
+ :type built_inputs: Optional[Dict[str, NodeInput]]
+ :return: A tuple of the inputs and search space
+ :rtype: Tuple[
+ Dict[str, Union[Input, str, bool, int, float]],
+ Dict[str, SweepDistribution],
+ ]
+ """
+ search_space: Dict = {}
+ inputs: Dict = {}
+ if built_inputs is not None:
+ for input_name, input_obj in built_inputs.items():
+ if isinstance(input_obj, NodeInput):
+ if isinstance(input_obj._data, SweepDistribution):
+ search_space[input_name] = input_obj._data
+ else:
+ inputs[input_name] = input_obj._data
+ else:
+ msg = "unsupported built input type: {}: {}"
+ raise ValidationException(
+ message=msg.format(input_name, type(input_obj)),
+ no_personal_data_message=msg.format("[input_name]", type(input_obj)),
+ target=ErrorTarget.SWEEP_JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return inputs, search_space
+
+ def _is_input_set(self, input_name: str) -> bool:
+ if super(Sweep, self)._is_input_set(input_name):
+ return True
+ return self.search_space is not None and input_name in self.search_space
+
+ def __setattr__(self, key: Any, value: Any) -> None:
+ super(Sweep, self).__setattr__(key, value)
+ if key == "early_termination" and isinstance(self.early_termination, BanditPolicy):
+ # Only one of slack_amount and slack_factor can be specified, but the default value is 0.0,
+ # so we need to keep track of which one is actually unset.
+ if self.early_termination.slack_amount == 0.0:
+ self.early_termination.slack_amount = None # type: ignore[assignment]
+ if self.early_termination.slack_factor == 0.0:
+ self.early_termination.slack_factor = None # type: ignore[assignment]
+
+ @property
+ def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]:
+ """The early termination policy for the sweep job.
+
+ :rtype: Union[str, ~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy,
+ ~azure.ai.ml.sweep.TruncationSelectionPolicy]
+ """
+ return self._early_termination
+
+ @early_termination.setter
+ def early_termination(self, value: Optional[Union[str, EarlyTerminationPolicy]]) -> None:
+ """Sets the early termination policy for the sweep job.
+
+ :param value: The early termination policy for the sweep job.
+ :type value: Union[~azure.ai.ml.sweep.BanditPolicy, ~azure.ai.ml.sweep.MedianStoppingPolicy,
+ ~azure.ai.ml.sweep.TruncationSelectionPolicy, dict[str, Union[str, float, int, bool]]]
+ """
+ if isinstance(value, dict):
+ early_termination_schema = EarlyTerminationField()
+ value = early_termination_schema._deserialize(value=value, attr=None, data=None)
+ self._early_termination = value # type: ignore[assignment]
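+
+ # Illustrative sketch (the field names follow the YAML schema and are assumptions here): assigning a dict such as
+ #   sweep_node.early_termination = {"type": "bandit", "slack_factor": 0.1, "evaluation_interval": 1}
+ # is deserialized by EarlyTerminationField into the corresponding policy object (BanditPolicy in this case).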
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/_additional_includes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/_additional_includes.py
new file mode 100644
index 00000000..85f609ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/_additional_includes.py
@@ -0,0 +1,541 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import json
+import os
+import shutil
+import tempfile
+import zipfile
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
+from multiprocessing import cpu_count
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+
+from azure.ai.ml.constants._common import AzureDevopsArtifactsType
+from azure.ai.ml.entities._validation import MutableValidationResult, ValidationResultBuilder
+
+from ..._utils._artifact_utils import ArtifactCache
+from ..._utils._asset_utils import IgnoreFile, get_upload_files_from_folder
+from ..._utils.utils import is_concurrent_component_registration_enabled, is_private_preview_enabled
+from ...entities._util import _general_copy
+from .._assets import Code
+from .code import ComponentCodeMixin, ComponentIgnoreFile
+
+PLACEHOLDER_FILE_NAME = "_placeholder_spec.yaml"
+
+
+class AdditionalIncludes:
+ """Initialize the AdditionalIncludes object.
+
+ :param origin_code_value: The origin code value.
+ :type origin_code_value: Optional[str]
+ :param base_path: The base path for origin code path and additional include configs.
+ :type base_path: Path
+ :param configs: The additional include configs.
+ :type configs: List[Union[str, dict]]
+ """
+
+ def __init__(
+ self,
+ *,
+ origin_code_value: Optional[str],
+ base_path: Path,
+ configs: Optional[List[Union[str, dict]]] = None,
+ ) -> None:
+ self._base_path = base_path
+ self._origin_code_value = origin_code_value
+ self._origin_configs = configs
+
+ @property
+ def origin_configs(self) -> List:
+ """The origin additional include configs.
+ Artifact additional include configs haven't been resolved in this property.
+
+ :return: The origin additional include configs.
+ :rtype: List[Union[str, dict]]
+ """
+ return self._origin_configs or []
+
+ @property
+ def resolved_code_path(self) -> Union[None, Path]:
+ """The resolved origin code path based on base path, if code path is not specified, return None.
+ We shouldn't change this property name given it's referenced in mldesigner.
+
+ :return: The resolved origin code path.
+ :rtype: Union[None, Path]
+ """
+ if self._origin_code_value is None:
+ return None
+ if os.path.isabs(self._origin_code_value):
+ return Path(self._origin_code_value)
+ return (self.base_path / self._origin_code_value).resolve()
+
+ @property
+ def base_path(self) -> Path:
+ """Base path for origin code path and additional include configs.
+
+ :return: The base path.
+ :rtype: Path
+ """
+ return self._base_path
+
+ @property
+ def with_includes(self) -> bool:
+ """Whether the additional include configs have been provided.
+
+ :return: True if additional include configs have been provided, False otherwise.
+ :rtype: bool
+ """
+ return len(self.origin_configs) != 0
+
+ @classmethod
+ def _get_artifacts_by_config(cls, artifact_config: Dict[str, str]) -> Union[str, os.PathLike]:
+ # config key existence has been validated in _validate_additional_include_config
+ res: Union[str, os.PathLike] = ArtifactCache().get(
+ organization=artifact_config.get("organization", None),
+ project=artifact_config.get("project", None),
+ feed=artifact_config["feed"],
+ name=artifact_config["name"],
+ version=artifact_config["version"],
+ scope=artifact_config.get("scope", "organization"),
+ resolve=True,
+ )
+ return res
+
+ def _validate_additional_include_config(
+ self, additional_include_config: Union[Dict, str]
+ ) -> MutableValidationResult:
+ validation_result = ValidationResultBuilder.success()
+ if (
+ isinstance(additional_include_config, dict)
+ and additional_include_config.get("type") == AzureDevopsArtifactsType.ARTIFACT
+ ):
+ # for an artifact additional include, we validate the required fields in the config but do not validate
+ # the artifact content, to avoid downloading it during the validation stage
+ # note that a runtime error will be thrown when loading the artifact
+ for item in ["feed", "name", "version"]:
+ if item not in additional_include_config:
+ # TODO: add yaml path after we support list index in yaml path
+ validation_result.append_error(
+ "{} are required for artifacts config but got {}.".format(
+ item, json.dumps(additional_include_config)
+ )
+ )
+ elif isinstance(additional_include_config, str):
+ validation_result.merge_with(self._validate_local_additional_include_config(additional_include_config))
+ else:
+ validation_result.append_error(
+ message=f"Unexpected format in additional_includes, {additional_include_config}"
+ )
+ return validation_result
+
+ @classmethod
+ def _resolve_artifact_additional_include_config(
+ cls, artifact_additional_include_config: Dict[str, str]
+ ) -> List[Tuple[str, str]]:
+ """Resolve an artifact additional include config into a list of (local_path, config_info) tuples.
+
+        The configured artifact will be downloaded to a local path first; config_info will be in the format below:
+        %name%:%version% in %feed%
+
+ :param artifact_additional_include_config: Additional include config for an artifact
+ :type artifact_additional_include_config: Dict[str, str]
+ :return: A list of 2-tuples of local_path and config_info
+ :rtype: List[Tuple[str, str]]
+ """
+ result = []
+ # Note that we don't validate the artifact config here, since it has already been validated in
+ # _validate_additional_include_config
+ artifact_path = cls._get_artifacts_by_config(artifact_additional_include_config)
+ for item in os.listdir(artifact_path):
+ config_info = (
+ f"{artifact_additional_include_config['name']}:{artifact_additional_include_config['version']} in "
+ f"{artifact_additional_include_config['feed']}"
+ )
+ result.append((os.path.join(artifact_path, item), config_info))
+ return result
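+    # Illustrative sketch (not part of the SDK): assuming a hypothetical artifact config, the method above would
+    # yield tuples like the following, one per top-level item in the downloaded artifact folder:
+    #
+    #     config = {"feed": "my-feed", "name": "my-package", "version": "1.0.0"}
+    #     # -> [("/tmp/.../my-package/lib", "my-package:1.0.0 in my-feed"),
+    #     #     ("/tmp/.../my-package/README.md", "my-package:1.0.0 in my-feed")]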
+
+ def _resolve_artifact_additional_include_configs(
+ self, artifact_additional_includes_configs: List[Dict[str, str]]
+ ) -> List:
+ additional_include_info_tuples = []
+        # Unlike component registration, artifact downloading is a pure download process, so we can use
+ # more threads to speed up the downloading process.
+ # We use 5 threads per CPU core plus 5 extra threads, and the max number of threads is 64.
+ num_threads = min(64, (int(cpu_count()) * 5) + 5)
+ if (
+ len(artifact_additional_includes_configs) > 1
+ and is_concurrent_component_registration_enabled()
+ and is_private_preview_enabled()
+ ):
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ all_artifact_pairs_itr = executor.map(
+ self._resolve_artifact_additional_include_config, artifact_additional_includes_configs
+ )
+
+ for artifact_pairs in all_artifact_pairs_itr:
+ additional_include_info_tuples.extend(artifact_pairs)
+ else:
+ all_artifact_pairs_list = list(
+ map(self._resolve_artifact_additional_include_config, artifact_additional_includes_configs)
+ )
+
+ for artifact_pairs in all_artifact_pairs_list:
+ additional_include_info_tuples.extend(artifact_pairs)
+
+ return additional_include_info_tuples
+
+ @staticmethod
+ def _copy(src: Path, dst: Path, *, ignore_file: Optional[Any] = None) -> None:
+ if ignore_file and ignore_file.is_file_excluded(src):
+ return
+ if not src.exists():
+ raise ValueError(f"Path {src} does not exist.")
+ if src.is_file():
+ _general_copy(src, dst)
+ if src.is_dir():
+ # TODO: should we cover empty folder?
+ # use os.walk to replace shutil.copytree, which may raise FileExistsError
+ # for same folder, the expected behavior is merging
+ # ignore will be also applied during this process
+ for name in src.glob("*"):
+ if ignore_file is not None:
+ AdditionalIncludes._copy(name, dst / name.name, ignore_file=ignore_file.merge(name))
+
+ @staticmethod
+ def _is_folder_to_compress(path: Path) -> bool:
+        """Check if the additional include needs to compress the corresponding folder as a zip.
+
+ For example, given additional include /mnt/c/hello.zip
+ 1) if a file named /mnt/c/hello.zip already exists, return False (simply copy)
+ 2) if a folder named /mnt/c/hello exists, return True (compress as a zip and copy)
+
+ :param path: Given path in additional include.
+ :type path: Path
+        :return: Whether the path needs to be compressed as a zip file.
+ :rtype: bool
+ """
+ if path.suffix != ".zip":
+ return False
+ # if zip file exists, simply copy as other additional includes
+ if path.exists():
+ return False
+ # remove .zip suffix and check whether the folder exists
+ stem_path = path.parent / path.stem
+ return stem_path.is_dir()
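+    # Illustrative sketch (not part of the SDK), assuming hypothetical paths:
+    #
+    #     _is_folder_to_compress(Path("/mnt/c/notes.txt"))   # False: suffix is not ".zip"
+    #     _is_folder_to_compress(Path("/mnt/c/hello.zip"))   # False if the zip file itself exists (copied as-is),
+    #                                                        # True if only the folder /mnt/c/hello exists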
+
+ def _resolve_folder_to_compress(self, include: str, dst_path: Path, ignore_file: IgnoreFile) -> None:
+        """Resolve a zip additional include by compressing the corresponding folder.
+
+ :param include: The path, relative to :attr:`AdditionalIncludes.base_path`, to zip
+ :type include: str
+ :param dst_path: The path to write the zipfile to
+ :type dst_path: Path
+ :param ignore_file: The ignore file to use to filter files
+ :type ignore_file: IgnoreFile
+ """
+ zip_additional_include = (self.base_path / include).resolve()
+ folder_to_zip = zip_additional_include.parent / zip_additional_include.stem
+ zip_file = dst_path / zip_additional_include.name
+ with zipfile.ZipFile(zip_file, "w") as zf:
+ zf.write(folder_to_zip, os.path.relpath(folder_to_zip, folder_to_zip.parent)) # write root in zip
+ paths = [path for path, _ in get_upload_files_from_folder(folder_to_zip, ignore_file=ignore_file)]
+ # sort the paths to make sure the zip file (namelist) is deterministic
+ for path in sorted(paths):
+ zf.write(path, os.path.relpath(path, folder_to_zip.parent))
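+    # Illustrative sketch (not part of the SDK): for an include "libs/helpers.zip" backed by a folder
+    # "libs/helpers" containing "a.py" and "sub/b.py", the archive written to dst_path would contain the entries
+    # "helpers", "helpers/a.py" and "helpers/sub/b.py" (the root folder entry first, then the sorted file paths).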
+
+ def _get_resolved_additional_include_configs(self) -> List[str]:
+ """
+ Resolve additional include configs to a list of local_paths and return it.
+
+        Additional includes is a list of included files, covering both local paths and Azure DevOps artifacts.
+        The YAML format of additional_includes looks like below:
+ additional_includes:
+ - your/local/path
+ - type: artifact
+ organization: devops_organization
+ project: devops_project
+ feed: artifacts_feed_name
+ name: universal_package_name
+ version: package_version
+ scope: scope_type
+        In this function, artifact packages are downloaded from Azure DevOps to local paths, and their configs are
+        replaced with those local paths;
+        local path configs are returned directly.
+        If there are conflicts among artifacts, a runtime error will be raised. Note that we won't check for
+        conflicts between artifacts and local paths, or among local paths. Reasons are:
+ 1. There can be ignore_file in local paths, which makes it hard to check the conflict and may lead to breaking
+ changes;
+ 2. Conflicts among artifacts are more likely to happen, since user may refer to 2 artifacts of the same name
+ but with different version & feed.
+ 3. According to current design, folders in local paths will be merged; while artifact conflicts can be
+ identified by folder name conflicts and are not allowed.
+
+        :return: Path list of additional_includes
+        :rtype: List[str]
+ """
+ additional_include_configs_in_local_path = []
+
+ artifact_additional_include_configs = []
+ for additional_include_config in self.origin_configs:
+ if isinstance(additional_include_config, str):
+ # add local additional include configs directly
+ additional_include_configs_in_local_path.append(additional_include_config)
+ else:
+ # artifact additional include config will be downloaded and resolved to a local path later
+ # note that there is no more validation for artifact additional include config here, since it has
+ # already been validated in _validate_additional_include_config
+ artifact_additional_include_configs.append(additional_include_config)
+
+ artifact_additional_include_info_tuples = self._resolve_artifact_additional_include_configs(
+ artifact_additional_include_configs
+ )
+ additional_include_configs_in_local_path.extend(
+ local_path for local_path, _ in artifact_additional_include_info_tuples
+ )
+
+ # check file conflicts among artifact package
+ # given this is not in validate stage, we will raise error if there are conflict files
+ conflict_files: dict = defaultdict(set)
+ for local_path, config_info in artifact_additional_include_info_tuples:
+ file_name = Path(local_path).name
+ conflict_files[file_name].add(config_info)
+
+ conflict_files = {k: v for k, v in conflict_files.items() if len(v) > 1}
+ if conflict_files:
+ raise RuntimeError(f"There are conflict files in additional include: {conflict_files}")
+
+ return additional_include_configs_in_local_path
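+    # Illustrative sketch (not part of the SDK), assuming a hypothetical additional_includes section:
+    #
+    #     additional_includes:
+    #       - ../shared/utils.py
+    #       - type: artifact
+    #         feed: my-feed
+    #         name: my-package
+    #         version: 1.0.0
+    #
+    # would resolve to ["../shared/utils.py"] plus one local path per top-level item of the downloaded
+    # my-package 1.0.0.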
+
+ def _validate_local_additional_include_config(
+ self, local_path: str, config_info: Optional[str] = None
+ ) -> MutableValidationResult:
+ """Validate local additional include config.
+
+        Note that we will check for file conflicts between each local additional include and the origin code, but
+        won't check for file conflicts among local additional includes for now.
+
+ :param local_path: The local path
+ :type local_path: str
+ :param config_info: The config info
+ :type config_info: Optional[str]
+ :return: The validation result.
+ :rtype: ~azure.ai.ml.entities._validation.MutableValidationResult
+ """
+ validation_result = ValidationResultBuilder.success()
+ include_path = self.base_path / local_path
+ # if additional include has not supported characters, resolve will fail and raise OSError
+ try:
+ src_path = include_path.resolve()
+ except OSError:
+ # no need to include potential yaml file name in error message as it will be covered by
+ # validation message construction.
+ error_msg = (
+ f"Failed to resolve additional include " f"{config_info or local_path} " f"based on {self.base_path}."
+ )
+ validation_result.append_error(message=error_msg)
+ return validation_result
+
+ if not src_path.exists() and not self._is_folder_to_compress(src_path):
+ error_msg = f"Unable to find additional include {config_info or local_path}"
+ validation_result.append_error(message=error_msg)
+ return validation_result
+
+ if len(src_path.parents) == 0:
+ error_msg = "Root directory is not supported for additional includes."
+ validation_result.append_error(message=error_msg)
+ return validation_result
+
+ dst_path = Path(self.resolved_code_path) / src_path.name if self.resolved_code_path else None
+ if dst_path:
+ if dst_path.is_symlink():
+ # if destination path is symbolic link, check if it points to the same file/folder as source path
+ if dst_path.resolve() != src_path.resolve():
+ error_msg = f"A symbolic link already exists for additional include {config_info or local_path}."
+ validation_result.append_error(message=error_msg)
+ return validation_result
+ elif dst_path.exists():
+ error_msg = f"A file already exists for additional include {config_info or local_path}."
+ validation_result.append_error(message=error_msg)
+                return validation_result
+
+        return validation_result
+
+ def validate(self) -> MutableValidationResult:
+ """Validate the AdditionalIncludes object.
+
+ :return: The validation result.
+ :rtype: ~azure.ai.ml.entities._validation.MutableValidationResult
+ """
+ validation_result = ValidationResultBuilder.success()
+ for additional_include_config in self.origin_configs:
+ validation_result.merge_with(self._validate_additional_include_config(additional_include_config))
+ return validation_result
+
+ def _copy_origin_code(self, target_path: Path) -> ComponentIgnoreFile:
+ """Copy origin code to target path.
+
+ :param target_path: The destination to copy to
+ :type target_path: Path
+ :return: The component ignore file for the origin path
+ :rtype: ComponentIgnoreFile
+ """
+ # code can be either file or folder, as additional includes exists, need to copy to temporary folder
+ if self.resolved_code_path is None:
+ # if additional include configs exist but no origin code path, return a dummy ignore file
+ return ComponentIgnoreFile(
+ self.base_path,
+ )
+
+ if Path(self.resolved_code_path).is_file():
+ # use a dummy ignore file to save base path
+ root_ignore_file = ComponentIgnoreFile(
+ Path(self.resolved_code_path).parent,
+ skip_ignore_file=True,
+ )
+ self._copy(
+ Path(self.resolved_code_path),
+ target_path / Path(self.resolved_code_path).name,
+ ignore_file=root_ignore_file,
+ )
+ else:
+ # current implementation of ignore file is based on absolute path, so it cannot be shared
+ root_ignore_file = ComponentIgnoreFile(self.resolved_code_path)
+ self._copy(self.resolved_code_path, target_path, ignore_file=root_ignore_file)
+ return root_ignore_file
+
+ @contextmanager
+ def merge_local_code_and_additional_includes(self) -> Generator:
+ """Merge code and potential additional includes into a temporary folder and return the absolute path of it.
+
+ If no additional includes are specified, just return the absolute path of the original code path.
+ If no original code path is specified, return None.
+
+ :return: The absolute path of the merged code and additional includes.
+ :rtype: Path
+ """
+ if not self.with_includes:
+ if self.resolved_code_path is None:
+ yield None
+ else:
+ yield self.resolved_code_path.absolute()
+ return
+
+ # for now, upload path of a code asset will include the folder name of the code path (name of folder or
+ # parent name of file). For example, if code path is /mnt/c/code-a, upload path will be xxx/code-a
+ # which means that the upload path will change every time as we will merge additional includes into a temp
+ # folder. To avoid this, we will copy the code path to a child folder with a fixed name under the temp folder,
+ # then the child folder will be used in upload path.
+ # This issue shouldn't impact users as there is a separate asset existence check before uploading.
+ # We still make this change as:
+        # 1. We would otherwise always need to record the asset twice, as the upload path changes on the first upload
+ # 2. This will improve the stability of the code asset existence check - AssetNotChanged check in
+ # BlobStorageClient will be a backup check
+ tmp_folder_path = Path(tempfile.mkdtemp(), "code_with_additional_includes")
+ tmp_folder_path.mkdir(parents=True, exist_ok=True)
+
+ root_ignore_file = self._copy_origin_code(tmp_folder_path)
+
+ # resolve additional includes
+ base_path = self.base_path
+        # additional includes from artifacts will be downloaded to a temp local path when resolving the configs,
+        # so no need to add specific logic for artifacts
+
+ # TODO: skip ignored files defined in code when copying additional includes
+ # copy additional includes disregarding ignore files as current ignore file implementation
+ # is based on absolute path, which is not suitable for additional includes
+ for additional_include_local_path in self._get_resolved_additional_include_configs():
+ src_path = Path(additional_include_local_path)
+ if not src_path.is_absolute():
+ src_path = (base_path / additional_include_local_path).resolve()
+ dst_path = (tmp_folder_path / src_path.name).resolve()
+
+ root_ignore_file.rebase(src_path.parent)
+ if self._is_folder_to_compress(src_path):
+ self._resolve_folder_to_compress(
+ additional_include_local_path,
+ Path(tmp_folder_path),
+ # actual src path is without .zip suffix
+ ignore_file=root_ignore_file.merge(src_path.parent / src_path.stem),
+ )
+ # early continue as the folder is compressed as a zip file
+ continue
+
+ # no need to check if src_path exists as it is already validated
+ if src_path.is_file():
+ self._copy(src_path, dst_path, ignore_file=root_ignore_file)
+ elif src_path.is_dir():
+ self._copy(
+ src_path,
+ dst_path,
+ # root ignore file on parent + ignore file on src_path
+ ignore_file=root_ignore_file.merge(src_path),
+ )
+ else:
+ raise ValueError(f"Unable to find additional include {additional_include_local_path}.")
+ try:
+ yield tmp_folder_path.absolute()
+
+ finally:
+ # clean up tmp folder as it can be very disk space consuming
+ shutil.rmtree(tmp_folder_path, ignore_errors=True)
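+    # Illustrative usage sketch (not part of the SDK); "includes" is a hypothetical AdditionalIncludes instance:
+    #
+    #     with includes.merge_local_code_and_additional_includes() as merged_path:
+    #         if merged_path is not None:
+    #             upload_folder(merged_path)  # hypothetical helper; the temp folder is cleaned up on exit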
+
+
+class AdditionalIncludesMixin(ComponentCodeMixin):
+ @classmethod
+ def _get_additional_includes_field_name(cls) -> str:
+ """Get the field name for additional includes.
+
+ :return: The field name
+ :rtype: str
+ """
+ return "additional_includes"
+
+ def _get_all_additional_includes_configs(self) -> List:
+ return getattr(self, self._get_additional_includes_field_name(), [])
+
+ def _append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(
+ self, base_validation_result: Optional[MutableValidationResult] = None
+ ) -> bool:
+ is_reliable: bool = super()._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(
+ base_validation_result
+ )
+ additional_includes_obj = self._generate_additional_includes_obj()
+
+ if base_validation_result is not None:
+ base_validation_result.merge_with(
+ additional_includes_obj.validate(), field_name=self._get_additional_includes_field_name()
+ )
+ # if additional includes is specified, origin code will be merged with additional includes into a temp folder
+ # before registered as a code asset, so origin code value is not reliable for local path validation
+ if additional_includes_obj.with_includes:
+ return False
+ return is_reliable
+
+ def _generate_additional_includes_obj(self) -> AdditionalIncludes:
+ return AdditionalIncludes(
+ base_path=self._get_base_path_for_code(),
+ configs=self._get_all_additional_includes_configs(),
+ origin_code_value=self._get_origin_code_in_str(),
+ )
+
+ @contextmanager
+ def _try_build_local_code(self) -> Generator:
+ """Build final code when origin code is a local code.
+
+ Will merge code path with additional includes into a temp folder if additional includes is specified.
+
+ :return: The built Code object
+ :rtype: Iterable[Optional[Code]]
+ """
+ # will try to merge code and additional includes even if code is None
+ tmp_code_dir: Any
+ with self._generate_additional_includes_obj().merge_local_code_and_additional_includes() as tmp_code_dir:
+ if tmp_code_dir is None:
+ yield None
+ else:
+ yield Code(
+ base_path=self._get_base_path_for_code(),
+ path=tmp_code_dir,
+ ignore_file=ComponentIgnoreFile(tmp_code_dir),
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/automl_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/automl_component.py
new file mode 100644
index 00000000..3e7be727
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/automl_component.py
@@ -0,0 +1,42 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Optional
+
+from azure.ai.ml._schema import PathAwareSchema
+from azure.ai.ml._schema.component.automl_component import AutoMLComponentSchema
+from azure.ai.ml.constants._common import COMPONENT_TYPE
+from azure.ai.ml.constants._component import NodeType
+from azure.ai.ml.entities._component.component import Component
+
+
+class AutoMLComponent(Component):
+ """AutoML component entity, used to define an automl component.
+
+    AutoML Component is only used "internally" for the scenarios that need it. Its schema is not intended for end
+    users, so it is not exposed to them and has no public documentation.
+
+ :param task: Task of the component.
+ :type task: str
+ """
+
+ def __init__(
+ self,
+ *,
+ task: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[COMPONENT_TYPE] = NodeType.AUTOML
+ super(AutoMLComponent, self).__init__(**kwargs)
+ self._task = task
+
+ @property
+ def task(self) -> Optional[str]:
+ """Returns task of the component."""
+ return self._task
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema:
+ return AutoMLComponentSchema(context=context)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/code.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/code.py
new file mode 100644
index 00000000..1f838bec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/code.py
@@ -0,0 +1,297 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os
+from contextlib import contextmanager
+from enum import Enum
+from pathlib import Path
+from typing import Any, Generator, List, Optional, Union
+
+from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, is_registry_id_for_resource
+from azure.ai.ml._utils._asset_utils import IgnoreFile, get_ignore_file
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+from azure.ai.ml.entities._assets import Code
+from azure.ai.ml.entities._validation import MutableValidationResult
+
+
+class ComponentIgnoreFile(IgnoreFile):
+ _COMPONENT_CODE_IGNORES = ["__pycache__"]
+ """Component-specific ignore file used for ignoring files in a component directory.
+
+ :param directory_path: The directory path for the ignore file.
+ :type directory_path: Union[str, Path]
+ :param additional_includes_file_name: Name of the additional includes file in the root directory to be ignored.
+ :type additional_includes_file_name: str
+ :param skip_ignore_file: Whether to skip the ignore file, defaults to False.
+ :type skip_ignore_file: bool
+ :param extra_ignore_list: List of additional ignore files to be considered during file exclusion.
+ :type extra_ignore_list: List[~azure.ai.ml._utils._asset_utils.IgnoreFile]
+ :raises ValueError: If additional include file is not found.
+ :return: The ComponentIgnoreFile object.
+ :rtype: ComponentIgnoreFile
+ """
+
+ def __init__(
+ self,
+ directory_path: Union[str, Path],
+ *,
+ additional_includes_file_name: Optional[str] = None,
+ skip_ignore_file: bool = False,
+ extra_ignore_list: Optional[List[IgnoreFile]] = None,
+ ):
+ self._base_path: Union[str, Path] = Path(directory_path)
+ self._extra_ignore_list: List[IgnoreFile] = extra_ignore_list or []
+ # only the additional include file in root directory is ignored
+ # additional include files in subdirectories are not processed so keep them
+ self._additional_includes_file_name = additional_includes_file_name
+ # note: the parameter changes to directory path in this class, rather than file path
+ file_path = None if skip_ignore_file else get_ignore_file(directory_path).path
+ super(ComponentIgnoreFile, self).__init__(file_path=file_path)
+
+ def exists(self) -> bool:
+ """Check if the ignore file exists.
+
+ :return: True
+ :rtype: bool
+ """
+ return True
+
+ @property
+ def base_path(self) -> Union[str, Path]:
+ """Get the base path of the ignore file.
+
+ :return: The base path.
+ :rtype: Path
+ """
+ # for component ignore file, the base path can be different from file.parent
+ return self._base_path
+
+ def rebase(self, directory_path: Union[str, Path]) -> "ComponentIgnoreFile":
+ """Rebase the ignore file to a new directory.
+
+ :param directory_path: The new directory path.
+ :type directory_path: Union[str, Path]
+ :return: The rebased ComponentIgnoreFile object.
+ :rtype: ComponentIgnoreFile
+ """
+ self._base_path = directory_path
+ return self
+
+ def is_file_excluded(self, file_path: Union[str, Path]) -> bool:
+ """Check if a file should be excluded based on the ignore file rules.
+
+ :param file_path: The file path.
+ :type file_path: Union[str, Path]
+ :return: True if the file should be excluded, False otherwise.
+ :rtype: bool
+ """
+ if self._additional_includes_file_name and self._get_rel_path(file_path) == self._additional_includes_file_name:
+ return True
+ for ignore_file in self._extra_ignore_list:
+ if ignore_file.is_file_excluded(file_path):
+ return True
+ res: bool = super(ComponentIgnoreFile, self).is_file_excluded(file_path)
+ return res
+
+ def merge(self, other_path: Path) -> "ComponentIgnoreFile":
+ """Merge the ignore list from another ComponentIgnoreFile object.
+
+ :param other_path: The path of the other ignore file.
+ :type other_path: Path
+ :return: The merged ComponentIgnoreFile object.
+ :rtype: ComponentIgnoreFile
+ """
+ if other_path.is_file():
+ return self
+ return ComponentIgnoreFile(other_path, extra_ignore_list=self._extra_ignore_list + [self])
+
+ def _get_ignore_list(self) -> List[str]:
+ """Retrieves the list of ignores from ignore file
+
+ Override to add custom ignores.
+
+ :return: The ignore rules
+ :rtype: List[str]
+ """
+ if not super(ComponentIgnoreFile, self).exists():
+ return self._COMPONENT_CODE_IGNORES
+ res: list = super(ComponentIgnoreFile, self)._get_ignore_list() + self._COMPONENT_CODE_IGNORES
+ return res
+
+
+class CodeType(Enum):
+ """Code type."""
+
+ LOCAL = "local"
+ NONE = "none"
+ GIT = "git"
+ ARM_ID = "arm_id"
+ UNKNOWN = "unknown"
+
+
+def _get_code_type(origin_code_value: Optional[str]) -> CodeType:
+ if origin_code_value is None:
+ return CodeType.NONE
+ if not isinstance(origin_code_value, str):
+ # note that:
+ # 1. Code & CodeOperation are not public for now
+ # 2. AnonymousCodeSchema is not within CodeField
+ # 3. Code will be returned as an arm id as an attribute of a component when getting a component from remote
+ # So origin_code_value should never be a Code object, or an exception will be raised
+ # in validation stage.
+ return CodeType.UNKNOWN
+ if is_ARM_id_for_resource(origin_code_value, AzureMLResourceType.CODE) or is_registry_id_for_resource(
+ origin_code_value
+ ):
+ return CodeType.ARM_ID
+ if origin_code_value.startswith("git+"):
+ return CodeType.GIT
+ return CodeType.LOCAL
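+# Illustrative sketch (not part of the SDK) of how code values map to CodeType, assuming hypothetical values:
+#
+#     _get_code_type(None)                                   # CodeType.NONE
+#     _get_code_type("git+https://github.com/org/repo.git")  # CodeType.GIT
+#     _get_code_type("./src")                                # CodeType.LOCAL
+#     # an ARM id or registry id of a code asset             -> CodeType.ARM_ID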
+
+
+class ComponentCodeMixin:
+ """Mixin class for components with local files as part of the component. Those local files will be uploaded to
+ blob storage and further referenced as a code asset in arm id. In below docstring, we will refer to those local
+ files as "code".
+
+ The major interface of this mixin is self._customized_code_validate and self._build_code.
+ self._customized_code_validate will return a validation result indicating whether the code is valid.
+ self._build_code will return a temp Code object for server-side code asset creation.
+ """
+
+ def _get_base_path_for_code(self) -> Path:
+ """Get base path for additional includes.
+
+ :return: The base path
+ :rtype: Path
+ """
+ if hasattr(self, BASE_PATH_CONTEXT_KEY):
+ return Path(getattr(self, BASE_PATH_CONTEXT_KEY))
+ raise NotImplementedError(
+ "Component must have a base_path attribute to use ComponentCodeMixin. "
+ "Please set base_path in __init__ or override _get_base_path_for_code."
+ )
+
+ @classmethod
+ def _get_code_field_name(cls) -> str:
+ """Get the field name for code.
+
+ Will be used to get origin code value by default and will be used as field name of validation diagnostics.
+
+ :return: Code field name
+ :rtype: str
+ """
+ return "code"
+
+ def _get_origin_code_value(self) -> Union[str, os.PathLike, None]:
+ """Get origin code value.
+ Origin code value is either an absolute path or a relative path to base path if it's a local path.
+ Additional includes are only supported for component types with code attribute. Origin code path will be copied
+ to a temp folder along with additional includes to form a new code content.
+ """
+ return getattr(self, self._get_code_field_name(), None)
+
+ def _fill_back_code_value(self, value: str) -> None:
+ """Fill resolved code value back to the component.
+
+ :param value: resolved code value
+ :type value: str
+ :return: no return
+ :rtype: None
+ """
+ return setattr(self, self._get_code_field_name(), value)
+
+ def _get_origin_code_in_str(self) -> Optional[str]:
+ """Get origin code value in str to simplify following logic."""
+ origin_code_value = self._get_origin_code_value()
+ if origin_code_value is None:
+ return None
+ if isinstance(origin_code_value, Path):
+ return origin_code_value.as_posix()
+ return str(origin_code_value)
+
+ def _append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(
+ self, base_validation_result: Optional[MutableValidationResult] = None
+ ) -> bool:
+ """Append diagnostics from customized validation logic to the base validation result and check if origin code
+ value is valid for path validation.
+
+ For customized validation logic, this method shouldn't cover the validation logic duplicated with schema
+ validation, like local code existence check.
+ For the check, as "code" includes file dependencies of a component, other fields may depend on those files.
+ However, the origin code value may not be reliable for validation of those fields. For example:
+ 1. origin code value can be a remote git path or an arm id of a code asset.
+ 2. some file operations may be done during build_code, which makes final code content different from what we can
+ get from origin code value.
+ So, we use this function to check if origin code value is reliable for further local path validation.
+
+ :param base_validation_result: base validation result to append diagnostics to.
+ :type base_validation_result: MutableValidationResult
+ :return: whether origin code value is reliable for further local path validation.
+ :rtype: bool
+ """
+        # If private features are enabled and the component has a code value of type str, we need to check
+        # that it is a valid git path. Otherwise, we should report a validation error
+        # saying that the code value is not valid.
+ code_type = _get_code_type(self._get_origin_code_in_str())
+ if code_type == CodeType.GIT and not is_private_preview_enabled():
+ if base_validation_result is not None:
+ base_validation_result.append_error(
+ message="Not a valid code value: git paths are not supported.",
+ yaml_path=self._get_code_field_name(),
+ )
+ return code_type == CodeType.LOCAL
+
+ @contextmanager
+ def _build_code(self) -> Generator:
+ """Create a Code object if necessary based on origin code value and yield it.
+
+ :return: If built code is the same as its origin value, do nothing and yield None.
+ Otherwise, yield a Code object pointing to the code.
+ :rtype: Iterable[Optional[Code]]
+ """
+ origin_code_value = self._get_origin_code_in_str()
+ code_type = _get_code_type(origin_code_value)
+
+ if code_type == CodeType.GIT:
+ # git also need to be resolved into arm id
+ yield Code(path=origin_code_value)
+ elif code_type in [CodeType.LOCAL, CodeType.NONE]:
+ code: Any
+ # false-positive by pylint, hence disable it
+ # (https://github.com/pylint-dev/pylint/blob/main/doc/data/messages
+ # /c/contextmanager-generator-missing-cleanup/details.rst)
+ with self._try_build_local_code() as code: # pylint:disable=contextmanager-generator-missing-cleanup
+ yield code
+ else:
+ # arm id, None and unknown need no extra resolution
+ yield None
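+    # Illustrative usage sketch (not part of the SDK); "component" is a hypothetical instance of a class that
+    # uses this mixin:
+    #
+    #     with component._build_code() as code:
+    #         if code is not None:
+    #             register_code_asset(code)  # hypothetical helper; None means no extra resolution is needed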
+
+ @contextmanager
+ def _try_build_local_code(self) -> Generator:
+ """Extract the logic of _build_code for local code for further override.
+
+ :return: The Code object if could be constructed, None otherwise
+ :rtype: Iterable[Optional[Code]]
+ """
+ origin_code_value = self._get_origin_code_in_str()
+ if origin_code_value is None:
+ yield None
+ else:
+ base_path = self._get_base_path_for_code()
+ absolute_path: Union[str, Path] = (
+ origin_code_value if os.path.isabs(origin_code_value) else base_path / origin_code_value
+ )
+
+ yield Code(
+ base_path=base_path,
+ path=origin_code_value,
+ ignore_file=ComponentIgnoreFile(absolute_path),
+ )
+
+ def _with_local_code(self) -> bool:
+ # TODO: remove this method after we have a better way to do this judge in cache_utils
+ origin_code_value = self._get_origin_code_in_str()
+ code_type = _get_code_type(origin_code_value)
+ return code_type == CodeType.LOCAL
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/command_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/command_component.py
new file mode 100644
index 00000000..9bdcd3d1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/command_component.py
@@ -0,0 +1,300 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os
+from typing import Any, Dict, List, Optional, Union, cast
+
+from marshmallow import Schema
+
+from azure.ai.ml._schema.component.command_component import CommandComponentSchema
+from azure.ai.ml.constants._common import COMPONENT_TYPE
+from azure.ai.ml.constants._component import NodeType
+from azure.ai.ml.entities._assets import Environment
+from azure.ai.ml.entities._job.distribution import (
+ DistributionConfiguration,
+ MpiDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ TensorFlowDistribution,
+)
+from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
+from azure.ai.ml.entities._job.parameterized_command import ParameterizedCommand
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+from ..._restclient.v2022_10_01.models import ComponentVersion
+from ..._schema import PathAwareSchema
+from ..._utils.utils import get_all_data_binding_expressions, parse_args_description_from_docstring
+from .._util import convert_ordered_dict_to_dict, validate_attribute_type
+from .._validation import MutableValidationResult
+from ._additional_includes import AdditionalIncludesMixin
+from .component import Component
+
+# pylint: disable=protected-access
+
+
+class CommandComponent(Component, ParameterizedCommand, AdditionalIncludesMixin):
+ """Command component version, used to define a Command Component or Job.
+
+ :keyword name: The name of the Command job or component.
+ :paramtype name: Optional[str]
+ :keyword version: The version of the Command job or component.
+ :paramtype version: Optional[str]
+ :keyword description: The description of the component. Defaults to None.
+ :paramtype description: Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :paramtype tags: Optional[dict]
+ :keyword display_name: The display name of the component.
+ :paramtype display_name: Optional[str]
+ :keyword command: The command to be executed.
+ :paramtype command: Optional[str]
+ :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing
+ to a remote location.
+    :paramtype code: Optional[str]
+ :keyword environment: The environment that the job will run in.
+ :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ :keyword distribution: The configuration for distributed jobs. Defaults to None.
+ :paramtype distribution: Optional[Union[~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
+ ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]]
+ :keyword resources: The compute resource configuration for the command.
+ :paramtype resources: Optional[~azure.ai.ml.entities.JobResourceConfiguration]
+ :keyword inputs: A mapping of input names to input data sources used in the job. Defaults to None.
+ :paramtype inputs: Optional[dict[str, Union[
+ ~azure.ai.ml.Input,
+ str,
+ bool,
+ int,
+ float,
+ Enum,
+ ]]]
+ :keyword outputs: A mapping of output names to output data sources used in the job. Defaults to None.
+ :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]]
+ :keyword instance_count: The number of instances or nodes to be used by the compute target. Defaults to 1.
+ :paramtype instance_count: Optional[int]
+ :keyword is_deterministic: Specifies whether the Command will return the same output given the same input.
+ Defaults to True. When True, if a Command (component) is deterministic and has been run before in the
+ current workspace with the same input and settings, it will reuse results from a previous submitted job
+ when used as a node or step in a pipeline. In that scenario, no compute resources will be used.
+ :paramtype is_deterministic: Optional[bool]
+ :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None.
+ :paramtype additional_includes: Optional[List[str]]
+ :keyword properties: The job property dictionary. Defaults to None.
+ :paramtype properties: Optional[dict[str, str]]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if CommandComponent cannot be successfully validated.
+ Details will be provided in the error message.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_command_configurations.py
+ :start-after: [START command_component_definition]
+ :end-before: [END command_component_definition]
+ :language: python
+ :dedent: 8
+ :caption: Creating a CommandComponent.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ command: Optional[str] = None,
+ code: Optional[Union[str, os.PathLike]] = None,
+ environment: Optional[Union[str, Environment]] = None,
+ distribution: Optional[
+ Union[
+ Dict,
+ MpiDistribution,
+ TensorFlowDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ DistributionConfiguration,
+ ]
+ ] = None,
+ resources: Optional[JobResourceConfiguration] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ instance_count: Optional[int] = None, # promoted property from resources.instance_count
+ is_deterministic: bool = True,
+ additional_includes: Optional[List] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+
+ kwargs[COMPONENT_TYPE] = NodeType.COMMAND
+
+ # Component backend doesn't support environment_variables yet,
+ # this is to support the case of CommandComponent being the trial of
+ # a SweepJob, where environment_variables is stored as part of trial
+ environment_variables = kwargs.pop("environment_variables", None)
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ inputs=inputs,
+ outputs=outputs,
+ is_deterministic=is_deterministic,
+ properties=properties,
+ **kwargs,
+ )
+
+        # No validation on the values passed here because in a pipeline job, the required code & environment may be
+        # absent and filled in later with job defaults.
+ self.command = command
+ self.code = code
+ self.environment_variables = environment_variables
+ self.environment = environment
+ self.resources = resources # type: ignore[assignment]
+ self.distribution = distribution
+
+ # check mutual exclusivity of promoted properties
+ if self.resources is not None and instance_count is not None:
+ msg = "instance_count and resources are mutually exclusive"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.instance_count = instance_count
+ self.additional_includes = additional_includes or []
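+    # Illustrative sketch (not part of the SDK): instance_count is a promoted shortcut for
+    # resources.instance_count, so only one of the two should be passed, e.g.:
+    #
+    #     CommandComponent(name="train", command="python train.py", instance_count=2)
+    #     # or, equivalently:
+    #     CommandComponent(name="train", command="python train.py",
+    #                      resources=JobResourceConfiguration(instance_count=2))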
+
+ def _to_ordered_dict_for_yaml_dump(self) -> Dict:
+ """Dump the component content into a sorted yaml string.
+
+ :return: The ordered dict
+ :rtype: Dict
+ """
+
+ obj: dict = super()._to_ordered_dict_for_yaml_dump()
+        # a dict dumped based on the schema will resolve code to an absolute path, while we want to keep its original value
+ if self.code and isinstance(self.code, str):
+ obj["code"] = self.code
+ return obj
+
+ @property
+ def instance_count(self) -> Optional[int]:
+ """The number of instances or nodes to be used by the compute target.
+
+ :return: The number of instances or nodes.
+ :rtype: int
+ """
+ return self.resources.instance_count if self.resources and not isinstance(self.resources, dict) else None
+
+ @instance_count.setter
+ def instance_count(self, value: int) -> None:
+ """Sets the number of instances or nodes to be used by the compute target.
+
+        :param value: The number of instances or nodes to be used by the compute target. Defaults to 1.
+ :type value: int
+ """
+ if not value:
+ return
+ if not self.resources:
+ self.resources = JobResourceConfiguration(instance_count=value)
+ else:
+ if not isinstance(self.resources, dict):
+ self.resources.instance_count = value
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "environment": (str, Environment),
+ "environment_variables": dict,
+ "resources": (dict, JobResourceConfiguration),
+ "code": (str, os.PathLike),
+ }
+
+ def _to_dict(self) -> Dict:
+ return cast(
+ dict, convert_ordered_dict_to_dict({**self._other_parameter, **super(CommandComponent, self)._to_dict()})
+ )
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict:
+ # put it here as distribution is shared by some components, e.g. command
+ distribution = obj.properties.component_spec.pop("distribution", None)
+ init_kwargs: dict = super()._from_rest_object_to_init_params(obj)
+ if distribution:
+ init_kwargs["distribution"] = DistributionConfiguration._from_rest_object(distribution)
+ return init_kwargs
+
+ def _get_environment_id(self) -> Union[str, None]:
+ # Return environment id of environment
+ # handle case when environment is defined inline
+ if isinstance(self.environment, Environment):
+ _id: Optional[str] = self.environment.id
+ return _id
+ return self.environment
+
+ # region SchemaValidatableMixin
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return CommandComponentSchema(context=context)
+
+ def _customized_validate(self) -> MutableValidationResult:
+ validation_result = super(CommandComponent, self)._customized_validate()
+ self._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(validation_result)
+ validation_result.merge_with(self._validate_command())
+ validation_result.merge_with(self._validate_early_available_output())
+ return validation_result
+
+ def _validate_command(self) -> MutableValidationResult:
+ validation_result = self._create_empty_validation_result()
+ # command
+ if self.command:
+ invalid_expressions = []
+ for data_binding_expression in get_all_data_binding_expressions(self.command, is_singular=False):
+ if not self._is_valid_data_binding_expression(data_binding_expression):
+ invalid_expressions.append(data_binding_expression)
+
+ if invalid_expressions:
+ validation_result.append_error(
+ yaml_path="command",
+ message="Invalid data binding expression: {}".format(", ".join(invalid_expressions)),
+ )
+ return validation_result
+
+ def _validate_early_available_output(self) -> MutableValidationResult:
+ validation_result = self._create_empty_validation_result()
+ for name, output in self.outputs.items():
+ if output.early_available is True and output._is_primitive_type is not True:
+ msg = (
+ f"Early available output {name!r} requires output is primitive type, "
+ f"got {output._is_primitive_type!r}."
+ )
+ validation_result.append_error(message=msg, yaml_path=f"outputs.{name}")
+ return validation_result
+
+ def _is_valid_data_binding_expression(self, data_binding_expression: str) -> bool:
+ current_obj: Any = self
+ for item in data_binding_expression.split("."):
+ if hasattr(current_obj, item):
+ current_obj = getattr(current_obj, item)
+ else:
+ try:
+ current_obj = current_obj[item]
+ except Exception: # pylint: disable=W0718
+ return False
+ return True
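+    # Illustrative sketch (not part of the SDK): an expression such as "inputs.learning_rate" is considered valid
+    # when the component defines an input named "learning_rate", because the check above walks the dotted path
+    # segment by segment via getattr/__getitem__; "inputs.not_defined" would fail the walk and be reported.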
+
+ # endregion
+
+ @classmethod
+ def _parse_args_description_from_docstring(cls, docstring: str) -> Dict:
+ res: dict = parse_args_description_from_docstring(docstring)
+ return res
+
+ def __str__(self) -> str:
+ try:
+ toYaml: str = self._to_yaml()
+ return toYaml
+ except BaseException: # pylint: disable=W0718
+ toStr: str = super(CommandComponent, self).__str__()
+ return toStr
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component.py
new file mode 100644
index 00000000..c02a3a33
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component.py
@@ -0,0 +1,641 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import re
+import uuid
+from os import PathLike
+from pathlib import Path
+from typing import IO, TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Optional, Tuple, Union
+
+from marshmallow import INCLUDE
+
+from ..._restclient.v2024_01_01_preview.models import (
+ ComponentContainer,
+ ComponentContainerProperties,
+ ComponentVersion,
+ ComponentVersionProperties,
+)
+from ..._schema import PathAwareSchema
+from ..._schema.component import ComponentSchema
+from ..._utils.utils import dump_yaml_to_file, hash_dict
+from ...constants._common import (
+ ANONYMOUS_COMPONENT_NAME,
+ BASE_PATH_CONTEXT_KEY,
+ PARAMS_OVERRIDE_KEY,
+ REGISTRY_URI_FORMAT,
+ SOURCE_PATH_CONTEXT_KEY,
+ CommonYamlFields,
+ SchemaUrl,
+)
+from ...constants._component import ComponentSource, IOConstants, NodeType
+from ...entities._assets.asset import Asset
+from ...entities._inputs_outputs import Input, Output
+from ...entities._mixins import LocalizableMixin, TelemetryMixin, YamlTranslatableMixin
+from ...entities._system_data import SystemData
+from ...entities._util import find_type_in_override
+from ...entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin, RemoteValidatableMixin
+from ...exceptions import ErrorCategory, ErrorTarget, ValidationException
+from .._inputs_outputs import GroupInput
+
+if TYPE_CHECKING:
+ from ...entities.builders import BaseNode
+# pylint: disable=protected-access, redefined-builtin
+# disable redefined-builtin to use id/type as argument name
+
+
+COMPONENT_PLACEHOLDER = "COMPONENT_PLACEHOLDER"
+
+
+class Component(
+ Asset,
+ RemoteValidatableMixin,
+ TelemetryMixin,
+ YamlTranslatableMixin,
+ PathAwareSchemaValidatableMixin,
+ LocalizableMixin,
+):
+ """Base class for component version, used to define a component. Can't be instantiated directly.
+
+ :param name: Name of the resource.
+ :type name: str
+ :param version: Version of the resource.
+ :type version: str
+ :param id: Global ID of the resource, Azure Resource Manager ID.
+ :type id: str
+    :param type: Type of the component; the supported value is 'command'.
+ :type type: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict
+ :param properties: Internal use only.
+ :type properties: dict
+ :param display_name: Display name of the component.
+ :type display_name: str
+ :param is_deterministic: Whether the component is deterministic. Defaults to True.
+ :type is_deterministic: bool
+ :param inputs: Inputs of the component.
+ :type inputs: dict
+ :param outputs: Outputs of the component.
+ :type outputs: dict
+ :param yaml_str: The YAML string of the component.
+ :type yaml_str: str
+ :param _schema: Schema of the component.
+ :type _schema: str
+ :param creation_context: Creation metadata of the component.
+ :type creation_context: ~azure.ai.ml.entities.SystemData
+ :param kwargs: Additional parameters for the component.
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Component cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ # pylint: disable=too-many-instance-attributes
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ id: Optional[str] = None,
+ type: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ is_deterministic: bool = True,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ yaml_str: Optional[str] = None,
+ _schema: Optional[str] = None,
+ creation_context: Optional[SystemData] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.latest_version = None
+ self._intellectual_property = kwargs.pop("intellectual_property", None)
+        # Setting this before super init because when the asset initializes its version, _auto_increment_version's value may change
+ self._auto_increment_version = kwargs.pop("auto_increment", False)
+ # Get source from id first, then kwargs.
+ self._source = (
+ self._resolve_component_source_from_id(id) if id else kwargs.pop("_source", ComponentSource.CLASS)
+ )
+ # use ANONYMOUS_COMPONENT_NAME instead of guid
+ is_anonymous = kwargs.pop("is_anonymous", False)
+ if not name and version is None:
+ name = ANONYMOUS_COMPONENT_NAME
+ version = "1"
+ is_anonymous = True
+
+ super().__init__(
+ name=name,
+ version=version,
+ id=id,
+ description=description,
+ tags=tags,
+ properties=properties,
+ creation_context=creation_context,
+ is_anonymous=is_anonymous,
+ base_path=kwargs.pop(BASE_PATH_CONTEXT_KEY, None),
+ source_path=kwargs.pop(SOURCE_PATH_CONTEXT_KEY, None),
+ )
+ # store kwargs to self._other_parameter instead of pop to super class to allow component have extra
+ # fields not defined in current schema.
+
+ inputs = inputs if inputs else {}
+ outputs = outputs if outputs else {}
+
+ self.name = name
+ self._schema = _schema
+ self._type = type
+ self._display_name = display_name
+ self._is_deterministic = is_deterministic
+ self._inputs = self._build_io(inputs, is_input=True)
+ self._outputs = self._build_io(outputs, is_input=False)
+ # Store original yaml
+ self._yaml_str = yaml_str
+ self._other_parameter = kwargs
+
+ @property
+ def _func(self) -> Callable[..., "BaseNode"]:
+ from azure.ai.ml.entities._job.pipeline._load_component import _generate_component_function
+
+ # validate input/output names before creating component function
+ validation_result = self._validate_io_names(self.inputs)
+ validation_result.merge_with(self._validate_io_names(self.outputs))
+ self._try_raise(validation_result)
+
+ res: Callable = _generate_component_function(self)
+ return res
+
+ @property
+ def type(self) -> Optional[str]:
+ """Type of the component, default is 'command'.
+
+ :return: Type of the component.
+ :rtype: str
+ """
+ return self._type
+
+ @property
+ def display_name(self) -> Optional[str]:
+ """Display name of the component.
+
+ :return: Display name of the component.
+ :rtype: str
+ """
+ return self._display_name
+
+ @display_name.setter
+ def display_name(self, custom_display_name: str) -> None:
+ """Set display_name of the component.
+
+ :param custom_display_name: The new display name
+ :type custom_display_name: str
+ """
+ self._display_name = custom_display_name
+
+ @property
+ def is_deterministic(self) -> Optional[bool]:
+ """Whether the component is deterministic.
+
+ :return: Whether the component is deterministic
+ :rtype: bool
+ """
+ return self._is_deterministic
+
+ @property
+ def inputs(self) -> Dict:
+ """Inputs of the component.
+
+ :return: Inputs of the component.
+ :rtype: dict
+ """
+ res: dict = self._inputs
+ return res
+
+ @property
+ def outputs(self) -> Dict:
+ """Outputs of the component.
+
+ :return: Outputs of the component.
+ :rtype: dict
+ """
+ return self._outputs
+
+ @property
+ def version(self) -> Optional[str]:
+ """Version of the component.
+
+ :return: Version of the component.
+ :rtype: str
+ """
+ return self._version
+
+ @version.setter
+ def version(self, value: str) -> None:
+ """Set the version of the component.
+
+ :param value: The version of the component.
+ :type value: str
+ """
+ if value:
+ if not isinstance(value, str):
+ msg = f"Component version must be a string, not type {type(value)}."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self._version = value
+ self._auto_increment_version = self.name and not self._version
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the component content into a file in yaml format.
+
+ :param dest: The destination to receive this component's content.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+ and an exception will be raised if the file is not writable.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
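+    # Illustrative usage sketch (not part of the SDK), assuming a hypothetical local YAML file:
+    #
+    #     from azure.ai.ml import load_component
+    #     component = load_component("component.yml")
+    #     component.dump("exported_component.yml")  # raises if the destination file already exists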
+
+ @staticmethod
+ def _resolve_component_source_from_id( # pylint: disable=docstring-type-do-not-use-class
+ id: Optional[Union["Component", str]],
+ ) -> Any:
+ """Resolve the component source from id.
+
+ :param id: The component ID
+ :type id: Optional[str]
+ :return: The component source
+        :rtype: Literal[
+            ComponentSource.CLASS,
+            ComponentSource.REMOTE_REGISTRY,
+            ComponentSource.REMOTE_WORKSPACE_COMPONENT
+        ]
+ """
+ if id is None:
+ return ComponentSource.CLASS
+ # Consider default is workspace source, as
+ # azureml: prefix will be removed for arm versioned id.
+ return (
+ ComponentSource.REMOTE_REGISTRY
+ if not isinstance(id, Component) and id.startswith(REGISTRY_URI_FORMAT)
+ else ComponentSource.REMOTE_WORKSPACE_COMPONENT
+ )
+
+ @classmethod
+ def _validate_io_names(cls, io_names: Iterable[str], raise_error: bool = False) -> MutableValidationResult:
+ """Validate input/output names, raise exception if invalid.
+
+ :param io_names: The names to validate
+ :type io_names: Iterable[str]
+ :param raise_error: Whether to raise if validation fails. Defaults to False
+ :type raise_error: bool
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ validation_result = cls._create_empty_validation_result()
+ lower2original_kwargs: dict = {}
+
+ for name in io_names:
+ if re.match(IOConstants.VALID_KEY_PATTERN, name) is None:
+                msg = "{!r} is not a valid parameter name; it must be composed of letters, numbers, and underscores."
+ validation_result.append_error(message=msg.format(name), yaml_path=f"inputs.{name}")
+ # validate name conflict
+ lower_key = name.lower()
+ if lower_key in lower2original_kwargs:
+                msg = "Invalid component input names {!r} and {!r}, which are equal when ignoring case."
+ validation_result.append_error(
+ message=msg.format(name, lower2original_kwargs[lower_key]), yaml_path=f"inputs.{name}"
+ )
+ else:
+ lower2original_kwargs[lower_key] = name
+ return cls._try_raise(validation_result, raise_error=raise_error)
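+    # Illustrative sketch (not part of the SDK), assuming hypothetical names:
+    #
+    #     Component._validate_io_names(["learning_rate", "max_epochs"])  # passes
+    #     Component._validate_io_names(["learning-rate"])                # error: only letters, numbers, underscores
+    #     Component._validate_io_names(["Rate", "rate"])                 # error: names equal when ignoring case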
+
+ @classmethod
+ def _build_io(cls, io_dict: Union[Dict, Input, Output], is_input: bool) -> Dict:
+ component_io: dict = {}
+ for name, port in io_dict.items():
+ if is_input:
+ component_io[name] = port if isinstance(port, Input) else Input(**port)
+ else:
+ component_io[name] = port if isinstance(port, Output) else Output(**port)
+
+ if is_input:
+ # Restore flattened parameters to group
+ res: dict = GroupInput.restore_flattened_inputs(component_io)
+ return res
+ return component_io
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema:
+ return ComponentSchema(context=context)
+
+ @classmethod
+ def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException:
+ return ValidationException(
+ message=message,
+ no_personal_data_message=no_personal_data_message,
+ target=ErrorTarget.COMPONENT,
+ )
+
+ @classmethod
+ def _is_flow(cls, data: Any) -> bool:
+ _schema = data.get(CommonYamlFields.SCHEMA, None)
+
+ if _schema and _schema in [SchemaUrl.PROMPTFLOW_FLOW, SchemaUrl.PROMPTFLOW_RUN]:
+ return True
+ return False
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Component":
+ data = data or {}
+ params_override = params_override or []
+ base_path = Path(yaml_path).parent if yaml_path else Path("./")
+
+ type_in_override = find_type_in_override(params_override)
+
+ # type_in_override > type_in_yaml > default (command)
+ if type_in_override is None:
+ type_in_override = data.get(CommonYamlFields.TYPE, None)
+ if type_in_override is None and cls._is_flow(data):
+ type_in_override = NodeType.FLOW_PARALLEL
+ if type_in_override is None:
+ type_in_override = NodeType.COMMAND
+ data[CommonYamlFields.TYPE] = type_in_override
+
+ from azure.ai.ml.entities._component.component_factory import component_factory
+
+ create_instance_func, _ = component_factory.get_create_funcs(
+ data,
+ for_load=True,
+ )
+ new_instance: Component = create_instance_func()
+ # specific keys must be popped before loading with schema using kwargs
+ init_kwargs = {
+ "yaml_str": kwargs.pop("yaml_str", None),
+ "_source": kwargs.pop("_source", ComponentSource.YAML_COMPONENT),
+ }
+ init_kwargs.update(
+ new_instance._load_with_schema( # pylint: disable=protected-access
+ data,
+ context={
+ BASE_PATH_CONTEXT_KEY: base_path,
+ SOURCE_PATH_CONTEXT_KEY: yaml_path,
+ PARAMS_OVERRIDE_KEY: params_override,
+ },
+ unknown=INCLUDE,
+ raise_original_exception=True,
+ **kwargs,
+ )
+ )
+ # Set base path separately to avoid doing this in post load, as return types of post load are not unified,
+ # could be object or dict.
+ # base_path in context can be changed in loading, so we use original base_path here.
+ init_kwargs[BASE_PATH_CONTEXT_KEY] = base_path.absolute()
+ if yaml_path:
+ init_kwargs[SOURCE_PATH_CONTEXT_KEY] = Path(yaml_path).absolute().as_posix()
+ # TODO: Bug Item number: 2883415
+ new_instance.__init__( # type: ignore
+ **init_kwargs,
+ )
+ return new_instance
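+ # Illustrative sketch (assumption, yaml_dict is a hypothetical dict loaded from a component
+ # YAML file): a type passed via params_override wins over the type declared in the YAML,
+ # which in turn wins over the command default, e.g.
+ #   >>> component = Component._load(data=yaml_dict, params_override=[{"type": "pipeline"}])
+ # while a YAML carrying a promptflow $schema and no explicit type falls back to
+ # NodeType.FLOW_PARALLEL.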
+
+ @classmethod
+ def _from_container_rest_object(cls, component_container_rest_object: ComponentContainer) -> "Component":
+ component_container_details: ComponentContainerProperties = component_container_rest_object.properties
+ component = Component(
+ id=component_container_rest_object.id,
+ name=component_container_rest_object.name,
+ description=component_container_details.description,
+ creation_context=SystemData._from_rest_object(component_container_rest_object.system_data),
+ tags=component_container_details.tags,
+ properties=component_container_details.properties,
+ type=NodeType._CONTAINER,
+ # Set this field to None as it holds a default of True in init.
+ is_deterministic=None, # type: ignore[arg-type]
+ )
+ component.latest_version = component_container_details.latest_version
+ return component
+
+ @classmethod
+ def _from_rest_object(cls, obj: ComponentVersion) -> "Component":
+ # TODO: Remove in PuP with native import job/component type support in MFE/Designer
+ # Convert command component back to import component private preview
+ component_spec = obj.properties.component_spec
+ if component_spec[CommonYamlFields.TYPE] == NodeType.COMMAND and component_spec["command"] == NodeType.IMPORT:
+ component_spec[CommonYamlFields.TYPE] = NodeType.IMPORT
+ component_spec["source"] = component_spec.pop("inputs")
+ component_spec["output"] = component_spec.pop("outputs")["output"]
+
+ # shouldn't block serialization when name is not valid
+ # maybe override serialization method for name field?
+ from azure.ai.ml.entities._component.component_factory import component_factory
+
+ create_instance_func, _ = component_factory.get_create_funcs(obj.properties.component_spec, for_load=True)
+
+ instance: Component = create_instance_func()
+ # TODO: Bug Item number: 2883415
+ instance.__init__(**instance._from_rest_object_to_init_params(obj)) # type: ignore
+ return instance
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict:
+ # Objects retrieved from REST data contain _source; delete it.
+ if "_source" in obj.properties.component_spec:
+ del obj.properties.component_spec["_source"]
+
+ rest_component_version = obj.properties
+ _type = rest_component_version.component_spec[CommonYamlFields.TYPE]
+
+ # inputs/outputs will be parsed by instance._build_io in instance's __init__
+ inputs = rest_component_version.component_spec.pop("inputs", {})
+ # parse String -> string, Integer -> integer, etc
+ for _input in inputs.values():
+ _input["type"] = Input._map_from_rest_type(_input["type"])
+ outputs = rest_component_version.component_spec.pop("outputs", {})
+
+ origin_name = rest_component_version.component_spec[CommonYamlFields.NAME]
+ rest_component_version.component_spec[CommonYamlFields.NAME] = ANONYMOUS_COMPONENT_NAME
+ init_kwargs = cls._load_with_schema(
+ rest_component_version.component_spec, context={BASE_PATH_CONTEXT_KEY: Path.cwd()}, unknown=INCLUDE
+ )
+ init_kwargs.update(
+ {
+ "id": obj.id,
+ "is_anonymous": rest_component_version.is_anonymous,
+ "creation_context": obj.system_data,
+ "inputs": inputs,
+ "outputs": outputs,
+ "name": origin_name,
+ }
+ )
+
+ # remove empty values, because some properties only apply to specific components, e.g. distribution for command
+ # note that there is an issue that environment == {} will always be true, so use isinstance here
+ return {k: v for k, v in init_kwargs.items() if v is not None and not (isinstance(v, dict) and not v)}
+
+ def _get_anonymous_hash(self) -> str:
+ """Return the hash of anonymous component.
+
+ Anonymous components (same code and interface) will have the same hash.
+
+ :return: The component hash
+ :rtype: str
+ """
+ # omit version since an anonymous component's version is a random GUID
+ # omit name since name doesn't impact the component's uniqueness
+ return self._get_component_hash(keys_to_omit=["name", "id", "version"])
+
+ def _get_component_hash(self, keys_to_omit: Optional[Iterable[str]] = None) -> str:
+ """Return the hash of component.
+
+ :param keys_to_omit: An iterable of keys to omit when computing the component hash
+ :type keys_to_omit: Optional[Iterable[str]]
+ :return: The component hash
+ :rtype: str
+ """
+ component_interface_dict = self._to_dict()
+ res: str = hash_dict(component_interface_dict, keys_to_omit=keys_to_omit)
+ return res
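+ # Illustrative sketch (assumption, comp_a and comp_b are hypothetical component instances):
+ # two components that differ only in name, id, or version share the same anonymous hash,
+ # which is what enables reuse of anonymous components:
+ #   >>> comp_a._get_anonymous_hash() == comp_b._get_anonymous_hash()  # True for same code/interface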
+
+ @classmethod
+ def _get_resource_type(cls) -> str:
+ return "Microsoft.MachineLearningServices/workspaces/components/versions"
+
+ def _get_resource_name_version(self) -> Tuple:
+ version: Optional[str] = None
+ if not self.version and not self._auto_increment_version:
+ version = str(uuid.uuid4())
+ else:
+ version = self.version
+ return self.name or ANONYMOUS_COMPONENT_NAME, version
+
+ def _validate(self, raise_error: Optional[bool] = False) -> MutableValidationResult:
+ origin_name = self.name
+ # skip name validation for anonymous component as ANONYMOUS_COMPONENT_NAME will be used in component creation
+ if self._is_anonymous:
+ self.name = ANONYMOUS_COMPONENT_NAME
+ try:
+ return super()._validate(raise_error)
+ finally:
+ self.name = origin_name
+
+ def _customized_validate(self) -> MutableValidationResult:
+ validation_result = super(Component, self)._customized_validate()
+
+ # validate inputs names
+ validation_result.merge_with(self._validate_io_names(self.inputs, raise_error=False))
+ validation_result.merge_with(self._validate_io_names(self.outputs, raise_error=False))
+
+ return validation_result
+
+ def _get_anonymous_component_name_version(self) -> Tuple:
+ return ANONYMOUS_COMPONENT_NAME, self._get_anonymous_hash()
+
+ def _get_rest_name_version(self) -> Tuple:
+ if self._is_anonymous:
+ return self._get_anonymous_component_name_version()
+ return self.name, self.version
+
+ def _to_rest_object(self) -> ComponentVersion:
+ component = self._to_dict()
+
+ # TODO: Remove in PuP with native import job/component type support in MFE/Designer
+ # Convert import component to command component private preview
+ if component.get(CommonYamlFields.TYPE, None) == NodeType.IMPORT:
+ component[CommonYamlFields.TYPE] = NodeType.COMMAND
+ component["inputs"] = component.pop("source")
+ component["outputs"] = dict({"output": component.pop("output")})
+ # method _to_dict() will remove empty keys
+ if "tags" not in component:
+ component["tags"] = {}
+ component["tags"]["component_type_overwrite"] = NodeType.IMPORT
+ component["command"] = NodeType.IMPORT
+
+ # add source type to component rest object
+ component["_source"] = self._source
+ if self._intellectual_property:
+ # hack while full pass-through support is worked on for IPP fields
+ component.pop("intellectual_property")
+ component["intellectualProperty"] = self._intellectual_property._to_rest_object().serialize()
+ properties = ComponentVersionProperties(
+ component_spec=component,
+ description=self.description,
+ is_anonymous=self._is_anonymous,
+ properties=dict(self.properties) if self.properties else {},
+ tags=self.tags,
+ )
+ result = ComponentVersion(properties=properties)
+ if self._is_anonymous:
+ result.name = ANONYMOUS_COMPONENT_NAME
+ else:
+ result.name = self.name
+ result.properties.properties["client_component_hash"] = self._get_component_hash(keys_to_omit=["version"])
+ return result
+
+ def _to_dict(self) -> Dict:
+ # Replace the name of $schema to schema.
+ component_schema_dict: dict = self._dump_for_validation()
+ component_schema_dict.pop(BASE_PATH_CONTEXT_KEY, None)
+
+ # TODO: handle other_parameters and remove override from subclass
+ return component_schema_dict
+
+ def _localize(self, base_path: str) -> None:
+ """Called on an asset got from service to clean up remote attributes like id, creation_context, etc. and update
+ base_path.
+
+ :param base_path: The base_path
+ :type base_path: str
+ """
+ if not getattr(self, "id", None):
+ raise ValueError("Only remote asset can be localize but got a {} without id.".format(type(self)))
+ self._id = None
+ self._creation_context = None
+ self._base_path = base_path
+
+ def _get_telemetry_values(self, *args: Any, **kwargs: Any) -> Dict:
+ # Note: is_anonymous is not reliable here; create_or_update will log is_anonymous from its parameter.
+ is_anonymous = self.name is None or ANONYMOUS_COMPONENT_NAME in self.name
+ return {"type": self.type, "source": self._source, "is_anonymous": is_anonymous}
+
+ # pylint: disable-next=docstring-missing-param
+ def __call__(self, *args: Any, **kwargs: Any) -> "BaseNode":
+ """Call ComponentVersion as a function and get a Component object.
+
+ :return: The component object
+ :rtype: BaseNode
+ """
+ if args:
+ # raise clear error message for unsupported positional args
+ if self._func._has_parameters: # type: ignore
+ _error = f"got {args} for {self.name}"
+ msg = (
+ f"Component function doesn't support positional arguments, {_error}. " # type: ignore
+ f"Please use keyword arguments like: {self._func._func_calling_example}."
+ )
+ else:
+ msg = (
+ "Component function doesn't has any parameters, "
+ f"please make sure component {self.name} has inputs. "
+ )
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return self._func(*args, **kwargs) # pylint: disable=not-callable
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component_factory.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component_factory.py
new file mode 100644
index 00000000..012dd260
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/component_factory.py
@@ -0,0 +1,171 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from marshmallow import Schema
+
+from ..._restclient.v2022_10_01.models import ComponentVersion
+from ..._utils.utils import is_internal_component_data
+from ...constants._common import SOURCE_PATH_CONTEXT_KEY
+from ...constants._component import DataTransferTaskType, NodeType
+from ...entities._component.automl_component import AutoMLComponent
+from ...entities._component.command_component import CommandComponent
+from ...entities._component.component import Component
+from ...entities._component.datatransfer_component import (
+ DataTransferCopyComponent,
+ DataTransferExportComponent,
+ DataTransferImportComponent,
+)
+from ...entities._component.import_component import ImportComponent
+from ...entities._component.parallel_component import ParallelComponent
+from ...entities._component.pipeline_component import PipelineComponent
+from ...entities._component.spark_component import SparkComponent
+from ...entities._util import get_type_from_spec
+from .flow import FlowComponent
+
+
+class _ComponentFactory:
+ """A class to create component instances from yaml dict or rest objects without hard-coded type check."""
+
+ def __init__(self) -> None:
+ self._create_instance_funcs: Dict = {}
+ self._create_schema_funcs: Dict = {}
+
+ self.register_type(
+ _type=NodeType.PARALLEL,
+ create_instance_func=lambda: ParallelComponent.__new__(ParallelComponent),
+ create_schema_func=ParallelComponent._create_schema_for_validation,
+ )
+ self.register_type(
+ _type=NodeType.COMMAND,
+ create_instance_func=lambda: CommandComponent.__new__(CommandComponent),
+ create_schema_func=CommandComponent._create_schema_for_validation,
+ )
+ self.register_type(
+ _type=NodeType.IMPORT,
+ create_instance_func=lambda: ImportComponent.__new__(ImportComponent),
+ create_schema_func=ImportComponent._create_schema_for_validation,
+ )
+ self.register_type(
+ _type=NodeType.PIPELINE,
+ create_instance_func=lambda: PipelineComponent.__new__(PipelineComponent),
+ create_schema_func=PipelineComponent._create_schema_for_validation,
+ )
+ self.register_type(
+ _type=NodeType.AUTOML,
+ create_instance_func=lambda: AutoMLComponent.__new__(AutoMLComponent),
+ create_schema_func=AutoMLComponent._create_schema_for_validation,
+ )
+ self.register_type(
+ _type=NodeType.SPARK,
+ create_instance_func=lambda: SparkComponent.__new__(SparkComponent),
+ create_schema_func=SparkComponent._create_schema_for_validation,
+ )
+ self.register_type(
+ _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.COPY_DATA]),
+ create_instance_func=lambda: DataTransferCopyComponent.__new__(DataTransferCopyComponent),
+ create_schema_func=DataTransferCopyComponent._create_schema_for_validation,
+ )
+
+ self.register_type(
+ _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.IMPORT_DATA]),
+ create_instance_func=lambda: DataTransferImportComponent.__new__(DataTransferImportComponent),
+ create_schema_func=DataTransferImportComponent._create_schema_for_validation,
+ )
+
+ self.register_type(
+ _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.EXPORT_DATA]),
+ create_instance_func=lambda: DataTransferExportComponent.__new__(DataTransferExportComponent),
+ create_schema_func=DataTransferExportComponent._create_schema_for_validation,
+ )
+
+ self.register_type(
+ _type=NodeType.FLOW_PARALLEL,
+ create_instance_func=lambda: FlowComponent.__new__(FlowComponent),
+ create_schema_func=FlowComponent._create_schema_for_validation,
+ )
+
+ def get_create_funcs(
+ self, yaml_spec: dict, for_load: bool = False
+ ) -> Tuple[Callable[..., Component], Callable[[Any], Schema]]:
+ """Get registered functions to create an instance and its corresponding schema for the given type.
+
+ :param yaml_spec: The YAML specification.
+ :type yaml_spec: dict
+ :param for_load: Whether the function is called for loading a component. Defaults to False.
+ :type for_load: bool
+ :return: A tuple containing the create_instance_func and create_schema_func.
+ :rtype: tuple
+ """
+
+ _type = get_type_from_spec(yaml_spec, valid_keys=self._create_instance_funcs)
+ # SparkComponent and InternalSparkComponent share the same type name, but they are different types.
+ if for_load and is_internal_component_data(yaml_spec, raise_if_not_enabled=True) and _type == NodeType.SPARK:
+ from azure.ai.ml._internal._schema.node import NodeType as InternalNodeType
+
+ _type = InternalNodeType.SPARK
+
+ create_instance_func = self._create_instance_funcs[_type]
+ create_schema_func = self._create_schema_funcs[_type]
+ return create_instance_func, create_schema_func
+
+ def register_type(
+ self,
+ _type: str,
+ create_instance_func: Callable[..., Component],
+ create_schema_func: Callable[[Any], Schema],
+ ) -> None:
+ """Register a new component type.
+
+ :param _type: The type name of the component.
+ :type _type: str
+ :param create_instance_func: A function to create an instance of the component.
+ :type create_instance_func: Callable[..., ~azure.ai.ml.entities.Component]
+ :param create_schema_func: A function to create a schema for the component.
+ :type create_schema_func: Callable[[Any], Schema]
+ """
+ self._create_instance_funcs[_type] = create_instance_func
+ self._create_schema_funcs[_type] = create_schema_func
+
+ @classmethod
+ def load_from_dict(cls, *, data: Dict, context: Dict, _type: Optional[str] = None, **kwargs: Any) -> Component:
+ """Load a component from a YAML dict.
+
+ :keyword data: The YAML dict.
+ :paramtype data: dict
+ :keyword context: The context of the YAML dict.
+ :paramtype context: dict
+ :keyword _type: The type name of the component. When None, it will be inferred from the YAML dict.
+ :paramtype _type: str
+ :return: The loaded component.
+ :rtype: ~azure.ai.ml.entities.Component
+ """
+
+ return Component._load(
+ data=data,
+ yaml_path=context.get(SOURCE_PATH_CONTEXT_KEY, None),
+ params_override=[{"type": _type}] if _type is not None else [],
+ **kwargs,
+ )
+
+ @classmethod
+ def load_from_rest(cls, *, obj: ComponentVersion, _type: Optional[str] = None) -> Component:
+ """Load a component from a REST object.
+
+ :keyword obj: The REST object.
+ :paramtype obj: ComponentVersion
+ :keyword _type: The type name of the component. When None, it will be inferred from the REST object.
+ :paramtype _type: str
+ :return: The loaded component.
+ :rtype: ~azure.ai.ml.entities.Component
+ """
+ if _type is not None:
+ obj.properties.component_spec["type"] = _type
+ return Component._from_rest_object(obj)
+
+
+component_factory = _ComponentFactory()
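+
+# Illustrative usage sketch (not part of the original module): the module-level factory maps
+# a YAML "type" field to the matching Component subclass without hard-coded type checks, e.g.
+#   >>> create_instance_func, create_schema_func = component_factory.get_create_funcs({"type": "command"})
+#   >>> isinstance(create_instance_func(), CommandComponent)
+#   True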
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/datatransfer_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/datatransfer_component.py
new file mode 100644
index 00000000..e71712ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/datatransfer_component.py
@@ -0,0 +1,325 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from pathlib import Path
+from typing import Any, Dict, NoReturn, Optional, Union, cast
+
+from marshmallow import Schema
+
+from azure.ai.ml._schema.component.data_transfer_component import (
+ DataTransferCopyComponentSchema,
+ DataTransferExportComponentSchema,
+ DataTransferImportComponentSchema,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, COMPONENT_TYPE, AssetTypes
+from azure.ai.ml.constants._component import DataTransferTaskType, ExternalDataType, NodeType
+from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem
+from azure.ai.ml.entities._inputs_outputs.output import Output
+from azure.ai.ml.entities._validation.core import MutableValidationResult
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from ..._schema import PathAwareSchema
+from .._util import convert_ordered_dict_to_dict, validate_attribute_type
+from .component import Component
+
+
+class DataTransferComponent(Component):
+ """DataTransfer component version, used to define a data transfer component.
+
+ :param task: Task type in the data transfer component. Possible values are "copy_data",
+ "import_data", and "export_data".
+ :type task: str
+ :param inputs: Mapping of input data bindings used in the job.
+ :type inputs: dict
+ :param outputs: Mapping of output data bindings used in the job.
+ :type outputs: dict
+ :param kwargs: Additional parameters for the data transfer component.
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ task: Optional[str] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+
+ kwargs[COMPONENT_TYPE] = NodeType.DATA_TRANSFER
+ # Set default base path
+ if BASE_PATH_CONTEXT_KEY not in kwargs:
+ kwargs[BASE_PATH_CONTEXT_KEY] = Path(".")
+
+ super().__init__(
+ inputs=inputs,
+ outputs=outputs,
+ **kwargs,
+ )
+ self._task = task
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {}
+
+ @property
+ def task(self) -> Optional[str]:
+ """Task type of the component.
+
+ :return: Task type of the component.
+ :rtype: str
+ """
+ return self._task
+
+ def _to_dict(self) -> Dict:
+ return cast(
+ dict,
+ convert_ordered_dict_to_dict({**self._other_parameter, **super(DataTransferComponent, self)._to_dict()}),
+ )
+
+ def __str__(self) -> str:
+ try:
+ _toYaml: str = self._to_yaml()
+ return _toYaml
+ except BaseException: # pylint: disable=W0718
+ _toStr: str = super(DataTransferComponent, self).__str__()
+ return _toStr
+
+ @classmethod
+ def _build_source_sink(cls, io_dict: Union[Dict, Database, FileSystem]) -> Union[Database, FileSystem]:
+ component_io: Union[Database, FileSystem] = Database()
+
+ if isinstance(io_dict, Database):
+ component_io = Database()
+ elif isinstance(io_dict, FileSystem):
+ component_io = FileSystem()
+ else:
+ if isinstance(io_dict, dict):
+ data_type = io_dict.pop("type", None)
+ if data_type == ExternalDataType.DATABASE:
+ component_io = Database()
+ elif data_type == ExternalDataType.FILE_SYSTEM:
+ component_io = FileSystem()
+ else:
+ msg = "Type in source or sink only support {} and {}, currently got {}."
+ raise ValidationException(
+ message=msg.format(
+ ExternalDataType.DATABASE,
+ ExternalDataType.FILE_SYSTEM,
+ data_type,
+ ),
+ no_personal_data_message=msg.format(
+ ExternalDataType.DATABASE,
+ ExternalDataType.FILE_SYSTEM,
+ "data_type",
+ ),
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ else:
+ msg = "Source or sink only support dict, Database and FileSystem"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ return component_io
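+ # Illustrative sketch (assumption): a plain dict source/sink is mapped to the matching
+ # external data entity based on its "type" key, e.g.
+ #   >>> DataTransferComponent._build_source_sink({"type": ExternalDataType.DATABASE})     # -> Database()
+ #   >>> DataTransferComponent._build_source_sink({"type": ExternalDataType.FILE_SYSTEM})  # -> FileSystem()
+ # while any other type value raises a ValidationException.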
+
+
+@experimental
+class DataTransferCopyComponent(DataTransferComponent):
+ """DataTransfer copy component version, used to define a data transfer copy component.
+
+ :param data_copy_mode: Data copy mode in the copy task.
+ Possible values are "merge_with_overwrite" and "fail_if_conflict".
+ :type data_copy_mode: str
+ :param inputs: Mapping of input data bindings used in the job.
+ :type inputs: dict
+ :param outputs: Mapping of output data bindings used in the job.
+ :type outputs: dict
+ :param kwargs: Additional parameters for the data transfer copy component.
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ data_copy_mode: Optional[str] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs["task"] = DataTransferTaskType.COPY_DATA
+ super().__init__(
+ inputs=inputs,
+ outputs=outputs,
+ **kwargs,
+ )
+
+ self._data_copy_mode = data_copy_mode
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return DataTransferCopyComponentSchema(context=context)
+
+ @property
+ def data_copy_mode(self) -> Optional[str]:
+ """Data copy mode of the component.
+
+ :return: Data copy mode of the component.
+ :rtype: str
+ """
+ return self._data_copy_mode
+
+ def _customized_validate(self) -> MutableValidationResult:
+ validation_result = super(DataTransferCopyComponent, self)._customized_validate()
+ validation_result.merge_with(self._validate_input_output_mapping())
+ return validation_result
+
+ def _validate_input_output_mapping(self) -> MutableValidationResult:
+ validation_result = self._create_empty_validation_result()
+ inputs_count = len(self.inputs)
+ outputs_count = len(self.outputs)
+ if outputs_count != 1:
+ msg = "Only support single output in {}, but there're {} outputs."
+ validation_result.append_error(
+ message=msg.format(DataTransferTaskType.COPY_DATA, outputs_count),
+ yaml_path="outputs",
+ )
+ else:
+ input_type = None
+ output_type = None
+ if inputs_count == 1:
+ for _, input_data in self.inputs.items():
+ input_type = input_data.type
+ for _, output_data in self.outputs.items():
+ output_type = output_data.type
+ if input_type is None or output_type is None or input_type != output_type:
+ msg = "Input type {} doesn't exactly match with output type {} in task {}"
+ validation_result.append_error(
+ message=msg.format(input_type, output_type, DataTransferTaskType.COPY_DATA),
+ yaml_path="outputs",
+ )
+ elif inputs_count > 1:
+ for _, output_data in self.outputs.items():
+ output_type = output_data.type
+ if output_type is None or output_type != AssetTypes.URI_FOLDER:
+ msg = "output type {} need to be {} in task {}"
+ validation_result.append_error(
+ message=msg.format(
+ output_type,
+ AssetTypes.URI_FOLDER,
+ DataTransferTaskType.COPY_DATA,
+ ),
+ yaml_path="outputs",
+ )
+ else:
+ msg = "Inputs must be set in task {}."
+ validation_result.append_error(
+ message=msg.format(DataTransferTaskType.COPY_DATA),
+ yaml_path="inputs",
+ )
+ return validation_result
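+ # Illustrative sketch (assumption, Input imported from azure.ai.ml): with a single input the
+ # output type must match the input type exactly; with multiple inputs the single output must
+ # be a uri_folder, e.g.
+ #   >>> DataTransferCopyComponent(
+ #   ...     inputs={"src": Input(type=AssetTypes.URI_FILE)},
+ #   ...     outputs={"dst": Output(type=AssetTypes.URI_FILE)},
+ #   ... )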
+
+
+@experimental
+class DataTransferImportComponent(DataTransferComponent):
+ """DataTransfer import component version, used to define a data transfer import component.
+
+ :param source: The data source of the file system or database.
+ :type source: dict
+ :param outputs: Mapping of output data bindings used in the job.
+ Default value is an output port with the key "sink" and the type "mltable".
+ :type outputs: dict
+ :param kwargs: Additional parameters for the data transfer import component.
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ source: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ outputs = outputs or {"sink": Output(type=AssetTypes.MLTABLE)}
+ kwargs["task"] = DataTransferTaskType.IMPORT_DATA
+ super().__init__(
+ outputs=outputs,
+ **kwargs,
+ )
+
+ source = source if source else {}
+ self.source = self._build_source_sink(source)
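+ # Illustrative usage sketch (assumption): an import component only declares its source; the
+ # default outputs carry a single "sink" port of type mltable, e.g.
+ #   >>> DataTransferImportComponent(source={"type": ExternalDataType.DATABASE})
+ # Calling the resulting component as a function is not supported and raises a
+ # ValidationException, as __call__ below shows.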
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return DataTransferImportComponentSchema(context=context)
+
+ # pylint: disable-next=docstring-missing-param
+ def __call__(self, *args: Any, **kwargs: Any) -> NoReturn:
+ """Call ComponentVersion as a function and get a Component object."""
+
+ msg = "DataTransfer component is not callable for import task."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+
+@experimental
+class DataTransferExportComponent(DataTransferComponent):
+ """DataTransfer export component version, used to define a data transfer export component.
+
+ :param sink: The sink of external data and databases.
+ :type sink: Union[Dict, Database, FileSystem]
+ :param inputs: Mapping of input data bindings used in the job.
+ :type inputs: dict
+ :param kwargs: Additional parameters for the data transfer export component.
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ inputs: Optional[Dict] = None,
+ sink: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs["task"] = DataTransferTaskType.EXPORT_DATA
+ super().__init__(
+ inputs=inputs,
+ **kwargs,
+ )
+
+ sink = sink if sink else {}
+ self.sink = self._build_source_sink(sink)
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return DataTransferExportComponentSchema(context=context)
+
+ # pylint: disable-next=docstring-missing-param
+ def __call__(self, *args: Any, **kwargs: Any) -> NoReturn:
+ """Call ComponentVersion as a function and get a Component object."""
+
+ msg = "DataTransfer component is not callable for export task."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/flow.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/flow.py
new file mode 100644
index 00000000..e4ff06cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/flow.py
@@ -0,0 +1,553 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import contextlib
+import json
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
+
+import yaml # type: ignore[import]
+from marshmallow import EXCLUDE, Schema, ValidationError
+
+from azure.ai.ml.constants._common import (
+ BASE_PATH_CONTEXT_KEY,
+ COMPONENT_TYPE,
+ PROMPTFLOW_AZUREML_OVERRIDE_KEY,
+ SOURCE_PATH_CONTEXT_KEY,
+ AssetTypes,
+ SchemaUrl,
+)
+from azure.ai.ml.constants._component import ComponentParameterTypes, NodeType
+
+from ..._restclient.v2022_10_01.models import ComponentVersion
+from ..._schema import PathAwareSchema
+from ..._schema.component.flow import FlowComponentSchema, FlowSchema, RunSchema
+from ...exceptions import ErrorCategory, ErrorTarget, ValidationException
+from .. import Environment
+from .._inputs_outputs import GroupInput, Input, Output
+from ._additional_includes import AdditionalIncludesMixin
+from .component import Component
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._builders.parallel import Parallel
+
+# pylint: disable=protected-access
+
+
+class _FlowPortNames:
+ """Common yaml fields.
+
+ Common yaml fields are used to define the common fields in yaml files. It can be one of the following values: type,
+ name, $schema.
+ """
+
+ DATA = "data"
+ RUN_OUTPUTS = "run_outputs"
+ CONNECTIONS = "connections"
+
+ FLOW_OUTPUTS = "flow_outputs"
+ DEBUG_INFO = "debug_info"
+
+
+class _FlowComponentPortDict(dict):
+ def __init__(self, ports: Dict):
+ self._allow_update_item = True
+ super().__init__()
+ for input_port_name, input_port in ports.items():
+ self[input_port_name] = input_port
+ self._allow_update_item = False
+
+ def __setitem__(self, key: Any, value: Any) -> None:
+ if not self._allow_update_item:
+ raise RuntimeError("Ports of flow component are not editable.")
+ super().__setitem__(key, value)
+
+ def __delitem__(self, key: Any) -> None:
+ if not self._allow_update_item:
+ raise RuntimeError("Ports of flow component are not editable.")
+ super().__delitem__(key)
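+ # Illustrative sketch (assumption): port dictionaries are frozen after construction, so any
+ # attempt to add or remove a port fails fast, e.g.
+ #   >>> ports = _FlowComponentPortDict({"data": Input(type=AssetTypes.URI_FOLDER)})
+ #   >>> ports["extra"] = Input(type=AssetTypes.URI_FOLDER)  # raises RuntimeError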
+
+
+class FlowComponentInputDict(_FlowComponentPortDict):
+ """Input port dictionary for FlowComponent, with fixed input ports."""
+
+ def __init__(self) -> None:
+ super().__init__(
+ {
+ _FlowPortNames.CONNECTIONS: GroupInput(values={}, _group_class=None),
+ _FlowPortNames.DATA: Input(type=AssetTypes.URI_FOLDER, optional=False),
+ _FlowPortNames.FLOW_OUTPUTS: Input(type=AssetTypes.URI_FOLDER, optional=True),
+ }
+ )
+
+ @contextlib.contextmanager
+ def _fit_inputs(self, inputs: Optional[Dict]) -> Generator:
+ """Add dynamic input ports to the input port dictionary.
+ Input ports of a flow component include:
+ 1. data: required major uri_folder input
+ 2. run_output: optional uri_folder input
+ 3. connections.xxx.xxx: group of string parameters, first layer key can be any node name,
+ but we won't resolve the exact keys in SDK
+ 4. xxx: input_mapping parameters, key can be any node name, but we won't resolve the exact keys in SDK
+
+ #3 will be grouped into connections, we make it a fixed group input port.
+ #4 are dynamic input ports, we will add them temporarily in this context manager and remove them
+ after the context manager is finished.
+
+ :param inputs: The dynamic input to fit.
+ :type inputs: Dict[str, Any]
+ :return: None
+ :rtype: None
+ """
+ dynamic_columns_mapping_keys = []
+ dynamic_connections_inputs = defaultdict(list)
+ from azure.ai.ml.entities._job.pipeline._io import _GroupAttrDict
+ from azure.ai.ml.entities._job.pipeline._io.mixin import flatten_dict
+
+ flattened_inputs = flatten_dict(inputs, _GroupAttrDict, allow_dict_fields=[_FlowPortNames.CONNECTIONS])
+
+ for flattened_input_key in flattened_inputs:
+ if flattened_input_key.startswith(f"{_FlowPortNames.CONNECTIONS}."):
+ if flattened_input_key.count(".") != 2:
+ raise ValidationException(
+ message="flattened connection input prot name must be "
+ "in the format of connections.<node_name>.<port_name>, "
+ "but got %s" % flattened_input_key,
+ no_personal_data_message="flattened connection input prot name must be in the format of "
+ "connections.<node_name>.<port_name>",
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ _, node_name, param_name = flattened_input_key.split(".")
+ dynamic_connections_inputs[node_name].append(param_name)
+ continue
+ if flattened_input_key not in self:
+ dynamic_columns_mapping_keys.append(flattened_input_key)
+
+ self._allow_update_item = True
+ for flattened_input_key in dynamic_columns_mapping_keys:
+ self[flattened_input_key] = Input(type=ComponentParameterTypes.STRING, optional=True)
+ if dynamic_connections_inputs:
+ self[_FlowPortNames.CONNECTIONS] = GroupInput(
+ values={
+ node_name: GroupInput(
+ values={
+ parameter_name: Input(
+ type=ComponentParameterTypes.STRING,
+ )
+ for parameter_name in param_names
+ },
+ _group_class=None,
+ )
+ for node_name, param_names in dynamic_connections_inputs.items()
+ },
+ _group_class=None,
+ )
+ self._allow_update_item = False
+
+ yield
+
+ self._allow_update_item = True
+ for flattened_input_key in dynamic_columns_mapping_keys:
+ del self[flattened_input_key]
+ self[_FlowPortNames.CONNECTIONS] = GroupInput(values={}, _group_class=None)
+ self._allow_update_item = False
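+ # Illustrative sketch (assumption about the exact flattening behavior): inside the context
+ # manager, column-mapping keys and connection overrides temporarily become input ports, e.g.
+ #   >>> inputs_dict = FlowComponentInputDict()
+ #   >>> with inputs_dict._fit_inputs({"url": "${data.url}",
+ #   ...                               "connections.node_a.connection": "my_connection"}):
+ #   ...     "url" in inputs_dict  # True while inside the context
+ # and the dynamic ports are removed again when the context exits.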
+
+
+class FlowComponentOutputDict(_FlowComponentPortDict):
+ """Output port dictionary for FlowComponent, with fixed output ports."""
+
+ def __init__(self) -> None:
+ super().__init__(
+ {
+ _FlowPortNames.FLOW_OUTPUTS: Output(type=AssetTypes.URI_FOLDER),
+ _FlowPortNames.DEBUG_INFO: Output(type=AssetTypes.URI_FOLDER),
+ }
+ )
+
+
+class FlowComponent(Component, AdditionalIncludesMixin):
+ """Flow component version, used to define a Flow Component or Job.
+
+ :keyword name: The name of the Flow job or component.
+ :type name: Optional[str]
+ :keyword version: The version of the Flow job or component.
+ :type version: Optional[str]
+ :keyword description: The description of the component. Defaults to None.
+ :type description: Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :type tags: Optional[dict]
+ :keyword display_name: The display name of the component.
+ :type display_name: Optional[str]
+ :keyword flow: The path to the flow directory or flow definition file. Defaults to None and base path of this
+ component will be used as flow directory.
+ :type flow: Optional[Union[str, Path]]
+ :keyword column_mapping: The column mapping for the flow. Defaults to None.
+ :type column_mapping: Optional[dict[str, str]]
+ :keyword variant: The variant of the flow. Defaults to None.
+ :type variant: Optional[str]
+ :keyword connections: The connections for the flow. Defaults to None.
+ :type connections: Optional[dict[str, dict[str, str]]]
+ :keyword environment_variables: The environment variables for the flow. Defaults to None.
+ :type environment_variables: Optional[dict[str, str]]
+ :keyword environment: The environment for the flow component. Defaults to None.
+ :type environment: Optional[Union[str, Environment]]
+ :keyword is_deterministic: Specifies whether the Flow will return the same output given the same input.
+ Defaults to True. When True, if a Flow (component) is deterministic and has been run before in the
+ current workspace with the same input and settings, it will reuse results from a previous submitted job
+ when used as a node or step in a pipeline. In that scenario, no compute resources will be used.
+ :type is_deterministic: Optional[bool]
+ :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None.
+ :type additional_includes: Optional[list[str]]
+ :keyword properties: The job property dictionary. Defaults to None.
+ :type properties: Optional[dict[str, str]]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FlowComponent cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ flow: Optional[Union[str, Path]] = None,
+ column_mapping: Optional[Dict[str, str]] = None,
+ variant: Optional[str] = None,
+ connections: Optional[Dict[str, Dict[str, str]]] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ environment: Optional[Union[str, Environment]] = None,
+ is_deterministic: bool = True,
+ additional_includes: Optional[List] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ kwargs[COMPONENT_TYPE] = NodeType.FLOW_PARALLEL
+
+ # always use flow directory as base path
+ # Note: we suppose that there is no relative path in run.yaml other than flow.
+ # If there are any, we will need to rebase them so that they have the same base path as attributes in
+ # flow.dag.yaml
+ flow_dir, self._flow = self._get_flow_definition(
+ flow=flow,
+ base_path=kwargs.pop(BASE_PATH_CONTEXT_KEY, Path.cwd()),
+ source_path=kwargs.get(SOURCE_PATH_CONTEXT_KEY, None),
+ )
+ kwargs[BASE_PATH_CONTEXT_KEY] = flow_dir
+
+ super().__init__(
+ name=name or self._normalize_component_name(flow_dir.name),
+ version=version or "1",
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ inputs={},
+ outputs={},
+ is_deterministic=is_deterministic,
+ properties=properties,
+ **kwargs,
+ )
+ self._environment = environment
+ self._column_mapping = column_mapping or {}
+ self._variant = variant
+ self._connections = connections or {}
+
+ self._inputs = FlowComponentInputDict()
+ self._outputs = FlowComponentOutputDict()
+
+ if flow:
+ # file existence has been checked in _get_flow_definition
+ # we don't need to rebase additional_includes as we have updated base_path
+ with open(Path(self.base_path, self._flow), "r", encoding="utf-8") as f:
+ flow_content = yaml.safe_load(f.read())
+ additional_includes = flow_content.get("additional_includes", None)
+ # environment variables in run.yaml have higher priority than those in flow.dag.yaml
+ self._environment_variables = flow_content.get("environment_variables", {})
+ self._environment_variables.update(environment_variables or {})
+ else:
+ self._environment_variables = environment_variables or {}
+
+ self._additional_includes = additional_includes or []
+
+ # Unlike other Component types, code is a private property in FlowComponent and
+ # is used to store the ARM id of the created code before constructing the rest object.
+ # We don't use self.flow directly because self.flow can be a path to the flow dag yaml file instead of a directory.
+ self._code_arm_id: Optional[str] = None
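+ # Illustrative usage sketch (assumption, paths are hypothetical): a flow component can be
+ # created directly from a local flow directory containing flow.dag.yaml, e.g.
+ #   >>> flow_component = FlowComponent(
+ #   ...     name="my_flow_component",
+ #   ...     flow="flows/my_flow",
+ #   ...     column_mapping={"url": "${data.url}"},
+ #   ... )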
+
+ # region valid properties
+ @property
+ def flow(self) -> str:
+ """The path to the flow definition file relative to the flow directory.
+
+ :rtype: str
+ """
+ return self._flow
+
+ @property
+ def environment(self) -> Optional[Union[str, Environment]]:
+ """The environment for the flow component. Defaults to None.
+
+ :rtype: Union[str, Environment]
+ """
+ return self._environment
+
+ @environment.setter
+ def environment(self, value: Union[str, Environment]) -> None:
+ """The environment for the flow component. Defaults to None.
+
+ :param value: The environment for the flow component.
+ :type value: Union[str, Environment]
+ """
+ self._environment = value
+
+ @property
+ def column_mapping(self) -> Dict[str, str]:
+ """The column mapping for the flow. Defaults to None.
+
+ :rtype: Dict[str, str]
+ """
+ return self._column_mapping
+
+ @column_mapping.setter
+ def column_mapping(self, value: Optional[Dict[str, str]]) -> None:
+ """
+ The column mapping for the flow. Defaults to None.
+
+ :param value: The column mapping for the flow.
+ :type value: Optional[Dict[str, str]]
+ """
+ self._column_mapping = value or {}
+
+ @property
+ def variant(self) -> Optional[str]:
+ """The variant of the flow. Defaults to None.
+
+ :rtype: Optional[str]
+ """
+ return self._variant
+
+ @variant.setter
+ def variant(self, value: Optional[str]) -> None:
+ """The variant of the flow. Defaults to None.
+
+ :param value: The variant of the flow.
+ :type value: Optional[str]
+ """
+ self._variant = value
+
+ @property
+ def connections(self) -> Dict[str, Dict[str, str]]:
+ """The connections for the flow. Defaults to None.
+
+ :rtype: Dict[str, Dict[str, str]]
+ """
+ return self._connections
+
+ @connections.setter
+ def connections(self, value: Optional[Dict[str, Dict[str, str]]]) -> None:
+ """
+ The connections for the flow. Defaults to None.
+
+ :param value: The connections for the flow.
+ :type value: Optional[Dict[str, Dict[str, str]]]
+ """
+ self._connections = value or {}
+
+ @property
+ def environment_variables(self) -> Dict[str, str]:
+ """The environment variables for the flow. Defaults to None.
+
+ :rtype: Dict[str, str]
+ """
+ return self._environment_variables
+
+ @environment_variables.setter
+ def environment_variables(self, value: Optional[Dict[str, str]]) -> None:
+ """The environment variables for the flow. Defaults to None.
+
+ :param value: The environment variables for the flow.
+ :type value: Optional[Dict[str, str]]
+ """
+ self._environment_variables = value or {}
+
+ @property
+ def additional_includes(self) -> List:
+ """A list of shared additional files to be included in the component. Defaults to None.
+
+ :rtype: List
+ """
+ return self._additional_includes
+
+ @additional_includes.setter
+ def additional_includes(self, value: Optional[List]) -> None:
+ """A list of shared additional files to be included in the component. Defaults to None.
+ All local additional includes should be relative to the flow directory.
+
+ :param value: A list of shared additional files to be included in the component.
+ :type value: Optional[List]
+ """
+ self._additional_includes = value or []
+
+ # endregion
+
+ @classmethod
+ def _normalize_component_name(cls, value: str) -> str:
+ return value.replace("-", "_")
+
+ # region Component
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict:
+ raise RuntimeError("FlowComponent does not support loading from REST object.")
+
+ def _to_rest_object(self) -> ComponentVersion:
+ rest_obj = super()._to_rest_object()
+ rest_obj.properties.component_spec["code"] = self._code_arm_id
+ rest_obj.properties.component_spec["flow_file_name"] = self._flow
+ return rest_obj
+
+ def _func(self, **kwargs: Any) -> "Parallel": # pylint: disable=invalid-overridden-method
+ from azure.ai.ml.entities._builders.parallel import Parallel
+
+ with self._inputs._fit_inputs(kwargs): # type: ignore[attr-defined]
+ # pylint: disable=not-callable
+ return super()._func(**kwargs) # type: ignore
+
+ @classmethod
+ def _get_flow_definition(
+ cls,
+ base_path: Path,
+ *,
+ flow: Optional[Union[str, os.PathLike]] = None,
+ source_path: Optional[Union[str, os.PathLike]] = None,
+ ) -> Tuple[Path, str]:
+ """
+ Get the path to the flow directory and the file name of the flow dag yaml file.
+ If flow is not specified, we will assume that the source_path is the path to the flow dag yaml file.
+ If flow is specified, it can be either a path to the flow dag yaml file or a path to the flow directory.
+ If flow is a path to the flow directory, we will assume that the flow dag yaml file is named flow.dag.yaml.
+
+ :param base_path: The base path of the flow component.
+ :type base_path: Path
+ :keyword flow: The path to the flow directory or flow definition file. Defaults to None and base path of this
+ component will be used as flow directory.
+ :type flow: Optional[Union[str, Path]]
+ :keyword source_path: The source path of the flow component, should be path to the flow dag yaml file
+ if specified.
+ :type source_path: Optional[Union[str, os.PathLike]]
+ :return: The path to the flow directory and the file name of the flow dag yaml file.
+ :rtype: Tuple[Path, str]
+ """
+ flow_file_name = "flow.dag.yaml"
+
+ if flow is None and source_path is None:
+ raise cls._create_validation_error(
+ message="Either flow or source_path must be specified.",
+ no_personal_data_message="Either flow or source_path must be specified.",
+ )
+
+ if flow is None:
+ # Flow component must be created with a local yaml file, so no need to check if source_path exists
+ if isinstance(source_path, (os.PathLike, str)):
+ flow_file_name = os.path.basename(source_path)
+ return Path(base_path), flow_file_name
+
+ flow_path = Path(flow)
+ if not flow_path.is_absolute():
+ # if flow_path points to a symlink, we still use the parent of the symlink as origin code
+ flow_path = Path(base_path, flow)
+
+ if flow_path.is_dir() and (flow_path / flow_file_name).is_file():
+ return flow_path, flow_file_name
+
+ if flow_path.is_file():
+ return flow_path.parent, flow_path.name
+
+ raise cls._create_validation_error(
+ message="Flow path must be a directory containing flow.dag.yaml or a file, but got %s" % flow_path,
+ no_personal_data_message="Flow path must be a directory or a file",
+ )
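+ # Illustrative sketch (assumption, local paths are hypothetical and must exist): the flow
+ # argument may point at either a flow directory or a dag yaml file, e.g.
+ #   >>> FlowComponent._get_flow_definition(Path("."), flow="flows/my_flow")
+ #   (PosixPath('flows/my_flow'), 'flow.dag.yaml')
+ #   >>> FlowComponent._get_flow_definition(Path("."), flow="flows/my_flow/flow.dag.yaml")
+ #   (PosixPath('flows/my_flow'), 'flow.dag.yaml')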
+
+ # endregion
+
+ # region SchemaValidatableMixin
+ @classmethod
+ def _load_with_schema(
+ cls, data: Any, *, context: Optional[Any] = None, raise_original_exception: bool = False, **kwargs: Any
+ ) -> Any:
+ # FlowComponent should be loaded with FlowSchema or FlowRunSchema instead of FlowComponentSchema
+ context = context or {BASE_PATH_CONTEXT_KEY: Path.cwd()}
+ _schema = data.get("$schema", None)
+ if _schema == SchemaUrl.PROMPTFLOW_RUN:
+ schema = RunSchema(context=context)
+ elif _schema == SchemaUrl.PROMPTFLOW_FLOW:
+ schema = FlowSchema(context=context)
+ else:
+ raise cls._create_validation_error(
+ message="$schema must be specified correctly for loading component from flow, but got %s" % _schema,
+ no_personal_data_message="$schema must be specified for loading component from flow",
+ )
+
+ # unlike other component, we should ignore unknown fields in flow to keep init_params clean and avoid
+ # too much understanding of flow.dag.yaml & run.yaml
+ kwargs["unknown"] = EXCLUDE
+ try:
+ loaded_dict = schema.load(data, **kwargs)
+ except ValidationError as e:
+ if raise_original_exception:
+ raise e
+ msg = "Trying to load data with schema failed. Data:\n%s\nError: %s" % (
+ json.dumps(data, indent=4) if isinstance(data, dict) else data,
+ json.dumps(e.messages, indent=4),
+ )
+ raise cls._create_validation_error(
+ message=msg,
+ no_personal_data_message=str(e),
+ ) from e
+ loaded_dict.update(loaded_dict.pop(PROMPTFLOW_AZUREML_OVERRIDE_KEY, {}))
+ return loaded_dict
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return FlowComponentSchema(context=context)
+
+ # endregion
+
+ # region AdditionalIncludesMixin
+ def _get_origin_code_value(self) -> Union[str, os.PathLike, None]:
+ if self._code_arm_id:
+ return self._code_arm_id
+ res: Union[str, os.PathLike, None] = self.base_path
+ return res
+
+ def _fill_back_code_value(self, value: str) -> None:
+ self._code_arm_id = value
+
+ @contextlib.contextmanager
+ def _try_build_local_code(self) -> Generator:
+ # false-positive by pylint, hence disable it
+ # (https://github.com/pylint-dev/pylint/blob/main/doc/data/messages
+ # /c/contextmanager-generator-missing-cleanup/details.rst)
+ with super()._try_build_local_code() as code: # pylint:disable=contextmanager-generator-missing-cleanup
+ if not code or not code.path:
+ yield code
+ return
+
+ if not (Path(code.path) / ".promptflow" / "flow.tools.json").is_file():
+ raise self._create_validation_error(
+ message="Flow component must be created with a ./promptflow/flow.tools.json, "
+ "please run `pf flow validate` to generate it or skip it in your ignore file.",
+ no_personal_data_message="Flow component must be created with a ./promptflow/flow.tools.json, "
+ "please run `pf flow validate` to generate it or skip it in your ignore file.",
+ )
+ # TODO: should we remove additional includes from flow.dag.yaml? for now we suppose it will be removed
+ # by mldesigner compile if needed
+
+ yield code
+
+ # endregion
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/import_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/import_component.py
new file mode 100644
index 00000000..13464a06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/import_component.py
@@ -0,0 +1,96 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from marshmallow import Schema
+
+from azure.ai.ml._schema.component.import_component import ImportComponentSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, COMPONENT_TYPE
+from azure.ai.ml.constants._component import NodeType
+
+from ..._schema import PathAwareSchema
+from ..._utils.utils import parse_args_description_from_docstring
+from .._util import convert_ordered_dict_to_dict
+from .component import Component
+
+
+class ImportComponent(Component):
+ """Import component version, used to define an import component.
+
+ :param name: Name of the component.
+ :type name: str
+ :param version: Version of the component.
+ :type version: str
+ :param description: Description of the component.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict
+ :param display_name: Display name of the component.
+ :type display_name: str
+ :param source: Input source parameters of the component.
+ :type source: dict
+ :param output: Output of the component.
+ :type output: dict
+ :param is_deterministic: Whether the command component is deterministic. Defaults to True.
+ :type is_deterministic: bool
+ :param kwargs: Additional parameters for the import component.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ source: Optional[Dict] = None,
+ output: Optional[Dict] = None,
+ is_deterministic: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[COMPONENT_TYPE] = NodeType.IMPORT
+ # Set default base path
+ if BASE_PATH_CONTEXT_KEY not in kwargs:
+ kwargs[BASE_PATH_CONTEXT_KEY] = Path(".")
+
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ inputs=source,
+ outputs={"output": output} if output else None,
+ is_deterministic=is_deterministic,
+ **kwargs,
+ )
+
+ self.source = source
+ self.output = output
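+ # Illustrative usage sketch (assumption): an import component describes its external source
+ # parameters as inputs and exposes a single output, e.g.
+ #   >>> ImportComponent(
+ #   ...     name="my_import",
+ #   ...     source={"connection": {"type": "string"}, "query": {"type": "string"}},
+ #   ...     output={"type": "mltable"},
+ #   ... )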
+
+ def _to_dict(self) -> Dict:
+ # TODO: Bug Item number: 2897665
+ res: Dict = convert_ordered_dict_to_dict( # type: ignore
+ {**self._other_parameter, **super(ImportComponent, self)._to_dict()}
+ )
+ return res
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return ImportComponentSchema(context=context)
+
+ @classmethod
+ def _parse_args_description_from_docstring(cls, docstring: str) -> Dict:
+ res: dict = parse_args_description_from_docstring(docstring)
+ return res
+
+ def __str__(self) -> str:
+ try:
+ toYaml: str = self._to_yaml()
+ return toYaml
+ except BaseException: # pylint: disable=W0718
+ toStr: str = super(ImportComponent, self).__str__()
+ return toStr
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/parallel_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/parallel_component.py
new file mode 100644
index 00000000..3f29b1e1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/parallel_component.py
@@ -0,0 +1,305 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import json
+import os
+import re
+from typing import Any, Dict, List, Optional, Union, cast
+
+from marshmallow import Schema
+
+from azure.ai.ml._restclient.v2022_10_01.models import ComponentVersion
+from azure.ai.ml._schema.component.parallel_component import ParallelComponentSchema
+from azure.ai.ml.constants._common import COMPONENT_TYPE
+from azure.ai.ml.constants._component import NodeType
+from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
+from azure.ai.ml.entities._job.parallel.parallel_task import ParallelTask
+from azure.ai.ml.entities._job.parallel.parameterized_parallel import ParameterizedParallel
+from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+from ..._schema import PathAwareSchema
+from .._util import validate_attribute_type
+from .._validation import MutableValidationResult
+from .code import ComponentCodeMixin
+from .component import Component
+
+
+class ParallelComponent(
+ Component, ParameterizedParallel, ComponentCodeMixin
+): # pylint: disable=too-many-instance-attributes
+ """Parallel component version, used to define a parallel component.
+
+ :param name: Name of the component. Defaults to None
+ :type name: str
+ :param version: Version of the component. Defaults to None
+ :type version: str
+ :param description: Description of the component. Defaults to None
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None
+ :type tags: dict
+ :param display_name: Display name of the component. Defaults to None
+ :type display_name: str
+ :param retry_settings: parallel component run failed retry. Defaults to None
+ :type retry_settings: BatchRetrySettings
+ :param logging_level: A string of the logging level name. Defaults to None
+ :type logging_level: str
+ :param max_concurrency_per_instance: The max parallelism that each compute instance has. Defaults to None
+ :type max_concurrency_per_instance: int
+ :param error_threshold: The number of item processing failures that should be ignored. Defaults to None
+ :type error_threshold: int
+ :param mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored. Defaults to None
+ :type mini_batch_error_threshold: int
+ :param task: The parallel task. Defaults to None
+ :type task: ParallelTask
+ :param mini_batch_size: For FileDataset input, this field is the number of files a user script can process
+ in one run() call. For TabularDataset input, this field is the approximate size of data the user script
+ can process in one run() call. Example values are 1024, 1024KB, 10MB, and 1GB.
+ (optional, default value is 10 files for FileDataset and 1MB for TabularDataset.) This value could be set
+ through PipelineParameter.
+ :type mini_batch_size: str
+ :param partition_keys: The keys used to partition dataset into mini-batches. Defaults to None
+ If specified, the data with the same key will be partitioned into the same mini-batch.
+ If both partition_keys and mini_batch_size are specified, partition_keys will take effect.
+ The input(s) must be partitioned dataset(s),
+ and the partition_keys must be a subset of the keys of every input dataset for this to work.
+ :type partition_keys: list
+ :param input_data: The input data. Defaults to None
+ :type input_data: str
+ :param resources: Compute Resource configuration for the component. Defaults to None
+ :type resources: Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration]
+ :param inputs: Inputs of the component. Defaults to None
+ :type inputs: dict
+ :param outputs: Outputs of the component. Defaults to None
+ :type outputs: dict
+ :param code: promoted property from task.code
+ :type code: str
+ :param instance_count: promoted property from resources.instance_count. Defaults to None
+ :type instance_count: int
+ :param is_deterministic: Whether the parallel component is deterministic. Defaults to True
+ :type is_deterministic: bool
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ParallelComponent cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__( # pylint: disable=too-many-locals
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ display_name: Optional[str] = None,
+ retry_settings: Optional[RetrySettings] = None,
+ logging_level: Optional[str] = None,
+ max_concurrency_per_instance: Optional[int] = None,
+ error_threshold: Optional[int] = None,
+ mini_batch_error_threshold: Optional[int] = None,
+ task: Optional[ParallelTask] = None,
+ mini_batch_size: Optional[str] = None,
+ partition_keys: Optional[List] = None,
+ input_data: Optional[str] = None,
+ resources: Optional[JobResourceConfiguration] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ code: Optional[str] = None, # promoted property from task.code
+ instance_count: Optional[int] = None, # promoted property from resources.instance_count
+ is_deterministic: bool = True,
+ **kwargs: Any,
+ ):
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+
+ kwargs[COMPONENT_TYPE] = NodeType.PARALLEL
+
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ inputs=inputs,
+ outputs=outputs,
+ is_deterministic=is_deterministic,
+ **kwargs,
+ )
+
+ # No validation on the values passed here because, in a pipeline job, the required code & environment may be
+ # absent and filled in later with job defaults.
+ self.task = task
+ self.mini_batch_size: int = 0
+ self.partition_keys = partition_keys
+ self.input_data = input_data
+ self.retry_settings = retry_settings
+ self.logging_level = logging_level
+ self.max_concurrency_per_instance = max_concurrency_per_instance
+ self.error_threshold = error_threshold
+ self.mini_batch_error_threshold = mini_batch_error_threshold
+ self.resources = resources
+
+ # check mutual exclusivity of promoted properties
+ if self.resources is not None and instance_count is not None:
+ msg = "instance_count and resources are mutually exclusive"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.instance_count = instance_count
+ self.code = code
+
+ if mini_batch_size is not None:
+ # Convert str to int.
+ pattern = re.compile(r"^\d+([kKmMgG][bB])*$")
+ if not pattern.match(mini_batch_size):
+ raise ValueError(r"Parameter mini_batch_size must follow regex rule ^\d+([kKmMgG][bB])*$")
+
+ try:
+ self.mini_batch_size = int(mini_batch_size)
+ except ValueError as e:
+ unit = mini_batch_size[-2:].lower()
+ if unit == "kb":
+ self.mini_batch_size = int(mini_batch_size[0:-2]) * 1024
+ elif unit == "mb":
+ self.mini_batch_size = int(mini_batch_size[0:-2]) * 1024 * 1024
+ elif unit == "gb":
+ self.mini_batch_size = int(mini_batch_size[0:-2]) * 1024 * 1024 * 1024
+ else:
+ raise ValueError("mini_batch_size unit must be kb, mb or gb") from e
+
+ @property
+ def instance_count(self) -> Optional[int]:
+ """Return value of promoted property resources.instance_count.
+
+ :return: Value of resources.instance_count.
+ :rtype: Optional[int]
+ """
+ return self.resources.instance_count if self.resources and not isinstance(self.resources, dict) else None
+
+ @instance_count.setter
+ def instance_count(self, value: int) -> None:
+ """Set the value of the promoted property resources.instance_count.
+
+ :param value: The value to set for resources.instance_count.
+ :type value: int
+ """
+ if not value:
+ return
+ if not self.resources:
+ self.resources = JobResourceConfiguration(instance_count=value)
+ else:
+ if not isinstance(self.resources, dict):
+ self.resources.instance_count = value
+
+ @property
+ def code(self) -> Optional[str]:
+ """Return value of promoted property task.code, which is a local or
+ remote path pointing at source code.
+
+ :return: Value of task.code.
+ :rtype: Optional[str]
+ """
+ return self.task.code if self.task else None
+
+ @code.setter
+ def code(self, value: str) -> None:
+ """Set the value of the promoted property task.code.
+
+ :param value: The value to set for task.code.
+ :type value: str
+ """
+ if not value:
+ return
+ if not self.task:
+ self.task = ParallelTask(code=value)
+ else:
+ self.task.code = value
+
+ def _to_ordered_dict_for_yaml_dump(self) -> Dict:
+ """Dump the component content into a sorted yaml string.
+
+ :return: The ordered dict
+ :rtype: Dict
+ """
+
+ obj: dict = super()._to_ordered_dict_for_yaml_dump()
+ # a dict dumped based on the schema will convert code to an absolute path, while we want to keep its original value
+ if self.code and isinstance(self.code, str):
+ obj["task"]["code"] = self.code
+ return obj
+
+ @property
+ def environment(self) -> Optional[str]:
+ """Return value of promoted property task.environment, indicate the
+ environment that training job will run in.
+
+ :return: Value of task.environment.
+ :rtype: Optional[Union[Environment, str]]
+ """
+ if self.task:
+ return cast(Optional[str], self.task.environment)
+ return None
+
+ @environment.setter
+ def environment(self, value: str) -> None:
+ """Set the value of the promoted property task.environment.
+
+ :param value: The value to set for task.environment.
+ :type value: str
+ """
+ if not value:
+ return
+ if not self.task:
+ self.task = ParallelTask(environment=value)
+ else:
+ self.task.environment = value
+
+ def _customized_validate(self) -> MutableValidationResult:
+ validation_result = super()._customized_validate()
+ self._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(validation_result)
+ return validation_result
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "retry_settings": (dict, RetrySettings),
+ "task": (dict, ParallelTask),
+ "logging_level": str,
+ "max_concurrency_per_instance": int,
+ "input_data": str,
+ "error_threshold": int,
+ "mini_batch_error_threshold": int,
+ "code": (str, os.PathLike),
+ "resources": (dict, JobResourceConfiguration),
+ }
+
+ def _to_rest_object(self) -> ComponentVersion:
+ rest_object = super()._to_rest_object()
+ # the schema requires a list while the backend accepts a JSON string
+ if self.partition_keys:
+ rest_object.properties.component_spec["partition_keys"] = json.dumps(self.partition_keys)
+ return rest_object
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict:
+ # the schema requires a list while the backend accepts a JSON string;
+ # update the rest object here since it is passed to super()._from_rest_object_to_init_params below
+ partition_keys = obj.properties.component_spec.get("partition_keys", None)
+ if partition_keys:
+ obj.properties.component_spec["partition_keys"] = json.loads(partition_keys)
+ res: dict = super()._from_rest_object_to_init_params(obj)
+ return res
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return ParallelComponentSchema(context=context)
+
+ def __str__(self) -> str:
+ try:
+ toYaml: str = self._to_yaml()
+ return toYaml
+ except BaseException: # pylint: disable=W0718
+ toStr: str = super(ParallelComponent, self).__str__()
+ return toStr
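Illustrative usage sketch (not part of the patched file): constructing a ParallelComponent directly to show the mini_batch_size parsing and the promoted instance_count property defined above. The import assumes ParallelComponent is re-exported from azure.ai.ml.entities (as the package __init__ in this changeset suggests); the names and values are hypothetical, and parallel components are more commonly loaded from YAML via load_component.

# Hedged sketch; values are illustrative only.
from azure.ai.ml.entities import ParallelComponent

component = ParallelComponent(
    name="score_in_parallel",             # hypothetical component name
    version="1",
    input_data="${{inputs.input_data}}",  # binding resolved at pipeline time
    mini_batch_size="10mb",               # parsed by the regex/unit logic above
    instance_count=2,                     # promoted into resources.instance_count
)

print(component.mini_batch_size)           # 10485760 (10 * 1024 * 1024)
print(component.resources.instance_count)  # 2
print(component.instance_count)            # 2, read back through the promoted property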
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/pipeline_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/pipeline_component.py
new file mode 100644
index 00000000..229b714d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/pipeline_component.py
@@ -0,0 +1,529 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import json
+import logging
+import os
+import re
+import time
+import typing
+from collections import Counter
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from marshmallow import Schema
+
+from azure.ai.ml._restclient.v2022_10_01.models import ComponentVersion, ComponentVersionProperties
+from azure.ai.ml._schema import PathAwareSchema
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentSchema
+from azure.ai.ml._utils._asset_utils import get_object_hash
+from azure.ai.ml._utils.utils import hash_dict, is_data_binding_expression
+from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ARM_ID_REGEX_FORMAT, COMPONENT_TYPE
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+from azure.ai.ml.constants._job.pipeline import ValidationErrorCode
+from azure.ai.ml.entities._builders import BaseNode, Command
+from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode, LoopNode
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._inputs_outputs import GroupInput, Input
+from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe, try_get_non_arbitrary_attr
+from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression
+from azure.ai.ml.entities._validation import MutableValidationResult
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineComponent(Component):
+ """Pipeline component, currently used to store components in an azure.ai.ml.dsl.pipeline.
+
+ :param name: Name of the component.
+ :type name: str
+ :param version: Version of the component.
+ :type version: str
+ :param description: Description of the component.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict
+ :param display_name: Display name of the component.
+ :type display_name: str
+ :param inputs: Component inputs.
+ :type inputs: dict
+ :param outputs: Component outputs.
+ :type outputs: dict
+ :param jobs: A map of node name to component node inside the pipeline definition.
+ :type jobs: Dict[str, ~azure.ai.ml.entities._builders.BaseNode]
+ :param is_deterministic: Whether the pipeline component is deterministic.
+ :type is_deterministic: bool
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if PipelineComponent cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ jobs: Optional[Dict[str, BaseNode]] = None,
+ is_deterministic: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[COMPONENT_TYPE] = NodeType.PIPELINE
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ inputs=inputs,
+ outputs=outputs,
+ is_deterministic=is_deterministic, # type: ignore[arg-type]
+ **kwargs,
+ )
+ self._jobs = self._process_jobs(jobs) if jobs else {}
+ # for telemetry
+ self._job_types, self._job_sources = self._get_job_type_and_source()
+ # Private support: create pipeline component from pipeline job
+ self._source_job_id = kwargs.pop("source_job_id", None)
+ # TODO: set anonymous hash for reuse
+
+ def _process_jobs(self, jobs: Dict[str, BaseNode]) -> Dict[str, BaseNode]:
+ """Process and validate jobs.
+
+ :param jobs: A map of node name to node
+ :type jobs: Dict[str, BaseNode]
+ :return: The processed jobs
+ :rtype: Dict[str, BaseNode]
+ """
+ # Remove swept Command
+ node_names_to_skip = []
+ for node_name, job_instance in jobs.items():
+ if isinstance(job_instance, Command) and job_instance._swept is True:
+ node_names_to_skip.append(node_name)
+
+ for key in node_names_to_skip:
+ del jobs[key]
+
+ # Set path and validate node type.
+ for _, job_instance in jobs.items():
+ if isinstance(job_instance, BaseNode):
+ job_instance._set_base_path(self.base_path)
+
+ if not isinstance(job_instance, (BaseNode, AutoMLJob, ControlFlowNode)):
+ msg = f"Not supported pipeline job type: {type(job_instance)}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return jobs
+
+ def _customized_validate(self) -> MutableValidationResult:
+ """Validate pipeline component structure.
+
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ validation_result = super(PipelineComponent, self)._customized_validate()
+
+ # Validate inputs
+ for input_name, input_value in self.inputs.items():
+ if input_value.type is None:
+ validation_result.append_error(
+ yaml_path="inputs.{}".format(input_name),
+ message="Parameter type unknown, please add type annotation or specify input default value.",
+ error_code=ValidationErrorCode.PARAMETER_TYPE_UNKNOWN,
+ )
+
+ # Validate all nodes
+ for node_name, node in self.jobs.items():
+ if isinstance(node, BaseNode):
+ # Node inputs will be validated.
+ validation_result.merge_with(node._validate(), "jobs.{}".format(node_name))
+ if isinstance(node.component, Component):
+ # Validate binding if not remote resource.
+ validation_result.merge_with(self._validate_binding_inputs(node))
+ elif isinstance(node, AutoMLJob):
+ pass
+ elif isinstance(node, ControlFlowNode):
+ # Validate control flow node.
+ validation_result.merge_with(node._validate(), "jobs.{}".format(node_name))
+ else:
+ validation_result.append_error(
+ yaml_path="jobs.{}".format(node_name),
+ message=f"Not supported pipeline job type: {type(node)}",
+ )
+
+ return validation_result
+
+ def _validate_compute_is_set(self, *, parent_node_name: Optional[str] = None) -> MutableValidationResult:
+ """Validate compute in pipeline component.
+
+ This function is only called from pipeline_job._validate_compute_is_set
+ when both pipeline_job.compute and pipeline_job.settings.default_compute are None.
+ Rules:
+ - For a pipeline node: call node._component._validate_compute_is_set to validate node compute in the sub graph.
+ - For a general node:
+ - If _skip_required_compute_missing_validation is True, validation is skipped.
+ - All remaining cases without compute add a "compute not set" error to the validation result.
+
+ :keyword parent_node_name: The name of the parent node.
+ :type parent_node_name: Optional[str]
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+
+ # Note: do not put this into customized validate, as we would like to call
+ # this from pipeline_job._validate_compute_is_set
+ validation_result = self._create_empty_validation_result()
+ no_compute_nodes = []
+ parent_node_name = parent_node_name if parent_node_name else ""
+ for node_name, node in self.jobs.items():
+ full_node_name = f"{parent_node_name}{node_name}.jobs."
+ if node.type == NodeType.PIPELINE and isinstance(node._component, PipelineComponent):
+ validation_result.merge_with(node._component._validate_compute_is_set(parent_node_name=full_node_name))
+ continue
+ if isinstance(node, BaseNode) and node._skip_required_compute_missing_validation:
+ continue
+ if has_attr_safe(node, "compute") and node.compute is None:
+ no_compute_nodes.append(node_name)
+
+ for node_name in no_compute_nodes:
+ validation_result.append_error(
+ yaml_path=f"jobs.{parent_node_name}{node_name}.compute",
+ message="Compute not set",
+ )
+ return validation_result
+
+ def _get_input_binding_dict(self, node: BaseNode) -> Tuple[dict, dict]:
+ """Return the input binding dict for each node.
+
+ :param node: The node
+ :type node: BaseNode
+ :return: A 2-tuple of (binding_dict, optional_binding_in_expression_dict)
+ :rtype: Tuple[dict, dict]
+ """
+ # pylint: disable=too-many-nested-blocks
+ binding_inputs = node._build_inputs()
+ # Collect binding relation dict {'pipeline_input': ['node_input']}
+ binding_dict: dict = {}
+ optional_binding_in_expression_dict: dict = {}
+ for component_input_name, component_binding_input in binding_inputs.items():
+ if isinstance(component_binding_input, PipelineExpression):
+ for pipeline_input_name in component_binding_input._inputs.keys():
+ if pipeline_input_name not in self.inputs:
+ continue
+ if pipeline_input_name not in binding_dict:
+ binding_dict[pipeline_input_name] = []
+ binding_dict[pipeline_input_name].append(component_input_name)
+ if pipeline_input_name not in optional_binding_in_expression_dict:
+ optional_binding_in_expression_dict[pipeline_input_name] = []
+ optional_binding_in_expression_dict[pipeline_input_name].append(pipeline_input_name)
+ else:
+ if isinstance(component_binding_input, Input):
+ component_binding_input = component_binding_input.path
+ if is_data_binding_expression(component_binding_input, ["parent"]):
+ # data binding may have more than one PipelineInput now
+ for pipeline_input_name in PipelineExpression.parse_pipeline_inputs_from_data_binding(
+ component_binding_input
+ ):
+ if pipeline_input_name not in self.inputs:
+ continue
+ if pipeline_input_name not in binding_dict:
+ binding_dict[pipeline_input_name] = []
+ binding_dict[pipeline_input_name].append(component_input_name)
+ # for data binding expression "${{parent.inputs.pipeline_input}}", it should not be optional
+ if len(component_binding_input.replace("${{parent.inputs." + pipeline_input_name + "}}", "")):
+ if pipeline_input_name not in optional_binding_in_expression_dict:
+ optional_binding_in_expression_dict[pipeline_input_name] = []
+ optional_binding_in_expression_dict[pipeline_input_name].append(pipeline_input_name)
+ return binding_dict, optional_binding_in_expression_dict
+
+ def _validate_binding_inputs(self, node: BaseNode) -> MutableValidationResult:
+ """Validate pipeline binding inputs and return all used pipeline input names.
+
+ Mark the input as optional if all bindings are optional and optional is not set. Raise an error if a pipeline
+ input is optional but linked to required inputs.
+
+ :param node: The node to validate
+ :type node: BaseNode
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ component_definition_inputs = {}
+ # Add flattened group input into definition inputs.
+ # e.g. Add {'group_name.item': PipelineInput} for {'group_name': GroupInput}
+ for name, val in node.component.inputs.items():
+ if isinstance(val, GroupInput):
+ component_definition_inputs.update(val.flatten(group_parameter_name=name))
+ component_definition_inputs[name] = val
+ # Collect binding relation dict {'pipeline_input': ['node_input']}
+ validation_result = self._create_empty_validation_result()
+ binding_dict, optional_binding_in_expression_dict = self._get_input_binding_dict(node)
+
+ # Validate links required and optional
+ for pipeline_input_name, binding_inputs in binding_dict.items():
+ pipeline_input = self.inputs[pipeline_input_name]
+ required_bindings = []
+ for name in binding_inputs:
+ # do not check optional/required for pipeline inputs used in pipeline expressions
+ if name in optional_binding_in_expression_dict.get(pipeline_input_name, []):
+ continue
+ if name in component_definition_inputs and component_definition_inputs[name].optional is not True:
+ required_bindings.append(f"{node.name}.inputs.{name}")
+ if pipeline_input.optional is None and not required_bindings:
+ # Set the input as optional if all bindings are optional and optional is not set.
+ pipeline_input.optional = True
+ pipeline_input._is_inferred_optional = True
+ elif pipeline_input.optional is True and required_bindings:
+ if pipeline_input._is_inferred_optional:
+ # Change optional=True to None if it was inferred by us
+ pipeline_input.optional = None
+ else:
+ # Record a validation error if the pipeline input was explicitly set optional by the user but is linked to required inputs.
+ validation_result.append_error(
+ yaml_path="inputs.{}".format(pipeline_input._port_name),
+ message=f"Pipeline optional Input binding to required inputs: {required_bindings}",
+ )
+ return validation_result
+
+ def _get_job_type_and_source(self) -> Tuple[Dict[str, int], Dict[str, int]]:
+ """Get job types and sources for telemetry.
+
+ :return: A 2-tuple of
+ * A map of job type to the number of occurrences
+ * A map of job source to the number of occurrences
+ :rtype: Tuple[Dict[str, int], Dict[str, int]]
+ """
+ job_types: list = []
+ job_sources = []
+ for job in self.jobs.values():
+ job_types.append(job.type)
+ if isinstance(job, BaseNode):
+ job_sources.append(job._source)
+ elif isinstance(job, AutoMLJob):
+ # Consider all AutoML jobs to have the builder source for now,
+ # as it's not easy to distinguish their source (yaml/builder).
+ job_sources.append(ComponentSource.BUILDER)
+ else:
+ # Fall back to CLASS
+ job_sources.append(ComponentSource.CLASS)
+ return dict(Counter(job_types)), dict(Counter(job_sources))
+
+ @property
+ def jobs(self) -> Dict[str, BaseNode]:
+ """Return a dictionary from component variable name to component object.
+
+ :return: Dictionary mapping component variable names to component objects.
+ :rtype: Dict[str, ~azure.ai.ml.entities._builders.BaseNode]
+ """
+ return self._jobs
+
+ def _get_anonymous_hash(self) -> str:
+ """Get anonymous hash for pipeline component.
+
+ :return: The anonymous hash of the pipeline component
+ :rtype: str
+ """
+ # ideally we should always use rest object to generate hash as it's the same as
+ # what we send to server-side, but changing the hash function will break reuse of
+ # existing components except for command component (hash result is the same for
+ # command component), so we just use rest object to generate hash for pipeline component,
+ # which doesn't have reuse issue.
+ component_interface_dict = self._to_rest_object().properties.component_spec
+ # Hash local inputs in pipeline component jobs
+ for job_name, job in self.jobs.items():
+ if getattr(job, "inputs", None):
+ for input_name, input_value in job.inputs.items():
+ try:
+ if (
+ getattr(input_value, "_data", None)
+ and isinstance(input_value._data, Input)
+ and input_value.path
+ and os.path.exists(input_value.path)
+ ):
+ start_time = time.time()
+ component_interface_dict["jobs"][job_name]["inputs"][input_name]["content_hash"] = (
+ get_object_hash(input_value.path)
+ )
+ module_logger.debug(
+ "Takes %s seconds to calculate the content hash of local input %s",
+ time.time() - start_time,
+ input_value.path,
+ )
+ except ValidationException:
+ pass
+ hash_value: str = hash_dict(
+ component_interface_dict,
+ keys_to_omit=[
+ # omit name since anonymous component will have same name
+ "name",
+ # omit _source since it doesn't impact component's uniqueness
+ "_source",
+ # omit id since it will be set after component is registered
+ "id",
+ # omit version since it will be set to this hash later
+ "version",
+ ],
+ )
+ return hash_value
+
+ @classmethod
+ def _load_from_rest_pipeline_job(cls, data: Dict) -> "PipelineComponent":
+ # TODO: refine this?
+ # Set type as None here to avoid schema validation failed
+ definition_inputs = {p: {"type": None} for p in data.get("inputs", {}).keys()}
+ definition_outputs = {p: {"type": None} for p in data.get("outputs", {}).keys()}
+ return PipelineComponent(
+ display_name=data.get("display_name"),
+ description=data.get("description"),
+ inputs=definition_inputs,
+ outputs=definition_outputs,
+ jobs=data.get("jobs"),
+ _source=ComponentSource.REMOTE_WORKSPACE_JOB,
+ )
+
+ @classmethod
+ def _resolve_sub_nodes(cls, rest_jobs: Dict) -> Dict:
+ from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
+
+ sub_nodes = {}
+ if rest_jobs is None:
+ return sub_nodes
+ for node_name, node in rest_jobs.items():
+ # TODO: Remove this ad-hoc fix after unified arm id format in object
+ component_id = node.get("componentId", "")
+ if isinstance(component_id, str) and re.match(ASSET_ARM_ID_REGEX_FORMAT, component_id):
+ node["componentId"] = component_id[len(ARM_ID_PREFIX) :]
+ if not LoopNode._is_loop_node_dict(node):
+ # skip resolving LoopNode first since it may reference other nodes
+ # use the node factory instead of BaseNode._from_rest_object here as AutoMLJob is not a BaseNode
+ sub_nodes[node_name] = pipeline_node_factory.load_from_rest_object(obj=node)
+ for node_name, node in rest_jobs.items():
+ if LoopNode._is_loop_node_dict(node):
+ # resolve LoopNode after all other nodes are resolved
+ sub_nodes[node_name] = pipeline_node_factory.load_from_rest_object(obj=node, pipeline_jobs=sub_nodes)
+ return sub_nodes
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return PipelineComponentSchema(context=context)
+
+ @classmethod
+ def _get_skip_fields_in_schema_validation(cls) -> typing.List[str]:
+ # jobs validations are done in _customized_validate()
+ return ["jobs"]
+
+ @classmethod
+ def _check_ignored_keys(cls, obj: object) -> List[str]:
+ """Return ignored keys in obj as a pipeline component when its value be set.
+
+ :param obj: The object to examine
+ :type obj: object
+ :return: List of keys to ignore
+ :rtype: List[str]
+ """
+ examine_mapping = {
+ "compute": lambda val: val is not None,
+ "settings": lambda val: val is not None and any(v is not None for v in val._to_dict().values()),
+ }
+ # To avoid new attributes being added, use `try_get_non_arbitrary_attr` instead of `hasattr` or `getattr` directly.
+ return [k for k, has_set in examine_mapping.items() if has_set(try_get_non_arbitrary_attr(obj, k))]
+
+ def _get_telemetry_values(self, *args: Any, **kwargs: Any) -> Dict:
+ telemetry_values: dict = super()._get_telemetry_values()
+ telemetry_values.update(
+ {
+ "source": self._source,
+ "node_count": len(self.jobs),
+ "node_type": json.dumps(self._job_types),
+ "node_source": json.dumps(self._job_sources),
+ }
+ )
+ return telemetry_values
+
+ @classmethod
+ def _from_rest_object_to_init_params(cls, obj: ComponentVersion) -> Dict:
+ # Pop jobs to avoid them going through schema load
+ jobs = obj.properties.component_spec.pop("jobs", None)
+ init_params_dict: dict = super()._from_rest_object_to_init_params(obj)
+ if jobs:
+ try:
+ init_params_dict["jobs"] = PipelineComponent._resolve_sub_nodes(jobs)
+ except Exception as e: # pylint: disable=W0718
+ # Skip parsing jobs if an error occurs.
+ # TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/2052262
+ module_logger.debug("Parse pipeline component jobs failed with: %s", e)
+ return init_params_dict
+
+ def _to_dict(self) -> Dict:
+ return {**self._other_parameter, **super()._to_dict()}
+
+ def _build_rest_component_jobs(self) -> Dict[str, dict]:
+ """Build pipeline component jobs to rest.
+
+ :return: A map of job name to rest objects
+ :rtype: Dict[str, dict]
+ """
+ # Build the jobs to dict
+ rest_component_jobs = {}
+ for job_name, job in self.jobs.items():
+ if isinstance(job, (BaseNode, ControlFlowNode)):
+ rest_node_dict = job._to_rest_object()
+ elif isinstance(job, AutoMLJob):
+ rest_node_dict = json.loads(json.dumps(job._to_dict(inside_pipeline=True)))
+ else:
+ msg = f"Non supported job type in Pipeline jobs: {type(job)}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ rest_component_jobs[job_name] = rest_node_dict
+ return rest_component_jobs
+
+ def _to_rest_object(self) -> ComponentVersion:
+ """Check ignored keys and return rest object.
+
+ :return: The component version
+ :rtype: ComponentVersion
+ """
+ ignored_keys = self._check_ignored_keys(self)
+ if ignored_keys:
+ module_logger.warning("%s ignored on pipeline component %r.", ignored_keys, self.name)
+ component = self._to_dict()
+ # add source type to component rest object
+ component["_source"] = self._source
+ component["jobs"] = self._build_rest_component_jobs()
+ component["sourceJobId"] = self._source_job_id
+ if self._intellectual_property:
+ # hack while full pass-through support is being worked on for IPP fields
+ component.pop("intellectual_property")
+ component["intellectualProperty"] = self._intellectual_property._to_rest_object().serialize()
+ properties = ComponentVersionProperties(
+ component_spec=component,
+ description=self.description,
+ is_anonymous=self._is_anonymous,
+ properties=self.properties,
+ tags=self.tags,
+ )
+ result = ComponentVersion(properties=properties)
+ result.name = self.name
+ return result
+
+ def __str__(self) -> str:
+ try:
+ toYaml: str = self._to_yaml()
+ return toYaml
+ except BaseException: # pylint: disable=W0718
+ toStr: str = super(PipelineComponent, self).__str__()
+ return toStr
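Illustrative usage sketch (not part of the patched file): a PipelineComponent is normally produced by the dsl.pipeline decorator rather than constructed directly. The YAML path, component input names, and data path below are hypothetical assumptions; pipeline_job.component is expected to be backed by the PipelineComponent class above when the job is built through the DSL.

# Hedged sketch; names and paths are hypothetical.
from azure.ai.ml import Input, load_component
from azure.ai.ml.dsl import pipeline

train = load_component(source="./train.yml")  # hypothetical command component spec

@pipeline(description="toy pipeline")
def toy_pipeline(training_data):
    # assumes the loaded component exposes an input named "input_data"
    train(input_data=training_data)

pipeline_job = toy_pipeline(training_data=Input(path="./data"))
pipeline_component = pipeline_job.component      # PipelineComponent built from the DSL graph
print(type(pipeline_component).__name__)         # expected: PipelineComponent
print(list(pipeline_component.jobs))             # node names stored in the jobs dict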
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/spark_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/spark_component.py
new file mode 100644
index 00000000..7da65fb6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_component/spark_component.py
@@ -0,0 +1,211 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os
+from typing import Any, Dict, List, Optional, Union
+
+from marshmallow import Schema
+
+from azure.ai.ml._schema.component.spark_component import SparkComponentSchema
+from azure.ai.ml.constants._common import COMPONENT_TYPE
+from azure.ai.ml.constants._component import NodeType
+from azure.ai.ml.constants._job.job import RestSparkConfKey
+from azure.ai.ml.entities._assets import Environment
+from azure.ai.ml.entities._job.parameterized_spark import ParameterizedSpark
+
+from ..._schema import PathAwareSchema
+from .._job.spark_job_entry_mixin import SparkJobEntry, SparkJobEntryMixin
+from .._util import convert_ordered_dict_to_dict, validate_attribute_type
+from .._validation import MutableValidationResult
+from ._additional_includes import AdditionalIncludesMixin
+from .component import Component
+
+
+class SparkComponent(
+ Component, ParameterizedSpark, SparkJobEntryMixin, AdditionalIncludesMixin
+): # pylint: disable=too-many-instance-attributes
+ """Spark component version, used to define a Spark Component or Job.
+
+ :keyword code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing
+ to a remote location. Defaults to ".", indicating the current directory.
+ :type code: Union[str, os.PathLike]
+ :keyword entry: The file or class entry point.
+ :paramtype entry: Optional[Union[dict[str, str], ~azure.ai.ml.entities.SparkJobEntry]]
+ :keyword py_files: The list of .zip, .egg or .py files to place on the PYTHONPATH for Python apps. Defaults to None.
+ :paramtype py_files: Optional[List[str]]
+ :keyword jars: The list of .JAR files to include on the driver and executor classpaths. Defaults to None.
+ :paramtype jars: Optional[List[str]]
+ :keyword files: The list of files to be placed in the working directory of each executor. Defaults to None.
+ :paramtype files: Optional[List[str]]
+ :keyword archives: The list of archives to be extracted into the working directory of each executor.
+ Defaults to None.
+ :paramtype archives: Optional[List[str]]
+ :keyword driver_cores: The number of cores to use for the driver process, only in cluster mode.
+ :paramtype driver_cores: Optional[int]
+ :keyword driver_memory: The amount of memory to use for the driver process, formatted as strings with a size unit
+ suffix ("k", "m", "g" or "t") (e.g. "512m", "2g").
+ :paramtype driver_memory: Optional[str]
+ :keyword executor_cores: The number of cores to use on each executor.
+ :paramtype executor_cores: Optional[int]
+ :keyword executor_memory: The amount of memory to use per executor process, formatted as strings with a size unit
+ suffix ("k", "m", "g" or "t") (e.g. "512m", "2g").
+ :paramtype executor_memory: Optional[str]
+ :keyword executor_instances: The initial number of executors.
+ :paramtype executor_instances: Optional[int]
+ :keyword dynamic_allocation_enabled: Whether to use dynamic resource allocation, which scales the number of
+ executors registered with this application up and down based on the workload. Defaults to False.
+ :paramtype dynamic_allocation_enabled: Optional[bool]
+ :keyword dynamic_allocation_min_executors: The lower bound for the number of executors if dynamic allocation is
+ enabled.
+ :paramtype dynamic_allocation_min_executors: Optional[int]
+ :keyword dynamic_allocation_max_executors: The upper bound for the number of executors if dynamic allocation is
+ enabled.
+ :paramtype dynamic_allocation_max_executors: Optional[int]
+ :keyword conf: A dictionary of pre-defined Spark configuration keys and values. Defaults to None.
+ :paramtype conf: Optional[dict[str, str]]
+ :keyword environment: The Azure ML environment to run the job in.
+ :paramtype environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ :keyword inputs: A mapping of input names to input data sources used in the job. Defaults to None.
+ :paramtype inputs: Optional[dict[str, Union[
+ ~azure.ai.ml.entities._job.pipeline._io.NodeOutput,
+ ~azure.ai.ml.Input,
+ str,
+ bool,
+ int,
+ float,
+ Enum,
+ ]]]
+ :keyword outputs: A mapping of output names to output data sources used in the job. Defaults to None.
+ :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]]
+ :keyword args: The arguments for the job. Defaults to None.
+ :paramtype args: Optional[str]
+ :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None.
+ :paramtype additional_includes: Optional[List[str]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_component_definition]
+ :end-before: [END spark_component_definition]
+ :language: python
+ :dedent: 8
+ :caption: Creating SparkComponent.
+ """
+
+ def __init__(
+ self,
+ *,
+ code: Optional[Union[str, os.PathLike]] = ".",
+ entry: Optional[Union[Dict[str, str], SparkJobEntry]] = None,
+ py_files: Optional[List[str]] = None,
+ jars: Optional[List[str]] = None,
+ files: Optional[List[str]] = None,
+ archives: Optional[List[str]] = None,
+ driver_cores: Optional[Union[int, str]] = None,
+ driver_memory: Optional[str] = None,
+ executor_cores: Optional[Union[int, str]] = None,
+ executor_memory: Optional[str] = None,
+ executor_instances: Optional[Union[int, str]] = None,
+ dynamic_allocation_enabled: Optional[Union[bool, str]] = None,
+ dynamic_allocation_min_executors: Optional[Union[int, str]] = None,
+ dynamic_allocation_max_executors: Optional[Union[int, str]] = None,
+ conf: Optional[Dict[str, str]] = None,
+ environment: Optional[Union[str, Environment]] = None,
+ inputs: Optional[Dict] = None,
+ outputs: Optional[Dict] = None,
+ args: Optional[str] = None,
+ additional_includes: Optional[List] = None,
+ **kwargs: Any,
+ ) -> None:
+ # validate init params are valid type
+ validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())
+
+ kwargs[COMPONENT_TYPE] = NodeType.SPARK
+
+ super().__init__(
+ inputs=inputs,
+ outputs=outputs,
+ **kwargs,
+ )
+
+ self.code: Optional[Union[str, os.PathLike]] = code
+ self.entry = entry
+ self.py_files = py_files
+ self.jars = jars
+ self.files = files
+ self.archives = archives
+ self.conf = conf
+ self.environment = environment
+ self.args = args
+ self.additional_includes = additional_includes or []
+ # For pipeline spark jobs, we also allow the user to set driver_cores, driver_memory and so on via conf.
+ # If the root-level fields are not set by the user, we promote the conf settings to the root level to facilitate
+ # subsequent verification. This usually happens when to_component(SparkJob) or the builder function spark() is
+ # used as a node in the pipeline SDK.
+ conf = conf or {}
+ self.driver_cores = driver_cores or conf.get(RestSparkConfKey.DRIVER_CORES, None)
+ self.driver_memory = driver_memory or conf.get(RestSparkConfKey.DRIVER_MEMORY, None)
+ self.executor_cores = executor_cores or conf.get(RestSparkConfKey.EXECUTOR_CORES, None)
+ self.executor_memory = executor_memory or conf.get(RestSparkConfKey.EXECUTOR_MEMORY, None)
+ self.executor_instances = executor_instances or conf.get(RestSparkConfKey.EXECUTOR_INSTANCES, None)
+ self.dynamic_allocation_enabled = dynamic_allocation_enabled or conf.get(
+ RestSparkConfKey.DYNAMIC_ALLOCATION_ENABLED, None
+ )
+ self.dynamic_allocation_min_executors = dynamic_allocation_min_executors or conf.get(
+ RestSparkConfKey.DYNAMIC_ALLOCATION_MIN_EXECUTORS, None
+ )
+ self.dynamic_allocation_max_executors = dynamic_allocation_max_executors or conf.get(
+ RestSparkConfKey.DYNAMIC_ALLOCATION_MAX_EXECUTORS, None
+ )
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
+ return SparkComponentSchema(context=context)
+
+ @classmethod
+ def _attr_type_map(cls) -> dict:
+ return {
+ "environment": (str, Environment),
+ "code": (str, os.PathLike),
+ }
+
+ def _customized_validate(self) -> MutableValidationResult:
+ validation_result = super()._customized_validate()
+ self._append_diagnostics_and_check_if_origin_code_reliable_for_local_path_validation(validation_result)
+ return validation_result
+
+ def _to_dict(self) -> Dict:
+ # TODO: Bug Item number: 2897665
+ res: Dict = convert_ordered_dict_to_dict( # type: ignore
+ {**self._other_parameter, **super(SparkComponent, self)._to_dict()}
+ )
+ return res
+
+ def _to_ordered_dict_for_yaml_dump(self) -> Dict:
+ """Dump the component content into a sorted yaml string.
+
+ :return: The ordered dict
+ :rtype: Dict
+ """
+
+ obj: dict = super()._to_ordered_dict_for_yaml_dump()
+ # a dict dumped based on the schema will convert code to an absolute path, while we want to keep its original value
+ if self.code and isinstance(self.code, str):
+ obj["code"] = self.code
+ return obj
+
+ def _get_environment_id(self) -> Union[str, None]:
+ # Return the environment id of the environment;
+ # handle the case when the environment is defined inline
+ if isinstance(self.environment, Environment):
+ res: Optional[str] = self.environment.id
+ return res
+ return self.environment
+
+ def __str__(self) -> str:
+ try:
+ toYaml: str = self._to_yaml()
+ return toYaml
+ except BaseException: # pylint: disable=W0718
+ toStr: str = super(SparkComponent, self).__str__()
+ return toStr
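Illustrative usage sketch (not part of the patched file): a minimal SparkComponent showing how conf values are promoted to root-level fields when those fields are unset. The import assumes SparkComponent is re-exported from azure.ai.ml.entities; the source folder, entry file, environment name, and the assumption that "spark.executor.instances" matches RestSparkConfKey.EXECUTOR_INSTANCES are all illustrative.

# Hedged sketch; paths, environment, and conf key are assumptions.
from azure.ai.ml.entities import SparkComponent

spark_component = SparkComponent(
    code="./src",                           # hypothetical local source folder
    entry={"file": "main.py"},              # file entry point (dict form per the type hint above)
    driver_cores=1,
    driver_memory="2g",
    executor_cores=2,
    executor_memory="2g",
    conf={"spark.executor.instances": 2},   # assumed to match RestSparkConfKey.EXECUTOR_INSTANCES
    environment="azureml:my-spark-env:1",   # hypothetical registered environment
    args="--input ${{inputs.input_data}}",
)

print(spark_component.executor_instances)   # 2, promoted from conf since the root field was unset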
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_aml_compute_node_info.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_aml_compute_node_info.py
new file mode 100644
index 00000000..823a89ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_aml_compute_node_info.py
@@ -0,0 +1,50 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Dict, Optional
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import AmlComputeNodeInformation
+from azure.ai.ml._schema.compute.aml_compute_node_info import AmlComputeNodeInfoSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+
+class AmlComputeNodeInfo:
+ """Compute node information related to AmlCompute."""
+
+ def __init__(self) -> None:
+ self.node_id = None
+ self.private_ip_address = None
+ self.public_ip_address = None
+ self.port = None
+ self.node_state = None
+ self.run_id: Optional[str] = None
+
+ @property
+ def current_job_name(self) -> Optional[str]:
+ """The run ID of the current job.
+
+ :return: The run ID of the current job.
+ :rtype: str
+ """
+ return self.run_id
+
+ @current_job_name.setter
+ def current_job_name(self, value: str) -> None:
+ """Set the current job run ID.
+
+ :param value: The job run ID.
+ :type value: str
+ """
+ self.run_id = value
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: AmlComputeNodeInformation) -> "AmlComputeNodeInfo":
+ result = cls()
+ result.__dict__.update(rest_obj.as_dict())
+ return result
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = AmlComputeNodeInfoSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
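Illustrative usage sketch (not part of the patched file): node information is normally obtained from the service (for example via compute list-nodes operations) rather than built by hand; the run id below is hypothetical. The private import path mirrors the file added in this changeset.

# Hedged sketch showing the current_job_name alias over run_id.
from azure.ai.ml.entities._compute._aml_compute_node_info import AmlComputeNodeInfo

node = AmlComputeNodeInfo()
node.current_job_name = "toy-run-id"   # setter writes through to run_id
print(node.run_id)                     # "toy-run-id"
print(node._to_dict())                 # schema-based dict representation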
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_custom_applications.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_custom_applications.py
new file mode 100644
index 00000000..2ee65e7f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_custom_applications.py
@@ -0,0 +1,221 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access,redefined-builtin
+
+from typing import Any, Dict, List, Optional
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import CustomService, Docker
+from azure.ai.ml._restclient.v2022_10_01_preview.models import Endpoint as RestEndpoint
+from azure.ai.ml._restclient.v2022_10_01_preview.models import EnvironmentVariable as RestEnvironmentVariable
+from azure.ai.ml._restclient.v2022_10_01_preview.models import EnvironmentVariableType as RestEnvironmentVariableType
+from azure.ai.ml._restclient.v2022_10_01_preview.models import Image as RestImage
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ImageType as RestImageType
+from azure.ai.ml._restclient.v2022_10_01_preview.models import Protocol
+from azure.ai.ml._restclient.v2022_10_01_preview.models import VolumeDefinition as RestVolumeDefinition
+from azure.ai.ml._restclient.v2022_10_01_preview.models import VolumeDefinitionType as RestVolumeDefinitionType
+from azure.ai.ml.constants._compute import DUPLICATE_APPLICATION_ERROR, INVALID_VALUE_ERROR, CustomApplicationDefaults
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class ImageSettings:
+ """Specifies an image configuration for a Custom Application.
+
+ :param reference: Image reference URL.
+ :type reference: str
+ """
+
+ def __init__(self, *, reference: str):
+ self.reference = reference
+
+ def _to_rest_object(self) -> RestImage:
+ return RestImage(type=RestImageType.DOCKER, reference=self.reference)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestImage) -> "ImageSettings":
+ return ImageSettings(reference=obj.reference)
+
+
+class EndpointsSettings:
+ """Specifies an endpoint configuration for a Custom Application.
+
+ :param target: Application port inside the container.
+ :type target: int
+ :param published: The port over which the application is exposed from the container.
+ :type published: int
+ """
+
+ def __init__(self, *, target: int, published: int):
+ EndpointsSettings._validate_endpoint_settings(target=target, published=published)
+ self.target = target
+ self.published = published
+
+ def _to_rest_object(self) -> RestEndpoint:
+ return RestEndpoint(
+ name=CustomApplicationDefaults.ENDPOINT_NAME,
+ target=self.target,
+ published=self.published,
+ protocol=Protocol.HTTP,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestEndpoint) -> "EndpointsSettings":
+ return EndpointsSettings(target=obj.target, published=obj.published)
+
+ @classmethod
+ def _validate_endpoint_settings(cls, target: int, published: int) -> None:
+ ports = {
+ CustomApplicationDefaults.TARGET_PORT: target,
+ CustomApplicationDefaults.PUBLISHED_PORT: published,
+ }
+ min_value = CustomApplicationDefaults.PORT_MIN_VALUE
+ max_value = CustomApplicationDefaults.PORT_MAX_VALUE
+
+ for port_name, port in ports.items():
+ message = INVALID_VALUE_ERROR.format(port_name, min_value, max_value)
+ if not min_value < port < max_value:
+ raise ValidationException(
+ message=message,
+ target=ErrorTarget.COMPUTE,
+ no_personal_data_message=message,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+
+class VolumeSettings:
+ """Specifies the Bind Mount settings for a Custom Application.
+
+ :param source: The host path of the mount.
+ :type source: str
+ :param target: The path in the container for the mount.
+ :type target: str
+ """
+
+ def __init__(self, *, source: str, target: str):
+ self.source = source
+ self.target = target
+
+ def _to_rest_object(self) -> RestVolumeDefinition:
+ return RestVolumeDefinition(
+ type=RestVolumeDefinitionType.BIND,
+ read_only=False,
+ source=self.source,
+ target=self.target,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestVolumeDefinition) -> "VolumeSettings":
+ return VolumeSettings(source=obj.source, target=obj.target)
+
+
+class CustomApplications:
+ """Specifies the custom service application configuration.
+
+ :param name: Name of the Custom Application.
+ :type name: str
+ :param image: Describes the Image Specifications.
+ :type image: ImageSettings
+ :param type: Type of the Custom Application.
+ :type type: Optional[str]
+ :param endpoints: Configuring the endpoints for the container.
+ :type endpoints: List[EndpointsSettings]
+ :param environment_variables: Environment Variables for the container.
+ :type environment_variables: Optional[Dict[str, str]]
+ :param bind_mounts: Configuration of the bind mounts for the container.
+ :type bind_mounts: Optional[List[VolumeSettings]]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ image: ImageSettings,
+ type: str = CustomApplicationDefaults.DOCKER,
+ endpoints: List[EndpointsSettings],
+ environment_variables: Optional[Dict] = None,
+ bind_mounts: Optional[List[VolumeSettings]] = None,
+ **kwargs: Any
+ ):
+ self.name = name
+ self.type = type
+ self.image = image
+ self.endpoints = endpoints
+ self.environment_variables = environment_variables
+ self.bind_mounts = bind_mounts
+ self.additional_properties = kwargs
+
+ def _to_rest_object(self) -> CustomService:
+ endpoints = None
+ if self.endpoints:
+ endpoints = [endpoint._to_rest_object() for endpoint in self.endpoints]
+
+ environment_variables = None
+ if self.environment_variables:
+ environment_variables = {
+ name: RestEnvironmentVariable(type=RestEnvironmentVariableType.LOCAL, value=value)
+ for name, value in self.environment_variables.items()
+ }
+
+ volumes = None
+ if self.bind_mounts:
+ volumes = [volume._to_rest_object() for volume in self.bind_mounts]
+
+ return CustomService(
+ name=self.name,
+ image=self.image._to_rest_object(),
+ endpoints=endpoints,
+ environment_variables=environment_variables,
+ volumes=volumes,
+ docker=Docker(privileged=True),
+ additional_properties={**{"type": self.type}, **self.additional_properties},
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: CustomService) -> "CustomApplications":
+ endpoints = []
+ for endpoint in obj.endpoints:
+ endpoints.append(EndpointsSettings._from_rest_object(endpoint))
+
+ environment_variables = (
+ {name: value.value for name, value in obj.environment_variables.items()}
+ if obj.environment_variables
+ else None
+ )
+
+ bind_mounts = []
+ if obj.volumes:
+ for volume in obj.volumes:
+ bind_mounts.append(VolumeSettings._from_rest_object(volume))
+
+ return CustomApplications(
+ name=obj.name,
+ image=ImageSettings._from_rest_object(obj.image),
+ endpoints=endpoints,
+ environment_variables=environment_variables,
+ bind_mounts=bind_mounts,
+ type=obj.additional_properties.pop("type", CustomApplicationDefaults.DOCKER),
+ **obj.additional_properties,
+ )
+
+
+def validate_custom_applications(custom_apps: List[CustomApplications]) -> None:
+ message = DUPLICATE_APPLICATION_ERROR
+
+ names = [app.name for app in custom_apps]
+ if len(set(names)) != len(names):
+ raise ValidationException(
+ message=message.format("application_name"),
+ target=ErrorTarget.COMPUTE,
+ no_personal_data_message=message.format("application_name"),
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ published_ports = [endpoint.published for app in custom_apps for endpoint in app.endpoints]
+
+ if len(set(published_ports)) != len(published_ports):
+ raise ValidationException(
+ message=message.format("published_port"),
+ target=ErrorTarget.COMPUTE,
+ no_personal_data_message=message.format("published_port"),
+ error_category=ErrorCategory.USER_ERROR,
+ )
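Illustrative usage sketch (not part of the patched file): wiring the settings classes above into a custom application and running the duplicate-name/port validation. The application name, image reference, ports, and mount paths are assumptions; the private import path mirrors the file added in this changeset, and the ports are assumed to fall inside the allowed range.

# Hedged sketch; values are illustrative only.
from azure.ai.ml.entities._compute._custom_applications import (
    CustomApplications,
    EndpointsSettings,
    ImageSettings,
    VolumeSettings,
    validate_custom_applications,
)

app = CustomApplications(
    name="rstudio-workbench",                                        # hypothetical name
    image=ImageSettings(reference="ghcr.io/example/rstudio:latest"), # hypothetical image
    endpoints=[EndpointsSettings(target=8787, published=8788)],      # container -> host port
    environment_variables={"USER": "azureuser"},
    bind_mounts=[VolumeSettings(source="/home/azureuser", target="/home/rstudio")],
)

validate_custom_applications([app])   # raises ValidationException on duplicate names/ports
rest_service = app._to_rest_object()  # CustomService payload sent to the service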
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_image_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_image_metadata.py
new file mode 100644
index 00000000..342e4a97
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_image_metadata.py
@@ -0,0 +1,63 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Optional
+
+
+class ImageMetadata:
+ """Metadata about the operating system image for the compute instance.
+
+ :param is_latest_os_image_version: Specifies if the compute instance is running on the latest OS image version.
+ :type is_latest_os_image_version: bool
+ :param current_image_version: Version of the current image.
+ :type current_image_version: str
+ :param latest_image_version: The latest image version.
+ :type latest_image_version: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START image_metadata]
+ :end-before: [END image_metadata]
+ :language: python
+ :dedent: 8
+ :caption: Creating an ImageMetadata object.
+ """
+
+ def __init__(
+ self,
+ *,
+ is_latest_os_image_version: Optional[bool],
+ current_image_version: Optional[str],
+ latest_image_version: Optional[str]
+ ) -> None:
+ self._is_latest_os_image_version = is_latest_os_image_version
+ self._current_image_version = current_image_version
+ self._latest_image_version = latest_image_version
+
+ @property
+ def is_latest_os_image_version(self) -> Optional[bool]:
+ """Whether or not a compute instance is running on the latest OS image version.
+
+ :return: Boolean indicating if the compute instance is running the latest OS image version.
+ :rtype: bool
+ """
+ return self._is_latest_os_image_version
+
+ @property
+ def current_image_version(self) -> Optional[str]:
+ """The current OS image version number.
+
+ :return: The current OS image version number.
+ :rtype: str
+ """
+ return self._current_image_version
+
+ @property
+ def latest_image_version(self) -> Optional[str]:
+ """The latest OS image version number.
+
+ :return: The latest OS image version number.
+ :rtype: str
+ """
+ return self._latest_image_version
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_schedule.py
new file mode 100644
index 00000000..3616a5cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_schedule.py
@@ -0,0 +1,153 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+from typing import Any, List, Optional, Union
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputePowerAction
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeSchedules as RestComputeSchedules
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeStartStopSchedule as RestComputeStartStopSchedule
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ScheduleStatus as ScheduleState
+from azure.ai.ml._restclient.v2022_10_01_preview.models import TriggerType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+from .._schedule.trigger import CronTrigger, RecurrencePattern, RecurrenceTrigger
+
+
+class ComputeStartStopSchedule(RestTranslatableMixin):
+ """Schedules for compute start or stop scenario.
+
+ :param trigger: The trigger of the schedule.
+ :type trigger: Union[~azure.ai.ml.entities.CronTrigger, ~azure.ai.ml.entities.RecurrenceTrigger]
+ :param action: The compute power action.
+ :type action: ~azure.ai.ml.entities.ComputePowerAction
+ :param state: The state of the schedule.
+ :type state: ~azure.ai.ml.entities.ScheduleState
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_start_stop_schedule]
+ :end-before: [END compute_start_stop_schedule]
+ :language: python
+ :dedent: 8
+ :caption: Creating a ComputeStartStopSchedule object.
+ """
+
+ def __init__(
+ self,
+ *,
+ trigger: Optional[Union[CronTrigger, RecurrenceTrigger]] = None,
+ action: Optional[ComputePowerAction] = None,
+ state: ScheduleState = ScheduleState.ENABLED,
+ **kwargs: Any
+ ) -> None:
+ self.trigger = trigger
+ self.action = action
+ self.state = state
+ self._schedule_id: Optional[str] = kwargs.pop("schedule_id", None)
+ self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)
+
+ @property
+ def schedule_id(self) -> Optional[str]:
+ """The schedule ID.
+
+ :return: The schedule ID.
+ :rtype: Optional[str]
+ """
+ return self._schedule_id
+
+ @property
+ def provisioning_state(self) -> Optional[str]:
+ """The schedule provisioning state.
+
+ :return: The schedule provisioning state.
+ :rtype: Optional[str]
+ """
+ return self._provisioning_state
+
+ def _to_rest_object(self) -> RestComputeStartStopSchedule:
+ rest_object = RestComputeStartStopSchedule(
+ action=self.action,
+ status=self.state,
+ )
+
+ if isinstance(self.trigger, CronTrigger):
+ rest_object.trigger_type = TriggerType.CRON
+ rest_object.cron = self.trigger._to_rest_compute_cron_object()
+ elif isinstance(self.trigger, RecurrenceTrigger):
+ rest_object.trigger_type = TriggerType.RECURRENCE
+ rest_object.recurrence = self.trigger._to_rest_compute_recurrence_object()
+
+ return rest_object
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestComputeStartStopSchedule) -> "ComputeStartStopSchedule":
+ schedule = ComputeStartStopSchedule(
+ action=obj.action,
+ state=obj.status,
+ schedule_id=obj.id,
+ provisioning_state=obj.provisioning_status,
+ )
+
+ if obj.trigger_type == TriggerType.CRON:
+ schedule.trigger = CronTrigger(
+ start_time=obj.cron.start_time,
+ time_zone=obj.cron.time_zone,
+ expression=obj.cron.expression,
+ )
+ elif obj.trigger_type == TriggerType.RECURRENCE:
+ schedule.trigger = RecurrenceTrigger(
+ start_time=obj.recurrence.start_time,
+ time_zone=obj.recurrence.time_zone,
+ frequency=obj.recurrence.frequency,
+ interval=obj.recurrence.interval,
+ schedule=RecurrencePattern._from_rest_object(obj.recurrence.schedule),
+ )
+
+ return schedule
+
+
+class ComputeSchedules(RestTranslatableMixin):
+ """Compute schedules.
+
+ :param compute_start_stop: Compute start or stop schedules.
+ :type compute_start_stop: List[~azure.ai.ml.entities.ComputeStartStopSchedule]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_start_stop_schedule]
+ :end-before: [END compute_start_stop_schedule]
+ :language: python
+ :dedent: 8
+ :caption: Creating a ComputeSchedules object.
+ """
+
+ def __init__(self, *, compute_start_stop: Optional[List[ComputeStartStopSchedule]] = None) -> None:
+ self.compute_start_stop = compute_start_stop
+
+ def _to_rest_object(self) -> RestComputeSchedules:
+ rest_schedules: List[RestComputeStartStopSchedule] = []
+ if self.compute_start_stop:
+ for schedule in self.compute_start_stop:
+ rest_schedules.append(schedule._to_rest_object())
+
+ return RestComputeSchedules(
+ compute_start_stop=rest_schedules,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestComputeSchedules) -> "ComputeSchedules":
+ schedules: List[ComputeStartStopSchedule] = []
+ if obj.compute_start_stop:
+ for schedule in obj.compute_start_stop:
+ schedules.append(ComputeStartStopSchedule._from_rest_object(schedule))
+
+ return ComputeSchedules(
+ compute_start_stop=schedules,
+ )
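Illustrative usage sketch (not part of the patched file): building a nightly stop schedule with the classes above. The cron expression is an assumption; CronTrigger, ComputeSchedules, and ComputeStartStopSchedule are assumed to be re-exported from azure.ai.ml.entities, and ComputePowerAction is taken from the same REST client module this file imports (the STOP member name is assumed).

# Hedged sketch; the schedule values are illustrative only.
from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputePowerAction
from azure.ai.ml.entities import ComputeSchedules, ComputeStartStopSchedule, CronTrigger

stop_at_night = ComputeStartStopSchedule(
    trigger=CronTrigger(expression="0 22 * * *"),  # every day at 22:00
    action=ComputePowerAction.STOP,                # assumed enum member for "Stop"
)

schedules = ComputeSchedules(compute_start_stop=[stop_at_night])
rest_schedules = schedules._to_rest_object()       # RestComputeSchedules payload for the service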
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_setup_scripts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_setup_scripts.py
new file mode 100644
index 00000000..d2e12fd4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_setup_scripts.py
@@ -0,0 +1,90 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+
+import re
+from typing import Optional, cast
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ScriptReference as RestScriptReference
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ScriptsToExecute as RestScriptsToExecute
+from azure.ai.ml._restclient.v2022_10_01_preview.models import SetupScripts as RestSetupScripts
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class ScriptReference(RestTranslatableMixin):
+ """Script reference.
+
+ :keyword path: The location of scripts in workspace storage.
+ :paramtype path: Optional[str]
+ :keyword command: Command line arguments passed to the script to run.
+ :paramtype command: Optional[str]
+ :keyword timeout_minutes: Timeout, in minutes, for the script to run.
+ :paramtype timeout_minutes: Optional[int]
+ """
+
+ def __init__(
+ self, *, path: Optional[str] = None, command: Optional[str] = None, timeout_minutes: Optional[int] = None
+ ) -> None:
+ self.path = path
+ self.command = command
+ self.timeout_minutes = timeout_minutes
+
+ def _to_rest_object(self) -> RestScriptReference:
+ return RestScriptReference(
+ script_source="workspaceStorage",
+ script_data=self.path,
+ script_arguments=self.command,
+ timeout=f"{self.timeout_minutes}m",
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestScriptReference) -> Optional["ScriptReference"]:
+ if obj is None:
+ return obj
+        timeout_match = re.match(r"(\d+)m", obj.timeout) if obj.timeout else None
+        timeout_minutes = int(timeout_match.group(1)) if timeout_match else None
+        script_reference = ScriptReference(
+            path=obj.script_data if obj.script_data else None,
+            command=obj.script_arguments if obj.script_arguments else None,
+            timeout_minutes=timeout_minutes,
+ )
+ return script_reference
+
+
+class SetupScripts(RestTranslatableMixin):
+ """Customized setup scripts.
+
+ :keyword startup_script: The script to be run every time the compute is started.
+ :paramtype startup_script: Optional[~azure.ai.ml.entities.ScriptReference]
+ :keyword creation_script: The script to be run only when the compute is created.
+ :paramtype creation_script: Optional[~azure.ai.ml.entities.ScriptReference]
+ """
+
+ def __init__(
+ self, *, startup_script: Optional[ScriptReference] = None, creation_script: Optional[ScriptReference] = None
+ ) -> None:
+ self.startup_script = startup_script
+ self.creation_script = creation_script
+
+ def _to_rest_object(self) -> RestScriptsToExecute:
+ scripts_to_execute = RestScriptsToExecute(
+ startup_script=self.startup_script._to_rest_object() if self.startup_script else None,
+ creation_script=self.creation_script._to_rest_object() if self.creation_script else None,
+ )
+ return RestSetupScripts(scripts=scripts_to_execute)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestSetupScripts) -> Optional["SetupScripts"]:
+ if obj is None or obj.scripts is None:
+ return None
+ scripts = obj.scripts
+ setup_scripts = SetupScripts(
+ startup_script=ScriptReference._from_rest_object(
+ scripts.startup_script if scripts.startup_script else None
+ ),
+ creation_script=ScriptReference._from_rest_object(
+ scripts.creation_script if scripts.creation_script else None
+ ),
+ )
+ return setup_scripts
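
An illustrative round trip through the setup-script entities defined above; the script paths, command, and timeouts are placeholder values, and the azure.ai.ml.entities import path is assumed:

from azure.ai.ml.entities import ScriptReference, SetupScripts

setup = SetupScripts(
    startup_script=ScriptReference(path="scripts/start.sh", command="--quiet", timeout_minutes=5),
    creation_script=ScriptReference(path="scripts/create.sh", timeout_minutes=15),
)

rest_setup = setup._to_rest_object()                      # RestSetupScripts wrapping ScriptsToExecute
round_trip = SetupScripts._from_rest_object(rest_setup)   # back to the entity representation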
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_usage.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_usage.py
new file mode 100644
index 00000000..6702382e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_usage.py
@@ -0,0 +1,100 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from abc import abstractmethod
+from os import PathLike
+from typing import IO, Any, AnyStr, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import Usage as RestUsage
+from azure.ai.ml._restclient.v2022_10_01_preview.models import UsageUnit
+from azure.ai.ml._schema.compute.usage import UsageSchema
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class UsageName:
+ def __init__(self, *, value: Optional[str] = None, localized_value: Optional[str] = None) -> None:
+ """The usage name.
+
+ :param value: The name of the resource.
+ :type value: Optional[str]
+ :param localized_value: The localized name of the resource.
+ :type localized_value: Optional[str]
+ """
+ self.value = value
+ self.localized_value = localized_value
+
+
+class Usage(RestTranslatableMixin):
+ """AzureML resource usage.
+
+ :param id: The resource ID.
+ :type id: Optional[str]
+ :param aml_workspace_location: The region of the AzureML workspace specified by the ID.
+ :type aml_workspace_location: Optional[str]
+ :param type: The resource type.
+ :type type: Optional[str]
+ :param unit: The unit of measurement for usage. Accepted value is "Count".
+ :type unit: Optional[Union[str, ~azure.ai.ml.entities.UsageUnit]]
+ :param current_value: The current usage of the resource.
+ :type current_value: Optional[int]
+ :param limit: The maximum permitted usage for the resource.
+ :type limit: Optional[int]
+ :param name: The name of the usage type.
+ :type name: Optional[~azure.ai.ml.entities.UsageName]
+ """
+
+ def __init__(
+ self,
+ id: Optional[str] = None, # pylint: disable=redefined-builtin
+ aml_workspace_location: Optional[str] = None,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ unit: Optional[Union[str, UsageUnit]] = None, # enum
+ current_value: Optional[int] = None,
+ limit: Optional[int] = None,
+ name: Optional[UsageName] = None,
+ ) -> None:
+ self.id = id
+ self.aml_workspace_location = aml_workspace_location
+ self.type = type
+ self.unit = unit
+ self.current_value = current_value
+ self.limit = limit
+ self.name = name
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestUsage) -> "Usage":
+ result = cls()
+ result.__dict__.update(obj.as_dict())
+ return result
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dumps the job content into a file in YAML format.
+
+ :param dest: The local path or file stream to write the YAML content to.
+ If dest is a file path, a new file will be created.
+ If dest is an open file, the file will be written to directly.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ :raises: FileExistsError if dest is a file path and the file already exists.
+ :raises: IOError if dest is an open file and the file is not writable.
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = UsageSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ @abstractmethod
+ def _load(
+ cls,
+ path: Union[PathLike, str],
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Usage":
+ pass
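
Usage objects are normally returned by the compute operations rather than built by hand; this hedged sketch constructs one only to show the dump() path. All field values are illustrative and the top-level import path is assumed:

from azure.ai.ml.entities import Usage, UsageName

usage = Usage(
    aml_workspace_location="eastus",
    type="Microsoft.MachineLearningServices/workspaces/usages",
    unit="Count",
    current_value=12,
    limit=100,
    name=UsageName(value="Total Cluster Dedicated Regional vCPUs"),
)
usage.dump("usage.yml")  # serializes through UsageSchema and writes a new YAML file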
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_vm_size.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_vm_size.py
new file mode 100644
index 00000000..2f0049f0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/_vm_size.py
@@ -0,0 +1,104 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from abc import abstractmethod
+from os import PathLike
+from typing import IO, Any, AnyStr, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import VirtualMachineSize
+from azure.ai.ml._schema.compute.vm_size import VmSizeSchema
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class VmSize(RestTranslatableMixin):
+ """Virtual Machine Size.
+
+ :param name: The virtual machine size name.
+ :type name: Optional[str]
+ :param family: The virtual machine size family name.
+ :type family: Optional[str]
+ :param v_cp_us: The number of vCPUs supported by the virtual machine size.
+ :type v_cp_us: Optional[int]
+ :param gpus: The number of GPUs supported by the virtual machine size.
+ :type gpus: Optional[int]
+ :param os_vhd_size_mb: The OS VHD disk size, in MB, allowed by the virtual machine size.
+ :type os_vhd_size_mb: Optional[int]
+ :param max_resource_volume_mb: The resource volume size, in MB, allowed by the virtual machine
+ size.
+ :type max_resource_volume_mb: Optional[int]
+ :param memory_gb: The amount of memory, in GB, supported by the virtual machine size.
+ :type memory_gb: Optional[float]
+ :param low_priority_capable: Specifies if the virtual machine size supports low priority VMs.
+ :type low_priority_capable: Optional[bool]
+ :param premium_io: Specifies if the virtual machine size supports premium IO.
+ :type premium_io: Optional[bool]
+ :param supported_compute_types: Specifies the compute types supported by the virtual machine
+ size.
+ :type supported_compute_types: Optional[list[str]]
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ family: Optional[str] = None,
+ v_cp_us: Optional[int] = None,
+ gpus: Optional[int] = None,
+ os_vhd_size_mb: Optional[int] = None,
+ max_resource_volume_mb: Optional[int] = None,
+ memory_gb: Optional[float] = None,
+ low_priority_capable: Optional[bool] = None,
+ premium_io: Optional[bool] = None,
+ supported_compute_types: Optional[List[str]] = None,
+ ) -> None:
+ self.name = name
+ self.family = family
+ self.v_cp_us = v_cp_us
+ self.gpus = gpus
+ self.os_vhd_size_mb = os_vhd_size_mb
+ self.max_resource_volume_mb = max_resource_volume_mb
+ self.memory_gb = memory_gb
+ self.low_priority_capable = low_priority_capable
+ self.premium_io = premium_io
+ self.supported_compute_types = ",".join(map(str, supported_compute_types)) if supported_compute_types else None
+
+ @classmethod
+ def _from_rest_object(cls, obj: VirtualMachineSize) -> "VmSize":
+ result = cls()
+ result.__dict__.update(obj.as_dict())
+ return result
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the virtual machine size content into a file in yaml format.
+
+ :param dest: The destination to receive this virtual machine size's content.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+ and an exception will be raised if the file is not writable.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ """
+
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = VmSizeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ @abstractmethod
+ def _load(
+ cls,
+ path: Union[PathLike, str],
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "VmSize":
+ pass
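
Similarly, VmSize values normally come from the compute size listing; this sketch fills one in by hand just to exercise dump(). Values are illustrative and the top-level import path is assumed:

from azure.ai.ml.entities import VmSize

size = VmSize(
    name="STANDARD_DS3_V2",
    family="standardDSv2Family",
    v_cp_us=4,
    gpus=0,
    memory_gb=14.0,
    low_priority_capable=True,
    premium_io=True,
    supported_compute_types=["AmlCompute", "ComputeInstance"],
)
size.dump("vm_size.yml")  # serializes through VmSizeSchema into a new YAML file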
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/aml_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/aml_compute.py
new file mode 100644
index 00000000..3ec7c10f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/aml_compute.py
@@ -0,0 +1,281 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,too-many-instance-attributes
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2022_12_01_preview.models import (
+ AmlCompute as AmlComputeRest,
+)
+from azure.ai.ml._restclient.v2022_12_01_preview.models import (
+ AmlComputeProperties,
+ ComputeResource,
+ ResourceId,
+ ScaleSettings,
+ UserAccountCredentials,
+)
+from azure.ai.ml._schema._utils.utils import get_subnet_str
+from azure.ai.ml._schema.compute.aml_compute import AmlComputeSchema
+from azure.ai.ml._utils.utils import (
+ camel_to_snake,
+ snake_to_pascal,
+ to_iso_duration_format,
+)
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.constants._compute import ComputeDefaults, ComputeType
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._util import load_from_dict
+
+from .compute import Compute, NetworkSettings
+
+
+class AmlComputeSshSettings:
+ """SSH settings to access a AML compute target.
+
+ :param admin_username: SSH user name.
+ :type admin_username: str
+ :param admin_password: SSH user password. Defaults to None.
+    :type admin_password: Optional[str]
+ :param ssh_key_value: The SSH RSA private key. Use "ssh-keygen -t
+ rsa -b 2048" to generate your SSH key pairs. Defaults to None.
+ :type ssh_key_value: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START aml_compute_ssh_settings]
+ :end-before: [END aml_compute_ssh_settings]
+ :language: python
+ :dedent: 8
+ :caption: Configuring an AmlComputeSshSettings object.
+ """
+
+ def __init__(
+ self,
+ *,
+ admin_username: str,
+ admin_password: Optional[str] = None,
+ ssh_key_value: Optional[str] = None,
+ ) -> None:
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+ self.ssh_key_value = ssh_key_value
+
+ def _to_user_account_credentials(self) -> UserAccountCredentials:
+ return UserAccountCredentials(
+ admin_user_name=self.admin_username,
+ admin_user_password=self.admin_password,
+ admin_user_ssh_public_key=self.ssh_key_value,
+ )
+
+ @classmethod
+ def _from_user_account_credentials(cls, credentials: UserAccountCredentials) -> "AmlComputeSshSettings":
+ return cls(
+ admin_username=credentials.admin_user_name,
+ admin_password=credentials.admin_user_password,
+ ssh_key_value=credentials.admin_user_ssh_public_key,
+ )
+
+
+class AmlCompute(Compute):
+ """AzureML Compute resource.
+
+ :param name: Name of the compute resource.
+ :type name: str
+ :param description: Description of the compute resource.
+ :type description: Optional[str]
+ :param size: Size of the compute. Defaults to None.
+ :type size: Optional[str]
+ :param tags: A set of tags. Contains resource tags defined as key/value pairs.
+ :type tags: Optional[dict[str, str]]
+ :param ssh_settings: SSH settings to access the AzureML compute cluster.
+ :type ssh_settings: Optional[~azure.ai.ml.entities.AmlComputeSshSettings]
+ :param network_settings: Virtual network settings for the AzureML compute cluster.
+ :type network_settings: Optional[~azure.ai.ml.entities.NetworkSettings]
+ :param idle_time_before_scale_down: Node idle time before scaling down. Defaults to None.
+ :type idle_time_before_scale_down: Optional[int]
+ :param identity: The identities that are associated with the compute cluster.
+ :type identity: Optional[~azure.ai.ml.entities.IdentityConfiguration]
+ :param tier: Virtual Machine tier. Accepted values include: "Dedicated", "LowPriority". Defaults to None.
+ :type tier: Optional[str]
+ :param min_instances: Minimum number of instances. Defaults to None.
+ :type min_instances: Optional[int]
+ :param max_instances: Maximum number of instances. Defaults to None.
+ :type max_instances: Optional[int]
+ :param ssh_public_access_enabled: State of the public SSH port. Accepted values are:
+ * False - Indicates that the public SSH port is closed on all nodes of the cluster.
+ * True - Indicates that the public SSH port is open on all nodes of the cluster.
+ * None - Indicates that the public SSH port is closed on all nodes of the cluster if VNet is defined,
+        else is open on all public nodes.
+ It can be None only during cluster creation time. After creation it will be either True or False.
+ Defaults to None.
+ :type ssh_public_access_enabled: Optional[bool]
+ :param enable_node_public_ip: Enable or disable node public IP address provisioning. Accepted values are:
+ * True - Indicates that the compute nodes will have public IPs provisioned.
+ * False - Indicates that the compute nodes will have a private endpoint and no public IPs.
+ Defaults to True.
+ :type enable_node_public_ip: bool
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START amlcompute]
+ :end-before: [END amlcompute]
+ :language: python
+ :dedent: 8
+ :caption: Creating an AmlCompute object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ description: Optional[str] = None,
+ size: Optional[str] = None,
+ tags: Optional[dict] = None,
+ ssh_public_access_enabled: Optional[bool] = None,
+ ssh_settings: Optional[AmlComputeSshSettings] = None,
+ min_instances: Optional[int] = None,
+ max_instances: Optional[int] = None,
+ network_settings: Optional[NetworkSettings] = None,
+ idle_time_before_scale_down: Optional[int] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ tier: Optional[str] = None,
+ enable_node_public_ip: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = ComputeType.AMLCOMPUTE
+ super().__init__(
+ name=name,
+ description=description,
+ location=kwargs.pop("location", None),
+ tags=tags,
+ **kwargs,
+ )
+ self.size = size
+ self.min_instances = min_instances or 0
+ self.max_instances = max_instances or 1
+ self.idle_time_before_scale_down = idle_time_before_scale_down
+ self.identity = identity
+ self.ssh_public_access_enabled = ssh_public_access_enabled
+ self.ssh_settings = ssh_settings
+ self.network_settings = network_settings
+ self.tier = tier
+ self.enable_node_public_ip = enable_node_public_ip
+ self.subnet = None
+
+ @classmethod
+ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
+ prop = rest_obj.properties
+
+ network_settings = None
+ if prop.properties.subnet or (prop.properties.enable_node_public_ip is not None):
+ network_settings = NetworkSettings(
+ subnet=prop.properties.subnet.id if prop.properties.subnet else None,
+ )
+
+ ssh_settings = (
+ AmlComputeSshSettings._from_user_account_credentials(prop.properties.user_account_credentials)
+ if prop.properties.user_account_credentials
+ else None
+ )
+
+ response = AmlCompute(
+ name=rest_obj.name,
+ id=rest_obj.id,
+ description=prop.description,
+ location=(prop.compute_location if prop.compute_location else rest_obj.location),
+ tags=rest_obj.tags if rest_obj.tags else None,
+ provisioning_state=prop.provisioning_state,
+ provisioning_errors=(
+ prop.provisioning_errors[0].error.code
+ if (prop.provisioning_errors and len(prop.provisioning_errors) > 0)
+ else None
+ ),
+ size=prop.properties.vm_size,
+ tier=camel_to_snake(prop.properties.vm_priority),
+ min_instances=(prop.properties.scale_settings.min_node_count if prop.properties.scale_settings else None),
+ max_instances=(prop.properties.scale_settings.max_node_count if prop.properties.scale_settings else None),
+ network_settings=network_settings or None,
+ ssh_settings=ssh_settings,
+ ssh_public_access_enabled=(prop.properties.remote_login_port_public_access == "Enabled"),
+ idle_time_before_scale_down=(
+ prop.properties.scale_settings.node_idle_time_before_scale_down.total_seconds()
+ if prop.properties.scale_settings and prop.properties.scale_settings.node_idle_time_before_scale_down
+ else None
+ ),
+ identity=(
+ IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None
+ ),
+ created_on=prop.additional_properties.get("createdOn", None),
+ enable_node_public_ip=(
+ prop.properties.enable_node_public_ip if prop.properties.enable_node_public_ip is not None else True
+ ),
+ )
+ return response
+
+ def _set_full_subnet_name(self, subscription_id: str, rg: str) -> None:
+ if self.network_settings:
+ self.subnet = get_subnet_str(
+ self.network_settings.vnet_name,
+ self.network_settings.subnet,
+ subscription_id,
+ rg,
+ )
+
+ def _to_dict(self) -> Dict:
+ res: dict = AmlComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "AmlCompute":
+ loaded_data = load_from_dict(AmlComputeSchema, data, context, **kwargs)
+ return AmlCompute(**loaded_data)
+
+ def _to_rest_object(self) -> ComputeResource:
+ if self.network_settings and self.network_settings.subnet:
+ subnet_resource = ResourceId(id=self.subnet)
+ else:
+ subnet_resource = None
+
+ # Scale settings is required when creating an AzureML compute cluster.
+ scale_settings = ScaleSettings(
+ max_node_count=self.max_instances,
+ min_node_count=self.min_instances,
+ node_idle_time_before_scale_down=(
+ to_iso_duration_format(int(self.idle_time_before_scale_down))
+ if self.idle_time_before_scale_down
+ else None
+ ),
+ )
+ remote_login_public_access = "Enabled"
+ disableLocalAuth = not (self.ssh_public_access_enabled and self.ssh_settings is not None)
+ if self.ssh_public_access_enabled is not None:
+ remote_login_public_access = "Enabled" if self.ssh_public_access_enabled else "Disabled"
+
+ else:
+ remote_login_public_access = "NotSpecified"
+ aml_prop = AmlComputeProperties(
+ vm_size=self.size if self.size else ComputeDefaults.VMSIZE,
+ vm_priority=snake_to_pascal(self.tier),
+ user_account_credentials=(self.ssh_settings._to_user_account_credentials() if self.ssh_settings else None),
+ scale_settings=scale_settings,
+ subnet=subnet_resource,
+ remote_login_port_public_access=remote_login_public_access,
+ enable_node_public_ip=self.enable_node_public_ip,
+ )
+
+ aml_comp = AmlComputeRest(
+ description=self.description,
+ compute_type=self.type,
+ properties=aml_prop,
+        disable_local_auth=disable_local_auth,
+ )
+ return ComputeResource(
+ location=self.location,
+ properties=aml_comp,
+ identity=(self.identity._to_compute_rest_object() if self.identity else None),
+ tags=self.tags,
+ )
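
A hedged sketch of declaring an AmlCompute cluster with the SSH and network settings defined above; the cluster name, VM size, VNet, and the client call in the trailing comment are placeholders, not verified values:

from azure.ai.ml.entities import AmlCompute, AmlComputeSshSettings, NetworkSettings

cluster = AmlCompute(
    name="cpu-cluster",
    size="STANDARD_DS3_V2",
    min_instances=0,
    max_instances=4,
    idle_time_before_scale_down=120,  # illustrative value
    tier="Dedicated",
    ssh_public_access_enabled=True,
    ssh_settings=AmlComputeSshSettings(admin_username="azureuser"),
    network_settings=NetworkSettings(vnet_name="my-vnet", subnet="default"),
)
# Submitting it (assumed client API): ml_client.compute.begin_create_or_update(cluster)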
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute.py
new file mode 100644
index 00000000..de18da5a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute.py
@@ -0,0 +1,261 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from abc import abstractmethod
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource
+from azure.ai.ml._schema.compute.compute import ComputeSchema
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, CommonYamlFields
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._util import find_type_in_override
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class Compute(Resource, RestTranslatableMixin):
+ """Base class for compute resources.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+ :param type: The compute type. Accepted values are "amlcompute", "computeinstance",
+ "virtualmachine", "kubernetes", and "synapsespark".
+ :type type: str
+ :param name: Name of the compute resource.
+ :type name: str
+ :param location: The resource location. Defaults to workspace location.
+ :type location: Optional[str]
+ :param description: Description of the resource. Defaults to None.
+ :type description: Optional[str]
+ :param resource_id: ARM resource id of the underlying compute. Defaults to None.
+ :type resource_id: Optional[str]
+ :param tags: A set of tags. Contains resource tags defined as key/value pairs.
+ :type tags: Optional[dict[str, str]]
+ """
+
+ def __init__(
+ self,
+ name: str,
+ location: Optional[str] = None,
+ description: Optional[str] = None,
+ resource_id: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._type: Optional[str] = kwargs.pop("type", None)
+ if self._type:
+ self._type = self._type.lower()
+
+ self._created_on: Optional[str] = kwargs.pop("created_on", None)
+ self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)
+ self._provisioning_errors: Optional[str] = kwargs.pop("provisioning_errors", None)
+
+ super().__init__(name=name, description=description, **kwargs)
+ self.resource_id = resource_id
+ self.location = location
+ self.tags = tags
+
+ @property
+ def type(self) -> Optional[str]:
+ """The compute type.
+
+ :return: The compute type.
+ :rtype: Optional[str]
+ """
+ return self._type
+
+ @property
+ def created_on(self) -> Optional[str]:
+ """The compute resource creation timestamp.
+
+ :return: The compute resource creation timestamp.
+ :rtype: Optional[str]
+ """
+ return self._created_on
+
+ @property
+ def provisioning_state(self) -> Optional[str]:
+ """The compute resource's provisioning state.
+
+ :return: The compute resource's provisioning state.
+ :rtype: Optional[str]
+ """
+ return self._provisioning_state
+
+ @property
+ def provisioning_errors(self) -> Optional[str]:
+ """The compute resource provisioning errors.
+
+ :return: The compute resource provisioning errors.
+ :rtype: Optional[str]
+ """
+ return self._provisioning_errors
+
+ def _to_rest_object(self) -> ComputeResource:
+ pass
+
+ @classmethod
+ def _from_rest_object(cls, obj: ComputeResource) -> "Compute":
+ from azure.ai.ml.entities import (
+ AmlCompute,
+ ComputeInstance,
+ KubernetesCompute,
+ SynapseSparkCompute,
+ UnsupportedCompute,
+ VirtualMachineCompute,
+ )
+
+ mapping = {
+ ComputeType.AMLCOMPUTE.lower(): AmlCompute,
+ ComputeType.COMPUTEINSTANCE.lower(): ComputeInstance,
+ ComputeType.VIRTUALMACHINE.lower(): VirtualMachineCompute,
+ ComputeType.KUBERNETES.lower(): KubernetesCompute,
+ ComputeType.SYNAPSESPARK.lower(): SynapseSparkCompute,
+ }
+ compute_type = obj.properties.compute_type.lower() if obj.properties.compute_type else None
+
+ class_type = cast(
+ Optional[Union[AmlCompute, ComputeInstance, VirtualMachineCompute, KubernetesCompute, SynapseSparkCompute]],
+ mapping.get(compute_type, None), # type: ignore
+ )
+ if class_type:
+ return class_type._load_from_rest(obj)
+ _unsupported_from_rest: Compute = UnsupportedCompute._load_from_rest(obj)
+ return _unsupported_from_rest
+
+ @classmethod
+ @abstractmethod
+ def _load_from_rest(cls, rest_obj: ComputeResource) -> "Compute":
+ pass
+
+ def _set_full_subnet_name(self, subscription_id: str, rg: str) -> None:
+ pass
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the compute content into a file in yaml format.
+
+ :param dest: The destination to receive this compute's content.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+            and an exception will be raised if the file is not writable.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ def _to_dict(self) -> Dict:
+ res: dict = ComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Compute":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ from azure.ai.ml.entities import (
+ AmlCompute,
+ ComputeInstance,
+ KubernetesCompute,
+ SynapseSparkCompute,
+ VirtualMachineCompute,
+ )
+
+ type_in_override = find_type_in_override(params_override) if params_override else None
+ compute_type = type_in_override or data.get(CommonYamlFields.TYPE, None) # override takes the priority
+ if compute_type:
+ if compute_type.lower() == ComputeType.VIRTUALMACHINE:
+ _vm_load_from_dict: Compute = VirtualMachineCompute._load_from_dict(data, context, **kwargs)
+ return _vm_load_from_dict
+ if compute_type.lower() == ComputeType.AMLCOMPUTE:
+ _aml_load_from_dict: Compute = AmlCompute._load_from_dict(data, context, **kwargs)
+ return _aml_load_from_dict
+ if compute_type.lower() == ComputeType.COMPUTEINSTANCE:
+ _compute_instance_load_from_dict: Compute = ComputeInstance._load_from_dict(data, context, **kwargs)
+ return _compute_instance_load_from_dict
+ if compute_type.lower() == ComputeType.KUBERNETES:
+ _kub_load_from_dict: Compute = KubernetesCompute._load_from_dict(data, context, **kwargs)
+ return _kub_load_from_dict
+ if compute_type.lower() == ComputeType.SYNAPSESPARK:
+ _synapse_spark_load_from_dict: Compute = SynapseSparkCompute._load_from_dict(data, context, **kwargs)
+ return _synapse_spark_load_from_dict
+ msg = f"Unknown compute type: {compute_type}"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPUTE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @classmethod
+ @abstractmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "Compute":
+ pass
+
+
+class NetworkSettings:
+ """Network settings for a compute resource. If the workspace and VNet are in different resource groups,
+    please provide the full URI for the subnet and leave vnet_name as None.
+
+ :param vnet_name: The virtual network name.
+ :type vnet_name: Optional[str]
+ :param subnet: The subnet name.
+ :type subnet: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START network_settings]
+ :end-before: [END network_settings]
+ :language: python
+ :dedent: 8
+ :caption: Configuring NetworkSettings for an AmlCompute object.
+ """
+
+ def __init__(
+ self,
+ *,
+ vnet_name: Optional[str] = None,
+ subnet: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.vnet_name = vnet_name
+ self.subnet = subnet
+ self._public_ip_address: str = kwargs.pop("public_ip_address", None)
+ self._private_ip_address: str = kwargs.pop("private_ip_address", None)
+
+ @property
+ def public_ip_address(self) -> str:
+ """Public IP address of the compute instance.
+
+ :return: Public IP address.
+ :rtype: str
+ """
+ return self._public_ip_address
+
+ @property
+ def private_ip_address(self) -> str:
+ """Private IP address of the compute instance.
+
+ :return: Private IP address.
+ :rtype: str
+ """
+ return self._private_ip_address
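
A sketch of the type-based dispatch in Compute._load: the "type" field selects the concrete subclass, whose schema then validates the rest of the document. The dict below stands in for parsed YAML; its field names follow the amlcompute YAML schema and are assumptions here:

from azure.ai.ml.entities._compute.compute import Compute

data = {
    "type": "amlcompute",
    "name": "cpu-cluster",
    "size": "STANDARD_DS3_V2",
    "min_instances": 0,
    "max_instances": 2,
}
compute = Compute._load(data=data)  # dispatches to AmlCompute._load_from_dict
print(type(compute).__name__)       # -> "AmlCompute"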
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute_instance.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute_instance.py
new file mode 100644
index 00000000..9cbb2528
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute_instance.py
@@ -0,0 +1,511 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,too-many-instance-attributes
+
+import logging
+import re
+import warnings
+from typing import Any, Dict, List, Optional
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import AssignedUser
+from azure.ai.ml._restclient.v2023_08_01_preview.models import ComputeInstance as CIRest
+from azure.ai.ml._restclient.v2023_08_01_preview.models import ComputeInstanceProperties
+from azure.ai.ml._restclient.v2023_08_01_preview.models import ComputeInstanceSshSettings as CiSShSettings
+from azure.ai.ml._restclient.v2023_08_01_preview.models import (
+ ComputeResource,
+ PersonalComputeInstanceSettings,
+ ResourceId,
+)
+from azure.ai.ml._schema._utils.utils import get_subnet_str
+from azure.ai.ml._schema.compute.compute_instance import ComputeInstanceSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.constants._compute import ComputeDefaults, ComputeType
+from azure.ai.ml.entities._compute.compute import Compute, NetworkSettings
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._mixins import DictMixin
+from azure.ai.ml.entities._util import load_from_dict
+
+from ._custom_applications import CustomApplications, validate_custom_applications
+from ._image_metadata import ImageMetadata
+from ._schedule import ComputeSchedules
+from ._setup_scripts import SetupScripts
+
+module_logger = logging.getLogger(__name__)
+
+
+class ComputeInstanceSshSettings:
+ """Credentials for an administrator user account to SSH into the compute node.
+
+    Can only be configured if `ssh_public_access_enabled` is set to true on the
+    compute resource.
+
+ :param ssh_key_value: The SSH public key of the administrator user account.
+ :type ssh_key_value: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_instance_ssh_settings]
+ :end-before: [END compute_instance_ssh_settings]
+ :language: python
+ :dedent: 8
+ :caption: Configuring ComputeInstanceSshSettings object.
+ """
+
+ def __init__(
+ self,
+ *,
+ ssh_key_value: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.ssh_key_value = ssh_key_value
+ self._ssh_port: str = kwargs.pop("ssh_port", None)
+ self._admin_username: str = kwargs.pop("admin_username", None)
+
+ @property
+ def admin_username(self) -> str:
+ """The name of the administrator user account which can be used to SSH into nodes.
+
+ :return: The name of the administrator user account.
+ :rtype: str
+ """
+ return self._admin_username
+
+ @property
+ def ssh_port(self) -> str:
+ """SSH port.
+
+ :return: SSH port.
+ :rtype: str
+ """
+ return self._ssh_port
+
+
+class AssignedUserConfiguration(DictMixin):
+ """Settings to create a compute resource on behalf of another user.
+
+ :param user_tenant_id: Tenant ID of the user to assign the compute target to.
+ :type user_tenant_id: str
+ :param user_object_id: Object ID of the user to assign the compute target to.
+ :type user_object_id: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START assigned_user_configuration]
+ :end-before: [END assigned_user_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Creating an AssignedUserConfiguration.
+ """
+
+ def __init__(self, *, user_tenant_id: str, user_object_id: str) -> None:
+ self.user_tenant_id = user_tenant_id
+ self.user_object_id = user_object_id
+
+
+class ComputeInstance(Compute):
+ """Compute Instance resource.
+
+ :param name: Name of the compute.
+ :type name: str
+ :param location: The resource location.
+ :type location: Optional[str]
+ :param description: Description of the resource.
+ :type description: Optional[str]
+ :param size: Compute size.
+ :type size: Optional[str]
+ :param tags: A set of tags. Contains resource tags defined as key/value pairs.
+ :type tags: Optional[dict[str, str]]
+ :param create_on_behalf_of: Configuration to create resource on behalf of another user. Defaults to None.
+ :type create_on_behalf_of: Optional[~azure.ai.ml.entities.AssignedUserConfiguration]
+ :ivar state: State of the resource.
+ :type state: Optional[str]
+ :ivar last_operation: The last operation.
+ :type last_operation: Optional[Dict[str, str]]
+ :ivar applications: Applications associated with the compute instance.
+ :type applications: Optional[List[Dict[str, str]]]
+ :param network_settings: Network settings for the compute instance.
+ :type network_settings: Optional[~azure.ai.ml.entities.NetworkSettings]
+ :param ssh_settings: SSH settings for the compute instance.
+ :type ssh_settings: Optional[~azure.ai.ml.entities.ComputeInstanceSshSettings]
+ :param ssh_public_access_enabled: State of the public SSH port. Defaults to None.
+ Possible values are:
+
+ * False - Indicates that the public ssh port is closed on all nodes of the cluster.
+ * True - Indicates that the public ssh port is open on all nodes of the cluster.
+        * None - Indicates that the public ssh port is closed on all nodes of the cluster if VNet is defined,
+          else is open on all public nodes. It can be None only during cluster creation time; after
+          creation it will be either True or False.
+
+ :type ssh_public_access_enabled: Optional[bool]
+ :param schedules: Compute instance schedules. Defaults to None.
+ :type schedules: Optional[~azure.ai.ml.entities.ComputeSchedules]
+ :param identity: The identities that are associated with the compute cluster.
+ :type identity: ~azure.ai.ml.entities.IdentityConfiguration
+ :param idle_time_before_shutdown: Deprecated. Use the `idle_time_before_shutdown_minutes` parameter instead.
+ Stops compute instance after user defined period of inactivity.
+ Time is defined in ISO8601 format. Minimum is 15 minutes, maximum is 3 days.
+ :type idle_time_before_shutdown: Optional[str]
+ :param idle_time_before_shutdown_minutes: Stops compute instance after a user defined period of
+ inactivity in minutes. Minimum is 15 minutes, maximum is 3 days.
+ :type idle_time_before_shutdown_minutes: Optional[int]
+ :param enable_node_public_ip: Enable or disable node public IP address provisioning. Defaults to True.
+ Possible values are:
+
+ * True - Indicates that the compute nodes will have public IPs provisioned.
+ * False - Indicates that the compute nodes will have a private endpoint and no public IPs.
+
+ :type enable_node_public_ip: Optional[bool]
+ :param setup_scripts: Details of customized scripts to execute for setting up the cluster.
+ :type setup_scripts: Optional[~azure.ai.ml.entities.SetupScripts]
+ :param custom_applications: List of custom applications and their endpoints for the compute instance.
+ :type custom_applications: Optional[List[~azure.ai.ml.entities.CustomApplications]]
+ :param enable_sso: Enable or disable single sign-on. Defaults to True.
+ :type enable_sso: bool
+ :param enable_root_access: Enable or disable root access. Defaults to True.
+ :type enable_root_access: bool
+ :param release_quota_on_stop: Release quota on stop for the compute instance. Defaults to False.
+ :type release_quota_on_stop: bool
+ :param enable_os_patching: Enable or disable OS patching for the compute instance. Defaults to False.
+ :type enable_os_patching: bool
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_instance]
+ :end-before: [END compute_instance]
+ :language: python
+ :dedent: 8
+ :caption: Creating a ComputeInstance object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ description: Optional[str] = None,
+ size: Optional[str] = None,
+ tags: Optional[dict] = None,
+ ssh_public_access_enabled: Optional[bool] = None,
+ create_on_behalf_of: Optional[AssignedUserConfiguration] = None,
+ network_settings: Optional[NetworkSettings] = None,
+ ssh_settings: Optional[ComputeInstanceSshSettings] = None,
+ schedules: Optional[ComputeSchedules] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ idle_time_before_shutdown: Optional[str] = None,
+ idle_time_before_shutdown_minutes: Optional[int] = None,
+ setup_scripts: Optional[SetupScripts] = None,
+ enable_node_public_ip: bool = True,
+ custom_applications: Optional[List[CustomApplications]] = None,
+ enable_sso: bool = True,
+ enable_root_access: bool = True,
+ release_quota_on_stop: bool = False,
+ enable_os_patching: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = ComputeType.COMPUTEINSTANCE
+ self._state: str = kwargs.pop("state", None)
+ self._last_operation: dict = kwargs.pop("last_operation", None)
+ self._os_image_metadata: ImageMetadata = kwargs.pop("os_image_metadata", None)
+ self._services: list = kwargs.pop("services", None)
+ super().__init__(
+ name=name,
+ location=kwargs.pop("location", None),
+ resource_id=kwargs.pop("resource_id", None),
+ description=description,
+ tags=tags,
+ **kwargs,
+ )
+ self.size = size
+ self.ssh_public_access_enabled = ssh_public_access_enabled
+ self.create_on_behalf_of = create_on_behalf_of
+ self.network_settings = network_settings
+ self.ssh_settings = ssh_settings
+ self.schedules = schedules
+ self.identity = identity
+ self.idle_time_before_shutdown = idle_time_before_shutdown
+ self.idle_time_before_shutdown_minutes = idle_time_before_shutdown_minutes
+ self.setup_scripts = setup_scripts
+ self.enable_node_public_ip = enable_node_public_ip
+ self.enable_sso = enable_sso
+ self.enable_root_access = enable_root_access
+ self.release_quota_on_stop = release_quota_on_stop
+ self.enable_os_patching = enable_os_patching
+ self.custom_applications = custom_applications
+ self.subnet = None
+
+ @property
+ def services(self) -> List[Dict[str, str]]:
+ """The compute instance's services.
+
+ :return: The compute instance's services.
+ :rtype: List[Dict[str, str]]
+ """
+ return self._services
+
+ @property
+ def last_operation(self) -> Dict[str, str]:
+ """The last operation.
+
+ :return: The last operation.
+        :rtype: Dict[str, str]
+ """
+ return self._last_operation
+
+ @property
+ def state(self) -> str:
+ """The state of the compute.
+
+ :return: The state of the compute.
+ :rtype: str
+ """
+ return self._state
+
+ @property
+ def os_image_metadata(self) -> ImageMetadata:
+ """Metadata about the operating system image for this compute instance.
+
+ :return: Operating system image metadata.
+ :rtype: ~azure.ai.ml.entities.ImageMetadata
+ """
+ return self._os_image_metadata
+
+ def _to_rest_object(self) -> ComputeResource:
+ if self.network_settings and self.network_settings.subnet:
+ subnet_resource = ResourceId(id=self.subnet)
+ else:
+ subnet_resource = None
+
+ ssh_settings = None
+ if self.ssh_public_access_enabled is not None or self.ssh_settings is not None:
+ ssh_settings = CiSShSettings()
+ ssh_settings.ssh_public_access = "Enabled" if self.ssh_public_access_enabled else "Disabled"
+ ssh_settings.admin_public_key = (
+ self.ssh_settings.ssh_key_value if self.ssh_settings and self.ssh_settings.ssh_key_value else None
+ )
+
+ personal_compute_instance_settings = None
+ if self.create_on_behalf_of:
+ personal_compute_instance_settings = PersonalComputeInstanceSettings(
+ assigned_user=AssignedUser(
+ object_id=self.create_on_behalf_of.user_object_id,
+ tenant_id=self.create_on_behalf_of.user_tenant_id,
+ )
+ )
+
+ idle_time_before_shutdown = None
+ if self.idle_time_before_shutdown_minutes:
+ idle_time_before_shutdown = f"PT{self.idle_time_before_shutdown_minutes}M"
+ elif self.idle_time_before_shutdown:
+ warnings.warn(
+ """ The property 'idle_time_before_shutdown' is deprecated.
+ Please use'idle_time_before_shutdown_minutes' instead.""",
+ DeprecationWarning,
+ )
+ idle_time_before_shutdown = self.idle_time_before_shutdown
+
+ compute_instance_prop = ComputeInstanceProperties(
+ vm_size=self.size if self.size else ComputeDefaults.VMSIZE,
+ subnet=subnet_resource,
+ ssh_settings=ssh_settings,
+ personal_compute_instance_settings=personal_compute_instance_settings,
+ idle_time_before_shutdown=idle_time_before_shutdown,
+ enable_node_public_ip=self.enable_node_public_ip,
+ enable_sso=self.enable_sso,
+ enable_root_access=self.enable_root_access,
+ release_quota_on_stop=self.release_quota_on_stop,
+ enable_os_patching=self.enable_os_patching,
+ )
+ compute_instance_prop.schedules = self.schedules._to_rest_object() if self.schedules else None
+ compute_instance_prop.setup_scripts = self.setup_scripts._to_rest_object() if self.setup_scripts else None
+ if self.custom_applications:
+ validate_custom_applications(self.custom_applications)
+ compute_instance_prop.custom_services = []
+ for app in self.custom_applications:
+ compute_instance_prop.custom_services.append(app._to_rest_object())
+ compute_instance = CIRest(
+ description=self.description,
+ compute_type=self.type,
+ properties=compute_instance_prop,
+ )
+ return ComputeResource(
+ location=self.location,
+ properties=compute_instance,
+ identity=(self.identity._to_compute_rest_object() if self.identity else None),
+ tags=self.tags,
+ )
+
+ def _to_dict(self) -> Dict:
+ res: dict = ComputeInstanceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def _set_full_subnet_name(self, subscription_id: str, rg: str) -> None:
+ if self.network_settings and (self.network_settings.vnet_name or self.network_settings.subnet):
+ self.subnet = get_subnet_str(
+ self.network_settings.vnet_name,
+ self.network_settings.subnet,
+ subscription_id,
+ rg,
+ )
+
+ @classmethod
+ def _load_from_rest(cls, rest_obj: ComputeResource) -> "ComputeInstance":
+ prop = rest_obj.properties
+ create_on_behalf_of = None
+ if prop.properties and prop.properties.personal_compute_instance_settings:
+ create_on_behalf_of = AssignedUserConfiguration(
+ user_tenant_id=prop.properties.personal_compute_instance_settings.assigned_user.tenant_id,
+ user_object_id=prop.properties.personal_compute_instance_settings.assigned_user.object_id,
+ )
+ ssh_settings = None
+ if prop.properties and prop.properties.ssh_settings:
+ ssh_settings = ComputeInstanceSshSettings(
+ ssh_key_value=prop.properties.ssh_settings.admin_public_key,
+ ssh_port=prop.properties.ssh_settings.ssh_port,
+ admin_username=prop.properties.ssh_settings.admin_user_name,
+ )
+
+ network_settings = None
+ if prop.properties and (
+ prop.properties.subnet
+ or (
+ prop.properties.connectivity_endpoints
+ and (
+ prop.properties.connectivity_endpoints.private_ip_address
+ or prop.properties.connectivity_endpoints.public_ip_address
+ )
+ )
+ ):
+ network_settings = NetworkSettings(
+ subnet=prop.properties.subnet.id if prop.properties.subnet else None,
+ public_ip_address=(
+ prop.properties.connectivity_endpoints.public_ip_address
+ if prop.properties.connectivity_endpoints
+ and prop.properties.connectivity_endpoints.public_ip_address
+ else None
+ ),
+ private_ip_address=(
+ prop.properties.connectivity_endpoints.private_ip_address
+ if prop.properties.connectivity_endpoints
+ and prop.properties.connectivity_endpoints.private_ip_address
+ else None
+ ),
+ )
+ os_image_metadata = None
+ if prop.properties and prop.properties.os_image_metadata:
+ metadata = prop.properties.os_image_metadata
+ os_image_metadata = ImageMetadata(
+ is_latest_os_image_version=(
+ metadata.is_latest_os_image_version if metadata.is_latest_os_image_version is not None else None
+ ),
+ current_image_version=metadata.current_image_version if metadata.current_image_version else None,
+ latest_image_version=metadata.latest_image_version if metadata.latest_image_version else None,
+ )
+
+ idle_time_before_shutdown = None
+ idle_time_before_shutdown_minutes = None
+ idle_time_before_shutdown_pattern = r"PT([0-9]+)M"
+ if prop.properties and prop.properties.idle_time_before_shutdown:
+ idle_time_before_shutdown = prop.properties.idle_time_before_shutdown
+ idle_time_match = re.match(
+ pattern=idle_time_before_shutdown_pattern,
+ string=idle_time_before_shutdown,
+ )
+ idle_time_before_shutdown_minutes = int(idle_time_match[1]) if idle_time_match else None
+ custom_applications = None
+ if prop.properties and prop.properties.custom_services:
+ custom_applications = []
+ for app in prop.properties.custom_services:
+ custom_applications.append(CustomApplications._from_rest_object(app))
+ response = ComputeInstance(
+ name=rest_obj.name,
+ id=rest_obj.id,
+ description=prop.description,
+ location=rest_obj.location,
+ resource_id=prop.resource_id,
+ tags=rest_obj.tags if rest_obj.tags else None,
+ provisioning_state=prop.provisioning_state,
+ provisioning_errors=(
+ prop.provisioning_errors[0].error.code
+ if (prop.provisioning_errors and len(prop.provisioning_errors) > 0)
+ else None
+ ),
+ size=prop.properties.vm_size if prop.properties else None,
+ state=prop.properties.state if prop.properties else None,
+ last_operation=(
+ prop.properties.last_operation.as_dict() if prop.properties and prop.properties.last_operation else None
+ ),
+ services=(
+ [app.as_dict() for app in prop.properties.applications]
+ if prop.properties and prop.properties.applications
+ else None
+ ),
+ created_on=(
+ rest_obj.properties.created_on.strftime("%Y-%m-%dT%H:%M:%S.%f%z")
+ if rest_obj.properties and rest_obj.properties.created_on is not None
+ else None
+ ),
+ create_on_behalf_of=create_on_behalf_of,
+ network_settings=network_settings,
+ ssh_settings=ssh_settings,
+ ssh_public_access_enabled=(
+ _ssh_public_access_to_bool(prop.properties.ssh_settings.ssh_public_access)
+ if (prop.properties and prop.properties.ssh_settings and prop.properties.ssh_settings.ssh_public_access)
+ else None
+ ),
+ schedules=(
+ ComputeSchedules._from_rest_object(prop.properties.schedules)
+ if prop.properties and prop.properties.schedules and prop.properties.schedules.compute_start_stop
+ else None
+ ),
+ identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None,
+ setup_scripts=(
+ SetupScripts._from_rest_object(prop.properties.setup_scripts)
+ if prop.properties and prop.properties.setup_scripts
+ else None
+ ),
+ idle_time_before_shutdown=idle_time_before_shutdown,
+ idle_time_before_shutdown_minutes=idle_time_before_shutdown_minutes,
+ os_image_metadata=os_image_metadata,
+ enable_node_public_ip=(
+ prop.properties.enable_node_public_ip
+ if (prop.properties and prop.properties.enable_node_public_ip is not None)
+ else True
+ ),
+ custom_applications=custom_applications,
+ enable_sso=(
+ prop.properties.enable_sso if (prop.properties and prop.properties.enable_sso is not None) else True
+ ),
+ enable_root_access=(
+ prop.properties.enable_root_access
+ if (prop.properties and prop.properties.enable_root_access is not None)
+ else True
+ ),
+ release_quota_on_stop=(
+ prop.properties.release_quota_on_stop
+ if (prop.properties and prop.properties.release_quota_on_stop is not None)
+ else False
+ ),
+ enable_os_patching=(
+ prop.properties.enable_os_patching
+ if (prop.properties and prop.properties.enable_os_patching is not None)
+ else False
+ ),
+ )
+ return response
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "ComputeInstance":
+ loaded_data = load_from_dict(ComputeInstanceSchema, data, context, **kwargs)
+ return ComputeInstance(**loaded_data)
+
+
+def _ssh_public_access_to_bool(value: str) -> Optional[bool]:
+ if value.lower() == "disabled":
+ return False
+ if value.lower() == "enabled":
+ return True
+ return None
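
A sketch combining ComputeInstance with the schedule classes shown earlier; the cron expression, VM size, and the action/state literals are illustrative assumptions:

from azure.ai.ml.entities import (
    ComputeInstance,
    ComputeSchedules,
    ComputeStartStopSchedule,
    CronTrigger,
)

stop_at_night = ComputeStartStopSchedule(action="Stop", state="Enabled")
stop_at_night.trigger = CronTrigger(expression="0 20 * * *", time_zone="UTC")

ci = ComputeInstance(
    name="my-ci",
    size="STANDARD_DS3_V2",
    idle_time_before_shutdown_minutes=60,
    ssh_public_access_enabled=False,
    schedules=ComputeSchedules(compute_start_stop=[stop_at_night]),
)
rest_payload = ci._to_rest_object()  # ComputeResource carrying ComputeInstanceProperties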
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/kubernetes_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/kubernetes_compute.py
new file mode 100644
index 00000000..bc8c2c28
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/kubernetes_compute.py
@@ -0,0 +1,105 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource, Kubernetes, KubernetesProperties
+from azure.ai.ml._schema.compute.kubernetes_compute import KubernetesComputeSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.entities._compute.compute import Compute
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class KubernetesCompute(Compute):
+ """Kubernetes Compute resource.
+
+ :param namespace: The namespace of the KubernetesCompute. Defaults to "default".
+ :type namespace: Optional[str]
+ :param properties: The properties of the Kubernetes compute resource.
+ :type properties: Optional[Dict]
+ :param identity: The identities that are associated with the compute cluster.
+ :type identity: ~azure.ai.ml.entities.IdentityConfiguration
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START kubernetes_compute]
+ :end-before: [END kubernetes_compute]
+ :language: python
+ :dedent: 8
+ :caption: Creating a KubernetesCompute object.
+ """
+
+ def __init__(
+ self,
+ *,
+ namespace: str = "default",
+ properties: Optional[Dict[str, Any]] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = ComputeType.KUBERNETES
+ super().__init__(**kwargs)
+ self.namespace = namespace
+ self.properties = properties if properties else {}
+ if "properties" in self.properties:
+ self.properties["properties"]["namespace"] = namespace
+ self.identity = identity
+
+ @classmethod
+ def _load_from_rest(cls, rest_obj: ComputeResource) -> "KubernetesCompute":
+ prop = rest_obj.properties
+ return KubernetesCompute(
+ name=rest_obj.name,
+ id=rest_obj.id,
+ description=prop.description,
+ location=rest_obj.location,
+ resource_id=prop.resource_id,
+ tags=rest_obj.tags if rest_obj.tags else None,
+ provisioning_state=prop.provisioning_state,
+ provisioning_errors=(
+ prop.provisioning_errors[0].error.code
+ if (prop.provisioning_errors and len(prop.provisioning_errors) > 0)
+ else None
+ ),
+ created_on=prop.additional_properties.get("createdOn", None),
+ properties=prop.properties.as_dict() if prop.properties else None,
+ namespace=prop.properties.namespace,
+ identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None,
+ )
+
+ def _to_dict(self) -> Dict:
+ res: dict = KubernetesComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "KubernetesCompute":
+ if not data:
+ data = {"namespace": "default"}
+ if "namespace" not in data:
+ data["namespace"] = "default"
+
+ loaded_data = load_from_dict(KubernetesComputeSchema, data, context, **kwargs)
+ return KubernetesCompute(**loaded_data)
+
+ def _to_rest_object(self) -> ComputeResource:
+ kubernetes_prop = KubernetesProperties.from_dict(self.properties)
+ kubernetes_prop.namespace = self.namespace
+ kubernetes_comp = Kubernetes(
+ resource_id=self.resource_id,
+ compute_location=self.location,
+ description=self.description,
+ properties=kubernetes_prop,
+ )
+ return ComputeResource(
+ location=self.location,
+ properties=kubernetes_comp,
+ name=self.name,
+ identity=(self.identity._to_compute_rest_object() if self.identity else None),
+ tags=self.tags,
+ )
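
A sketch of attaching an existing connected Kubernetes cluster as compute; the resource_id is a placeholder pattern, not a real resource:

from azure.ai.ml.entities import KubernetesCompute

k8s = KubernetesCompute(
    name="k8s-compute",
    namespace="aml-jobs",
    resource_id=(
        "/subscriptions/<subscription-id>/resourceGroups/<resource-group>"
        "/providers/Microsoft.Kubernetes/connectedClusters/<cluster-name>"
    ),
)
rest_payload = k8s._to_rest_object()  # ComputeResource with Kubernetes properties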
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/synapsespark_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/synapsespark_compute.py
new file mode 100644
index 00000000..99b366cb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/synapsespark_compute.py
@@ -0,0 +1,234 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import (
+ AutoPauseProperties,
+ AutoScaleProperties,
+ ComputeResource,
+ SynapseSpark,
+)
+from azure.ai.ml._schema.compute.synapsespark_compute import SynapseSparkComputeSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.entities import Compute
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class AutoScaleSettings:
+ """Auto-scale settings for Synapse Spark compute.
+
+ :keyword min_node_count: The minimum compute node count.
+ :paramtype min_node_count: Optional[int]
+ :keyword max_node_count: The maximum compute node count.
+ :paramtype max_node_count: Optional[int]
+ :keyword enabled: Specifies if auto-scale is enabled.
+ :paramtype enabled: Optional[bool]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START synapse_spark_compute_configuration]
+ :end-before: [END synapse_spark_compute_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring AutoScaleSettings on SynapseSparkCompute.
+ """
+
+ def __init__(
+ self,
+ *,
+ min_node_count: Optional[int] = None,
+ max_node_count: Optional[int] = None,
+ enabled: Optional[bool] = None,
+ ) -> None:
+ self.min_node_count = min_node_count
+ self.max_node_count = max_node_count
+ self.auto_scale_enabled = enabled
+
+ def _to_auto_scale_settings(self) -> AutoScaleProperties:
+ return AutoScaleProperties(
+ min_node_count=self.min_node_count,
+ max_node_count=self.max_node_count,
+ auto_scale_enabled=self.auto_scale_enabled,
+ )
+
+ @classmethod
+ def _from_auto_scale_settings(cls, autoscaleprops: AutoScaleProperties) -> "AutoScaleSettings":
+ return cls(
+ min_node_count=autoscaleprops.min_node_count,
+ max_node_count=autoscaleprops.max_node_count,
+ enabled=autoscaleprops.enabled,
+ )
+
+
+class AutoPauseSettings:
+ """Auto pause settings for Synapse Spark compute.
+
+    :keyword delay_in_minutes: The time delay in minutes before pausing the cluster.
+ :paramtype delay_in_minutes: Optional[int]
+ :keyword enabled: Specifies if auto-pause is enabled.
+ :paramtype enabled: Optional[bool]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START synapse_spark_compute_configuration]
+ :end-before: [END synapse_spark_compute_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring AutoPauseSettings on SynapseSparkCompute.
+ """
+
+ def __init__(self, *, delay_in_minutes: Optional[int] = None, enabled: Optional[bool] = None) -> None:
+ self.delay_in_minutes = delay_in_minutes
+ self.auto_pause_enabled = enabled
+
+ def _to_auto_pause_settings(self) -> AutoPauseProperties:
+ return AutoPauseProperties(
+ delay_in_minutes=self.delay_in_minutes,
+ auto_pause_enabled=self.auto_pause_enabled,
+ )
+
+ @classmethod
+ def _from_auto_pause_settings(cls, autopauseprops: AutoPauseProperties) -> "AutoPauseSettings":
+ return cls(
+ delay_in_minutes=autopauseprops.delay_in_minutes,
+ enabled=autopauseprops.enabled,
+ )
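+
+
+# Minimal usage sketch with illustrative values (the node counts and pause delay
+# below are assumptions, not SDK defaults): both settings objects are plain value
+# holders that get attached to a SynapseSparkCompute via its scale_settings and
+# auto_pause_settings keywords.
+def _example_synapse_scaling_settings() -> tuple:
+    scale = AutoScaleSettings(min_node_count=1, max_node_count=4, enabled=True)
+    pause = AutoPauseSettings(delay_in_minutes=15, enabled=True)
+    return scale, pause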
+
+
+@experimental
+class SynapseSparkCompute(Compute):
+ """SynapseSpark Compute resource.
+
+ :keyword name: The name of the compute.
+ :paramtype name: str
+ :keyword description: The description of the resource. Defaults to None.
+ :paramtype description: Optional[str]
+ :keyword tags: The set of resource tags defined as key/value pairs. Defaults to None.
+    :paramtype tags: Optional[dict[str, str]]
+ :keyword node_count: The number of nodes in the compute.
+ :paramtype node_count: Optional[int]
+ :keyword node_family: The node family of the compute.
+ :paramtype node_family: Optional[str]
+ :keyword node_size: The size of the node.
+ :paramtype node_size: Optional[str]
+ :keyword spark_version: The version of Spark to use.
+ :paramtype spark_version: Optional[str]
+ :keyword identity: The configuration of identities that are associated with the compute cluster.
+ :paramtype identity: Optional[~azure.ai.ml.entities.IdentityConfiguration]
+ :keyword scale_settings: The scale settings for the compute.
+ :paramtype scale_settings: Optional[~azure.ai.ml.entities.AutoScaleSettings]
+ :keyword auto_pause_settings: The auto pause settings for the compute.
+ :paramtype auto_pause_settings: Optional[~azure.ai.ml.entities.AutoPauseSettings]
+ :keyword kwargs: Additional keyword arguments passed to the parent class.
+ :paramtype kwargs: Optional[dict]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START synapse_spark_compute_configuration]
+ :end-before: [END synapse_spark_compute_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Creating Synapse Spark compute.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ node_count: Optional[int] = None,
+ node_family: Optional[str] = None,
+ node_size: Optional[str] = None,
+ spark_version: Optional[str] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ scale_settings: Optional[AutoScaleSettings] = None,
+ auto_pause_settings: Optional[AutoPauseSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = ComputeType.SYNAPSESPARK
+ super().__init__(name=name, description=description, location=kwargs.pop("location", None), tags=tags, **kwargs)
+ self.identity = identity
+ self.node_count = node_count
+ self.node_family = node_family
+ self.node_size = node_size
+ self.spark_version = spark_version
+ self.scale_settings = scale_settings
+ self.auto_pause_settings = auto_pause_settings
+
+ @classmethod
+ def _load_from_rest(cls, rest_obj: ComputeResource) -> "SynapseSparkCompute":
+ prop = rest_obj.properties
+ scale_settings = (
+ # pylint: disable=protected-access
+ AutoScaleSettings._from_auto_scale_settings(prop.properties.auto_scale_properties)
+ if prop.properties.auto_scale_properties
+ else None
+ )
+
+ auto_pause_settings = (
+ # pylint: disable=protected-access
+ AutoPauseSettings._from_auto_pause_settings(prop.properties.auto_pause_properties)
+ if prop.properties.auto_pause_properties
+ else None
+ )
+
+ return SynapseSparkCompute(
+ name=rest_obj.name,
+ id=rest_obj.id,
+ description=prop.description,
+ location=rest_obj.location,
+ resource_id=prop.resource_id,
+ tags=rest_obj.tags if rest_obj.tags else None,
+ created_on=prop.created_on if prop.properties else None,
+ node_count=prop.properties.node_count if prop.properties else None,
+ node_family=prop.properties.node_size_family if prop.properties else None,
+ node_size=prop.properties.node_size if prop.properties else None,
+ spark_version=prop.properties.spark_version if prop.properties else None,
+ # pylint: disable=protected-access
+ identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None,
+ scale_settings=scale_settings,
+ auto_pause_settings=auto_pause_settings,
+ provisioning_state=prop.provisioning_state,
+ provisioning_errors=(
+ prop.provisioning_errors[0].error.code
+ if (prop.provisioning_errors and len(prop.provisioning_errors) > 0)
+ else None
+ ),
+ )
+
+ def _to_dict(self) -> Dict:
+ res: dict = SynapseSparkComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "SynapseSparkCompute":
+ loaded_data = load_from_dict(SynapseSparkComputeSchema, data, context, **kwargs)
+ return SynapseSparkCompute(**loaded_data)
+
+ def _to_rest_object(self) -> ComputeResource:
+ synapsespark_comp = SynapseSpark(
+ name=self.name,
+ compute_type=self.type,
+ resource_id=self.resource_id,
+ description=self.description,
+ )
+ return ComputeResource(
+ location=self.location,
+ properties=synapsespark_comp,
+ name=self.name,
+ identity=(
+ # pylint: disable=protected-access
+ self.identity._to_compute_rest_object()
+ if self.identity
+ else None
+ ),
+ tags=self.tags,
+ )
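+
+
+# Minimal usage sketch (the node family, node size, and Spark version strings are
+# placeholders chosen for illustration): a SynapseSparkCompute definition that
+# combines the auto-scale and auto-pause settings declared above. Actual creation
+# still goes through an MLClient elsewhere.
+def _example_synapse_spark_compute() -> SynapseSparkCompute:
+    return SynapseSparkCompute(
+        name="example-synapse-spark",
+        node_count=2,
+        node_family="<node-family>",
+        node_size="<node-size>",
+        spark_version="<spark-version>",
+        scale_settings=AutoScaleSettings(min_node_count=1, max_node_count=4, enabled=True),
+        auto_pause_settings=AutoPauseSettings(delay_in_minutes=15, enabled=True),
+    )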
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/unsupported_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/unsupported_compute.py
new file mode 100644
index 00000000..258fbf6b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/unsupported_compute.py
@@ -0,0 +1,62 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Any, Dict
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource
+from azure.ai.ml.constants._common import TYPE
+from azure.ai.ml.entities._compute.compute import Compute
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class UnsupportedCompute(Compute):
+ """Unsupported compute resource.
+
+ Only used for displaying compute properties for resources not fully supported in the SDK.
+ """
+
+ def __init__(
+ self,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = "*** Unsupported Compute Type ***"
+ super().__init__(**kwargs)
+
+ @classmethod
+ def _load_from_rest(cls, rest_obj: ComputeResource) -> "UnsupportedCompute":
+ prop = rest_obj.properties
+ if hasattr(rest_obj, "tags"):
+            # TODO(2294131): remove this once the issue where the DataFactory object has no tags is fixed
+ tags = rest_obj.tags
+ else:
+ tags = None
+ response = UnsupportedCompute(
+ name=rest_obj.name,
+ id=rest_obj.id,
+ description=prop.description,
+ location=rest_obj.location,
+ resource_id=prop.resource_id,
+ tags=tags,
+ provisioning_state=prop.provisioning_state,
+ created_on=prop.additional_properties.get("createdOn", None),
+ )
+ return response
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "UnsupportedCompute":
+ msg = "Cannot create unsupported compute type."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPUTE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def _to_rest_object(self) -> ComputeResource:
+ msg = "Cannot create unsupported compute type."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPUTE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/virtual_machine_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/virtual_machine_compute.py
new file mode 100644
index 00000000..90c3ec63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/virtual_machine_compute.py
@@ -0,0 +1,172 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource
+from azure.ai.ml._restclient.v2022_10_01_preview.models import VirtualMachine as VMResource
+from azure.ai.ml._restclient.v2022_10_01_preview.models import (
+ VirtualMachineSchemaProperties,
+ VirtualMachineSshCredentials,
+)
+from azure.ai.ml._schema.compute.virtual_machine_compute import VirtualMachineComputeSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE, DefaultOpenEncoding
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.entities._compute.compute import Compute
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class VirtualMachineSshSettings:
+ """SSH settings for a virtual machine.
+
+    :param admin_username: The admin username. Defaults to None.
+    :type admin_username: Optional[str]
+ :param admin_password: The admin user password. Defaults to None.
+ Required if `ssh_private_key_file` is not specified.
+ :type admin_password: Optional[str]
+ :param ssh_port: The ssh port number. Default is 22.
+ :type ssh_port: int
+ :param ssh_private_key_file: Path to the file containing the SSH rsa private key.
+ Use "ssh-keygen -t rsa -b 2048" to generate your SSH key pairs.
+ Required if admin_password is not specified.
+ :type ssh_private_key_file: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START vm_ssh_settings]
+ :end-before: [END vm_ssh_settings]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a VirtualMachineSshSettings object.
+ """
+
+ def __init__(
+ self,
+ *,
+ admin_username: Optional[str],
+ admin_password: Optional[str] = None,
+ ssh_port: Optional[int] = 22,
+ ssh_private_key_file: Optional[str] = None,
+ ) -> None:
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+ self.ssh_port = ssh_port
+ self.ssh_private_key_file = ssh_private_key_file
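+
+
+# Minimal usage sketch (the username and key path are placeholders): SSH settings
+# normally carry either an admin password or a private key file, per the docstring
+# above.
+def _example_vm_ssh_settings() -> VirtualMachineSshSettings:
+    return VirtualMachineSshSettings(
+        admin_username="<admin-username>",
+        ssh_port=22,
+        ssh_private_key_file="<path-to-private-key>",
+    )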
+
+
+class VirtualMachineCompute(Compute):
+ """Virtual Machine Compute resource.
+
+ :param name: Name of the compute resource.
+ :type name: str
+ :param description: Description of the resource. Defaults to None.
+ :type description: Optional[str]
+ :param resource_id: ARM resource ID of the underlying compute resource.
+ :type resource_id: str
+ :param tags: A set of tags. Contains resource tags defined as key/value pairs.
+ :type tags: Optional[dict]
+ :param ssh_settings: SSH settings. Defaults to None.
+ :type ssh_settings: Optional[~azure.ai.ml.entities.VirtualMachineSshSettings]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START vm_compute]
+ :end-before: [END vm_compute]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a VirtualMachineCompute object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ description: Optional[str] = None,
+ resource_id: str,
+ tags: Optional[dict] = None,
+ ssh_settings: Optional[VirtualMachineSshSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = ComputeType.VIRTUALMACHINE
+ self._public_key_data: str = kwargs.pop("public_key_data", None)
+ super().__init__(
+ name=name,
+ location=kwargs.pop("location", None),
+ description=description,
+ resource_id=resource_id,
+ tags=tags,
+ **kwargs,
+ )
+ self.ssh_settings = ssh_settings
+
+ @property
+ def public_key_data(self) -> str:
+ """Public key data.
+
+ :return: Public key data.
+ :rtype: str
+ """
+ return self._public_key_data
+
+ @classmethod
+ def _load_from_rest(cls, rest_obj: ComputeResource) -> "VirtualMachineCompute":
+ prop = rest_obj.properties
+ credentials = prop.properties.administrator_account if prop.properties else None
+ ssh_settings_param = None
+ if credentials or (prop.properties and prop.properties.ssh_port):
+ ssh_settings_param = VirtualMachineSshSettings(
+ admin_username=credentials.username if credentials else None,
+ admin_password=credentials.password if credentials else None,
+ ssh_port=prop.properties.ssh_port if prop.properties else None,
+ )
+ response = VirtualMachineCompute(
+ name=rest_obj.name,
+ id=rest_obj.id,
+ description=prop.description,
+ location=rest_obj.location,
+ resource_id=prop.resource_id,
+ tags=rest_obj.tags if rest_obj.tags else None,
+ public_key_data=credentials.public_key_data if credentials else None,
+ provisioning_state=prop.provisioning_state,
+ provisioning_errors=(
+ prop.provisioning_errors[0].error.code
+ if (prop.provisioning_errors and len(prop.provisioning_errors) > 0)
+ else None
+ ),
+ ssh_settings=ssh_settings_param,
+ )
+ return response
+
+ def _to_dict(self) -> Dict:
+ res: dict = VirtualMachineComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "VirtualMachineCompute":
+ loaded_data = load_from_dict(VirtualMachineComputeSchema, data, context, **kwargs)
+ return VirtualMachineCompute(**loaded_data)
+
+ def _to_rest_object(self) -> ComputeResource:
+ ssh_key_value = None
+ if self.ssh_settings and self.ssh_settings.ssh_private_key_file:
+ ssh_key_value = Path(self.ssh_settings.ssh_private_key_file).read_text(encoding=DefaultOpenEncoding.READ)
+ credentials = VirtualMachineSshCredentials(
+ username=self.ssh_settings.admin_username if self.ssh_settings else None,
+ password=self.ssh_settings.admin_password if self.ssh_settings else None,
+ public_key_data=self.public_key_data,
+ private_key_data=ssh_key_value,
+ )
+        properties = None
+        if self.ssh_settings is not None:
+            properties = VirtualMachineSchemaProperties(
+                ssh_port=self.ssh_settings.ssh_port, administrator_account=credentials
+            )
+        vm_compute = VMResource(
+            properties=properties,
+ resource_id=self.resource_id,
+ description=self.description,
+ )
+ resource = ComputeResource(name=self.name, location=self.location, tags=self.tags, properties=vm_compute)
+ return resource
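+
+
+# Minimal usage sketch (the ARM resource ID and SSH values are placeholders): an
+# existing virtual machine attached as compute, with SSH settings supplied inline.
+def _example_vm_compute() -> VirtualMachineCompute:
+    return VirtualMachineCompute(
+        name="example-vm-compute",
+        resource_id="<arm-resource-id-of-the-vm>",
+        ssh_settings=VirtualMachineSshSettings(
+            admin_username="<admin-username>",
+            ssh_private_key_file="<path-to-private-key>",
+        ),
+    )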
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py
new file mode 100644
index 00000000..b4d8e01d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_credentials.py
@@ -0,0 +1,964 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,redefined-builtin
+
+from abc import ABC
+from typing import Any, Dict, List, Optional, Type, Union
+
+from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata
+from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentityConfiguration
+from azure.ai.ml._restclient.v2022_01_01_preview.models import ManagedIdentity as RestWorkspaceConnectionManagedIdentity
+from azure.ai.ml._restclient.v2022_01_01_preview.models import (
+ PersonalAccessToken as RestWorkspaceConnectionPersonalAccessToken,
+)
+from azure.ai.ml._restclient.v2022_01_01_preview.models import (
+ ServicePrincipal as RestWorkspaceConnectionServicePrincipal,
+)
+from azure.ai.ml._restclient.v2022_01_01_preview.models import (
+ SharedAccessSignature as RestWorkspaceConnectionSharedAccessSignature,
+)
+from azure.ai.ml._restclient.v2022_01_01_preview.models import UserAssignedIdentity as RestUserAssignedIdentity
+from azure.ai.ml._restclient.v2022_01_01_preview.models import (
+ UsernamePassword as RestWorkspaceConnectionUsernamePassword,
+)
+from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration
+from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ AccountKeyDatastoreSecrets as RestAccountKeyDatastoreSecrets,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AmlToken as RestAmlToken
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ CertificateDatastoreCredentials as RestCertificateDatastoreCredentials,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import CertificateDatastoreSecrets, CredentialsType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import IdentityConfiguration as RestJobIdentityConfiguration
+from azure.ai.ml._restclient.v2023_04_01_preview.models import IdentityConfigurationType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ManagedIdentity as RestJobManagedIdentity
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ManagedServiceIdentity as RestRegistryManagedIdentity
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NoneDatastoreCredentials as RestNoneDatastoreCredentials
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SasDatastoreCredentials as RestSasDatastoreCredentials
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SasDatastoreSecrets as RestSasDatastoreSecrets
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ServicePrincipalDatastoreCredentials as RestServicePrincipalDatastoreCredentials,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ServicePrincipalDatastoreSecrets as RestServicePrincipalDatastoreSecrets,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import UserIdentity as RestUserIdentity
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ WorkspaceConnectionApiKey as RestWorkspaceConnectionApiKey,
+)
+
+# Note, this import needs to match the restclient that's imported by the
+# Connection class, otherwise some unit tests will start failing
+# Due to the mismatch between expected and received classes in WC rest conversions.
+from azure.ai.ml._restclient.v2024_04_01_preview.models import (
+ AADAuthTypeWorkspaceConnectionProperties,
+ AccessKeyAuthTypeWorkspaceConnectionProperties,
+ AccountKeyAuthTypeWorkspaceConnectionProperties,
+ ApiKeyAuthWorkspaceConnectionProperties,
+ ConnectionAuthType,
+ ManagedIdentityAuthTypeWorkspaceConnectionProperties,
+ NoneAuthTypeWorkspaceConnectionProperties,
+ PATAuthTypeWorkspaceConnectionProperties,
+ SASAuthTypeWorkspaceConnectionProperties,
+ ServicePrincipalAuthTypeWorkspaceConnectionProperties,
+ UsernamePasswordAuthTypeWorkspaceConnectionProperties,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import _snake_to_camel, camel_to_snake, snake_to_pascal
+from azure.ai.ml.constants._common import CommonYamlFields, IdentityType
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin, YamlTranslatableMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException, ValidationErrorType, ValidationException
+
+
+class _BaseIdentityConfiguration(ABC, DictMixin, RestTranslatableMixin):
+ def __init__(self) -> None:
+ self.type: Any = None
+
+ @classmethod
+ def _get_credential_class_from_rest_type(cls, auth_type: str) -> Type:
+ # Defined in this file instead of in constants file to avoid risking
+ # circular imports. This map links rest enums to the corresponding client classes.
+ # Enums are all lower-cased because rest enums aren't always consistent with their
+ # camel casing rules.
+ # Defined in this class because I didn't want this at the bottom of the file,
+ # but the classes aren't visible to the interpreter at the start of the file.
+        # Technically most of these classes aren't children of _BaseIdentityConfiguration, but
+ # I don't care.
+ REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP = {
+ ConnectionAuthType.SAS.lower(): SasTokenConfiguration,
+ ConnectionAuthType.PAT.lower(): PatTokenConfiguration,
+ ConnectionAuthType.ACCESS_KEY.lower(): AccessKeyConfiguration,
+ ConnectionAuthType.USERNAME_PASSWORD.lower(): UsernamePasswordConfiguration,
+ ConnectionAuthType.SERVICE_PRINCIPAL.lower(): ServicePrincipalConfiguration,
+ ConnectionAuthType.MANAGED_IDENTITY.lower(): ManagedIdentityConfiguration,
+ ConnectionAuthType.API_KEY.lower(): ApiKeyConfiguration,
+ ConnectionAuthType.ACCOUNT_KEY.lower(): AccountKeyConfiguration,
+ ConnectionAuthType.AAD.lower(): AadCredentialConfiguration,
+ }
+ if not auth_type:
+ return NoneCredentialConfiguration
+ return REST_CREDENTIAL_TYPE_TO_CLIENT_CLASS_MAP.get(
+ _snake_to_camel(auth_type).lower(), NoneCredentialConfiguration
+ )
+
+
+class AccountKeyConfiguration(RestTranslatableMixin, DictMixin):
+ def __init__(
+ self,
+ *,
+ account_key: Optional[str],
+ ) -> None:
+ self.type = camel_to_snake(CredentialsType.ACCOUNT_KEY)
+ self.account_key = account_key
+
+ def _to_datastore_rest_object(self) -> RestAccountKeyDatastoreCredentials:
+ secrets = RestAccountKeyDatastoreSecrets(key=self.account_key)
+ return RestAccountKeyDatastoreCredentials(secrets=secrets)
+
+ @classmethod
+ def _from_datastore_rest_object(cls, obj: RestAccountKeyDatastoreCredentials) -> "AccountKeyConfiguration":
+ return cls(account_key=obj.secrets.key if obj.secrets else None)
+
+ @classmethod
+ def _from_workspace_connection_rest_object(
+ cls, obj: Optional[RestWorkspaceConnectionSharedAccessSignature]
+ ) -> "AccountKeyConfiguration":
+        # As far as I can tell, account key configs use the same underlying
+        # rest object as sas token configs.
+ return cls(account_key=obj.sas if obj is not None and obj.sas else None)
+
+ def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionSharedAccessSignature:
+ return RestWorkspaceConnectionSharedAccessSignature(sas=self.account_key)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, AccountKeyConfiguration):
+ return NotImplemented
+ return self.account_key == other.account_key
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+ @classmethod
+ def _get_rest_properties_class(cls) -> Type:
+ return AccountKeyAuthTypeWorkspaceConnectionProperties
+
+
+class SasTokenConfiguration(RestTranslatableMixin, DictMixin):
+ def __init__(
+ self,
+ *,
+ sas_token: Optional[str],
+ ) -> None:
+ super().__init__()
+ self.type = camel_to_snake(CredentialsType.SAS)
+ self.sas_token = sas_token
+
+ def _to_datastore_rest_object(self) -> RestSasDatastoreCredentials:
+ secrets = RestSasDatastoreSecrets(sas_token=self.sas_token)
+ return RestSasDatastoreCredentials(secrets=secrets)
+
+ @classmethod
+ def _from_datastore_rest_object(cls, obj: RestSasDatastoreCredentials) -> "SasTokenConfiguration":
+ return cls(sas_token=obj.secrets.sas_token if obj.secrets else None)
+
+ def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionSharedAccessSignature:
+ return RestWorkspaceConnectionSharedAccessSignature(sas=self.sas_token)
+
+ @classmethod
+ def _from_workspace_connection_rest_object(
+ cls, obj: Optional[RestWorkspaceConnectionSharedAccessSignature]
+ ) -> "SasTokenConfiguration":
+ return cls(sas_token=obj.sas if obj is not None and obj.sas else None)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, SasTokenConfiguration):
+ return NotImplemented
+ return self.sas_token == other.sas_token
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+ @classmethod
+ def _get_rest_properties_class(cls) -> Type:
+ return SASAuthTypeWorkspaceConnectionProperties
+
+
+class PatTokenConfiguration(RestTranslatableMixin, DictMixin):
+ """Personal access token credentials.
+
+ :param pat: Personal access token.
+ :type pat: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START personal_access_token_configuration]
+ :end-before: [END personal_access_token_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a personal access token configuration for a WorkspaceConnection.
+ """
+
+ def __init__(self, *, pat: Optional[str]) -> None:
+ super().__init__()
+ self.type = camel_to_snake(ConnectionAuthType.PAT)
+ self.pat = pat
+
+ def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionPersonalAccessToken:
+ return RestWorkspaceConnectionPersonalAccessToken(pat=self.pat)
+
+ @classmethod
+ def _from_workspace_connection_rest_object(
+ cls, obj: Optional[RestWorkspaceConnectionPersonalAccessToken]
+ ) -> "PatTokenConfiguration":
+ return cls(pat=obj.pat if obj is not None and obj.pat else None)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, PatTokenConfiguration):
+ return NotImplemented
+ return self.pat == other.pat
+
+ @classmethod
+ def _get_rest_properties_class(cls) -> Type:
+ return PATAuthTypeWorkspaceConnectionProperties
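+
+
+# Minimal usage sketch (placeholder token value): a PAT credential as it would be
+# passed to a connection's credentials field.
+def _example_pat_credential() -> PatTokenConfiguration:
+    return PatTokenConfiguration(pat="<personal-access-token>")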
+
+
+class UsernamePasswordConfiguration(RestTranslatableMixin, DictMixin):
+ """Username and password credentials.
+
+    :param username: The username. The value should be URL-encoded.
+    :type username: str
+    :param password: The password. The value should be URL-encoded.
+    :type password: str
+ """
+
+ def __init__(
+ self,
+ *,
+ username: Optional[str],
+ password: Optional[str],
+ ) -> None:
+ super().__init__()
+ self.type = camel_to_snake(ConnectionAuthType.USERNAME_PASSWORD)
+ self.username = username
+ self.password = password
+
+ def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionUsernamePassword:
+ return RestWorkspaceConnectionUsernamePassword(username=self.username, password=self.password)
+
+ @classmethod
+ def _from_workspace_connection_rest_object(
+ cls, obj: Optional[RestWorkspaceConnectionUsernamePassword]
+ ) -> "UsernamePasswordConfiguration":
+ return cls(
+ username=obj.username if obj is not None and obj.username else None,
+ password=obj.password if obj is not None and obj.password else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, UsernamePasswordConfiguration):
+ return NotImplemented
+ return self.username == other.username and self.password == other.password
+
+ @classmethod
+ def _get_rest_properties_class(cls) -> Type:
+ return UsernamePasswordAuthTypeWorkspaceConnectionProperties
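+
+
+# Minimal usage sketch (placeholder values): a username/password credential; both
+# values are expected to be URL-encoded per the docstring above.
+def _example_username_password_credential() -> UsernamePasswordConfiguration:
+    return UsernamePasswordConfiguration(username="<username>", password="<password>")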
+
+
+class BaseTenantCredentials(RestTranslatableMixin, DictMixin, ABC):
+ """Base class for tenant credentials.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+    :param authority_url: The authority URL. If not specified, a URL is retrieved from the cloud metadata.
+ :type authority_url: Optional[str]
+ :param resource_url: The resource URL.
+ :type resource_url: Optional[str]
+ :param tenant_id: The tenant ID.
+ :type tenant_id: Optional[str]
+ :param client_id: The client ID.
+ :type client_id: Optional[str]
+ """
+
+ def __init__(
+ self,
+ authority_url: str = _get_active_directory_url_from_metadata(),
+ resource_url: Optional[str] = None,
+ tenant_id: Optional[str] = None,
+ client_id: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+ self.authority_url = authority_url
+ self.resource_url = resource_url
+ self.tenant_id = tenant_id
+ self.client_id = client_id
+
+
+class ServicePrincipalConfiguration(BaseTenantCredentials):
+ """Service Principal credentials configuration.
+
+ :param client_secret: The client secret.
+ :type client_secret: str
+ :keyword kwargs: Additional arguments to pass to the parent class.
+ :paramtype kwargs: Optional[dict]
+ """
+
+ def __init__(
+ self,
+ *,
+ client_secret: Optional[str],
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.type = camel_to_snake(CredentialsType.SERVICE_PRINCIPAL)
+ self.client_secret = client_secret
+
+ def _to_datastore_rest_object(self) -> RestServicePrincipalDatastoreCredentials:
+ secrets = RestServicePrincipalDatastoreSecrets(client_secret=self.client_secret)
+ return RestServicePrincipalDatastoreCredentials(
+ authority_url=self.authority_url,
+ resource_url=self.resource_url,
+ tenant_id=self.tenant_id,
+ client_id=self.client_id,
+ secrets=secrets,
+ )
+
+ @classmethod
+ def _from_datastore_rest_object(
+ cls, obj: RestServicePrincipalDatastoreCredentials
+ ) -> "ServicePrincipalConfiguration":
+ return cls(
+ authority_url=obj.authority_url,
+ resource_url=obj.resource_url,
+ tenant_id=obj.tenant_id,
+ client_id=obj.client_id,
+ client_secret=obj.secrets.client_secret if obj.secrets else None,
+ )
+
+ def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionServicePrincipal:
+ return RestWorkspaceConnectionServicePrincipal(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ tenant_id=self.tenant_id,
+ )
+
+ @classmethod
+ def _from_workspace_connection_rest_object(
+ cls, obj: Optional[RestWorkspaceConnectionServicePrincipal]
+ ) -> "ServicePrincipalConfiguration":
+ return cls(
+ client_id=obj.client_id if obj is not None and obj.client_id else None,
+ client_secret=obj.client_secret if obj is not None and obj.client_secret else None,
+ tenant_id=obj.tenant_id if obj is not None and obj.tenant_id else None,
+ authority_url="",
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ServicePrincipalConfiguration):
+ return NotImplemented
+ return (
+ self.authority_url == other.authority_url
+ and self.resource_url == other.resource_url
+ and self.tenant_id == other.tenant_id
+ and self.client_id == other.client_id
+ and self.client_secret == other.client_secret
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+ @classmethod
+ def _get_rest_properties_class(cls) -> Type:
+ return ServicePrincipalAuthTypeWorkspaceConnectionProperties
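+
+
+# Minimal usage sketch (placeholder IDs and secret): tenant_id and client_id are
+# accepted through the BaseTenantCredentials keyword arguments.
+def _example_service_principal_credential() -> ServicePrincipalConfiguration:
+    return ServicePrincipalConfiguration(
+        client_secret="<client-secret>",
+        tenant_id="<tenant-id>",
+        client_id="<client-id>",
+    )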
+
+
+class CertificateConfiguration(BaseTenantCredentials):
+ def __init__(
+ self,
+ certificate: Optional[str] = None,
+ thumbprint: Optional[str] = None,
+ **kwargs: str,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.type = CredentialsType.CERTIFICATE
+ self.certificate = certificate
+ self.thumbprint = thumbprint
+
+ def _to_datastore_rest_object(self) -> RestCertificateDatastoreCredentials:
+ secrets = CertificateDatastoreSecrets(certificate=self.certificate)
+ return RestCertificateDatastoreCredentials(
+ authority_url=self.authority_url,
+ resource_uri=self.resource_url,
+ tenant_id=self.tenant_id,
+ client_id=self.client_id,
+ thumbprint=self.thumbprint,
+ secrets=secrets,
+ )
+
+ @classmethod
+ def _from_datastore_rest_object(cls, obj: RestCertificateDatastoreCredentials) -> "CertificateConfiguration":
+ return cls(
+ authority_url=obj.authority_url,
+ resource_url=obj.resource_uri,
+ tenant_id=obj.tenant_id,
+ client_id=obj.client_id,
+ thumbprint=obj.thumbprint,
+ certificate=obj.secrets.certificate if obj.secrets else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, CertificateConfiguration):
+ return NotImplemented
+ return (
+ self.authority_url == other.authority_url
+ and self.resource_url == other.resource_url
+ and self.tenant_id == other.tenant_id
+ and self.client_id == other.client_id
+ and self.thumbprint == other.thumbprint
+ and self.certificate == other.certificate
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+
+class _BaseJobIdentityConfiguration(ABC, RestTranslatableMixin, DictMixin, YamlTranslatableMixin):
+ def __init__(self) -> None:
+ self.type = None
+
+ @classmethod
+    def _from_rest_object(
+        cls, obj: RestJobIdentityConfiguration
+    ) -> Optional[Union["AmlTokenConfiguration", "ManagedIdentityConfiguration", "UserIdentityConfiguration"]]:
+ if obj is None:
+ return None
+ mapping = {
+ IdentityConfigurationType.AML_TOKEN: AmlTokenConfiguration,
+ IdentityConfigurationType.MANAGED: ManagedIdentityConfiguration,
+ IdentityConfigurationType.USER_IDENTITY: UserIdentityConfiguration,
+ }
+
+ if isinstance(obj, dict):
+ # TODO: support data binding expression
+ obj = RestJobIdentityConfiguration.from_dict(obj)
+
+ identity_class = mapping.get(obj.identity_type, None)
+ if identity_class:
+ if obj.identity_type == IdentityConfigurationType.AML_TOKEN:
+ return AmlTokenConfiguration._from_job_rest_object(obj)
+
+ if obj.identity_type == IdentityConfigurationType.MANAGED:
+ return ManagedIdentityConfiguration._from_job_rest_object(obj)
+
+ if obj.identity_type == IdentityConfigurationType.USER_IDENTITY:
+ return UserIdentityConfiguration._from_job_rest_object(obj)
+
+ msg = f"Unknown identity type: {obj.identity_type}"
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.IDENTITY,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ @classmethod
+ def _load(
+ cls,
+ data: Dict,
+ ) -> Union["ManagedIdentityConfiguration", "UserIdentityConfiguration", "AmlTokenConfiguration"]:
+ type_str = data.get(CommonYamlFields.TYPE)
+ if type_str == IdentityType.MANAGED_IDENTITY:
+ return ManagedIdentityConfiguration._load_from_dict(data)
+
+ if type_str == IdentityType.USER_IDENTITY:
+ return UserIdentityConfiguration._load_from_dict(data)
+
+ if type_str == IdentityType.AML_TOKEN:
+ return AmlTokenConfiguration._load_from_dict(data)
+
+ msg = f"Unsupported identity type: {type_str}."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.IDENTITY,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+
+class ManagedIdentityConfiguration(_BaseIdentityConfiguration):
+ """Managed Identity credential configuration.
+
+ :keyword client_id: The client ID of the managed identity.
+ :paramtype client_id: Optional[str]
+ :keyword resource_id: The resource ID of the managed identity.
+ :paramtype resource_id: Optional[str]
+ :keyword object_id: The object ID.
+ :paramtype object_id: Optional[str]
+ :keyword principal_id: The principal ID.
+ :paramtype principal_id: Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ client_id: Optional[str] = None,
+ resource_id: Optional[str] = None,
+ object_id: Optional[str] = None,
+ principal_id: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+ self.type = IdentityType.MANAGED_IDENTITY
+ self.client_id = client_id
+ # TODO: Check if both client_id and resource_id are required
+ self.resource_id = resource_id
+ self.object_id = object_id
+ self.principal_id = principal_id
+
+ def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionManagedIdentity:
+ return RestWorkspaceConnectionManagedIdentity(client_id=self.client_id, resource_id=self.resource_id)
+
+ @classmethod
+ def _from_workspace_connection_rest_object(
+ cls, obj: Optional[RestWorkspaceConnectionManagedIdentity]
+ ) -> "ManagedIdentityConfiguration":
+ return cls(
+ client_id=obj.client_id if obj is not None and obj.client_id else None,
+            resource_id=obj.resource_id if obj is not None and obj.resource_id else None,
+ )
+
+ def _to_job_rest_object(self) -> RestJobManagedIdentity:
+ return RestJobManagedIdentity(
+ client_id=self.client_id,
+ object_id=self.object_id,
+ resource_id=self.resource_id,
+ )
+
+ @classmethod
+ def _from_job_rest_object(cls, obj: RestJobManagedIdentity) -> "ManagedIdentityConfiguration":
+ return cls(
+ client_id=obj.client_id,
+            object_id=obj.object_id,
+ resource_id=obj.resource_id,
+ )
+
+ def _to_identity_configuration_rest_object(self) -> RestUserAssignedIdentity:
+ return RestUserAssignedIdentity()
+
+ @classmethod
+ def _from_identity_configuration_rest_object(
+ cls, rest_obj: RestUserAssignedIdentity, **kwargs: Optional[str]
+ ) -> "ManagedIdentityConfiguration":
+ _rid: Optional[str] = kwargs["resource_id"]
+ result = cls(resource_id=_rid)
+ result.__dict__.update(rest_obj.as_dict())
+ return result
+
+ def _to_online_endpoint_rest_object(self) -> RestUserAssignedIdentityConfiguration:
+ return RestUserAssignedIdentityConfiguration()
+
+ def _to_workspace_rest_object(self) -> RestUserAssignedIdentityConfiguration:
+ return RestUserAssignedIdentityConfiguration(
+ principal_id=self.principal_id,
+ client_id=self.client_id,
+ )
+
+ @classmethod
+ def _from_workspace_rest_object(cls, obj: RestUserAssignedIdentityConfiguration) -> "ManagedIdentityConfiguration":
+ return cls(
+ principal_id=obj.principal_id,
+ client_id=obj.client_id,
+ )
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ from azure.ai.ml._schema.job.identity import ManagedIdentitySchema
+
+ _dict: Dict = ManagedIdentitySchema().dump(self)
+ return _dict
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict) -> "ManagedIdentityConfiguration":
+ # pylint: disable=no-member
+ from azure.ai.ml._schema.job.identity import ManagedIdentitySchema
+
+ _data: ManagedIdentityConfiguration = ManagedIdentitySchema().load(data)
+ return _data
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ManagedIdentityConfiguration):
+ return NotImplemented
+ return self.client_id == other.client_id and self.resource_id == other.resource_id
+
+ @classmethod
+ def _get_rest_properties_class(cls) -> Type:
+ return ManagedIdentityAuthTypeWorkspaceConnectionProperties
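+
+
+# Minimal usage sketch (placeholder IDs): a user-assigned managed identity reference;
+# client_id and resource_id are the fields most commonly populated.
+def _example_managed_identity_credential() -> ManagedIdentityConfiguration:
+    return ManagedIdentityConfiguration(
+        client_id="<client-id>",
+        resource_id="<user-assigned-identity-resource-id>",
+    )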
+
+
+class UserIdentityConfiguration(_BaseIdentityConfiguration):
+ """User identity configuration.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_authentication.py
+ :start-after: [START user_identity_configuration]
+ :end-before: [END user_identity_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a UserIdentityConfiguration for a command().
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.type = IdentityType.USER_IDENTITY
+
+ def _to_job_rest_object(self) -> RestUserIdentity:
+ return RestUserIdentity()
+
+ @classmethod
+ # pylint: disable=unused-argument
+    def _from_job_rest_object(cls, obj: RestUserIdentity) -> "UserIdentityConfiguration":
+ return cls()
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ from azure.ai.ml._schema.job.identity import UserIdentitySchema
+
+ _dict: Dict = UserIdentitySchema().dump(self)
+ return _dict
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict) -> "UserIdentityConfiguration":
+ # pylint: disable=no-member
+ from azure.ai.ml._schema.job.identity import UserIdentitySchema
+
+ _data: UserIdentityConfiguration = UserIdentitySchema().load(data)
+ return _data
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, UserIdentityConfiguration):
+ return NotImplemented
+ res: bool = self._to_job_rest_object() == other._to_job_rest_object()
+ return res
+
+
+class AmlTokenConfiguration(_BaseIdentityConfiguration):
+ """AzureML Token identity configuration.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_authentication.py
+ :start-after: [START aml_token_configuration]
+ :end-before: [END aml_token_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring an AmlTokenConfiguration for a command().
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.type = IdentityType.AML_TOKEN
+
+ def _to_job_rest_object(self) -> RestAmlToken:
+ return RestAmlToken()
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema
+
+ _dict: Dict = AMLTokenIdentitySchema().dump(self)
+ return _dict
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict) -> "AmlTokenConfiguration":
+ # pylint: disable=no-member
+ from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema
+
+ _data: AmlTokenConfiguration = AMLTokenIdentitySchema().load(data)
+ return _data
+
+ @classmethod
+ # pylint: disable=unused-argument
+ def _from_job_rest_object(cls, obj: RestAmlToken) -> "AmlTokenConfiguration":
+ return cls()
+
+
+# This class will be used to represent Identity property on compute, endpoint, and registry
+class IdentityConfiguration(RestTranslatableMixin):
+ """Identity configuration used to represent identity property on compute, endpoint, and registry resources.
+
+ :param type: The type of managed identity.
+ :type type: str
+ :param user_assigned_identities: A list of ManagedIdentityConfiguration objects.
+ :type user_assigned_identities: Optional[list[~azure.ai.ml.entities.ManagedIdentityConfiguration]]
+ """
+
+ def __init__(
+ self,
+ *,
+ type: str,
+ user_assigned_identities: Optional[List[ManagedIdentityConfiguration]] = None,
+ **kwargs: dict,
+ ) -> None:
+ self.type = type
+ self.user_assigned_identities = user_assigned_identities
+ self.principal_id = kwargs.pop("principal_id", None)
+ self.tenant_id = kwargs.pop("tenant_id", None)
+
+ def _to_compute_rest_object(self) -> RestIdentityConfiguration:
+ rest_user_assigned_identities = (
+ {uai.resource_id: uai._to_identity_configuration_rest_object() for uai in self.user_assigned_identities}
+ if self.user_assigned_identities
+ else None
+ )
+ return RestIdentityConfiguration(
+ type=snake_to_pascal(self.type), user_assigned_identities=rest_user_assigned_identities
+ )
+
+ @classmethod
+ def _from_compute_rest_object(cls, obj: RestIdentityConfiguration) -> "IdentityConfiguration":
+ from_rest_user_assigned_identities = (
+ [
+ ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id)
+ for (resource_id, uai) in obj.user_assigned_identities.items()
+ ]
+ if obj.user_assigned_identities
+ else None
+ )
+ result = cls(
+ type=camel_to_snake(obj.type),
+ user_assigned_identities=from_rest_user_assigned_identities,
+ )
+ result.principal_id = obj.principal_id
+ result.tenant_id = obj.tenant_id
+ return result
+
+ def _to_online_endpoint_rest_object(self) -> RestManagedServiceIdentityConfiguration:
+ rest_user_assigned_identities = (
+ {uai.resource_id: uai._to_online_endpoint_rest_object() for uai in self.user_assigned_identities}
+ if self.user_assigned_identities
+ else None
+ )
+
+ return RestManagedServiceIdentityConfiguration(
+ type=snake_to_pascal(self.type),
+ principal_id=self.principal_id,
+ tenant_id=self.tenant_id,
+ user_assigned_identities=rest_user_assigned_identities,
+ )
+
+ @classmethod
+ def _from_online_endpoint_rest_object(cls, obj: RestManagedServiceIdentityConfiguration) -> "IdentityConfiguration":
+ from_rest_user_assigned_identities = (
+ [
+ ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id)
+ for (resource_id, uai) in obj.user_assigned_identities.items()
+ ]
+ if obj.user_assigned_identities
+ else None
+ )
+ result = cls(
+ type=camel_to_snake(obj.type),
+ user_assigned_identities=from_rest_user_assigned_identities,
+ )
+ result.principal_id = obj.principal_id
+ result.tenant_id = obj.tenant_id
+ return result
+
+ @classmethod
+ def _from_workspace_rest_object(cls, obj: RestManagedServiceIdentityConfiguration) -> "IdentityConfiguration":
+ from_rest_user_assigned_identities = (
+ [
+ ManagedIdentityConfiguration._from_identity_configuration_rest_object(uai, resource_id=resource_id)
+ for (resource_id, uai) in obj.user_assigned_identities.items()
+ ]
+ if obj.user_assigned_identities
+ else None
+ )
+ result = cls(
+ type=camel_to_snake(obj.type),
+ user_assigned_identities=from_rest_user_assigned_identities,
+ )
+ result.principal_id = obj.principal_id
+ result.tenant_id = obj.tenant_id
+ return result
+
+ def _to_workspace_rest_object(self) -> RestManagedServiceIdentityConfiguration:
+ rest_user_assigned_identities = (
+ {uai.resource_id: uai._to_workspace_rest_object() for uai in self.user_assigned_identities}
+ if self.user_assigned_identities
+ else None
+ )
+ return RestManagedServiceIdentityConfiguration(
+ type=snake_to_pascal(self.type), user_assigned_identities=rest_user_assigned_identities
+ )
+
+ def _to_rest_object(self) -> RestRegistryManagedIdentity:
+ return RestRegistryManagedIdentity(
+ type=self.type,
+ principal_id=self.principal_id,
+ tenant_id=self.tenant_id,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestRegistryManagedIdentity) -> "IdentityConfiguration":
+ result = cls(
+ type=obj.type,
+ user_assigned_identities=None,
+ )
+ result.principal_id = obj.principal_id
+ result.tenant_id = obj.tenant_id
+ return result
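+
+
+# Minimal usage sketch (the type string and resource ID are assumptions shown for
+# illustration): an identity property with one user-assigned managed identity, as
+# it would be set on a compute, endpoint, or registry.
+def _example_identity_configuration() -> IdentityConfiguration:
+    return IdentityConfiguration(
+        type="user_assigned",
+        user_assigned_identities=[
+            ManagedIdentityConfiguration(resource_id="<user-assigned-identity-resource-id>")
+        ],
+    )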
+
+
+class NoneCredentialConfiguration(RestTranslatableMixin):
+    """None Credential Configuration. In many use cases, the presence of
+ this credential configuration indicates that the user's Entra ID will be
+ implicitly used instead of any other form of authentication."""
+
+ def __init__(self) -> None:
+ self.type = CredentialsType.NONE
+
+ def _to_datastore_rest_object(self) -> RestNoneDatastoreCredentials:
+ return RestNoneDatastoreCredentials()
+
+ @classmethod
+ # pylint: disable=unused-argument
+ def _from_datastore_rest_object(cls, obj: RestNoneDatastoreCredentials) -> "NoneCredentialConfiguration":
+ return cls()
+
+ def _to_workspace_connection_rest_object(self) -> None:
+ return None
+
+ def __eq__(self, other: object) -> bool:
+ if isinstance(other, NoneCredentialConfiguration):
+ return True
+ return False
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+ @classmethod
+ def _get_rest_properties_class(cls) -> Type:
+ return NoneAuthTypeWorkspaceConnectionProperties
+
+
+class AadCredentialConfiguration(RestTranslatableMixin):
+ """Azure Active Directory Credential Configuration"""
+
+ def __init__(self) -> None:
+ self.type = camel_to_snake(ConnectionAuthType.AAD)
+
+ def _to_datastore_rest_object(self) -> RestNoneDatastoreCredentials:
+ return RestNoneDatastoreCredentials()
+
+ @classmethod
+ # pylint: disable=unused-argument
+ def _from_datastore_rest_object(cls, obj: RestNoneDatastoreCredentials) -> "AadCredentialConfiguration":
+ return cls()
+
+ # Has no credential object, just a property bag class.
+ def _to_workspace_connection_rest_object(self) -> None:
+ return None
+
+ def __eq__(self, other: object) -> bool:
+ if isinstance(other, AadCredentialConfiguration):
+ return True
+ return False
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+ @classmethod
+ def _get_rest_properties_class(cls) -> Type:
+ return AADAuthTypeWorkspaceConnectionProperties
+
+
+class AccessKeyConfiguration(RestTranslatableMixin, DictMixin):
+ """Access Key Credentials.
+
+ :param access_key_id: The access key ID.
+ :type access_key_id: str
+ :param secret_access_key: The secret access key.
+ :type secret_access_key: str
+ """
+
+ def __init__(
+ self,
+ *,
+ access_key_id: Optional[str],
+ secret_access_key: Optional[str],
+ ) -> None:
+ super().__init__()
+ self.type = camel_to_snake(ConnectionAuthType.ACCESS_KEY)
+ self.access_key_id = access_key_id
+ self.secret_access_key = secret_access_key
+
+ def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionAccessKey:
+ return RestWorkspaceConnectionAccessKey(
+ access_key_id=self.access_key_id, secret_access_key=self.secret_access_key
+ )
+
+ @classmethod
+ def _from_workspace_connection_rest_object(
+ cls, obj: Optional[RestWorkspaceConnectionAccessKey]
+ ) -> "AccessKeyConfiguration":
+ return cls(
+ access_key_id=obj.access_key_id if obj is not None and obj.access_key_id else None,
+ secret_access_key=obj.secret_access_key if obj is not None and obj.secret_access_key else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, AccessKeyConfiguration):
+ return NotImplemented
+ return self.access_key_id == other.access_key_id and self.secret_access_key == other.secret_access_key
+
+ def _get_rest_properties_class(self):
+ return AccessKeyAuthTypeWorkspaceConnectionProperties
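+
+
+# Minimal usage sketch (placeholder key values): an access-key credential pair.
+def _example_access_key_credential() -> AccessKeyConfiguration:
+    return AccessKeyConfiguration(
+        access_key_id="<access-key-id>",
+        secret_access_key="<secret-access-key>",
+    )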
+
+
+@experimental
+class ApiKeyConfiguration(RestTranslatableMixin, DictMixin):
+ """Api Key Credentials.
+
+    :param key: The API key.
+ :type key: str
+ """
+
+ def __init__(
+ self,
+ *,
+ key: Optional[str],
+ ):
+ super().__init__()
+ self.type = camel_to_snake(ConnectionAuthType.API_KEY)
+ self.key = key
+
+ def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionApiKey:
+ return RestWorkspaceConnectionApiKey(
+ key=self.key,
+ )
+
+ @classmethod
+ def _from_workspace_connection_rest_object(
+ cls, obj: Optional[RestWorkspaceConnectionApiKey]
+ ) -> "ApiKeyConfiguration":
+ return cls(
+ key=obj.key if obj is not None and obj.key else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ApiKeyConfiguration):
+ return NotImplemented
+ return bool(self.key == other.key)
+
+ def _get_rest_properties_class(self):
+ return ApiKeyAuthWorkspaceConnectionProperties
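+
+
+# Minimal usage sketch (placeholder key value): an API-key credential.
+def _example_api_key_credential() -> ApiKeyConfiguration:
+    return ApiKeyConfiguration(key="<api-key>")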
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/mltable_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/mltable_metadata.py
new file mode 100644
index 00000000..452b2e53
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data/mltable_metadata.py
@@ -0,0 +1,92 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from marshmallow import INCLUDE
+
+from azure.ai.ml._schema._data.mltable_metadata_schema import MLTableMetadataSchema
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class MLTableMetadataPath:
+ type: str # Literal["pattern", "file", "folder"]
+ value: Optional[str]
+
+ def __init__(self, *, pathDict: Dict):
+ if pathDict.get("pattern", None):
+ self.type = "pattern"
+ self.value = pathDict.get("pattern")
+ if pathDict.get("file", None):
+ self.type = "file"
+ self.value = pathDict.get("file")
+ if pathDict.get("folder", None):
+ self.type = "folder"
+ self.value = pathDict.get("folder")
+
+
+class MLTableMetadata:
+ """MLTableMetadata for data assets.
+
+ :param paths: List of paths which the MLTableMetadata refers to.
+ :type paths: List[MLTableMetadataPath]
+ :param transformations: Any transformations to be applied to the data referenced in paths.
+ :type transformations: List[Any]
+ :param base_path: Base path to resolve relative paths from.
+ :type base_path: str
+ """
+
+ def __init__(
+ self,
+ *,
+ paths: List[MLTableMetadataPath],
+ transformations: Optional[List[Any]] = None,
+ base_path: str,
+ **_kwargs: Any,
+ ):
+ self.base_path = base_path
+ self.paths = paths
+ self.transformations = transformations
+
+ @classmethod
+ def load(
+ cls,
+ yaml_path: Union[PathLike, str],
+ **kwargs: Any,
+ ) -> "MLTableMetadata":
+        """Construct an MLTableMetadata object from a YAML file.
+
+ :param yaml_path: Path to a local file as the source.
+ :type yaml_path: PathLike | str
+
+        :return: Constructed MLTableMetadata object.
+        :rtype: MLTableMetadata
+ """
+ yaml_dict = load_yaml(yaml_path)
+ return cls._load(yaml_data=yaml_dict, yaml_path=yaml_path, **kwargs)
+
+ @classmethod
+ def _load(
+ cls,
+ yaml_data: Optional[Dict],
+ yaml_path: Optional[Union[PathLike, str]],
+ **kwargs: Any,
+ ) -> "MLTableMetadata":
+ yaml_data = yaml_data or {}
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ }
+ res: MLTableMetadata = load_from_dict(MLTableMetadataSchema, yaml_data, context, "", unknown=INCLUDE, **kwargs)
+ return res
+
+ def _to_dict(self) -> Dict:
+ res: dict = MLTableMetadataSchema(context={BASE_PATH_CONTEXT_KEY: "./"}, unknown=INCLUDE).dump(self)
+ return res
+
+ def referenced_uris(self) -> List[Optional[str]]:
+ return [path.value for path in self.paths]
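+
+
+# Minimal usage sketch (the path is a placeholder for a local MLTable YAML file):
+# loading MLTable metadata and listing the URIs it references.
+def _example_load_mltable_metadata() -> List[Optional[str]]:
+    metadata = MLTableMetadata.load("<path-to>/MLTable")
+    return metadata.referenced_uris()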
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/data_import.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/data_import.py
new file mode 100644
index 00000000..028d431c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/data_import.py
@@ -0,0 +1,130 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import DatabaseSource as RestDatabaseSource
+from azure.ai.ml._restclient.v2023_06_01_preview.models import DataImport as RestDataImport
+from azure.ai.ml._restclient.v2023_06_01_preview.models import FileSystemSource as RestFileSystemSource
+from azure.ai.ml._schema import DataImportSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, AssetTypes
+from azure.ai.ml.data_transfer import Database, FileSystem
+from azure.ai.ml.entities._assets import Data
+from azure.ai.ml.entities._util import load_from_dict
+
+
+@experimental
+class DataImport(Data):
+    """Data asset created by a data import job.
+
+ :param name: Name of the asset.
+ :type name: str
+ :param path: The path to the asset being created by data import job.
+ :type path: str
+ :param source: The source of the asset data being copied from.
+ :type source: Union[Database, FileSystem]
+ :param version: Version of the resource.
+ :type version: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ path: str,
+ source: Union[Database, FileSystem],
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ properties=properties,
+ path=path,
+ **kwargs,
+ )
+ self.source = source
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "DataImport":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: DataImport = load_from_dict(DataImportSchema, data, context, **kwargs)
+ return res
+
+ def _to_rest_object(self) -> RestDataImport:
+ if isinstance(self.source, Database):
+ source = RestDatabaseSource(
+ connection=self.source.connection,
+ query=self.source.query,
+ )
+ else:
+ source = RestFileSystemSource(
+ connection=self.source.connection,
+ path=self.source.path,
+ )
+
+ return RestDataImport(
+ description=self.description,
+ properties=self.properties,
+ tags=self.tags,
+ data_type=self.type,
+ data_uri=self.path,
+ asset_name=self.name,
+ source=source,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, data_rest_object: RestDataImport) -> "DataImport":
+ source: Any = None
+ if isinstance(data_rest_object.source, RestDatabaseSource):
+ source = Database(
+ connection=data_rest_object.source.connection,
+ query=data_rest_object.source.query,
+ )
+ data_type = AssetTypes.MLTABLE
+ else:
+ source = FileSystem(
+ connection=data_rest_object.source.connection,
+ path=data_rest_object.source.path,
+ )
+ data_type = AssetTypes.URI_FOLDER
+
+ data_import = cls(
+ name=data_rest_object.asset_name,
+ path=data_rest_object.data_uri,
+ source=source,
+ description=data_rest_object.description,
+ tags=data_rest_object.tags,
+ properties=data_rest_object.properties,
+ type=data_type,
+ is_anonymous=data_rest_object.is_anonymous,
+ )
+ return data_import
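
A minimal construction sketch for the DataImport entity defined above. It assumes the Database helper from azure.ai.ml.data_transfer accepts the connection and query keyword arguments that _to_rest_object reads back; the connection name, query, and path are placeholders.

from azure.ai.ml.data_transfer import Database
from azure.ai.ml.entities._data_import.data_import import DataImport

# Source to copy from; connection name and query are placeholders.
source = Database(
    connection="azureml:my_snowflake_connection",
    query="SELECT * FROM my_table",
)

# The imported data lands at `path` and is registered as a data asset.
data_import = DataImport(
    name="imported_sales_data",
    path="azureml://datastores/workspaceblobstore/paths/imports/sales/",
    source=source,
    description="Sales table imported from an external database.",
)

# The REST payload carries a DatabaseSource plus the asset name and data URI.
rest_obj = data_import._to_rest_object()
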
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/schedule.py
new file mode 100644
index 00000000..6a51878a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_data_import/schedule.py
@@ -0,0 +1,115 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImportDataAction
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Schedule as RestSchedule
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ScheduleProperties
+from azure.ai.ml._schema._data_import.schedule import ImportDataScheduleSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ScheduleType
+from azure.ai.ml.entities._data_import.data_import import DataImport
+from azure.ai.ml.entities._schedule.schedule import Schedule
+from azure.ai.ml.entities._schedule.trigger import CronTrigger, RecurrenceTrigger, TriggerBase
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+
+@experimental
+class ImportDataSchedule(Schedule):
+ """ImportDataSchedule object.
+
+ :param name: Name of the schedule.
+ :type name: str
+ :param trigger: Trigger of the schedule.
+ :type trigger: Union[CronTrigger, RecurrenceTrigger]
+ :param import_data: The schedule action data import definition.
+ :type import_data: DataImport
+ :param display_name: Display name of the schedule.
+ :type display_name: str
+ :param description: Description of the schedule, defaults to None
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The data import property dictionary.
+ :type properties: dict[str, str]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ trigger: Optional[Union[CronTrigger, RecurrenceTrigger]],
+ import_data: DataImport,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ name=name,
+ trigger=trigger,
+ display_name=display_name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+ self.import_data = import_data
+ self._type = ScheduleType.DATA_IMPORT
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "ImportDataSchedule":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ return ImportDataSchedule(
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ **load_from_dict(ImportDataScheduleSchema, data, context, **kwargs),
+ )
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> ImportDataScheduleSchema:
+ return ImportDataScheduleSchema(context=context)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestSchedule) -> "ImportDataSchedule":
+ return cls(
+ trigger=TriggerBase._from_rest_object(obj.properties.trigger),
+ import_data=DataImport._from_rest_object(obj.properties.action.data_import_definition),
+ name=obj.name,
+ display_name=obj.properties.display_name,
+ description=obj.properties.description,
+ tags=obj.properties.tags,
+ properties=obj.properties.properties,
+ provisioning_state=obj.properties.provisioning_state,
+ is_enabled=obj.properties.is_enabled,
+ creation_context=SystemData._from_rest_object(obj.system_data),
+ )
+
+ def _to_rest_object(self) -> RestSchedule:
+ return RestSchedule(
+ properties=ScheduleProperties(
+ description=self.description,
+ properties=self.properties,
+ tags=self.tags,
+ action=ImportDataAction(data_import_definition=self.import_data._to_rest_object()),
+ display_name=self.display_name,
+ is_enabled=self._is_enabled,
+ trigger=self.trigger._to_rest_object() if self.trigger is not None else None,
+ )
+ )
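
A sketch pairing the DataImport entity with the ImportDataSchedule above. It assumes CronTrigger accepts an `expression` keyword, as in the public trigger entities; all names and the cron expression are placeholders.

from azure.ai.ml.data_transfer import Database
from azure.ai.ml.entities._data_import.data_import import DataImport
from azure.ai.ml.entities._data_import.schedule import ImportDataSchedule
from azure.ai.ml.entities._schedule.trigger import CronTrigger

data_import = DataImport(
    name="imported_sales_data",
    path="azureml://datastores/workspaceblobstore/paths/imports/sales/",
    source=Database(connection="azureml:my_snowflake_connection", query="SELECT * FROM my_table"),
)

schedule = ImportDataSchedule(
    name="nightly_sales_import",
    trigger=CronTrigger(expression="0 1 * * *"),  # every day at 01:00; kwarg name assumed
    import_data=data_import,
    display_name="Nightly sales import",
)

# The REST schedule wraps the import definition in an ImportDataAction.
rest_schedule = schedule._to_rest_object()
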
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_constants.py
new file mode 100644
index 00000000..97a257ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_constants.py
@@ -0,0 +1,8 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# Miscellaneous
+HTTPS = "https"
+HTTP = "http"
+WORKSPACE_BLOB_STORE = "workspaceblobstore"
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem.py
new file mode 100644
index 00000000..e6c0dc3f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem.py
@@ -0,0 +1,121 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from base64 import b64encode
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import HdfsDatastore as RestHdfsDatastore
+from azure.ai.ml._schema._datastore._on_prem import HdfsSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._datastore.datastore import Datastore
+from azure.ai.ml.entities._datastore.utils import _from_rest_datastore_credentials_preview
+from azure.ai.ml.entities._util import load_from_dict
+
+from ._constants import HTTP
+from ._on_prem_credentials import KerberosKeytabCredentials, KerberosPasswordCredentials
+
+
+@experimental
+class HdfsDatastore(Datastore):
+ """HDFS datastore that is linked to an Azure ML workspace.
+
+ :param name: Name of the datastore.
+ :type name: str
+ :param name_node_address: IP Address or DNS HostName.
+ :type name_node_address: str
+ :param hdfs_server_certificate: The TLS certificate of the HDFS server (optional).
+ Must be a local file path on create; returned as a base64-encoded string on get.
+ :type hdfs_server_certificate: str
+ :param protocol: Protocol used to connect to the HDFS server, either "http" or "https". Defaults to "http".
+ :type protocol: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param credentials: Credentials to use for Azure ML workspace to connect to the storage.
+ :type credentials: Union[KerberosKeytabCredentials, KerberosPasswordCredentials]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ name_node_address: str,
+ hdfs_server_certificate: Optional[str] = None,
+ protocol: str = HTTP,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ credentials: Optional[Union[KerberosKeytabCredentials, KerberosPasswordCredentials]],
+ **kwargs: Any
+ ):
+ kwargs[TYPE] = DatastoreType.HDFS
+ super().__init__(
+ name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs
+ )
+
+ self.hdfs_server_certificate = hdfs_server_certificate
+ self.name_node_address = name_node_address
+ self.protocol = protocol
+
+ def _to_rest_object(self) -> DatastoreData:
+ use_this_cert = None
+ if self.hdfs_server_certificate:
+ with open(self.hdfs_server_certificate, "rb") as f:
+ use_this_cert = b64encode(f.read()).decode("utf-8")
+ hdfs_ds = RestHdfsDatastore(
+ credentials=self.credentials._to_rest_object(),
+ hdfs_server_certificate=use_this_cert,
+ name_node_address=self.name_node_address,
+ protocol=self.protocol,
+ description=self.description,
+ tags=self.tags,
+ )
+ return DatastoreData(properties=hdfs_ds)
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "HdfsDatastore":
+ res: HdfsDatastore = load_from_dict(HdfsSchema, data, context, additional_message)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "HdfsDatastore":
+ properties: RestHdfsDatastore = datastore_resource.properties
+ return HdfsDatastore(
+ name=datastore_resource.name,
+ id=datastore_resource.id,
+ credentials=_from_rest_datastore_credentials_preview(properties.credentials),
+ hdfs_server_certificate=properties.hdfs_server_certificate,
+ name_node_address=properties.name_node_address,
+ protocol=properties.protocol,
+ description=properties.description,
+ tags=properties.tags,
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = (
+ super().__eq__(other)
+ and self.hdfs_server_certificate == other.hdfs_server_certificate
+ and self.name_node_address == other.name_node_address
+ and self.protocol == other.protocol
+ )
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+ def _to_dict(self) -> Dict:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ res: dict = HdfsSchema(context=context).dump(self)
+ return res
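
A construction sketch for the experimental HdfsDatastore above, using the Kerberos password credentials from the sibling module; realm, host names, and the password are placeholders.

from azure.ai.ml.entities._datastore._on_prem import HdfsDatastore
from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosPasswordCredentials

credentials = KerberosPasswordCredentials(
    kerberos_realm="EXAMPLE.COM",
    kerberos_kdc_address="kdc.example.com",
    kerberos_principal="mlworkspace",
    kerberos_password="<placeholder>",  # never hard-code real secrets
)

hdfs_store = HdfsDatastore(
    name="my_hdfs_store",
    name_node_address="namenode.example.com",
    protocol="http",  # default; use "https" together with hdfs_server_certificate
    credentials=credentials,
    description="On-prem HDFS cluster registered as a workspace datastore.",
)

# Serializes to a RestHdfsDatastore wrapped in DatastoreData.
rest_obj = hdfs_store._to_rest_object()
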
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem_credentials.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem_credentials.py
new file mode 100644
index 00000000..b658851a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/_on_prem_credentials.py
@@ -0,0 +1,128 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from base64 import b64encode
+from typing import Any, Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview import models as model_preview
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.entities._credentials import NoneCredentialConfiguration
+
+
+# TODO: Move classes in this file to azure.ai.ml.entities._credentials
+@experimental
+class BaseKerberosCredentials(NoneCredentialConfiguration):
+ def __init__(self, kerberos_realm: str, kerberos_kdc_address: str, kerberos_principal: str):
+ super().__init__()
+ self.kerberos_realm = kerberos_realm
+ self.kerberos_kdc_address = kerberos_kdc_address
+ self.kerberos_principal = kerberos_principal
+
+
+@experimental
+class KerberosKeytabCredentials(BaseKerberosCredentials):
+ def __init__(
+ self,
+ *,
+ kerberos_realm: str,
+ kerberos_kdc_address: str,
+ kerberos_principal: str,
+ kerberos_keytab: Optional[str],
+ **kwargs: Any,
+ ):
+ super().__init__(
+ kerberos_realm=kerberos_realm,
+ kerberos_kdc_address=kerberos_kdc_address,
+ kerberos_principal=kerberos_principal,
+ **kwargs,
+ )
+ self.type = model_preview.CredentialsType.KERBEROS_KEYTAB
+ self.kerberos_keytab = kerberos_keytab
+
+ def _to_rest_object(self) -> model_preview.KerberosKeytabCredentials:
+ use_this_keytab = None
+ if self.kerberos_keytab:
+ with open(self.kerberos_keytab, "rb") as f:
+ use_this_keytab = b64encode(f.read()).decode("utf-8")
+ secrets = model_preview.KerberosKeytabSecrets(kerberos_keytab=use_this_keytab)
+ return model_preview.KerberosKeytabCredentials(
+ kerberos_kdc_address=self.kerberos_kdc_address,
+ kerberos_principal=self.kerberos_principal,
+ kerberos_realm=self.kerberos_realm,
+ secrets=secrets,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: model_preview.KerberosKeytabCredentials) -> "KerberosKeytabCredentials":
+ return cls(
+ kerberos_kdc_address=obj.kerberos_kdc_address,
+ kerberos_principal=obj.kerberos_principal,
+ kerberos_realm=obj.kerberos_realm,
+ kerberos_keytab=obj.secrets.kerberos_keytab if obj.secrets else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, KerberosKeytabCredentials):
+ return NotImplemented
+ return (
+ self.kerberos_kdc_address == other.kerberos_kdc_address
+ and self.kerberos_principal == other.kerberos_principal
+ and self.kerberos_realm == other.kerberos_realm
+ and self.kerberos_keytab == other.kerberos_keytab
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+
+@experimental
+class KerberosPasswordCredentials(BaseKerberosCredentials):
+ def __init__(
+ self,
+ *,
+ kerberos_realm: str,
+ kerberos_kdc_address: str,
+ kerberos_principal: str,
+ kerberos_password: Optional[str],
+ **kwargs: Any,
+ ):
+ super().__init__(
+ kerberos_realm=kerberos_realm,
+ kerberos_kdc_address=kerberos_kdc_address,
+ kerberos_principal=kerberos_principal,
+ **kwargs,
+ )
+ self.type = model_preview.CredentialsType.KERBEROS_PASSWORD
+ self.kerberos_password = kerberos_password
+
+ def _to_rest_object(self) -> model_preview.KerberosPasswordCredentials:
+ secrets = model_preview.KerberosPasswordSecrets(kerberos_password=self.kerberos_password)
+ return model_preview.KerberosPasswordCredentials(
+ kerberos_kdc_address=self.kerberos_kdc_address,
+ kerberos_principal=self.kerberos_principal,
+ kerberos_realm=self.kerberos_realm,
+ secrets=secrets,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: model_preview.KerberosPasswordCredentials) -> "KerberosPasswordCredentials":
+ return cls(
+ kerberos_kdc_address=obj.kerberos_kdc_address,
+ kerberos_principal=obj.kerberos_principal,
+ kerberos_realm=obj.kerberos_realm,
+ kerberos_password=obj.secrets.kerberos_password if obj.secrets else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, KerberosPasswordCredentials):
+ return NotImplemented
+ return (
+ self.kerberos_kdc_address == other.kerberos_kdc_address
+ and self.kerberos_principal == other.kerberos_principal
+ and self.kerberos_realm == other.kerberos_realm
+ and self.kerberos_password == other.kerberos_password
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
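
The keytab variant follows the same shape; a short sketch, noting that _to_rest_object reads and base64-encodes the keytab file, so the path must point at an existing file when the REST payload is built.

from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosKeytabCredentials

keytab_creds = KerberosKeytabCredentials(
    kerberos_realm="EXAMPLE.COM",
    kerberos_kdc_address="kdc.example.com",
    kerberos_principal="mlworkspace",
    kerberos_keytab="/secrets/mlworkspace.keytab",  # local path; encoded on upload
)

# On a later GET the secret comes back as the base64 string, not the original path,
# so a round trip through _from_rest_object does not compare equal to this object.
rest_creds = keytab_creds._to_rest_object()
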
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/adls_gen1.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/adls_gen1.py
new file mode 100644
index 00000000..c2610703
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/adls_gen1.py
@@ -0,0 +1,106 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ AzureDataLakeGen1Datastore as RestAzureDatalakeGen1Datastore,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType
+from azure.ai.ml._schema._datastore.adls_gen1 import AzureDataLakeGen1Schema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._credentials import CertificateConfiguration, ServicePrincipalConfiguration
+from azure.ai.ml.entities._datastore.datastore import Datastore
+from azure.ai.ml.entities._datastore.utils import from_rest_datastore_credentials
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class AzureDataLakeGen1Datastore(Datastore):
+ """Azure Data Lake aka Gen 1 datastore that is linked to an Azure ML workspace.
+
+ :param name: Name of the datastore.
+ :type name: str
+ :param store_name: Name of the Azure storage resource.
+ :type store_name: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param credentials: Credentials to use for Azure ML workspace to connect to the storage.
+ :type credentials: Union[ServicePrincipalConfiguration, CertificateConfiguration]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ store_name: str,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ credentials: Optional[Union[CertificateConfiguration, ServicePrincipalConfiguration]] = None,
+ **kwargs: Any
+ ):
+ kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN1
+ super().__init__(
+ name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs
+ )
+
+ self.store_name = store_name
+
+ def _to_rest_object(self) -> DatastoreData:
+ gen1_ds = RestAzureDatalakeGen1Datastore(
+ credentials=self.credentials._to_datastore_rest_object(),
+ store_name=self.store_name,
+ description=self.description,
+ tags=self.tags,
+ )
+ return DatastoreData(properties=gen1_ds)
+
+ @classmethod
+ def _load_from_dict(
+ cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any
+ ) -> "AzureDataLakeGen1Datastore":
+ res: AzureDataLakeGen1Datastore = load_from_dict(
+ AzureDataLakeGen1Schema, data, context, additional_message, **kwargs
+ )
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeGen1Datastore":
+ properties: RestAzureDatalakeGen1Datastore = datastore_resource.properties
+ return AzureDataLakeGen1Datastore(
+ id=datastore_resource.id,
+ name=datastore_resource.name,
+ store_name=properties.store_name,
+ credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type]
+ description=properties.description,
+ tags=properties.tags,
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = (
+ super().__eq__(other)
+ and self.name == other.name
+ and self.type == other.type
+ and self.store_name == other.store_name
+ and self.credentials == other.credentials
+ )
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+ def _to_dict(self) -> Dict:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ res: dict = AzureDataLakeGen1Schema(context=context).dump(self)
+ return res
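
A minimal sketch registering an Azure Data Lake Gen1 store. Credentials are omitted here, which falls back to NoneCredentialConfiguration in the Datastore base class; names are placeholders.

from azure.ai.ml.entities import AzureDataLakeGen1Datastore

adls_gen1 = AzureDataLakeGen1Datastore(
    name="my_adls_gen1",
    store_name="myadlsgen1account",  # name of the Azure Data Lake Gen1 account
    description="Gen1 lake registered without explicit credentials.",
)

# YAML-shaped dict produced by AzureDataLakeGen1Schema.
serialized = adls_gen1._to_dict()
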
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/azure_storage.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/azure_storage.py
new file mode 100644
index 00000000..0fff1925
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/azure_storage.py
@@ -0,0 +1,337 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AzureBlobDatastore as RestAzureBlobDatastore
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ AzureDataLakeGen2Datastore as RestAzureDataLakeGen2Datastore,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AzureFileDatastore as RestAzureFileDatastore
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType
+from azure.ai.ml._schema._datastore import AzureBlobSchema, AzureDataLakeGen2Schema, AzureFileSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._credentials import (
+ AccountKeyConfiguration,
+ CertificateConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+)
+from azure.ai.ml.entities._datastore.datastore import Datastore
+from azure.ai.ml.entities._datastore.utils import from_rest_datastore_credentials
+from azure.ai.ml.entities._util import load_from_dict
+
+from ._constants import HTTPS
+
+
+class AzureFileDatastore(Datastore):
+ """Azure file share that is linked to an Azure ML workspace.
+
+ :param name: Name of the datastore.
+ :type name: str
+ :param account_name: Name of the Azure storage account.
+ :type account_name: str
+ :param file_share_name: Name of the file share.
+ :type file_share_name: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param endpoint: Endpoint to use to connect with the Azure storage account
+ :type endpoint: str
+ :param protocol: Protocol to use to connect with the Azure storage account
+ :type protocol: str
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param credentials: Credentials to use for Azure ML workspace to connect to the storage. Defaults to None.
+ :type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration,
+ ~azure.ai.ml.entities.SasTokenConfiguration]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ account_name: str,
+ file_share_name: str,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ endpoint: str = _get_storage_endpoint_from_metadata(),
+ protocol: str = HTTPS,
+ properties: Optional[Dict] = None,
+ credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None,
+ **kwargs: Any
+ ):
+ kwargs[TYPE] = DatastoreType.AZURE_FILE
+ super().__init__(
+ name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs
+ )
+ self.file_share_name = file_share_name
+ self.account_name = account_name
+ self.endpoint = endpoint
+ self.protocol = protocol
+
+ def _to_rest_object(self) -> DatastoreData:
+ file_ds = RestAzureFileDatastore(
+ account_name=self.account_name,
+ file_share_name=self.file_share_name,
+ credentials=self.credentials._to_datastore_rest_object(),
+ endpoint=self.endpoint,
+ protocol=self.protocol,
+ description=self.description,
+ tags=self.tags,
+ )
+ return DatastoreData(properties=file_ds)
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "AzureFileDatastore":
+ res: AzureFileDatastore = load_from_dict(AzureFileSchema, data, context, additional_message)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureFileDatastore":
+ properties: RestAzureFileDatastore = datastore_resource.properties
+ return AzureFileDatastore(
+ name=datastore_resource.name,
+ id=datastore_resource.id,
+ account_name=properties.account_name,
+ credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type]
+ endpoint=properties.endpoint,
+ protocol=properties.protocol,
+ file_share_name=properties.file_share_name,
+ description=properties.description,
+ tags=properties.tags,
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = (
+ super().__eq__(other)
+ and self.file_share_name == other.file_share_name
+ and self.account_name == other.account_name
+ and self.endpoint == other.endpoint
+ and self.protocol == other.protocol
+ )
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+ def _to_dict(self) -> Dict:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ res: dict = AzureFileSchema(context=context).dump(self)
+ return res
+
+
+class AzureBlobDatastore(Datastore):
+ """Azure blob storage that is linked to an Azure ML workspace.
+
+ :param name: Name of the datastore.
+ :type name: str
+ :param account_name: Name of the Azure storage account.
+ :type account_name: str
+ :param container_name: Name of the container.
+ :type container_name: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param endpoint: Endpoint to use to connect with the Azure storage account.
+ :type endpoint: str
+ :param protocol: Protocol to use to connect with the Azure storage account.
+ :type protocol: str
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param credentials: Credentials to use for Azure ML workspace to connect to the storage.
+ :type credentials: Union[~azure.ai.ml.entities.AccountKeyConfiguration,
+ ~azure.ai.ml.entities.SasTokenConfiguration]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ account_name: str,
+ container_name: str,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ endpoint: Optional[str] = None,
+ protocol: str = HTTPS,
+ properties: Optional[Dict] = None,
+ credentials: Optional[Union[AccountKeyConfiguration, SasTokenConfiguration]] = None,
+ **kwargs: Any
+ ):
+ kwargs[TYPE] = DatastoreType.AZURE_BLOB
+ super().__init__(
+ name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs
+ )
+
+ self.container_name = container_name
+ self.account_name = account_name
+ self.endpoint = endpoint if endpoint else _get_storage_endpoint_from_metadata()
+ self.protocol = protocol
+
+ def _to_rest_object(self) -> DatastoreData:
+ blob_ds = RestAzureBlobDatastore(
+ account_name=self.account_name,
+ container_name=self.container_name,
+ credentials=self.credentials._to_datastore_rest_object(),
+ endpoint=self.endpoint,
+ protocol=self.protocol,
+ tags=self.tags,
+ description=self.description,
+ )
+ return DatastoreData(properties=blob_ds)
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "AzureBlobDatastore":
+ res: AzureBlobDatastore = load_from_dict(AzureBlobSchema, data, context, additional_message)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureBlobDatastore":
+ properties: RestAzureBlobDatastore = datastore_resource.properties
+ return AzureBlobDatastore(
+ name=datastore_resource.name,
+ id=datastore_resource.id,
+ account_name=properties.account_name,
+ credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type]
+ endpoint=properties.endpoint,
+ protocol=properties.protocol,
+ container_name=properties.container_name,
+ description=properties.description,
+ tags=properties.tags,
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = (
+ super().__eq__(other)
+ and self.container_name == other.container_name
+ and self.account_name == other.account_name
+ and self.endpoint == other.endpoint
+ and self.protocol == other.protocol
+ )
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+ def _to_dict(self) -> Dict:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ res: dict = AzureBlobSchema(context=context).dump(self)
+ return res
+
+
+class AzureDataLakeGen2Datastore(Datastore):
+ """Azure data lake gen 2 that is linked to an Azure ML workspace.
+
+ :param name: Name of the datastore.
+ :type name: str
+ :param account_name: Name of the Azure storage account.
+ :type account_name: str
+ :param filesystem: The name of the Data Lake Gen2 filesystem.
+ :type filesystem: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param endpoint: Endpoint to use to connect with the Azure storage account
+ :type endpoint: str
+ :param protocol: Protocol to use to connect with the Azure storage account
+ :type protocol: str
+ :param credentials: Credentials to use for Azure ML workspace to connect to the storage.
+ :type credentials: Union[
+ ~azure.ai.ml.entities.ServicePrincipalConfiguration,
+ ~azure.ai.ml.entities.CertificateConfiguration
+
+ ]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ account_name: str,
+ filesystem: str,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ endpoint: str = _get_storage_endpoint_from_metadata(),
+ protocol: str = HTTPS,
+ properties: Optional[Dict] = None,
+ credentials: Optional[Union[ServicePrincipalConfiguration, CertificateConfiguration]] = None,
+ **kwargs: Any
+ ):
+ kwargs[TYPE] = DatastoreType.AZURE_DATA_LAKE_GEN2
+ super().__init__(
+ name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs
+ )
+
+ self.account_name = account_name
+ self.filesystem = filesystem
+ self.endpoint = endpoint
+ self.protocol = protocol
+
+ def _to_rest_object(self) -> DatastoreData:
+ gen2_ds = RestAzureDataLakeGen2Datastore(
+ account_name=self.account_name,
+ filesystem=self.filesystem,
+ credentials=self.credentials._to_datastore_rest_object(),
+ endpoint=self.endpoint,
+ protocol=self.protocol,
+ description=self.description,
+ tags=self.tags,
+ )
+ return DatastoreData(properties=gen2_ds)
+
+ @classmethod
+ def _load_from_dict(
+ cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any
+ ) -> "AzureDataLakeGen2Datastore":
+ res: AzureDataLakeGen2Datastore = load_from_dict(AzureDataLakeGen2Schema, data, context, additional_message)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "AzureDataLakeGen2Datastore":
+ properties: RestAzureDataLakeGen2Datastore = datastore_resource.properties
+ return AzureDataLakeGen2Datastore(
+ name=datastore_resource.name,
+ id=datastore_resource.id,
+ account_name=properties.account_name,
+ credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type]
+ endpoint=properties.endpoint,
+ protocol=properties.protocol,
+ filesystem=properties.filesystem,
+ description=properties.description,
+ tags=properties.tags,
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = (
+ super().__eq__(other)
+ and self.filesystem == other.filesystem
+ and self.account_name == other.account_name
+ and self.endpoint == other.endpoint
+ and self.protocol == other.protocol
+ )
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+ def _to_dict(self) -> Dict:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ res: dict = AzureDataLakeGen2Schema(context=context).dump(self)
+ return res
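
A construction sketch for the blob-backed datastore above. It assumes AccountKeyConfiguration takes an account_key keyword (its signature is not shown in this diff); the account, container, and key values are placeholders.

from azure.ai.ml.entities import AzureBlobDatastore
from azure.ai.ml.entities._credentials import AccountKeyConfiguration

blob_store = AzureBlobDatastore(
    name="my_blob_store",
    account_name="mystorageaccount",
    container_name="ml-data",
    credentials=AccountKeyConfiguration(account_key="<placeholder-key>"),  # kwarg name assumed
    description="Blob container used for training data.",
)

# endpoint falls back to the cloud-specific storage suffix; protocol defaults to "https".
print(blob_store.endpoint, blob_store.protocol)
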
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/datastore.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/datastore.py
new file mode 100644
index 00000000..bc933cfb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/datastore.py
@@ -0,0 +1,221 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,redefined-builtin,arguments-renamed
+
+from abc import ABC, abstractmethod
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType
+from azure.ai.ml._utils.utils import camel_to_snake, dump_yaml_to_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, CommonYamlFields
+from azure.ai.ml.entities._credentials import (
+ AccountKeyConfiguration,
+ CertificateConfiguration,
+ NoneCredentialConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._util import find_type_in_override
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+
+class Datastore(Resource, RestTranslatableMixin, ABC):
+ """Datastore of an Azure ML workspace, abstract class.
+
+ :param name: Name of the datastore.
+ :type name: str
+ :param description: Description of the resource.
+ :type description: str
+ :param credentials: Credentials to use for Azure ML workspace to connect to the storage.
+ :type credentials: Optional[Union[
+ ~azure.ai.ml.entities.ServicePrincipalConfiguration,
+ ~azure.ai.ml.entities.CertificateConfiguration,
+ ~azure.ai.ml.entities.NoneCredentialConfiguration,
+ ~azure.ai.ml.entities.AccountKeyConfiguration,
+ ~azure.ai.ml.entities.SasTokenConfiguration
+
+ ]]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ credentials: Optional[
+ Union[
+ ServicePrincipalConfiguration,
+ CertificateConfiguration,
+ NoneCredentialConfiguration,
+ AccountKeyConfiguration,
+ SasTokenConfiguration,
+ ]
+ ],
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ):
+ self._type: str = kwargs.pop("type", None)
+ super().__init__(
+ name=name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+
+ self.credentials = NoneCredentialConfiguration() if credentials is None else credentials
+
+ @property
+ def type(self) -> str:
+ return self._type
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the datastore content into a file in yaml format.
+
+ :param dest: The destination to receive this datastore's content.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+ and an exception will be raised if the file is not writable.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ """
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, **kwargs)
+
+ @abstractmethod
+ def _to_dict(self) -> Dict:
+ pass
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Datastore":
+ data = data or {}
+ params_override = params_override or []
+
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+
+ from azure.ai.ml.entities import (
+ AzureBlobDatastore,
+ AzureDataLakeGen1Datastore,
+ AzureDataLakeGen2Datastore,
+ AzureFileDatastore,
+ OneLakeDatastore,
+ )
+
+ # from azure.ai.ml.entities._datastore._on_prem import (
+ # HdfsDatastore
+ # )
+
+ ds_type: Any = None
+ type_in_override = find_type_in_override(params_override)
+ type = type_in_override or data.get(
+ CommonYamlFields.TYPE, DatastoreType.AZURE_BLOB
+ ) # override takes the priority
+
+ # yaml expects snake casing, while service side constants are camel casing
+ if type == camel_to_snake(DatastoreType.AZURE_BLOB):
+ ds_type = AzureBlobDatastore
+ elif type == camel_to_snake(DatastoreType.AZURE_FILE):
+ ds_type = AzureFileDatastore
+ elif type == camel_to_snake(DatastoreType.AZURE_DATA_LAKE_GEN1):
+ ds_type = AzureDataLakeGen1Datastore
+ elif type == camel_to_snake(DatastoreType.AZURE_DATA_LAKE_GEN2):
+ ds_type = AzureDataLakeGen2Datastore
+ elif type == camel_to_snake(DatastoreType.ONE_LAKE):
+ ds_type = OneLakeDatastore
+ # disable unless preview release
+ # elif type == camel_to_snake(DatastoreTypePreview.HDFS):
+ # ds_type = HdfsDatastore
+ else:
+ msg = f"Unsupported datastore type: {type}."
+ raise ValidationException(
+ message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.DATASTORE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ res: Datastore = ds_type._load_from_dict(
+ data=data,
+ context=context,
+ additional_message="If the datastore type is incorrect, change the 'type' property.",
+ **kwargs,
+ )
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "Datastore":
+ from azure.ai.ml.entities import (
+ AzureBlobDatastore,
+ AzureDataLakeGen1Datastore,
+ AzureDataLakeGen2Datastore,
+ AzureFileDatastore,
+ OneLakeDatastore,
+ )
+
+ # from azure.ai.ml.entities._datastore._on_prem import (
+ # HdfsDatastore
+ # )
+
+ datastore_type = datastore_resource.properties.datastore_type
+ if datastore_type == DatastoreType.AZURE_DATA_LAKE_GEN1:
+ res_adl_gen1: Datastore = AzureDataLakeGen1Datastore._from_rest_object(datastore_resource)
+ return res_adl_gen1
+ if datastore_type == DatastoreType.AZURE_DATA_LAKE_GEN2:
+ res_adl_gen2: Datastore = AzureDataLakeGen2Datastore._from_rest_object(datastore_resource)
+ return res_adl_gen2
+ if datastore_type == DatastoreType.AZURE_BLOB:
+ res_abd: Datastore = AzureBlobDatastore._from_rest_object(datastore_resource)
+ return res_abd
+ if datastore_type == DatastoreType.AZURE_FILE:
+ res_afd: Datastore = AzureFileDatastore._from_rest_object(datastore_resource)
+ return res_afd
+ if datastore_type == DatastoreType.ONE_LAKE:
+ res_old: Datastore = OneLakeDatastore._from_rest_object(datastore_resource)
+ return res_old
+ # disable unless preview release
+ # elif datastore_type == DatastoreTypePreview.HDFS:
+ # return HdfsDatastore._from_rest_object(datastore_resource)
+ msg = f"Unsupported datastore type {datastore_resource.properties.contents.type}"
+ raise ValidationException(
+ message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.DATASTORE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ @classmethod
+ @abstractmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Datastore":
+ pass
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = self.name == other.name and self.type == other.type and self.credentials == other.credentials
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
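
The dump/_load pair above is the YAML round trip used by every concrete datastore; a small sketch, noting that _load is an internal helper that dispatches on the snake_cased "type" field and that the remaining field names are assumed to match the YAML schema.

from azure.ai.ml.entities import AzureBlobDatastore
from azure.ai.ml.entities._datastore.datastore import Datastore

blob_store = AzureBlobDatastore(
    name="my_blob_store",
    account_name="mystorageaccount",
    container_name="ml-data",
)

# Write the datastore as YAML (raises if the destination file already exists).
blob_store.dump("./my_blob_store.yml")

# Internal load path: the "type" key selects the concrete class (AzureBlobDatastore here).
loaded = Datastore._load(
    data={
        "type": "azure_blob",
        "name": "my_blob_store",
        "account_name": "mystorageaccount",
        "container_name": "ml-data",
    }
)
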
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/one_lake.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/one_lake.py
new file mode 100644
index 00000000..9bc06d92
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/one_lake.py
@@ -0,0 +1,153 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from abc import ABC
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Datastore as DatastoreData
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DatastoreType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import LakeHouseArtifact as RestLakeHouseArtifact
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NoneDatastoreCredentials as RestNoneDatastoreCredentials
+from azure.ai.ml._restclient.v2023_04_01_preview.models import OneLakeDatastore as RestOneLakeDatastore
+from azure.ai.ml._schema._datastore.one_lake import OneLakeSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._credentials import NoneCredentialConfiguration, ServicePrincipalConfiguration
+from azure.ai.ml.entities._datastore.datastore import Datastore
+from azure.ai.ml.entities._datastore.utils import from_rest_datastore_credentials
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.entities._util import load_from_dict
+
+
+@experimental
+class OneLakeArtifact(RestTranslatableMixin, DictMixin, ABC):
+ """OneLake artifact (data source) backing the OneLake workspace.
+
+ :param name: OneLake artifact name/GUID. ex) 01234567-abcd-1234-5678-012345678901
+ :type name: str
+ :param type: OneLake artifact type. Only LakeHouse artifacts are currently supported.
+ :type type: str
+ """
+
+ def __init__(self, *, name: str, type: Optional[str] = None):
+ super().__init__()
+ self.name = name
+ self.type = type
+
+
+@experimental
+class LakeHouseArtifact(OneLakeArtifact):
+ """LakeHouse artifact type for OneLake.
+
+ :param name: OneLake LakeHouse artifact name/GUID. ex) 01234567-abcd-1234-5678-012345678901
+ :type name: str
+ """
+
+ def __init__(self, *, name: str):
+ super(LakeHouseArtifact, self).__init__(name=name, type="lake_house")
+
+ def _to_datastore_rest_object(self) -> RestLakeHouseArtifact:
+ return RestLakeHouseArtifact(artifact_name=self.name)
+
+
+@experimental
+class OneLakeDatastore(Datastore):
+ """OneLake datastore that is linked to an Azure ML workspace.
+
+ :param name: Name of the datastore.
+ :type name: str
+ :param artifact: OneLake Artifact. Only LakeHouse artifacts are currently supported.
+ :type artifact: ~azure.ai.ml.entities.OneLakeArtifact
+ :param one_lake_workspace_name: OneLake workspace name/GUID. ex) 01234567-abcd-1234-5678-012345678901
+ :type one_lake_workspace_name: str
+ :param endpoint: OneLake endpoint to use for the datastore. ex) https://onelake.dfs.fabric.microsoft.com
+ :type endpoint: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param credentials: Credentials to use to authenticate against OneLake.
+ :type credentials: Union[
+ ~azure.ai.ml.entities.ServicePrincipalConfiguration, ~azure.ai.ml.entities.NoneCredentialConfiguration]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ artifact: OneLakeArtifact,
+ one_lake_workspace_name: str,
+ endpoint: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ credentials: Optional[Union[NoneCredentialConfiguration, ServicePrincipalConfiguration]] = None,
+ **kwargs: Any
+ ):
+ kwargs[TYPE] = DatastoreType.ONE_LAKE
+ super().__init__(
+ name=name, description=description, tags=tags, properties=properties, credentials=credentials, **kwargs
+ )
+ self.artifact = artifact
+ self.one_lake_workspace_name = one_lake_workspace_name
+ self.endpoint = endpoint
+
+ def _to_rest_object(self) -> DatastoreData:
+ one_lake_ds = RestOneLakeDatastore(
+ credentials=(
+ RestNoneDatastoreCredentials()
+ if self.credentials is None
+ else self.credentials._to_datastore_rest_object()
+ ),
+ artifact=RestLakeHouseArtifact(artifact_name=self.artifact["name"]),
+ one_lake_workspace_name=self.one_lake_workspace_name,
+ endpoint=self.endpoint,
+ description=self.description,
+ tags=self.tags,
+ )
+ return DatastoreData(properties=one_lake_ds)
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "OneLakeDatastore":
+ res: OneLakeDatastore = load_from_dict(OneLakeSchema, data, context, additional_message, **kwargs)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, datastore_resource: DatastoreData) -> "OneLakeDatastore":
+ properties: RestOneLakeDatastore = datastore_resource.properties
+ return OneLakeDatastore(
+ name=datastore_resource.name,
+ id=datastore_resource.id,
+ artifact=LakeHouseArtifact(name=properties.artifact.artifact_name),
+ one_lake_workspace_name=properties.one_lake_workspace_name,
+ endpoint=properties.endpoint,
+ credentials=from_rest_datastore_credentials(properties.credentials), # type: ignore[arg-type]
+ description=properties.description,
+ tags=properties.tags,
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ res: bool = (
+ super().__eq__(other)
+ and self.one_lake_workspace_name == other.one_lake_workspace_name
+ and self.artifact.type == other.artifact["type"]
+ and self.artifact.name == other.artifact["name"]
+ and self.endpoint == other.endpoint
+ )
+ return res
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+ def _to_dict(self) -> Dict:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ res: dict = OneLakeSchema(context=context).dump(self)
+ return res
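
A sketch of the experimental OneLake datastore, pairing it with the LakeHouseArtifact wrapper defined above; the GUIDs are placeholders and the endpoint mirrors the example in the docstring.

from azure.ai.ml.entities import OneLakeDatastore
from azure.ai.ml.entities._datastore.one_lake import LakeHouseArtifact

one_lake = OneLakeDatastore(
    name="my_onelake_store",
    artifact=LakeHouseArtifact(name="01234567-abcd-1234-5678-012345678901"),
    one_lake_workspace_name="89abcdef-0123-4567-89ab-cdef01234567",
    endpoint="https://onelake.dfs.fabric.microsoft.com",
    description="LakeHouse artifact exposed as a workspace datastore.",
)

# With no credentials supplied, the base class substitutes NoneCredentialConfiguration,
# which is what gets serialized into the RestOneLakeDatastore payload.
rest_obj = one_lake._to_rest_object()
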
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/utils.py
new file mode 100644
index 00000000..538f9590
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_datastore/utils.py
@@ -0,0 +1,70 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2023_04_01_preview import models
+from azure.ai.ml._restclient.v2024_07_01_preview import models as models2024
+from azure.ai.ml.entities._credentials import (
+ AccountKeyConfiguration,
+ CertificateConfiguration,
+ NoneCredentialConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+)
+from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosKeytabCredentials, KerberosPasswordCredentials
+
+
+def from_rest_datastore_credentials(
+ rest_credentials: models.DatastoreCredentials,
+) -> Union[
+ AccountKeyConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+ CertificateConfiguration,
+ NoneCredentialConfiguration,
+]:
+ config_class: Any = NoneCredentialConfiguration
+
+ if isinstance(rest_credentials, (models.AccountKeyDatastoreCredentials, models2024.AccountKeyDatastoreCredentials)):
+ # we no longer use the account key for key-based accounts.
+ # https://github.com/Azure/azure-sdk-for-python/pull/35716
+ if isinstance(rest_credentials.secrets, models2024.SasDatastoreSecrets):
+ config_class = SasTokenConfiguration
+ else:
+ config_class = AccountKeyConfiguration
+ elif isinstance(rest_credentials, (models.SasDatastoreCredentials, models2024.SasDatastoreCredentials)):
+ config_class = SasTokenConfiguration
+ elif isinstance(
+ rest_credentials, (models.ServicePrincipalDatastoreCredentials, models2024.ServicePrincipalDatastoreCredentials)
+ ):
+ config_class = ServicePrincipalConfiguration
+ elif isinstance(
+ rest_credentials, (models.CertificateDatastoreCredentials, models2024.CertificateDatastoreCredentials)
+ ):
+ config_class = CertificateConfiguration
+
+ return cast(
+ Union[
+ AccountKeyConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+ CertificateConfiguration,
+ NoneCredentialConfiguration,
+ ],
+ config_class._from_datastore_rest_object(rest_credentials),
+ )
+
+
+def _from_rest_datastore_credentials_preview(
+ rest_credentials: models.DatastoreCredentials,
+) -> Optional[Union[KerberosKeytabCredentials, KerberosPasswordCredentials]]:
+ if isinstance(rest_credentials, models.KerberosKeytabCredentials):
+ return KerberosKeytabCredentials._from_rest_object(rest_credentials)
+ if isinstance(rest_credentials, models.KerberosPasswordCredentials):
+ return KerberosPasswordCredentials._from_rest_object(rest_credentials)
+
+ return None
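
A sketch of the preview credential mapper above, fed with a REST payload built by the Kerberos entity's own _to_rest_object; the values are placeholders.

from azure.ai.ml.entities._datastore._on_prem_credentials import KerberosPasswordCredentials
from azure.ai.ml.entities._datastore.utils import _from_rest_datastore_credentials_preview

rest_creds = KerberosPasswordCredentials(
    kerberos_realm="EXAMPLE.COM",
    kerberos_kdc_address="kdc.example.com",
    kerberos_principal="mlworkspace",
    kerberos_password="<placeholder>",
)._to_rest_object()

# The mapper dispatches on the REST model type and rebuilds the entity-level class.
entity_creds = _from_rest_datastore_credentials_preview(rest_creds)
assert isinstance(entity_creds, KerberosPasswordCredentials)
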
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_deployment.py
new file mode 100644
index 00000000..59b23eb8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_deployment.py
@@ -0,0 +1,356 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import BatchDeployment as BatchDeploymentData
+from azure.ai.ml._restclient.v2024_01_01_preview.models import BatchDeploymentProperties as RestBatchDeployment
+from azure.ai.ml._restclient.v2024_01_01_preview.models import BatchOutputAction
+from azure.ai.ml._restclient.v2024_01_01_preview.models import CodeConfiguration as RestCodeConfiguration
+from azure.ai.ml._restclient.v2024_01_01_preview.models import IdAssetReference
+from azure.ai.ml._schema._deployment.batch.batch_deployment import BatchDeploymentSchema
+from azure.ai.ml._utils._arm_id_utils import _parse_endpoint_name_from_deployment_id
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction
+from azure.ai.ml.entities._assets import Environment, Model
+from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
+from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .code_configuration import CodeConfiguration
+from .deployment import Deployment
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchDeployment(Deployment):
+ """Batch endpoint deployment entity.
+
+ :param name: The name of the batch deployment.
+ :type name: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param model: Model entity for the endpoint deployment, defaults to None
+ :type model: Union[str, Model]
+ :param code_configuration: Code configuration for the endpoint deployment, defaults to None
+ :type code_configuration: CodeConfiguration
+ :param environment: Environment entity for the endpoint deployment, defaults to None
+ :type environment: Union[str, Environment]
+ :param compute: Compute target for batch inference operation.
+ :type compute: str
+ :param output_action: Indicates how the output will be organized. Possible values include:
+ "summary_only", "append_row". Defaults to "append_row"
+ :type output_action: str or ~azure.ai.ml.constants._deployment.BatchDeploymentOutputAction
+ :param output_file_name: Customized output file name for append_row output action, defaults to "predictions.csv"
+ :type output_file_name: str
+ :param max_concurrency_per_instance: Maximum degree of parallelism per instance, defaults to 1
+ :type max_concurrency_per_instance: int
+ :param error_threshold: Error threshold. If the error count for the entire input goes above
+ this value, the batch inference will be aborted. The range is [-1, int.MaxValue];
+ a value of -1 indicates that all failures during batch inference are ignored.
+ For FileDataset, this is the count of file failures.
+ For TabularDataset, this is the count of record failures.
+ Defaults to -1.
+ :type error_threshold: int
+ :param retry_settings: Retry settings for a batch inference operation, defaults to None
+ :type retry_settings: BatchRetrySettings
+ :param resources: Indicates compute configuration for the job.
+ :type resources: ~azure.mgmt.machinelearningservices.models.ResourceConfiguration
+ :param logging_level: Logging level for batch inference operation, defaults to "info"
+ :type logging_level: str
+ :param mini_batch_size: Size of the mini-batch passed to each batch invocation, defaults to 10
+ :type mini_batch_size: int
+ :param environment_variables: Environment variables that will be set in deployment.
+ :type environment_variables: dict
+ :param code_path: Folder path to local code assets. Equivalent to code_configuration.code.
+ :type code_path: Union[str, PathLike]
+ :param scoring_script: Scoring script name. Equivalent to code_configuration.scoring_script.
+ :type scoring_script: Union[str, PathLike]
+ :param instance_count: Number of instances the inference will run on. Equivalent to resources.instance_count.
+ :type instance_count: int
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if BatchDeployment cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ endpoint_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ properties: Optional[Dict[str, str]] = None,
+ model: Optional[Union[str, Model]] = None,
+ code_configuration: Optional[CodeConfiguration] = None,
+ environment: Optional[Union[str, Environment]] = None,
+ compute: Optional[str] = None,
+ resources: Optional[ResourceConfiguration] = None,
+ output_file_name: Optional[str] = None,
+ output_action: Optional[Union[BatchDeploymentOutputAction, str]] = None,
+ error_threshold: Optional[int] = None,
+ retry_settings: Optional[BatchRetrySettings] = None,
+ logging_level: Optional[str] = None,
+ mini_batch_size: Optional[int] = None,
+ max_concurrency_per_instance: Optional[int] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ code_path: Optional[Union[str, PathLike]] = None, # promoted property from code_configuration.code
+ scoring_script: Optional[
+ Union[str, PathLike]
+ ] = None, # promoted property from code_configuration.scoring_script
+ instance_count: Optional[int] = None, # promoted property from resources.instance_count
+ **kwargs: Any,
+ ) -> None:
+ self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)
+
+ super(BatchDeployment, self).__init__(
+ name=name,
+ endpoint_name=endpoint_name,
+ properties=properties,
+ tags=tags,
+ description=description,
+ model=model,
+ code_configuration=code_configuration,
+ environment=environment,
+ environment_variables=environment_variables,
+ code_path=code_path,
+ scoring_script=scoring_script,
+ **kwargs,
+ )
+
+ self.compute = compute
+ self.resources = resources
+ self.output_action = output_action
+ self.output_file_name = output_file_name
+ self.error_threshold = error_threshold
+ self.retry_settings = retry_settings
+ self.logging_level = logging_level
+ self.mini_batch_size = mini_batch_size
+ self.max_concurrency_per_instance = max_concurrency_per_instance
+
+ if self.resources and instance_count:
+ msg = "Can't set instance_count when resources is provided."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.BATCH_DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ if not self.resources and instance_count:
+ self.resources = ResourceConfiguration(instance_count=instance_count)
+
+ @property
+ def instance_count(self) -> Optional[int]:
+ return self.resources.instance_count if self.resources else None
+
+ @instance_count.setter
+ def instance_count(self, value: int) -> None:
+ if not self.resources:
+ self.resources = ResourceConfiguration()
+
+ self.resources.instance_count = value
+
+ @property
+ def provisioning_state(self) -> Optional[str]:
+ """Batch deployment provisioning state, readonly.
+
+ :return: Batch deployment provisioning state.
+ :rtype: Optional[str]
+ """
+ return self._provisioning_state
+
+ def _to_dict(self) -> Dict:
+ res: dict = BatchDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _rest_output_action_to_yaml_output_action(cls, rest_output_action: str) -> str:
+ output_switcher = {
+ BatchOutputAction.APPEND_ROW: BatchDeploymentOutputAction.APPEND_ROW,
+ BatchOutputAction.SUMMARY_ONLY: BatchDeploymentOutputAction.SUMMARY_ONLY,
+ }
+
+ return output_switcher.get(rest_output_action, rest_output_action)
+
+ @classmethod
+ def _yaml_output_action_to_rest_output_action(cls, yaml_output_action: Any) -> str:
+ output_switcher = {
+ BatchDeploymentOutputAction.APPEND_ROW: BatchOutputAction.APPEND_ROW,
+ BatchDeploymentOutputAction.SUMMARY_ONLY: BatchOutputAction.SUMMARY_ONLY,
+ }
+
+ return output_switcher.get(yaml_output_action, yaml_output_action)
+
+ # pylint: disable=arguments-differ
+ def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore
+ self._validate()
+ code_config = (
+ RestCodeConfiguration(
+ code_id=self.code_configuration.code,
+ scoring_script=self.code_configuration.scoring_script,
+ )
+ if self.code_configuration
+ else None
+ )
+ model = IdAssetReference(asset_id=self.model) if self.model else None
+ environment = self.environment
+
+ # The REST payload is identical in both cases except for output_action, so compute it once.
+ output_action = (
+ BatchDeployment._yaml_output_action_to_rest_output_action(self.output_action)
+ if isinstance(self.output_action, str)
+ else None
+ )
+ batch_deployment = RestBatchDeployment(
+ compute=self.compute,
+ description=self.description,
+ resources=self.resources._to_rest_object() if self.resources else None,
+ code_configuration=code_config,
+ environment_id=environment,
+ model=model,
+ output_file_name=self.output_file_name,
+ output_action=output_action,
+ error_threshold=self.error_threshold,
+ retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None,
+ logging_level=self.logging_level,
+ mini_batch_size=self.mini_batch_size,
+ max_concurrency_per_instance=self.max_concurrency_per_instance,
+ environment_variables=self.environment_variables,
+ properties=self.properties,
+ )
+
+ return BatchDeploymentData(location=location, properties=batch_deployment, tags=self.tags)
+
+ @classmethod
+ def _from_rest_object( # pylint: disable=arguments-renamed
+ cls, deployment: BatchDeploymentData
+ ) -> "BatchDeployment":
+ modelId = deployment.properties.model.asset_id if deployment.properties.model else None
+
+ if (
+ hasattr(deployment.properties, "deployment_configuration")
+ and deployment.properties.deployment_configuration is not None
+ ):
+ settings = deployment.properties.deployment_configuration.settings
+ deployment_comp_settings = {
+ "deployment_configuration_type": deployment.properties.deployment_configuration.deployment_configuration_type, # pylint: disable=line-too-long
+ "componentDeployment.Settings.continue_on_step_failure": settings.get(
+ "ComponentDeployment.Settings.continue_on_step_failure", None
+ ),
+ "default_datastore": settings.get("default_datastore", None),
+ "default_compute": settings.get("default_compute", None),
+ }
+ properties = {}
+ if deployment.properties.properties:
+ properties.update(deployment.properties.properties)
+ properties.update(deployment_comp_settings)
+ else:
+ properties = deployment.properties.properties
+
+ code_configuration = (
+ CodeConfiguration._from_rest_code_configuration(deployment.properties.code_configuration)
+ if deployment.properties.code_configuration
+ else None
+ )
+ deployment = BatchDeployment(
+ name=deployment.name,
+ description=deployment.properties.description,
+ id=deployment.id,
+ tags=deployment.tags,
+ model=modelId,
+ environment=deployment.properties.environment_id,
+ code_configuration=code_configuration,
+ output_file_name=(
+ deployment.properties.output_file_name
+ if cls._rest_output_action_to_yaml_output_action(deployment.properties.output_action)
+ == BatchDeploymentOutputAction.APPEND_ROW
+ else None
+ ),
+ output_action=cls._rest_output_action_to_yaml_output_action(deployment.properties.output_action),
+ error_threshold=deployment.properties.error_threshold,
+ retry_settings=BatchRetrySettings._from_rest_object(deployment.properties.retry_settings),
+ logging_level=deployment.properties.logging_level,
+ mini_batch_size=deployment.properties.mini_batch_size,
+ compute=deployment.properties.compute,
+ resources=ResourceConfiguration._from_rest_object(deployment.properties.resources),
+ environment_variables=deployment.properties.environment_variables,
+ max_concurrency_per_instance=deployment.properties.max_concurrency_per_instance,
+ endpoint_name=_parse_endpoint_name_from_deployment_id(deployment.id),
+ properties=properties,
+ creation_context=SystemData._from_rest_object(deployment.system_data),
+ provisioning_state=deployment.properties.provisioning_state,
+ )
+
+ return deployment
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "BatchDeployment":
+ data = data or {}
+ params_override = params_override or []
+ cls._update_params(params_override)
+
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: BatchDeployment = load_from_dict(BatchDeploymentSchema, data, context, **kwargs)
+ return res
+
+ def _validate(self) -> None:
+ self._validate_output_action()
+
+ @classmethod
+ def _update_params(cls, params_override: Any) -> None:
+ for param in params_override:
+ endpoint_name = param.get("endpoint_name")
+ if isinstance(endpoint_name, str):
+ param["endpoint_name"] = endpoint_name.lower()
+
+ def _validate_output_action(self) -> None:
+ if (
+ self.output_action
+ and self.output_action == BatchDeploymentOutputAction.SUMMARY_ONLY
+ and self.output_file_name
+ ):
+ msg = "When output_action is set to {}, the output_file_name need not to be specified."
+ msg = msg.format(BatchDeploymentOutputAction.SUMMARY_ONLY)
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.BATCH_DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
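To make the promoted properties above concrete, here is a minimal usage sketch. The endpoint, model, and compute names are placeholders, and in practice the object is submitted through an MLClient rather than used standalone.

from azure.ai.ml.entities import BatchDeployment, BatchRetrySettings

# Hypothetical names; instance_count is folded into resources.instance_count by __init__.
deployment = BatchDeployment(
    name="batch-dp-example",
    endpoint_name="my-batch-endpoint",
    model="azureml:my-model:1",
    compute="cpu-cluster",
    instance_count=2,
    max_concurrency_per_instance=2,
    mini_batch_size=10,
    output_action="append_row",
    output_file_name="predictions.csv",
    retry_settings=BatchRetrySettings(max_retries=3, timeout=300),
)
assert deployment.instance_count == 2  # read back through the promoted property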
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_job.py
new file mode 100644
index 00000000..c078f479
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/batch_job.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict
+
+from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import BatchJobResource
+
+
+class BatchJob(object):
+ """Batch jobs that are created with batch deployments/endpoints invocation.
+
+ This class shouldn't be instantiated directly. Instead, it is used as the return type of batch deployment/endpoint
+ invocation and job listing.
+ """
+
+ def __init__(self, **kwargs: Any):
+ self.id = kwargs.get("id", None)
+ self.name = kwargs.get("name", None)
+ self.type = kwargs.get("type", None)
+ self.status = kwargs.get("status", None)
+
+ def _to_dict(self) -> Dict:
+ return {
+ "id": self.id,
+ "name": self.name,
+ "type": self.type,
+ "status": self.status,
+ }
+
+ @classmethod
+ def _from_rest_object(cls, obj: BatchJobResource) -> "BatchJob":
+ return cls(
+ id=obj.id,
+ name=obj.name,
+ type=obj.type,
+ status=obj.properties.status,
+ )
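BatchJob is only a thin read model. Assuming `job` is an instance returned by a batch endpoint invocation or job listing (the variable name is a placeholder, not a verified client call), the surfaced fields can be inspected like this:

# `job` is assumed to be a BatchJob obtained from batch endpoint operations.
print(job.name, job.type, job.status)
print(job._to_dict())  # {"id": ..., "name": ..., "type": ..., "status": ...}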
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/code_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/code_configuration.py
new file mode 100644
index 00000000..cbae647d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/code_configuration.py
@@ -0,0 +1,93 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+import os
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2022_05_01.models import CodeConfiguration as RestCodeConfiguration
+from azure.ai.ml.entities._assets import Code
+from azure.ai.ml.entities._mixins import DictMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+class CodeConfiguration(DictMixin):
+ """Code configuration for a scoring job.
+
+ :param code: The code directory containing the scoring script. The code can be a Code object, the ARM resource ID
+ of an existing code asset, a local path, or an "http:", "https:", or "azureml:" URL pointing to a remote location.
+ :type code: Optional[Union[~azure.ai.ml.entities.Code, str]]
+ :param scoring_script: The scoring script file path relative to the code directory.
+ :type scoring_script: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START code_configuration]
+ :end-before: [END code_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Creating a CodeConfiguration for a BatchDeployment.
+ """
+
+ def __init__(
+ self,
+ code: Optional[Union[str, os.PathLike]] = None,
+ scoring_script: Optional[Union[str, os.PathLike]] = None,
+ ) -> None:
+ self.code: Optional[Union[str, os.PathLike]] = code
+ self._scoring_script: Optional[Union[str, os.PathLike]] = scoring_script
+
+ @property
+ def scoring_script(self) -> Optional[Union[str, os.PathLike]]:
+ """The scoring script file path relative to the code directory.
+
+ :rtype: str
+ """
+ return self._scoring_script
+
+ def _to_rest_code_configuration(self) -> RestCodeConfiguration:
+ return RestCodeConfiguration(code_id=self.code, scoring_script=self.scoring_script)
+
+ def _validate(self) -> None:
+ if self.code and not self.scoring_script:
+ msg = "scoring script can't be empty"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.CODE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+
+ @staticmethod
+ def _from_rest_code_configuration(code_configuration: RestCodeConfiguration) -> Optional["CodeConfiguration"]:
+ if code_configuration:
+ return CodeConfiguration(
+ code=code_configuration.code_id,
+ scoring_script=code_configuration.scoring_script,
+ )
+ return None
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, CodeConfiguration):
+ return NotImplemented
+ if not other:
+ return False
+ # only compare mutable fields
+ return (
+ self.scoring_script == other.scoring_script
+ and (
+ isinstance(self.code, Code)
+ and isinstance(other.code, Code)
+ or isinstance(self.code, str)
+ and isinstance(other.code, str)
+ )
+ and self.code == other.code
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
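A small sketch of the typical shape, with placeholder paths; _validate enforces that a scoring script accompanies any code directory:

from azure.ai.ml.entities import CodeConfiguration

code_config = CodeConfiguration(code="./src", scoring_script="score.py")
print(code_config.scoring_script)  # "score.py"; read-only property backed by _scoring_script

# Providing code without a scoring script fails validation when the deployment is prepared:
# CodeConfiguration(code="./src")._validate()  -> ValidationException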
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/container_resource_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/container_resource_settings.py
new file mode 100644
index 00000000..0d0bc15d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/container_resource_settings.py
@@ -0,0 +1,74 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=arguments-renamed
+
+import logging
+from typing import Optional
+
+from azure.ai.ml._restclient.v2022_05_01.models import ContainerResourceSettings
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class ResourceSettings(RestTranslatableMixin):
+ """Resource settings for a container.
+
+ This class uses Kubernetes Resource unit formats. For more information, see
+ https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
+
+ :param cpu: The CPU resource settings for a container.
+ :type cpu: Optional[str]
+ :param memory: The memory resource settings for a container.
+ :type memory: Optional[str]
+ :param gpu: The GPU resource settings for a container.
+ :type gpu: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START resource_requirements_configuration]
+ :end-before: [END resource_requirements_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring ResourceSettings for a Kubernetes deployment.
+ """
+
+ def __init__(self, cpu: Optional[str] = None, memory: Optional[str] = None, gpu: Optional[str] = None) -> None:
+ self.cpu = cpu
+ self.memory = memory
+ self.gpu = gpu
+
+ def _to_rest_object(self) -> ContainerResourceSettings:
+ return ContainerResourceSettings(cpu=self.cpu, memory=self.memory, gpu=self.gpu)
+
+ @classmethod
+ def _from_rest_object(cls, settings: ContainerResourceSettings) -> Optional["ResourceSettings"]:
+ return (
+ ResourceSettings(
+ cpu=settings.cpu,
+ memory=settings.memory,
+ gpu=settings.gpu,
+ )
+ if settings
+ else None
+ )
+
+ def _merge_with(self, other: Optional["ResourceSettings"]) -> None:
+ if other:
+ self.cpu = other.cpu or self.cpu
+ self.memory = other.memory or self.memory
+ self.gpu = other.gpu or self.gpu
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ResourceSettings):
+ return NotImplemented
+ if not other:
+ return False
+ # only compare mutable fields
+ return self.cpu == other.cpu and self.memory == other.memory and self.gpu == other.gpu
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
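The class mirrors Kubernetes resource units. A sketch of how _merge_with layers overrides over defaults (the values are illustrative only):

from azure.ai.ml.entities import ResourceSettings

defaults = ResourceSettings(cpu="100m", memory="0.5Gi")
override = ResourceSettings(memory="1Gi", gpu="1")
defaults._merge_with(override)
# Fields set on `override` win, untouched fields keep the default:
assert defaults.cpu == "100m" and defaults.memory == "1Gi" and defaults.gpu == "1"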
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_asset.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_asset.py
new file mode 100644
index 00000000..72d24131
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_asset.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Dict, Optional
+
+from azure.ai.ml._schema._deployment.online.data_asset_schema import DataAssetSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+
+@experimental
+class DataAsset:
+ """Data Asset entity
+
+ :keyword Optional[str] data_id: Arm id of registered data asset
+ :keyword Optional[str] name: Name of data asset
+ :keyword Optional[str] path: Path where the data asset is stored.
+ :keyword Optional[int] version: Version of data asset.
+ """
+
+ def __init__(
+ self,
+ *,
+ data_id: Optional[str] = None,
+ name: Optional[str] = None,
+ path: Optional[str] = None,
+ version: Optional[int] = None,
+ ):
+ self.data_id = data_id
+ self.name = name
+ self.path = path
+ self.version = version
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = DataAssetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
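A hedged sketch of this experimental container (field values are placeholders); it is a plain data holder whose _to_dict serializes through DataAssetSchema:

from azure.ai.ml.entities._deployment.data_asset import DataAsset

reference = DataAsset(
    name="collected-model-inputs",
    path="azureml://datastores/workspaceblobstore/paths/inputs/",
    version=1,
)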
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_collector.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_collector.py
new file mode 100644
index 00000000..74277c61
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/data_collector.py
@@ -0,0 +1,84 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DataCollector as RestDataCollector
+from azure.ai.ml._schema._deployment.online.data_collector_schema import DataCollectorSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._deployment.deployment_collection import DeploymentCollection
+from azure.ai.ml.entities._deployment.request_logging import RequestLogging
+
+
+@experimental
+class DataCollector:
+ """Data Capture deployment entity.
+
+ :param collections: Mapping dictionary of strings mapped to DeploymentCollection entities.
+ :type collections: Mapping[str, DeploymentCollection]
+ :param rolling_rate: The rolling rate of MDC (model data collection) files. Possible values: "minute", "hour", "day".
+ :type rolling_rate: str
+ :param sampling_rate: The sampling rate of MDC files, in the range [0.0, 1.0].
+ :type sampling_rate: float
+ :param request_logging: Logging of request payload parameters.
+ :type request_logging: RequestLogging
+ """
+
+ def __init__(
+ self,
+ collections: Dict[str, DeploymentCollection],
+ *,
+ rolling_rate: Optional[str] = None,
+ sampling_rate: Optional[float] = None,
+ request_logging: Optional[RequestLogging] = None,
+ **kwargs: Any,
+ ): # pylint: disable=unused-argument
+ self.collections = collections
+ self.rolling_rate = rolling_rate
+ self.sampling_rate = sampling_rate
+ self.request_logging = request_logging
+
+ if self.sampling_rate:
+ for collection in self.collections.values():
+ collection.sampling_rate = self.sampling_rate
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = DataCollectorSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestDataCollector) -> "DataCollector":
+ collections = {}
+ sampling_rate = None
+ for k, v in rest_obj.collections.items():
+ sampling_rate = v.sampling_rate
+ collections[k] = DeploymentCollection._from_rest_object(v)
+ delattr(collections[k], "sampling_rate")
+
+ return DataCollector(
+ collections=collections,
+ rolling_rate=rest_obj.rolling_rate,
+ request_logging=(
+ RequestLogging._from_rest_object(rest_obj.request_logging) if rest_obj.request_logging else None
+ ),
+ sampling_rate=sampling_rate,
+ )
+
+ def _to_rest_object(self) -> RestDataCollector:
+ rest_collections: dict = {}
+ for collection in self.collections.values():
+ collection.sampling_rate = self.sampling_rate
+ delattr(self, "sampling_rate")
+ if self.request_logging:
+ self.request_logging = self.request_logging._to_rest_object()
+ if self.collections:
+ rest_collections = {}
+ for k, v in self.collections.items():
+ rest_collections[k] = v._to_rest_object()
+ return RestDataCollector(
+ collections=rest_collections, rolling_rate=self.rolling_rate, request_logging=self.request_logging
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment.py
new file mode 100644
index 00000000..2f857cfa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment.py
@@ -0,0 +1,213 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,arguments-renamed
+
+import logging
+from abc import abstractmethod
+from os import PathLike
+from typing import IO, TYPE_CHECKING, Any, AnyStr, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import BatchDeploymentData
+from azure.ai.ml._restclient.v2022_05_01.models import OnlineDeploymentData
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.exceptions import (
+ DeploymentException,
+ ErrorCategory,
+ ErrorTarget,
+ ValidationErrorType,
+ ValidationException,
+)
+
+from .code_configuration import CodeConfiguration
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._assets._artifacts.model import Model
+ from azure.ai.ml.entities._assets.environment import Environment
+
+module_logger = logging.getLogger(__name__)
+
+
+class Deployment(Resource, RestTranslatableMixin):
+ """Endpoint Deployment base class.
+
+ :param name: Name of the deployment resource, defaults to None
+ :type name: typing.Optional[str]
+ :param endpoint_name: Name of the Endpoint resource, defaults to None
+ :type endpoint_name: typing.Optional[str]
+ :param description: Description of the deployment resource, defaults to None
+ :type description: typing.Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None
+ :type tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :param properties: The asset property dictionary, defaults to None
+ :type properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :param model: The Model entity, defaults to None
+ :type model: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Model]]
+ :param code_configuration: Code Configuration, defaults to None
+ :type code_configuration: typing.Optional[CodeConfiguration]
+ :param environment: The Environment entity, defaults to None
+ :type environment: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Environment]]
+ :param environment_variables: Environment variables that will be set in deployment, defaults to None
+ :type environment_variables: typing.Optional[typing.Dict[str, str]]
+ :param code_path: Folder path to local code assets. Equivalent to code_configuration.code, defaults to None
+ :type code_path: typing.Optional[typing.Union[str, PathLike]]
+ :param scoring_script: Scoring script name. Equivalent to code_configuration.scoring_script, defaults to None
+ :type scoring_script: typing.Optional[typing.Union[str, PathLike]]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Deployment cannot be successfully validated.
+ Exception details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ *,
+ endpoint_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ properties: Optional[Dict[str, Any]] = None,
+ model: Optional[Union[str, "Model"]] = None,
+ code_configuration: Optional[CodeConfiguration] = None,
+ environment: Optional[Union[str, "Environment"]] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ code_path: Optional[Union[str, PathLike]] = None,
+ scoring_script: Optional[Union[str, PathLike]] = None,
+ **kwargs: Any,
+ ):
+ # MFE is case-insensitive for Name. So convert the name into lower case here.
+ name = name.lower() if name else None
+ self.endpoint_name = endpoint_name
+ self._type: Optional[str] = kwargs.pop("type", None)
+
+ if code_configuration and (code_path or scoring_script):
+ msg = "code_path and scoring_script are not allowed if code_configuration is provided."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ super().__init__(name, description, tags, properties, **kwargs)
+
+ self.model = model
+ self.code_configuration = code_configuration
+ if not self.code_configuration and (code_path or scoring_script):
+ self.code_configuration = CodeConfiguration(code=code_path, scoring_script=scoring_script)
+
+ self.environment = environment
+ self.environment_variables = dict(environment_variables) if environment_variables else {}
+
+ @property
+ def type(self) -> Optional[str]:
+ """
+ Type of deployment.
+
+ :rtype: str
+ """
+ return self._type
+
+ @property
+ def code_path(self) -> Optional[Union[str, PathLike]]:
+ """
+ The code directory containing the scoring script.
+
+ :rtype: Union[str, PathLike]
+ """
+ return self.code_configuration.code if self.code_configuration and self.code_configuration.code else None
+
+ @code_path.setter
+ def code_path(self, value: Union[str, PathLike]) -> None:
+ if not self.code_configuration:
+ self.code_configuration = CodeConfiguration()
+
+ self.code_configuration.code = value
+
+ @property
+ def scoring_script(self) -> Optional[Union[str, PathLike]]:
+ """
+ The scoring script file path relative to the code directory.
+
+ :rtype: Union[str, PathLike]
+ """
+ return self.code_configuration.scoring_script if self.code_configuration else None
+
+ @scoring_script.setter
+ def scoring_script(self, value: Union[str, PathLike]) -> None:
+ if not self.code_configuration:
+ self.code_configuration = CodeConfiguration()
+
+ self.code_configuration.scoring_script = value # type: ignore[misc]
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the deployment content into a file in yaml format.
+
+ :param dest: The destination to receive this deployment's content.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+ and an exception will be raised if the file is not writable.
+ :type dest: typing.Union[os.PathLike, str, typing.IO[typing.AnyStr]]
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ @abstractmethod
+ def _to_dict(self) -> Dict:
+ pass
+
+ @classmethod
+ def _from_rest_object(
+ cls, deployment_rest_object: Union[OnlineDeploymentData, BatchDeploymentData]
+ ) -> "Deployment":
+ from azure.ai.ml.entities._deployment.batch_deployment import BatchDeployment
+ from azure.ai.ml.entities._deployment.online_deployment import OnlineDeployment
+
+ if isinstance(deployment_rest_object, OnlineDeploymentData):
+ return OnlineDeployment._from_rest_object(deployment_rest_object)
+ if isinstance(deployment_rest_object, BatchDeploymentData):
+ return BatchDeployment._from_rest_object(deployment_rest_object)
+
+ msg = f"Unsupported deployment type {type(deployment_rest_object)}"
+ raise DeploymentException(
+ message=msg,
+ target=ErrorTarget.DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ def _to_rest_object(self) -> Any:
+ pass
+
+ def _merge_with(self, other: "Deployment") -> None:
+ if other:
+ if self.name != other.name:
+ msg = "The deployment name: {} and {} are not matched when merging."
+ raise ValidationException(
+ message=msg.format(self.name, other.name),
+ target=ErrorTarget.DEPLOYMENT,
+ no_personal_data_message=msg.format("[name1]", "[name2]"),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ if other.tags:
+ self.tags: dict = {**self.tags, **other.tags}
+ if other.properties:
+ self.properties: dict = {**self.properties, **other.properties}
+ if other.environment_variables:
+ self.environment_variables = {
+ **self.environment_variables,
+ **other.environment_variables,
+ }
+ self.code_configuration = other.code_configuration or self.code_configuration
+ self.model = other.model or self.model
+ self.environment = other.environment or self.environment
+ self.endpoint_name = other.endpoint_name or self.endpoint_name
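Since Deployment is abstract, the promoted code_path/scoring_script behaviour is easiest to see through a concrete subclass. A sketch using the BatchDeployment class from earlier in this diff (the paths are placeholders):

from azure.ai.ml.entities import BatchDeployment

d = BatchDeployment(name="Example", endpoint_name="my-endpoint", code_path="./src", scoring_script="score.py")
assert d.name == "example"  # names are lower-cased because MFE treats them case-insensitively
assert d.code_configuration.code == "./src"  # the two promoted properties become one CodeConfiguration
assert d.scoring_script == "score.py"
d.dump("./batch-deployment.yml")  # raises if the file already exists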
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_collection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_collection.py
new file mode 100644
index 00000000..c1b1c750
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_collection.py
@@ -0,0 +1,62 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Collection as RestCollection
+from azure.ai.ml._schema._deployment.online.deployment_collection_schema import DeploymentCollectionSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from .data_asset import DataAsset
+
+
+@experimental
+class DeploymentCollection:
+ """Collection entity
+
+ :param enabled: Whether logging is enabled for this collection. Possible values: 'true', 'false'.
+ :type enabled: str
+ :param data: Data asset ID (or DataAsset entity) associated with collection logging.
+ :type data: Union[str, DataAsset]
+ :param client_id: Client ID associated with collection logging.
+ :type client_id: str
+
+ """
+
+ def __init__(
+ self,
+ *,
+ enabled: Optional[str] = None,
+ data: Optional[Union[str, DataAsset]] = None,
+ client_id: Optional[str] = None,
+ **kwargs: Any
+ ):
+ self.enabled = enabled # maps to data_collection_mode
+ self.data = data # maps to data_id
+ self.sampling_rate = kwargs.get(
+ "sampling_rate", None
+ ) # maps to sampling_rate, but it has to be passed from the data_collector root
+ self.client_id = client_id
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = DeploymentCollectionSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestCollection) -> "DeploymentCollection":
+ return DeploymentCollection(
+ enabled="true" if rest_obj.data_collection_mode == "Enabled" else "false",
+ sampling_rate=rest_obj.sampling_rate,
+ data=rest_obj.data_id,
+ client_id=rest_obj.client_id,
+ )
+
+ def _to_rest_object(self) -> RestCollection:
+ return RestCollection(
+ data_collection_mode="enabled" if str(self.enabled).lower() == "true" else "disabled",
+ sampling_rate=self.sampling_rate,
+ data_id=self.data,
+ client_id=self.client_id,
+ )
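Putting DataCollector and DeploymentCollection together, a sketch of a data-capture configuration; the collection names follow the usual model_inputs/model_outputs convention, which is an assumption here rather than something these classes enforce, and the imports assume the public re-exports in azure.ai.ml.entities:

from azure.ai.ml.entities import DataCollector, DeploymentCollection

collector = DataCollector(
    collections={
        "model_inputs": DeploymentCollection(enabled="true"),
        "model_outputs": DeploymentCollection(enabled="true"),
    },
    rolling_rate="hour",
    sampling_rate=1.0,  # copied onto each collection by DataCollector.__init__
)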
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_settings.py
new file mode 100644
index 00000000..0dbfc8fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/deployment_settings.py
@@ -0,0 +1,200 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=arguments-renamed
+
+import logging
+from typing import Optional
+
+from azure.ai.ml._restclient.v2022_05_01.models import BatchRetrySettings as RestBatchRetrySettings
+from azure.ai.ml._restclient.v2022_05_01.models import OnlineRequestSettings as RestOnlineRequestSettings
+from azure.ai.ml._restclient.v2022_05_01.models import ProbeSettings as RestProbeSettings
+from azure.ai.ml._utils.utils import (
+ from_iso_duration_format,
+ from_iso_duration_format_ms,
+ to_iso_duration_format,
+ to_iso_duration_format_ms,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchRetrySettings(RestTranslatableMixin):
+ """Retry settings for batch deployment.
+
+ :param max_retries: Number of retries on failure, defaults to 3
+ :type max_retries: int
+ :param timeout: Timeout in seconds, defaults to 30
+ :type timeout: int
+ """
+
+ def __init__(self, *, max_retries: Optional[int] = None, timeout: Optional[int] = None):
+ self.max_retries = max_retries
+ self.timeout = timeout
+
+ def _to_rest_object(self) -> RestBatchRetrySettings:
+ return RestBatchRetrySettings(
+ max_retries=self.max_retries,
+ timeout=to_iso_duration_format(self.timeout),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, settings: RestBatchRetrySettings) -> Optional["BatchRetrySettings"]:
+ return (
+ BatchRetrySettings(
+ max_retries=settings.max_retries,
+ timeout=from_iso_duration_format(settings.timeout),
+ )
+ if settings
+ else None
+ )
+
+ def _merge_with(self, other: "BatchRetrySettings") -> None:
+ if other:
+ self.timeout = other.timeout or self.timeout
+ self.max_retries = other.max_retries or self.max_retries
+
+
+class OnlineRequestSettings(RestTranslatableMixin):
+ """Request Settings entity.
+
+ :param request_timeout_ms: The scoring request timeout in milliseconds, defaults to 5000
+ :type request_timeout_ms: int
+ :param max_concurrent_requests_per_instance: The maximum number of concurrent requests per instance, defaults to 1
+ :type max_concurrent_requests_per_instance: int
+ :param max_queue_wait_ms: The maximum time, in milliseconds, a request may wait in the queue, defaults to 500
+ :type max_queue_wait_ms: int
+ """
+
+ def __init__(
+ self,
+ max_concurrent_requests_per_instance: Optional[int] = None,
+ request_timeout_ms: Optional[int] = None,
+ max_queue_wait_ms: Optional[int] = None,
+ ):
+ self.request_timeout_ms = request_timeout_ms
+ self.max_concurrent_requests_per_instance = max_concurrent_requests_per_instance
+ self.max_queue_wait_ms = max_queue_wait_ms
+
+ def _to_rest_object(self) -> RestOnlineRequestSettings:
+ return RestOnlineRequestSettings(
+ max_queue_wait=to_iso_duration_format_ms(self.max_queue_wait_ms),
+ max_concurrent_requests_per_instance=self.max_concurrent_requests_per_instance,
+ request_timeout=to_iso_duration_format_ms(self.request_timeout_ms),
+ )
+
+ def _merge_with(self, other: Optional["OnlineRequestSettings"]) -> None:
+ if other:
+ self.max_concurrent_requests_per_instance = (
+ other.max_concurrent_requests_per_instance or self.max_concurrent_requests_per_instance
+ )
+ self.request_timeout_ms = other.request_timeout_ms or self.request_timeout_ms
+ self.max_queue_wait_ms = other.max_queue_wait_ms or self.max_queue_wait_ms
+
+ @classmethod
+ def _from_rest_object(cls, settings: RestOnlineRequestSettings) -> Optional["OnlineRequestSettings"]:
+ return (
+ OnlineRequestSettings(
+ request_timeout_ms=from_iso_duration_format_ms(settings.request_timeout),
+ max_concurrent_requests_per_instance=settings.max_concurrent_requests_per_instance,
+ max_queue_wait_ms=from_iso_duration_format_ms(settings.max_queue_wait),
+ )
+ if settings
+ else None
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, OnlineRequestSettings):
+ return NotImplemented
+ if not other:
+ return False
+ # only compare mutable fields
+ return (
+ self.max_concurrent_requests_per_instance == other.max_concurrent_requests_per_instance
+ and self.request_timeout_ms == other.request_timeout_ms
+ and self.max_queue_wait_ms == other.max_queue_wait_ms
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+
+class ProbeSettings(RestTranslatableMixin):
+ def __init__(
+ self,
+ *,
+ failure_threshold: Optional[int] = None,
+ success_threshold: Optional[int] = None,
+ timeout: Optional[int] = None,
+ period: Optional[int] = None,
+ initial_delay: Optional[int] = None,
+ ):
+ """Settings on how to probe an endpoint.
+
+ :param failure_threshold: Threshold for probe failures, defaults to 30
+ :type failure_threshold: int
+ :param success_threshold: Threshold for probe success, defaults to 1
+ :type success_threshold: int
+ :param timeout: timeout in seconds, defaults to 2
+ :type timeout: int
+ :param period: How often (in seconds) to perform the probe, defaults to 10
+ :type period: int
+ :param initial_delay: How long (in seconds) to wait for the first probe, defaults to 10
+ :type initial_delay: int
+ """
+
+ self.failure_threshold = failure_threshold
+ self.success_threshold = success_threshold
+ self.timeout = timeout
+ self.period = period
+ self.initial_delay = initial_delay
+
+ def _to_rest_object(self) -> RestProbeSettings:
+ return RestProbeSettings(
+ failure_threshold=self.failure_threshold,
+ success_threshold=self.success_threshold,
+ timeout=to_iso_duration_format(self.timeout),
+ period=to_iso_duration_format(self.period),
+ initial_delay=to_iso_duration_format(self.initial_delay),
+ )
+
+ def _merge_with(self, other: Optional["ProbeSettings"]) -> None:
+ if other:
+ self.failure_threshold = other.failure_threshold or self.failure_threshold
+ self.success_threshold = other.success_threshold or self.success_threshold
+ self.timeout = other.timeout or self.timeout
+ self.period = other.period or self.period
+ self.initial_delay = other.initial_delay or self.initial_delay
+
+ @classmethod
+ def _from_rest_object(cls, settings: RestProbeSettings) -> Optional["ProbeSettings"]:
+ return (
+ ProbeSettings(
+ failure_threshold=settings.failure_threshold,
+ success_threshold=settings.success_threshold,
+ timeout=from_iso_duration_format(settings.timeout),
+ period=from_iso_duration_format(settings.period),
+ initial_delay=from_iso_duration_format(settings.initial_delay),
+ )
+ if settings
+ else None
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ProbeSettings):
+ return NotImplemented
+ if not other:
+ return False
+ # only compare mutable fields
+ return (
+ self.failure_threshold == other.failure_threshold
+ and self.success_threshold == other.success_threshold
+ and self.timeout == other.timeout
+ and self.period == other.period
+ and self.initial_delay == other.initial_delay
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
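A sketch of the three settings objects with their documented defaults spelled out. On the wire, the _to_rest_object methods convert the second/millisecond values into ISO 8601 durations (for example request_timeout_ms=5000 becomes "PT5S"):

from azure.ai.ml.entities import BatchRetrySettings, OnlineRequestSettings, ProbeSettings

retry = BatchRetrySettings(max_retries=3, timeout=30)
requests = OnlineRequestSettings(
    max_concurrent_requests_per_instance=1,
    request_timeout_ms=5000,
    max_queue_wait_ms=500,
)
liveness = ProbeSettings(failure_threshold=30, success_threshold=1, timeout=2, period=10, initial_delay=10)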
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/event_hub.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/event_hub.py
new file mode 100644
index 00000000..2729fa50
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/event_hub.py
@@ -0,0 +1,32 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._schema._deployment.online.event_hub_schema import EventHubSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._deployment.oversize_data_config import OversizeDataConfig
+
+
+class EventHub:
+ """Event Hub deployment entity
+
+ :param namespace: Name space of eventhub, provided in format of "{namespace}.{name}".
+ :type namespace: str
+ :param oversize_data_config: Oversized payload body configurations.
+ :type oversize_data_config: OversizeDataConfig
+
+ """
+
+ # pylint: disable=unused-argument
+ def __init__(
+ self, namespace: Optional[str] = None, oversize_data_config: Optional[OversizeDataConfig] = None, **kwargs: Any
+ ):
+ self.namespace = namespace
+ self.oversize_data_config = oversize_data_config
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = EventHubSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
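A minimal sketch of the EventHub holder; the namespace value is a placeholder in the documented "{namespace}.{name}" format:

from azure.ai.ml.entities._deployment.event_hub import EventHub

hub = EventHub(namespace="my-eh-namespace.my-eventhub")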
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/job_definition.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/job_definition.py
new file mode 100644
index 00000000..56bebebc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/job_definition.py
@@ -0,0 +1,58 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._schema._deployment.batch.job_definition_schema import JobDefinitionSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._job.job import Job
+
+
+@experimental
+class JobDefinition:
+ """Job Definition entity.
+
+ :param type: Job definition type. Allowed value is: pipeline
+ :type type: str
+ :param name: Job name
+ :type name: str
+ :param job: Job definition
+ :type job: Union[Job, str]
+ :param component: Component definition
+ :type component: Union[Component, str]
+ :param settings: Job settings
+ :type settings: Dict[str, Any]
+ :param description: Job description.
+ :type description: str
+ :param tags: Job tags
+ :type tags: Dict[str, Any]
+ """
+
+ def __init__(
+ self,
+ # pylint: disable=redefined-builtin
+ type: str,
+ name: Optional[str] = None,
+ job: Optional[Union[Job, str]] = None,
+ component: Optional[Union[Component, str]] = None,
+ settings: Optional[Dict[str, Any]] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ):
+ self.type = type
+ self.name = name
+ self.job = job
+ self.component = component
+ self.settings = settings
+ self.tags = tags
+ self.description = description
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = JobDefinitionSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
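A sketch of a pipeline-type job definition as used by pipeline-component batch deployments; the component reference and settings keys are placeholders, not required values:

from azure.ai.ml.entities._deployment.job_definition import JobDefinition

job_def = JobDefinition(
    type="pipeline",
    name="batch-scoring",
    component="azureml:my_scoring_pipeline_component:1",
    settings={"default_compute": "cpu-cluster", "continue_on_step_failure": False},
)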
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment.py
new file mode 100644
index 00000000..0ad4fd6f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment.py
@@ -0,0 +1,207 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2022_05_01.models import BatchDeploymentData
+from azure.ai.ml._restclient.v2022_05_01.models import BatchDeploymentDetails as RestBatchDeployment
+from azure.ai.ml._restclient.v2022_05_01.models import BatchOutputAction
+from azure.ai.ml._restclient.v2022_05_01.models import CodeConfiguration as RestCodeConfiguration
+from azure.ai.ml._restclient.v2022_05_01.models import IdAssetReference
+from azure.ai.ml._schema._deployment.batch.model_batch_deployment import ModelBatchDeploymentSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction
+from azure.ai.ml.entities._assets import Environment, Model
+from azure.ai.ml.entities._deployment.batch_deployment import BatchDeployment
+from azure.ai.ml.entities._deployment.deployment import Deployment
+from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .code_configuration import CodeConfiguration
+from .model_batch_deployment_settings import ModelBatchDeploymentSettings
+
+
+@experimental
+class ModelBatchDeployment(Deployment):
+ """Job Definition entity.
+
+ :param type: Job definition type. Allowed value is: pipeline
+ :type type: str
+ :param name: Job name
+ :type name: str
+ :param job: Job definition
+ :type job: Union[Job, str]
+ :param component: Component definition
+ :type component: Union[Component, str]
+ :param settings: Job settings
+ :type settings: Dict[str, Any]
+ :param description: Job description.
+ :type description: str
+ :param tags: Job tags
+ :type tags: Dict[str, Any]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str],
+ endpoint_name: Optional[str] = None,
+ environment: Optional[Union[str, Environment]] = None,
+ properties: Optional[Dict[str, str]] = None,
+ model: Optional[Union[str, Model]] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ settings: Optional[ModelBatchDeploymentSettings] = None,
+ resources: Optional[ResourceConfiguration] = None,
+ compute: Optional[str] = None,
+ code_configuration: Optional[CodeConfiguration] = None,
+ code_path: Optional[Union[str, PathLike]] = None, # promoted property from code_configuration.code
+ scoring_script: Optional[
+ Union[str, PathLike]
+ ] = None, # promoted property from code_configuration.scoring_script
+ **kwargs: Any,
+ ):
+ self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)
+ super().__init__(
+ name=name,
+ endpoint_name=endpoint_name,
+ properties=properties,
+ code_path=code_path,
+ scoring_script=scoring_script,
+ environment=environment,
+ model=model,
+ description=description,
+ tags=tags,
+ code_configuration=code_configuration,
+ **kwargs,
+ )
+ self.compute = compute
+ self.resources = resources
+ if settings is not None:
+ self.settings = ModelBatchDeploymentSettings(
+ mini_batch_size=settings.mini_batch_size,
+ instance_count=settings.instance_count,
+ max_concurrency_per_instance=settings.max_concurrency_per_instance,
+ output_action=settings.output_action,
+ output_file_name=settings.output_file_name,
+ retry_settings=settings.retry_settings,
+ environment_variables=settings.environment_variables,
+ error_threshold=settings.error_threshold,
+ logging_level=settings.logging_level,
+ )
+ if self.resources is not None:
+ if self.resources.instance_count is None and settings.instance_count is not None:
+ self.resources.instance_count = settings.instance_count
+ if self.resources is None and settings.instance_count is not None:
+ self.resources = ResourceConfiguration(instance_count=settings.instance_count)
+
+ # pylint: disable=arguments-differ
+ def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore
+ self._validate()
+ code_config = (
+ RestCodeConfiguration(
+ code_id=self.code_configuration.code,
+ scoring_script=self.code_configuration.scoring_script,
+ )
+ if self.code_configuration
+ else None
+ )
+ deployment_settings = self.settings
+ model = IdAssetReference(asset_id=self.model) if self.model else None
+ batch_deployment = RestBatchDeployment(
+ description=self.description,
+ environment_id=self.environment,
+ model=model,
+ code_configuration=code_config,
+ output_file_name=deployment_settings.output_file_name,
+ output_action=BatchDeployment._yaml_output_action_to_rest_output_action( # pylint: disable=protected-access
+ deployment_settings.output_action
+ ),
+ error_threshold=deployment_settings.error_threshold,
+ resources=self.resources._to_rest_object() if self.resources else None, # pylint: disable=protected-access
+ retry_settings=(
+ deployment_settings.retry_settings._to_rest_object() # pylint: disable=protected-access
+ if deployment_settings.retry_settings
+ else None
+ ),
+ logging_level=deployment_settings.logging_level,
+ mini_batch_size=deployment_settings.mini_batch_size,
+ max_concurrency_per_instance=deployment_settings.max_concurrency_per_instance,
+ environment_variables=deployment_settings.environment_variables,
+ compute=self.compute,
+ properties=self.properties,
+ )
+ return BatchDeploymentData(location=location, properties=batch_deployment, tags=self.tags)
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "ModelBatchDeployment":
+ data = data or {}
+ params_override = params_override or []
+ cls._update_params(params_override)
+
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: ModelBatchDeployment = load_from_dict(ModelBatchDeploymentSchema, data, context, **kwargs)
+ return res
+
+ @classmethod
+ def _update_params(cls, params_override: Any) -> None:
+ for param in params_override:
+ endpoint_name = param.get("endpoint_name")
+ if isinstance(endpoint_name, str):
+ param["endpoint_name"] = endpoint_name.lower()
+
+ @classmethod
+ def _yaml_output_action_to_rest_output_action(cls, yaml_output_action: str) -> str:
+ output_switcher = {
+ BatchDeploymentOutputAction.APPEND_ROW: BatchOutputAction.APPEND_ROW,
+ BatchDeploymentOutputAction.SUMMARY_ONLY: BatchOutputAction.SUMMARY_ONLY,
+ }
+ return output_switcher.get(yaml_output_action, yaml_output_action)
+
+ @property
+ def provisioning_state(self) -> Optional[str]:
+ """Batch deployment provisioning state, readonly.
+
+ :return: Batch deployment provisioning state.
+ :rtype: Optional[str]
+ """
+ return self._provisioning_state
+
+ def _validate(self) -> None:
+ self._validate_output_action()
+
+ def _validate_output_action(self) -> None:
+ if (
+ self.settings.output_action
+ and self.settings.output_action == BatchDeploymentOutputAction.SUMMARY_ONLY
+ and self.settings.output_file_name
+ ):
+ msg = "When output_action is set to {}, the output_file_name need not to be specified."
+ msg = msg.format(BatchDeploymentOutputAction.SUMMARY_ONLY)
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.BATCH_DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def _to_dict(self) -> Dict:
+ res: dict = ModelBatchDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment_settings.py
new file mode 100644
index 00000000..36151019
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/model_batch_deployment_settings.py
@@ -0,0 +1,81 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._schema._deployment.batch.model_batch_deployment_settings import ModelBatchDeploymentSettingsSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction
+from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
+
+
+@experimental
+class ModelBatchDeploymentSettings:
+ """Model Batch Deployment Settings entity.
+
+ :param mini_batch_size: Size of the mini-batch passed to each batch invocation, defaults to 10
+ :type mini_batch_size: int
+ :param instance_count: Number of instances the inference job will run on. Equivalent to resources.instance_count.
+ :type instance_count: int
+ :param output_action: Indicates how the output will be organized. Possible values include:
+ "summary_only", "append_row". Defaults to "append_row"
+ :type output_action: str or ~azure.ai.ml.constants._deployment.BatchDeploymentOutputAction
+ :param output_file_name: Customized output file name for append_row output action, defaults to "predictions.csv"
+ :type output_file_name: str
+ :param max_concurrency_per_instance: Indicates the maximum degree of parallelism per instance, defaults to 1
+ :type max_concurrency_per_instance: int
+ :param retry_settings: Retry settings for a batch inference operation, defaults to None
+ :type retry_settings: BatchRetrySettings
+ :param environment_variables: Environment variables that will be set in deployment.
+ :type environment_variables: dict
+ :param error_threshold: Error threshold. If the error count for the entire input exceeds this value,
+ the batch inference will be aborted. The allowed range is [-1, int.MaxValue], where -1 indicates that
+ all failures during batch inference are ignored. For FileDataset, this is the count of file failures;
+ for TabularDataset, it is the count of record failures. Defaults to -1
+ :type error_threshold: int
+ :param logging_level: Logging level for batch inference operation, defaults to "info"
+ :type logging_level: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START model_batch_deployment_settings_entity_create]
+ :end-before: [END model_batch_deployment_settings_entity_create]
+ :language: python
+ :dedent: 8
+ :caption: Creating a Model Batch Deployment Settings object.
+ """
+
+ def __init__(
+ self,
+ *,
+ mini_batch_size: Optional[int],
+ instance_count: Optional[int] = None,
+ max_concurrency_per_instance: Optional[int] = None,
+ output_action: Optional[BatchDeploymentOutputAction] = None,
+ output_file_name: Optional[str] = None,
+ retry_settings: Optional[BatchRetrySettings] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ error_threshold: Optional[int] = None,
+ logging_level: Optional[str] = None,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ):
+ self.mini_batch_size = mini_batch_size
+ self.instance_count = instance_count
+ self.max_concurrency_per_instance = max_concurrency_per_instance
+ self.output_action = output_action
+ self.output_file_name = output_file_name
+ self.retry_settings = retry_settings
+ self.environment_variables = environment_variables
+ self.error_threshold = error_threshold
+ self.logging_level = logging_level
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = ModelBatchDeploymentSettingsSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
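Combining the two classes above, a sketch of a model batch deployment where the job-level knobs live in ModelBatchDeploymentSettings; names and the model reference are placeholders, and the imports assume the public re-exports in azure.ai.ml.entities:

from azure.ai.ml.entities import (
    BatchRetrySettings,
    ModelBatchDeployment,
    ModelBatchDeploymentSettings,
)

deployment = ModelBatchDeployment(
    name="model-dp-example",
    endpoint_name="my-batch-endpoint",
    model="azureml:my-model:1",
    compute="cpu-cluster",
    settings=ModelBatchDeploymentSettings(
        mini_batch_size=10,
        instance_count=2,  # copied into resources.instance_count by __init__
        max_concurrency_per_instance=2,
        output_action="append_row",
        output_file_name="predictions.csv",
        retry_settings=BatchRetrySettings(max_retries=3, timeout=300),
    ),
)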
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/online_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/online_deployment.py
new file mode 100644
index 00000000..131d3293
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/online_deployment.py
@@ -0,0 +1,742 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,arguments-renamed,unidiomatic-typecheck
+
+import logging
+import os
+import typing
+from abc import abstractmethod
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import CodeConfiguration as RestCodeConfiguration
+from azure.ai.ml._restclient.v2023_04_01_preview.models import EndpointComputeType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ KubernetesOnlineDeployment as RestKubernetesOnlineDeployment,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ManagedOnlineDeployment as RestManagedOnlineDeployment
+from azure.ai.ml._restclient.v2023_04_01_preview.models import OnlineDeployment as RestOnlineDeploymentData
+from azure.ai.ml._restclient.v2023_04_01_preview.models import OnlineDeploymentProperties as RestOnlineDeploymentDetails
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Sku as RestSku
+from azure.ai.ml._schema._deployment.online.online_deployment import (
+ KubernetesOnlineDeploymentSchema,
+ ManagedOnlineDeploymentSchema,
+)
+from azure.ai.ml._utils._arm_id_utils import _parse_endpoint_name_from_deployment_id
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, TYPE, ArmConstants
+from azure.ai.ml.constants._endpoint import EndpointYamlFields
+from azure.ai.ml.entities._assets import Code
+from azure.ai.ml.entities._assets._artifacts.model import Model
+from azure.ai.ml.entities._assets.environment import Environment
+from azure.ai.ml.entities._deployment.code_configuration import CodeConfiguration
+from azure.ai.ml.entities._deployment.data_collector import DataCollector
+from azure.ai.ml.entities._deployment.deployment_settings import OnlineRequestSettings, ProbeSettings
+from azure.ai.ml.entities._deployment.resource_requirements_settings import ResourceRequirementsSettings
+from azure.ai.ml.entities._deployment.scale_settings import (
+ DefaultScaleSettings,
+ OnlineScaleSettings,
+ TargetUtilizationScaleSettings,
+)
+from azure.ai.ml.entities._endpoint._endpoint_helpers import validate_endpoint_or_deployment_name
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import (
+ DeploymentException,
+ ErrorCategory,
+ ErrorTarget,
+ ValidationErrorType,
+ ValidationException,
+)
+
+from .deployment import Deployment
+
+module_logger = logging.getLogger(__name__)
+
+
+# pylint: disable=too-many-instance-attributes
+class OnlineDeployment(Deployment):
+ """Online endpoint deployment entity.
+
+ :param name: Name of the deployment resource.
+ :type name: str
+ :param endpoint_name: Name of the endpoint resource, defaults to None
+ :type endpoint_name: typing.Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None
+ :type tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :param properties: The asset property dictionary, defaults to None
+ :type properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :param description: Description of the resource, defaults to None
+ :type description: typing.Optional[str]
+ :param model: Model entity for the endpoint deployment, defaults to None
+ :type model: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Model]]
+ :param data_collector: Data Collector entity for the endpoint deployment, defaults to None
+ :type data_collector: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.DataCollector]]
+ :param code_configuration: Code Configuration, defaults to None
+ :type code_configuration: typing.Optional[~azure.ai.ml.entities.CodeConfiguration]
+ :param environment: Environment entity for the endpoint deployment, defaults to None
+ :type environment: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Environment]]
+ :param app_insights_enabled: Whether Application Insights is enabled, defaults to False
+ :type app_insights_enabled: typing.Optional[bool]
+ :param scale_settings: How the online deployment will scale, defaults to None
+ :type scale_settings: typing.Optional[~azure.ai.ml.entities.OnlineScaleSettings]
+ :param request_settings: Online Request Settings, defaults to None
+ :type request_settings: typing.Optional[~azure.ai.ml.entities.OnlineRequestSettings]
+ :param liveness_probe: Liveness probe settings, defaults to None
+ :type liveness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+ :param readiness_probe: Readiness probe settings, defaults to None
+ :type readiness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+ :param environment_variables: Environment variables that will be set in deployment, defaults to None
+ :type environment_variables: typing.Optional[typing.Dict[str, str]]
+ :param instance_count: The instance count used for this deployment, defaults to None
+ :type instance_count: typing.Optional[int]
+ :param instance_type: Azure compute sku, defaults to None
+ :type instance_type: typing.Optional[str]
+ :param model_mount_path: The path to mount the model in custom container, defaults to None
+ :type model_mount_path: typing.Optional[str]
+ :param code_path: Equivalent to code_configuration.code; will be ignored if code_configuration
+ is present. Defaults to None
+ :type code_path: typing.Optional[typing.Union[str, os.PathLike]]
+ :param scoring_script: Equivalent to code_configuration.scoring_script; will be ignored if
+ code_configuration is present. Defaults to None
+ :type scoring_script: typing.Optional[typing.Union[str, os.PathLike]]
+ """
+
+ def __init__(
+ self,
+ name: str,
+ *,
+ endpoint_name: Optional[str] = None,
+ tags: Optional[Dict[str, typing.Any]] = None,
+ properties: Optional[Dict[str, typing.Any]] = None,
+ description: Optional[str] = None,
+ model: Optional[Union[str, "Model"]] = None,
+ data_collector: Optional[DataCollector] = None,
+ code_configuration: Optional[CodeConfiguration] = None,
+ environment: Optional[Union[str, "Environment"]] = None,
+ app_insights_enabled: Optional[bool] = False,
+ scale_settings: Optional[OnlineScaleSettings] = None,
+ request_settings: Optional[OnlineRequestSettings] = None,
+ liveness_probe: Optional[ProbeSettings] = None,
+ readiness_probe: Optional[ProbeSettings] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ instance_count: Optional[int] = None,
+ instance_type: Optional[str] = None,
+ model_mount_path: Optional[str] = None,
+ code_path: Optional[Union[str, os.PathLike]] = None, # promoted property from code_configuration.code
+ scoring_script: Optional[Union[str, os.PathLike]] = None, # promoted property from code_configuration.scoring_script
+ **kwargs: typing.Any,
+ ):
+ self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)
+
+ super(OnlineDeployment, self).__init__(
+ name=name,
+ endpoint_name=endpoint_name,
+ tags=tags,
+ properties=properties,
+ description=description,
+ model=model,
+ code_configuration=code_configuration,
+ environment=environment,
+ environment_variables=environment_variables,
+ code_path=code_path,
+ scoring_script=scoring_script,
+ **kwargs,
+ )
+
+ self.app_insights_enabled = app_insights_enabled
+ self.scale_settings = scale_settings
+ self.request_settings = request_settings
+ self.liveness_probe = liveness_probe
+ self.readiness_probe = readiness_probe
+ self.instance_count = instance_count
+ self._arm_type = ArmConstants.ONLINE_DEPLOYMENT_TYPE
+ self.model_mount_path = model_mount_path
+ self.instance_type = instance_type
+ self.data_collector: Any = data_collector
+
+ @property
+ def provisioning_state(self) -> Optional[str]:
+ """Deployment provisioning state, readonly.
+
+ :return: Deployment provisioning state.
+ :rtype: typing.Optional[str]
+ """
+ return self._provisioning_state
+
+ def _generate_dependencies(self) -> Tuple:
+ """Convert dependencies into ARM id or REST wrapper.
+
+ :return: A 3-tuple of the code configuration, environment ID, and model ID.
+ :rtype: Tuple[RestCodeConfiguration, str, str]
+ """
+ code = None
+
+ if self.code_configuration:
+ self.code_configuration._validate()
+ if self.code_configuration.code is not None:
+ if isinstance(self.code_configuration.code, str):
+ code_id = self.code_configuration.code
+ elif not isinstance(self.code_configuration.code, os.PathLike):
+ code_id = self.code_configuration.code.id
+
+ code = RestCodeConfiguration(
+ code_id=code_id, # pylint: disable=possibly-used-before-assignment
+ scoring_script=self.code_configuration.scoring_script,
+ )
+
+ model_id = None
+ if self.model:
+ model_id = self.model if isinstance(self.model, str) else self.model.id
+
+ environment_id = None
+ if self.environment:
+ environment_id = self.environment if isinstance(self.environment, str) else self.environment.id
+
+ return code, environment_id, model_id
+
+ @abstractmethod
+ def _to_dict(self) -> Dict:
+ pass
+
+ @abstractmethod
+ def _to_arm_resource_param(self, **kwargs: Any) -> Dict:
+ pass
+
+ @abstractmethod
+ def _to_rest_object(self) -> RestOnlineDeploymentData:
+ pass
+
+ @classmethod
+ def _from_rest_object(cls, deployment: RestOnlineDeploymentData) -> "OnlineDeployment":
+ if deployment.properties.endpoint_compute_type == EndpointComputeType.KUBERNETES:
+ return KubernetesOnlineDeployment._from_rest_object(deployment)
+ if deployment.properties.endpoint_compute_type == EndpointComputeType.MANAGED:
+ return ManagedOnlineDeployment._from_rest_object(deployment)
+
+ msg = f"Unsupported online endpoint type {deployment.properties.endpoint_compute_type}."
+ raise DeploymentException(
+ message=msg,
+ target=ErrorTarget.ONLINE_DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ def _get_arm_resource(self, **kwargs: Any) -> Dict:
+ resource: dict = super(OnlineDeployment, self)._get_arm_resource(**kwargs)
+ depends_on = []
+ if self.environment and isinstance(self.environment, Environment):
+ depends_on.append(f"{self.environment._arm_type}Deployment")
+ if self.code_configuration and self.code_configuration.code and isinstance(self.code_configuration.code, Code):
+ depends_on.append(f"{self.code_configuration.code._arm_type}Deployment")
+ if self.model and isinstance(self.model, Model):
+ depends_on.append(f"{self.model._arm_type}Deployment")
+ resource[ArmConstants.DEPENDSON_PARAMETER_NAME] = depends_on
+ return resource
+
+ def _get_arm_resource_and_params(self, **kwargs: Any) -> List:
+ resource_param_tuple_list = [(self._get_arm_resource(**kwargs), self._to_arm_resource_param(**kwargs))]
+ if self.environment and isinstance(self.environment, Environment):
+ resource_param_tuple_list.extend(self.environment._get_arm_resource_and_params())
+ if self.code_configuration and self.code_configuration.code and isinstance(self.code_configuration.code, Code):
+ resource_param_tuple_list.extend(self.code_configuration.code._get_arm_resource_and_params())
+ if self.model and isinstance(self.model, Model):
+ resource_param_tuple_list.extend(self.model._get_arm_resource_and_params())
+ return resource_param_tuple_list
+
+ def _validate_name(self) -> None:
+ if self.name:
+ validate_endpoint_or_deployment_name(self.name, is_deployment=True)
+
+ def _merge_with(self, other: Any) -> None:
+ if other:
+ if self.name != other.name:
+ msg = "The deployment names {} and {} do not match when merging."
+ raise ValidationException(
+ message=msg.format(self.name, other.name),
+ target=ErrorTarget.ONLINE_DEPLOYMENT,
+ no_personal_data_message=msg.format("[name1]", "[name2]"),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ super()._merge_with(other)
+ self.app_insights_enabled = other.app_insights_enabled or self.app_insights_enabled
+ # Adding noqa to suppress E721 (do not compare types, use 'isinstance()'):
+ # isinstance would also match subclasses, which is explicitly undesired by this logic.
+ if self.scale_settings and type(self.scale_settings) == type(other.scale_settings): # noqa
+ self.scale_settings._merge_with(other.scale_settings)
+ else:
+ self.scale_settings = other.scale_settings
+ if self.request_settings:
+ self.request_settings._merge_with(other.request_settings)
+ else:
+ self.request_settings = other.request_settings
+ if self.liveness_probe:
+ self.liveness_probe._merge_with(other.liveness_probe)
+ else:
+ self.liveness_probe = other.liveness_probe
+ if self.readiness_probe:
+ self.readiness_probe._merge_with(other.readiness_probe)
+ else:
+ self.readiness_probe = other.readiness_probe
+ self.instance_count = other.instance_count or self.instance_count
+ self.instance_type = other.instance_type or self.instance_type
+
+ @classmethod
+ def _set_scale_settings(cls, data: dict) -> None:
+ if not hasattr(data, EndpointYamlFields.SCALE_SETTINGS):
+ return
+
+ scale_settings = data[EndpointYamlFields.SCALE_SETTINGS]
+ keyName = TYPE
+ if scale_settings and scale_settings[keyName] == "default":
+ scale_copy = scale_settings.copy()
+ for key in scale_copy:
+ if key != keyName:
+ scale_settings.pop(key, None)
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[os.PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "OnlineDeployment":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+
+ deployment_type = data.get("type", None)
+
+ if deployment_type == camel_to_snake(EndpointComputeType.KUBERNETES.value):
+ res_kub: OnlineDeployment = load_from_dict(KubernetesOnlineDeploymentSchema, data, context, **kwargs)
+ return res_kub
+
+ res_manage: OnlineDeployment = load_from_dict(ManagedOnlineDeploymentSchema, data, context, **kwargs)
+ return res_manage
+
+
+class KubernetesOnlineDeployment(OnlineDeployment):
+ """Kubernetes Online endpoint deployment entity.
+
+ :param name: Name of the deployment resource.
+ :type name: str
+ :param endpoint_name: Name of the endpoint resource, defaults to None
+ :type endpoint_name: typing.Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None
+ :type tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :param properties: The asset property dictionary, defaults to None
+ :type properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :param description: Description of the resource, defaults to None
+ :type description: typing.Optional[str]
+ :param model: Model entity for the endpoint deployment, defaults to None
+ :type model: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Model]]
+ :param code_configuration: Code Configuration, defaults to None
+ :type code_configuration: typing.Optional[~azure.ai.ml.entities.CodeConfiguration]
+ :param environment: Environment entity for the endpoint deployment, defaults to None
+ :type environment: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Environment]]
+ :param app_insights_enabled: Whether Application Insights is enabled, defaults to False
+ :type app_insights_enabled: bool
+ :param scale_settings: How the online deployment will scale, defaults to None
+ :type scale_settings: typing.Optional[typing.Union[~azure.ai.ml.entities.DefaultScaleSettings
+ , ~azure.ai.ml.entities.TargetUtilizationScaleSettings]]
+ :param request_settings: Online Request Settings, defaults to None
+ :type request_settings: typing.Optional[OnlineRequestSettings]
+ :param liveness_probe: Liveness probe settings, defaults to None
+ :type liveness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+ :param readiness_probe: Readiness probe settings, defaults to None
+ :type readiness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+ :param environment_variables: Environment variables that will be set in deployment, defaults to None
+ :type environment_variables: typing.Optional[typing.Dict[str, str]]
+ :param resources: Resource requirements settings, defaults to None
+ :type resources: typing.Optional[~azure.ai.ml.entities.ResourceRequirementsSettings]
+ :param instance_count: The instance count used for this deployment, defaults to None
+ :type instance_count: typing.Optional[int]
+ :param instance_type: The instance type defined by K8S cluster admin, defaults to None
+ :type instance_type: typing.Optional[str]
+ :param code_path: Equivalent to code_configuration.code; will be ignored if code_configuration
+ is present. Defaults to None
+ :type code_path: typing.Optional[typing.Union[str, os.PathLike]]
+ :param scoring_script: Equivalent to code_configuration.scoring_script; will be ignored if
+ code_configuration is present. Defaults to None
+ :type scoring_script: typing.Optional[typing.Union[str, os.PathLike]]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ endpoint_name: Optional[str] = None,
+ tags: Optional[Dict[str, typing.Any]] = None,
+ properties: Optional[Dict[str, typing.Any]] = None,
+ description: Optional[str] = None,
+ model: Optional[Union[str, "Model"]] = None,
+ code_configuration: Optional[CodeConfiguration] = None,
+ environment: Optional[Union[str, "Environment"]] = None,
+ app_insights_enabled: bool = False,
+ scale_settings: Optional[Union[DefaultScaleSettings, TargetUtilizationScaleSettings]] = None,
+ request_settings: Optional[OnlineRequestSettings] = None,
+ liveness_probe: Optional[ProbeSettings] = None,
+ readiness_probe: Optional[ProbeSettings] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ resources: Optional[ResourceRequirementsSettings] = None,
+ instance_count: Optional[int] = None,
+ instance_type: Optional[str] = None,
+ code_path: Optional[Union[str, os.PathLike]] = None, # promoted property from code_configuration.code
+ scoring_script: Optional[
+ Union[str, os.PathLike]
+ ] = None, # promoted property from code_configuration.scoring_script
+ **kwargs: Any,
+ ):
+ kwargs["type"] = EndpointComputeType.KUBERNETES.value
+ super(KubernetesOnlineDeployment, self).__init__(
+ name=name,
+ endpoint_name=endpoint_name,
+ tags=tags,
+ properties=properties,
+ description=description,
+ model=model,
+ code_configuration=code_configuration,
+ environment=environment,
+ environment_variables=environment_variables,
+ instance_count=instance_count,
+ instance_type=instance_type,
+ app_insights_enabled=app_insights_enabled,
+ scale_settings=scale_settings,
+ request_settings=request_settings,
+ liveness_probe=liveness_probe,
+ readiness_probe=readiness_probe,
+ code_path=code_path,
+ scoring_script=scoring_script,
+ **kwargs,
+ )
+
+ self.resources = resources
+
+ def _to_dict(self) -> Dict:
+ res: dict = KubernetesOnlineDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ # pylint: disable=arguments-differ
+ def _to_rest_object(self, location: str) -> RestOnlineDeploymentData: # type: ignore
+ self._validate()
+ code, environment, model = self._generate_dependencies()
+
+ properties = RestKubernetesOnlineDeployment(
+ code_configuration=code,
+ environment_id=environment,
+ model=model,
+ model_mount_path=self.model_mount_path,
+ scale_settings=self.scale_settings._to_rest_object() if self.scale_settings else None,
+ properties=self.properties,
+ description=self.description,
+ environment_variables=self.environment_variables,
+ app_insights_enabled=self.app_insights_enabled,
+ request_settings=self.request_settings._to_rest_object() if self.request_settings else None,
+ liveness_probe=self.liveness_probe._to_rest_object() if self.liveness_probe else None,
+ readiness_probe=self.readiness_probe._to_rest_object() if self.readiness_probe else None,
+ container_resource_requirements=self.resources._to_rest_object() if self.resources else None,
+ instance_type=self.instance_type if self.instance_type else None,
+ data_collector=self.data_collector._to_rest_object() if self.data_collector else None,
+ )
+ sku = RestSku(name="Default", capacity=self.instance_count)
+
+ return RestOnlineDeploymentData(location=location, properties=properties, tags=self.tags, sku=sku)
+
+ def _to_arm_resource_param(self, **kwargs: Any) -> Dict:
+ rest_object = self._to_rest_object(**kwargs)
+ properties = rest_object.properties
+ sku = rest_object.sku
+ tags = rest_object.tags
+
+ return {
+ self._arm_type: {
+ ArmConstants.NAME: self.name,
+ ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "K8SOnlineDeployment"),
+ ArmConstants.SKU: self._serialize.body(sku, "Sku"),
+ ArmConstants.TAGS: tags,
+ }
+ }
+
+ def _merge_with(self, other: Any) -> None:
+ if other:
+ super()._merge_with(other)
+ if self.resources:
+ self.resources._merge_with(other.resources)
+ else:
+ self.resources = other.resources
+
+ def _validate(self) -> None:
+ self._validate_name()
+
+ @classmethod
+ def _from_rest_object(cls, resource: RestOnlineDeploymentData) -> "KubernetesOnlineDeployment":
+ deployment = resource.properties
+
+ code_config = (
+ CodeConfiguration(
+ code=deployment.code_configuration.code_id,
+ scoring_script=deployment.code_configuration.scoring_script,
+ )
+ if deployment.code_configuration
+ else None
+ )
+
+ return KubernetesOnlineDeployment(
+ id=resource.id,
+ name=resource.name,
+ tags=resource.tags,
+ properties=deployment.properties,
+ description=deployment.description,
+ request_settings=OnlineRequestSettings._from_rest_object(deployment.request_settings),
+ model=deployment.model,
+ code_configuration=code_config,
+ environment=deployment.environment_id,
+ resources=ResourceRequirementsSettings._from_rest_object(deployment.container_resource_requirements),
+ app_insights_enabled=deployment.app_insights_enabled,
+ scale_settings=cast(
+ Optional[Union[DefaultScaleSettings, TargetUtilizationScaleSettings]],
+ OnlineScaleSettings._from_rest_object(deployment.scale_settings),
+ ),
+ liveness_probe=ProbeSettings._from_rest_object(deployment.liveness_probe),
+ readiness_probe=ProbeSettings._from_rest_object(deployment.readiness_probe),
+ environment_variables=deployment.environment_variables,
+ endpoint_name=_parse_endpoint_name_from_deployment_id(resource.id),
+ instance_count=resource.sku.capacity if resource.sku else None,
+ instance_type=deployment.instance_type,
+ data_collector=(
+ DataCollector._from_rest_object(deployment.data_collector)
+ if hasattr(deployment, "data_collector") and deployment.data_collector
+ else None
+ ),
+ provisioning_state=deployment.provisioning_state if hasattr(deployment, "provisioning_state") else None,
+ )
+
+
+class ManagedOnlineDeployment(OnlineDeployment):
+ """Managed Online endpoint deployment entity.
+
+ :param name: Name of the deployment resource
+ :type name: str
+ :param endpoint_name: Name of the endpoint resource, defaults to None
+ :type endpoint_name: typing.Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None
+ :type tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :param properties: The asset property dictionary, defaults to None
+ :type properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :param description: Description of the resource, defaults to None
+ :type description: typing.Optional[str]
+ :param model: Model entity for the endpoint deployment, defaults to None
+ :type model: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Model]]
+ :param code_configuration: Code Configuration, defaults to None
+ :type code_configuration: typing.Optional[~azure.ai.ml.entities.CodeConfiguration]
+ :param environment: Environment entity for the endpoint deployment, defaults to None
+ :type environment: typing.Optional[typing.Union[str, ~azure.ai.ml.entities.Environment]]
+ :param app_insights_enabled: Whether Application Insights is enabled, defaults to False
+ :type app_insights_enabled: bool
+ :param scale_settings: How the online deployment will scale, defaults to None
+ :type scale_settings: typing.Optional[typing.Union[~azure.ai.ml.entities.DefaultScaleSettings
+ , ~azure.ai.ml.entities.TargetUtilizationScaleSettings]]
+ :param request_settings: Online Request Settings, defaults to None
+ :type request_settings: typing.Optional[OnlineRequestSettings]
+ :param liveness_probe: Liveness probe settings, defaults to None
+ :type liveness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+ :param readiness_probe: Readiness probe settings, defaults to None
+ :type readiness_probe: typing.Optional[~azure.ai.ml.entities.ProbeSettings]
+ :param environment_variables: Environment variables that will be set in deployment, defaults to None
+ :type environment_variables: typing.Optional[typing.Dict[str, str]]
+ :param instance_type: Azure compute sku, defaults to None
+ :type instance_type: typing.Optional[str]
+ :param instance_count: The instance count used for this deployment, defaults to None
+ :type instance_count: typing.Optional[int]
+ :param egress_public_network_access: Whether to restrict communication between a deployment and the
+ Azure resources used by the deployment. Allowed values are: "enabled", "disabled", defaults to None
+ :type egress_public_network_access: typing.Optional[str]
+ :param code_path: Equivalent to code_configuration.code; will be ignored if code_configuration
+ is present. Defaults to None
+ :type code_path: typing.Optional[typing.Union[str, os.PathLike]]
+ :param scoring_script: Equivalent to code_configuration.scoring_script; will be ignored if
+ code_configuration is present. Defaults to None
+ :type scoring_script: typing.Optional[typing.Union[str, os.PathLike]]
+ :param data_collector: Data collector, defaults to None
+ :type data_collector: typing.Optional[~azure.ai.ml.entities.DataCollector]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ endpoint_name: Optional[str] = None,
+ tags: Optional[Dict[str, typing.Any]] = None,
+ properties: Optional[Dict[str, typing.Any]] = None,
+ description: Optional[str] = None,
+ model: Optional[Union[str, "Model"]] = None,
+ code_configuration: Optional[CodeConfiguration] = None,
+ environment: Optional[Union[str, "Environment"]] = None,
+ app_insights_enabled: bool = False,
+ scale_settings: Optional[Union[DefaultScaleSettings, TargetUtilizationScaleSettings]] = None,
+ request_settings: Optional[OnlineRequestSettings] = None,
+ liveness_probe: Optional[ProbeSettings] = None,
+ readiness_probe: Optional[ProbeSettings] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ instance_type: Optional[str] = None,
+ instance_count: Optional[int] = None,
+ egress_public_network_access: Optional[str] = None,
+ code_path: Optional[Union[str, os.PathLike]] = None, # promoted property from code_configuration.code
+ scoring_script: Optional[
+ Union[str, os.PathLike]
+ ] = None, # promoted property from code_configuration.scoring_script
+ data_collector: Optional[DataCollector] = None,
+ **kwargs: Any,
+ ):
+ kwargs["type"] = EndpointComputeType.MANAGED.value
+ self.private_network_connection = kwargs.pop("private_network_connection", None)
+ self.package_model = kwargs.pop("package_model", False)
+
+ super(ManagedOnlineDeployment, self).__init__(
+ name=name,
+ endpoint_name=endpoint_name,
+ tags=tags,
+ properties=properties,
+ description=description,
+ model=model,
+ code_configuration=code_configuration,
+ environment=environment,
+ environment_variables=environment_variables,
+ app_insights_enabled=app_insights_enabled,
+ scale_settings=scale_settings,
+ request_settings=request_settings,
+ liveness_probe=liveness_probe,
+ readiness_probe=readiness_probe,
+ instance_count=instance_count,
+ instance_type=instance_type,
+ code_path=code_path,
+ scoring_script=scoring_script,
+ data_collector=data_collector,
+ **kwargs,
+ )
+
+ self.readiness_probe = readiness_probe
+ self.egress_public_network_access = egress_public_network_access
+
+ def _to_dict(self) -> Dict:
+ res: dict = ManagedOnlineDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ # pylint: disable=arguments-differ
+ def _to_rest_object(self, location: str) -> RestOnlineDeploymentData: # type: ignore
+ self._validate()
+ code, environment, model = self._generate_dependencies()
+ properties = RestManagedOnlineDeployment(
+ code_configuration=code,
+ environment_id=environment,
+ model=model,
+ model_mount_path=self.model_mount_path,
+ scale_settings=self.scale_settings._to_rest_object() if self.scale_settings else None,
+ properties=self.properties,
+ description=self.description,
+ environment_variables=self.environment_variables,
+ app_insights_enabled=self.app_insights_enabled,
+ request_settings=self.request_settings._to_rest_object() if self.request_settings else None,
+ liveness_probe=self.liveness_probe._to_rest_object() if self.liveness_probe else None,
+ instance_type=self.instance_type,
+ readiness_probe=self.readiness_probe._to_rest_object() if self.readiness_probe else None,
+ data_collector=self.data_collector._to_rest_object() if self.data_collector else None,
+ )
+ # TODO: SKU name is defaulted to value "Default" since service side requires it.
+ # Should be removed once service side defaults it.
+ sku = RestSku(name="Default", capacity=self.instance_count)
+
+ # mfe is expecting private network connection to be in both the attribute level
+ # as well as in the properties dictionary.
+ if hasattr(self, "private_network_connection") and self.private_network_connection:
+ properties.private_network_connection = self.private_network_connection
+ properties.properties["private-network-connection"] = self.private_network_connection
+ if hasattr(self, "egress_public_network_access") and self.egress_public_network_access:
+ properties.egress_public_network_access = self.egress_public_network_access
+ return RestOnlineDeploymentData(location=location, properties=properties, tags=self.tags, sku=sku)
+
+ def _to_arm_resource_param(self, **kwargs: Any) -> Dict:
+ rest_object = self._to_rest_object(**kwargs)
+ properties = rest_object.properties
+ sku = rest_object.sku
+ tags = rest_object.tags
+
+ return {
+ self._arm_type: {
+ ArmConstants.NAME: self.name,
+ ArmConstants.PROPERTIES_PARAMETER_NAME: self._serialize.body(properties, "ManagedOnlineDeployment"),
+ ArmConstants.SKU: self._serialize.body(sku, "Sku"),
+ ArmConstants.TAGS: tags,
+ }
+ }
+
+ @classmethod
+ def _from_rest_object(cls, resource: RestOnlineDeploymentData) -> "ManagedOnlineDeployment":
+ deployment = resource.properties
+
+ code_config = (
+ CodeConfiguration(
+ code=deployment.code_configuration.code_id,
+ scoring_script=deployment.code_configuration.scoring_script,
+ )
+ if deployment.code_configuration
+ else None
+ )
+
+ return ManagedOnlineDeployment(
+ id=resource.id,
+ name=resource.name,
+ tags=resource.tags,
+ properties=deployment.properties,
+ description=deployment.description,
+ request_settings=OnlineRequestSettings._from_rest_object(deployment.request_settings),
+ model=(deployment.model if deployment.model else None),
+ code_configuration=code_config,
+ environment=deployment.environment_id,
+ app_insights_enabled=deployment.app_insights_enabled,
+ scale_settings=OnlineScaleSettings._from_rest_object(deployment.scale_settings), # type: ignore
+ liveness_probe=ProbeSettings._from_rest_object(deployment.liveness_probe),
+ environment_variables=deployment.environment_variables,
+ readiness_probe=ProbeSettings._from_rest_object(deployment.readiness_probe),
+ instance_type=deployment.instance_type,
+ endpoint_name=_parse_endpoint_name_from_deployment_id(resource.id),
+ instance_count=resource.sku.capacity,
+ private_network_connection=(
+ deployment.private_network_connection if hasattr(deployment, "private_network_connection") else None
+ ),
+ egress_public_network_access=deployment.egress_public_network_access,
+ data_collector=(
+ DataCollector._from_rest_object(deployment.data_collector)
+ if hasattr(deployment, "data_collector") and deployment.data_collector
+ else None
+ ),
+ provisioning_state=deployment.provisioning_state if hasattr(deployment, "provisioning_state") else None,
+ creation_context=resource.system_data,
+ )
+
+ def _merge_with(self, other: Any) -> None:
+ if other:
+ super()._merge_with(other)
+ self.instance_type = other.instance_type or self.instance_type
+
+ def _validate(self) -> None:
+ self._validate_name()
+ self._validate_scale_settings()
+
+ def _validate_scale_settings(self) -> None:
+ if self.scale_settings:
+ if not isinstance(self.scale_settings, DefaultScaleSettings):
+ msg = "ManagedOnlineEndpoint supports DefaultScaleSettings only."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.ONLINE_DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
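+
+
+# Editor's note: the block below is an illustrative usage sketch, not part of the SDK source.
+# It constructs the two concrete deployment types defined above with placeholder endpoint, model,
+# environment, and path values (all assumptions). Actually creating the resources still requires an
+# MLClient, e.g. ml_client.online_deployments.begin_create_or_update(deployment).
+if __name__ == "__main__":
+ managed = ManagedOnlineDeployment(
+ name="blue",
+ endpoint_name="my-endpoint", # assumed existing managed online endpoint
+ model="azureml:my-model:1", # assumed registered model reference
+ environment="azureml:my-env:1", # assumed registered environment reference
+ code_configuration=CodeConfiguration(code="./src", scoring_script="score.py"),
+ instance_type="Standard_DS3_v2",
+ instance_count=1,
+ request_settings=OnlineRequestSettings(request_timeout_ms=5000),
+ liveness_probe=ProbeSettings(initial_delay=30),
+ )
+
+ kubernetes = KubernetesOnlineDeployment(
+ name="green",
+ endpoint_name="my-k8s-endpoint", # assumed existing Kubernetes online endpoint
+ model="azureml:my-model:1",
+ environment="azureml:my-env:1",
+ code_configuration=CodeConfiguration(code="./src", scoring_script="score.py"),
+ instance_type="defaultinstancetype", # instance type defined by the cluster admin
+ instance_count=1,
+ scale_settings=TargetUtilizationScaleSettings(min_instances=1, max_instances=3),
+ )
+ print(managed.name, kubernetes.name)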
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/oversize_data_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/oversize_data_config.py
new file mode 100644
index 00000000..80338c39
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/oversize_data_config.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._schema._deployment.online.oversize_data_config_schema import OversizeDataConfigSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+
+class OversizeDataConfig:
+ """Oversize Data Config deployment entity.
+
+ :param path: Blob path for Model Data Collector file.
+ :type path: str
+ """
+
+ # pylint: disable=unused-argument
+ def __init__(self, path: Optional[str] = None, **kwargs: Any):
+ self.path = path
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = OversizeDataConfigSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/payload_response.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/payload_response.py
new file mode 100644
index 00000000..b67d46c7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/payload_response.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._schema._deployment.online.payload_response_schema import PayloadResponseSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+
+class PayloadResponse:
+ """Response deployment entity
+
+ :param enabled: Whether response logging is enabled.
+ :type enabled: str
+
+ """
+
+ # pylint: disable=unused-argument
+ def __init__(self, enabled: Optional[str] = None, **kwargs: Any):
+ self.enabled = enabled
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = PayloadResponseSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/pipeline_component_batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/pipeline_component_batch_deployment.py
new file mode 100644
index 00000000..730bc39e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/pipeline_component_batch_deployment.py
@@ -0,0 +1,150 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import BatchDeployment as RestBatchDeployment
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ BatchDeploymentProperties,
+ BatchPipelineComponentDeploymentConfiguration,
+ IdAssetReference,
+)
+from azure.ai.ml._schema._deployment.batch.pipeline_component_batch_deployment_schema import (
+ PipelineComponentBatchDeploymentSchema,
+)
+from azure.ai.ml._utils._arm_id_utils import _parse_endpoint_name_from_deployment_id
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities import PipelineComponent
+from azure.ai.ml.entities._builders import BaseNode
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._util import load_from_dict
+
+
+@experimental
+class PipelineComponentBatchDeployment(Resource):
+ """Pipeline Component Batch Deployment entity.
+
+ :param type: Job definition type. Allowed value: "pipeline"
+ :type type: Optional[str]
+ :param name: Name of the deployment resource.
+ :type name: Optional[str]
+ :param description: Description of the deployment resource.
+ :type description: Optional[str]
+ :param component: Component definition.
+ :type component: Optional[Union[Component, str]]
+ :param settings: Run-time settings for the pipeline job.
+ :type settings: Optional[Dict[str, Any]]
+ :param tags: A set of tags. The tags which will be applied to the job.
+ :type tags: Optional[Dict[str, Any]]
+ :param job_definition: Arm ID or PipelineJob entity of an existing pipeline job.
+ :type job_definition: Optional[Dict[str, ~azure.ai.ml.entities._builders.BaseNode]]
+ :param endpoint_name: Name of the Endpoint resource, defaults to None.
+ :type endpoint_name: Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str],
+ endpoint_name: Optional[str] = None,
+ component: Optional[Union[Component, str]] = None,
+ settings: Optional[Dict[str, str]] = None,
+ job_definition: Optional[Dict[str, BaseNode]] = None,
+ tags: Optional[Dict] = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ self._type = kwargs.pop("type", None)
+ super().__init__(name=name, tags=tags, description=description, **kwargs)
+ self.component = component
+ self.endpoint_name = endpoint_name
+ self.settings = settings
+ self.job_definition = job_definition
+
+ def _to_rest_object(self, location: str) -> "RestBatchDeployment":
+ if isinstance(self.component, PipelineComponent):
+ id_asset_ref = IdAssetReference(asset_id=self.component.id)
+
+ batch_pipeline_config = BatchPipelineComponentDeploymentConfiguration(
+ settings=self.settings,
+ tags=self.component.tags,
+ description=self.component.description,
+ component_id=id_asset_ref,
+ )
+ else:
+ id_asset_ref = IdAssetReference(asset_id=self.component)
+ batch_pipeline_config = BatchPipelineComponentDeploymentConfiguration(
+ settings=self.settings, component_id=id_asset_ref
+ )
+ return RestBatchDeployment(
+ location=location,
+ tags=self.tags,
+ properties=BatchDeploymentProperties(
+ deployment_configuration=batch_pipeline_config,
+ description=self.description,
+ ),
+ )
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "PipelineComponentBatchDeployment":
+ data = data or {}
+ params_override = params_override or []
+ cls._update_params(params_override)
+
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: PipelineComponentBatchDeployment = load_from_dict(
+ PipelineComponentBatchDeploymentSchema, data, context, **kwargs
+ )
+ return res
+
+ @classmethod
+ def _update_params(cls, params_override: Any) -> None:
+ for param in params_override:
+ endpoint_name = param.get("endpoint_name")
+ if isinstance(endpoint_name, str):
+ param["endpoint_name"] = endpoint_name.lower()
+
+ @classmethod
+ def _from_rest_object(cls, deployment: RestBatchDeployment) -> "PipelineComponentBatchDeployment":
+ return PipelineComponentBatchDeployment(
+ name=deployment.name,
+ tags=deployment.tags,
+ component=deployment.properties.additional_properties["deploymentConfiguration"]["componentId"]["assetId"],
+ settings=deployment.properties.additional_properties["deploymentConfiguration"]["settings"],
+ endpoint_name=_parse_endpoint_name_from_deployment_id(deployment.id),
+ )
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the deployment content into a file in yaml format.
+
+ :param dest: The destination to receive this deployment's content.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+ and an exception will be raised if the file is not writable.
+ :type dest: typing.Union[os.PathLike, str, typing.IO[typing.AnyStr]]
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ def _to_dict(self) -> Dict:
+ res: dict = PipelineComponentBatchDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return res
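+
+
+# Editor's note: the block below is an illustrative usage sketch, not part of the SDK source.
+# The component reference, endpoint name, and settings are placeholder assumptions; deploying the
+# result still requires an MLClient.
+if __name__ == "__main__":
+ example_deployment = PipelineComponentBatchDeployment(
+ name="hello-batch-deployment",
+ endpoint_name="hello-batch-endpoint", # assumed existing batch endpoint
+ component="azureml:hello_pipeline_component:1", # assumed registered pipeline component
+ settings={"default_compute": "cpu-cluster", "continue_on_step_failure": "false"},
+ description="Deploys a registered pipeline component for batch scoring.",
+ )
+ # dump() raises if the target file already exists (see the docstring above).
+ example_deployment.dump("hello-batch-deployment.yml")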
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/request_logging.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/request_logging.py
new file mode 100644
index 00000000..20cc83fe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/request_logging.py
@@ -0,0 +1,39 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, List, Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import RequestLogging as RestRequestLogging
+from azure.ai.ml._schema._deployment.online.request_logging_schema import RequestLoggingSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+
+@experimental
+class RequestLogging:
+ """Request Logging deployment entity.
+
+ :param capture_headers: Request payload headers to capture.
+ :type capture_headers: list[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ capture_headers: Optional[List[str]] = None,
+ **kwargs: Any,
+ ): # pylint: disable=unused-argument
+ self.capture_headers = capture_headers
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestRequestLogging) -> "RequestLogging":
+ return RequestLogging(capture_headers=rest_obj.capture_headers)
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = RequestLoggingSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def _to_rest_object(self) -> RestRequestLogging:
+ return RestRequestLogging(capture_headers=self.capture_headers)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/resource_requirements_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/resource_requirements_settings.py
new file mode 100644
index 00000000..9db61aae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/resource_requirements_settings.py
@@ -0,0 +1,84 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Optional
+
+from azure.ai.ml._restclient.v2022_05_01.models import ContainerResourceRequirements
+from azure.ai.ml.entities._deployment.container_resource_settings import ResourceSettings
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class ResourceRequirementsSettings(RestTranslatableMixin):
+ """Resource requirements settings for a container.
+
+ :param requests: The minimum resource requests for a container.
+ :type requests: Optional[~azure.ai.ml.entities.ResourceSettings]
+ :param limits: The resource limits for a container.
+ :type limits: Optional[~azure.ai.ml.entities.ResourceSettings]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START resource_requirements_configuration]
+ :end-before: [END resource_requirements_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring ResourceRequirementSettings for a Kubernetes deployment.
+ """
+
+ def __init__(
+ self,
+ requests: Optional[ResourceSettings] = None,
+ limits: Optional[ResourceSettings] = None,
+ ) -> None:
+ self.requests = requests
+ self.limits = limits
+
+ def _to_rest_object(self) -> ContainerResourceRequirements:
+ return ContainerResourceRequirements(
+ container_resource_requests=self.requests._to_rest_object() if self.requests else None,
+ container_resource_limits=self.limits._to_rest_object() if self.limits else None,
+ )
+
+ @classmethod
+ def _from_rest_object( # pylint: disable=arguments-renamed
+ cls, settings: ContainerResourceRequirements
+ ) -> Optional["ResourceRequirementsSettings"]:
+ requests = settings.container_resource_requests
+ limits = settings.container_resource_limits
+ return (
+ ResourceRequirementsSettings(
+ requests=ResourceSettings._from_rest_object(requests),
+ limits=ResourceSettings._from_rest_object(limits),
+ )
+ if settings
+ else None
+ )
+
+ def _merge_with(self, other: Optional["ResourceRequirementsSettings"]) -> None:
+ if other:
+ if self.requests:
+ self.requests._merge_with(other.requests)
+ else:
+ self.requests = other.requests
+ if self.limits:
+ self.limits._merge_with(other.limits)
+ else:
+ self.limits = other.limits
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ResourceRequirementsSettings):
+ return NotImplemented
+ if not other:
+ return False
+ # only compare mutable fields
+ return self.requests == other.requests and self.limits == other.limits
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
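+
+
+# Editor's note: the block below is an illustrative usage sketch, not part of the SDK source.
+# It pairs the class defined above with ResourceSettings (imported at the top of this module);
+# the cpu/memory quantities are placeholder assumptions in the usual Kubernetes format.
+if __name__ == "__main__":
+ requirements = ResourceRequirementsSettings(
+ requests=ResourceSettings(cpu="0.5", memory="1Gi"),
+ limits=ResourceSettings(cpu="1", memory="2Gi"),
+ )
+ # Round-trip through the REST representation used by the service.
+ restored = ResourceRequirementsSettings._from_rest_object(requirements._to_rest_object())
+ print(restored.requests.cpu if restored and restored.requests else None)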
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/run_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/run_settings.py
new file mode 100644
index 00000000..f1deac83
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/run_settings.py
@@ -0,0 +1,50 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._schema._deployment.batch.run_settings_schema import RunSettingsSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+
+@experimental
+class RunSettings:
+ """Run Settings entity.
+
+ :param name: Name of the run
+ :type name: str
+ :param display_name: Display name of the run
+ :type display_name: str
+ :param experiment_name: Name of the experiment for the run
+ :type experiment_name: str
+ :param description: Description of the run
+ :type description: str
+ :param tags: Tags of the run
+ :type tags: Dict[str, Any]
+ :param settings: Dictionary of additional run settings
+ :type settings: Dict[str, Any]
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ settings: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ): # pylint: disable=unused-argument
+ self.name = name
+ self.display_name = display_name
+ self.experiment_name = experiment_name
+ self.description = description
+ self.tags = tags
+ self.settings = settings
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = RunSettingsSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/scale_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/scale_settings.py
new file mode 100644
index 00000000..85535ca0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_deployment/scale_settings.py
@@ -0,0 +1,173 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from abc import abstractmethod
+from typing import Any, Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DefaultScaleSettings as RestDefaultScaleSettings
+from azure.ai.ml._restclient.v2023_04_01_preview.models import OnlineScaleSettings as RestOnlineScaleSettings
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ScaleType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ TargetUtilizationScaleSettings as RestTargetUtilizationScaleSettings,
+)
+from azure.ai.ml._utils.utils import camel_to_snake, from_iso_duration_format, to_iso_duration_format
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.exceptions import DeploymentException, ErrorCategory, ErrorTarget
+
+module_logger = logging.getLogger(__name__)
+
+
+class OnlineScaleSettings(RestTranslatableMixin):
+ """Scale settings for online deployment.
+
+ :param type: Type of the scale settings, allowed values are "default" and "target_utilization".
+ :type type: str
+ """
+
+ def __init__(
+ self,
+ # pylint: disable=redefined-builtin
+ type: str,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ):
+ self.type = camel_to_snake(type)
+
+ @abstractmethod
+ def _to_rest_object(self) -> RestOnlineScaleSettings:
+ pass
+
+ def _merge_with(self, other: Any) -> None:
+ if other:
+ self.type = other.type or self.type
+
+ @classmethod
+ def _from_rest_object( # pylint: disable=arguments-renamed
+ cls, settings: RestOnlineScaleSettings
+ ) -> "OnlineScaleSettings":
+ if settings.scale_type == "Default":
+ return DefaultScaleSettings._from_rest_object(settings)
+ if settings.scale_type == "TargetUtilization":
+ return TargetUtilizationScaleSettings._from_rest_object(settings)
+
+ msg = f"Unsupported online scale setting type {settings.scale_type}."
+ raise DeploymentException(
+ message=msg,
+ target=ErrorTarget.ONLINE_DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+
+class DefaultScaleSettings(OnlineScaleSettings):
+ """Default scale settings.
+
+ :ivar type: Default scale settings type. Set automatically to "default" for this class.
+ :vartype type: str
+ """
+
+ def __init__(self, **kwargs: Any):
+ super(DefaultScaleSettings, self).__init__(
+ type=ScaleType.DEFAULT.value,
+ )
+
+ def _to_rest_object(self) -> RestDefaultScaleSettings:
+ return RestDefaultScaleSettings()
+
+ @classmethod
+ def _from_rest_object(cls, settings: RestDefaultScaleSettings) -> "DefaultScaleSettings":
+ return DefaultScaleSettings()
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, DefaultScaleSettings):
+ return NotImplemented
+ if not other:
+ return False
+ # only compare mutable fields
+ res: bool = self.type.lower() == other.type.lower()
+ return res
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+
+class TargetUtilizationScaleSettings(OnlineScaleSettings):
+ """Auto scale settings.
+
+ :param min_instances: Minimum number of instances
+ :type min_instances: int
+ :param max_instances: Maximum number of instances
+ :type max_instances: int
+ :param polling_interval: The polling interval in seconds. It is serialized as an ISO 8601 duration,
+ which only supports a precision as low as seconds.
+ :type polling_interval: int
+ :param target_utilization_percentage: Target utilization percentage used by the autoscaler.
+ :type target_utilization_percentage: int
+ :ivar type: Target utilization scale settings type. Set automatically to "target_utilization" for this class.
+ :vartype type: str
+ """
+
+ def __init__(
+ self,
+ *,
+ min_instances: Optional[int] = None,
+ max_instances: Optional[int] = None,
+ polling_interval: Optional[int] = None,
+ target_utilization_percentage: Optional[int] = None,
+ **kwargs: Any,
+ ):
+ super(TargetUtilizationScaleSettings, self).__init__(
+ type=ScaleType.TARGET_UTILIZATION.value,
+ )
+ self.min_instances = min_instances
+ self.max_instances = max_instances
+ self.polling_interval = polling_interval
+ self.target_utilization_percentage = target_utilization_percentage
+
+ def _to_rest_object(self) -> RestTargetUtilizationScaleSettings:
+ return RestTargetUtilizationScaleSettings(
+ min_instances=self.min_instances,
+ max_instances=self.max_instances,
+ polling_interval=to_iso_duration_format(self.polling_interval),
+ target_utilization_percentage=self.target_utilization_percentage,
+ )
+
+ def _merge_with(self, other: Optional["TargetUtilizationScaleSettings"]) -> None:
+ if other:
+ super()._merge_with(other)
+ self.min_instances = other.min_instances or self.min_instances
+ self.max_instances = other.max_instances or self.max_instances
+ self.polling_interval = other.polling_interval or self.polling_interval
+ self.target_utilization_percentage = (
+ other.target_utilization_percentage or self.target_utilization_percentage
+ )
+
+ @classmethod
+ def _from_rest_object(cls, settings: RestTargetUtilizationScaleSettings) -> "TargetUtilizationScaleSettings":
+ return cls(
+ min_instances=settings.min_instances,
+ max_instances=settings.max_instances,
+ polling_interval=from_iso_duration_format(settings.polling_interval),
+ target_utilization_percentage=settings.target_utilization_percentage,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TargetUtilizationScaleSettings):
+ return NotImplemented
+ if not other:
+ return False
+ # only compare mutable fields
+ return (
+ self.type.lower() == other.type.lower()
+ and self.min_instances == other.min_instances
+ and self.max_instances == other.max_instances
+ and self.polling_interval == other.polling_interval
+ and self.target_utilization_percentage == other.target_utilization_percentage
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
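+
+
+# Editor's note: the block below is an illustrative usage sketch, not part of the SDK source.
+# Per the validation in online_deployment.py, managed online deployments only accept
+# DefaultScaleSettings; TargetUtilizationScaleSettings applies to Kubernetes online deployments.
+# The numbers are placeholder assumptions, with polling_interval given in seconds.
+if __name__ == "__main__":
+ default_scale = DefaultScaleSettings()
+
+ autoscale = TargetUtilizationScaleSettings(
+ min_instances=1,
+ max_instances=4,
+ polling_interval=60, # seconds; converted to an ISO 8601 duration for the REST object
+ target_utilization_percentage=70,
+ )
+ print(default_scale.type, autoscale._to_rest_object().polling_interval)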
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py
new file mode 100644
index 00000000..5d62a229
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/_endpoint_helpers.py
@@ -0,0 +1,62 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import re
+from typing import Any, Optional
+
+from azure.ai.ml.constants._endpoint import EndpointConfigurations
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+
+def validate_endpoint_or_deployment_name(name: Optional[str], is_deployment: bool = False) -> None:
+ """Validates the name of an endpoint or a deployment
+
+ A valid name of an endpoint or deployment:
+
+ 1. Is between 3 and 32 characters long (inclusive of both ends of the range)
+ 2. Starts with a letter
+ 3. Is followed by 0 or more alphanumeric characters (`a-zA-Z0-9`) or hyphens (`-`)
+ 4. Ends with an alphanumeric character (`a-zA-Z0-9`)
+
+ :param name: Either an endpoint or deployment name
+ :type name: str
+ :param is_deployment: Whether the name is a deployment name. Defaults to False
+ :type is_deployment: bool
+ """
+ if name is None:
+ return
+
+ type_str = "a deployment" if is_deployment else "an endpoint"
+ target = ErrorTarget.DEPLOYMENT if is_deployment else ErrorTarget.ENDPOINT
+ if len(name) < EndpointConfigurations.MIN_NAME_LENGTH or len(name) > EndpointConfigurations.MAX_NAME_LENGTH:
+ msg = f"The name for {type_str} must be at least 3 and at most 32 characters long (inclusive of both limits)."
+ raise ValidationException(
+ message=msg,
+ target=target,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ if not re.match(EndpointConfigurations.NAME_REGEX_PATTERN, name):
+ msg = f"""The name for {type_str} must start with an upper- or lowercase letter
+ and only consist of '-'s and alphanumeric characters."""
+ raise ValidationException(
+ message=msg,
+ target=target,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+
+def validate_identity_type_defined(identity: Any) -> None:
+ if identity and not identity.type:
+ msg = "Identity type not found in provided yaml file."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.ENDPOINT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
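As a quick orientation, here is a minimal sketch of the two helpers above (assuming azure-ai-ml is installed and that EndpointConfigurations enforces the 3-32 character, letter-first rule documented in the docstring); it is an illustration, not an official usage pattern:

from azure.ai.ml.entities._endpoint._endpoint_helpers import validate_endpoint_or_deployment_name
from azure.ai.ml.exceptions import ValidationException

# A conforming name passes silently.
validate_endpoint_or_deployment_name("my-endpoint-01")

# A name starting with a digit violates the documented rules and raises.
try:
    validate_endpoint_or_deployment_name("1-bad-name", is_deployment=True)
except ValidationException as exc:
    print(exc)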
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py
new file mode 100644
index 00000000..4883c828
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/batch_endpoint.py
@@ -0,0 +1,134 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpoint as BatchEndpointData
+from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpointProperties as RestBatchEndpoint
+from azure.ai.ml._schema._endpoint import BatchEndpointSchema
+from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
+from azure.ai.ml.constants._common import AAD_TOKEN_YAML, BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._endpoint._endpoint_helpers import validate_endpoint_or_deployment_name
+from azure.ai.ml.entities._util import load_from_dict
+
+from .endpoint import Endpoint
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchEndpoint(Endpoint):
+ """Batch endpoint entity.
+
+ :param name: Name of the resource.
+ :type name: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param auth_mode: Possible values include: "AMLToken", "Key", "AADToken", defaults to None
+ :type auth_mode: str
+ :param description: Description of the inference endpoint, defaults to None
+ :type description: str
+    :param location: Location of the endpoint, defaults to None
+    :type location: str
+    :param defaults: Default settings for the endpoint, such as the default deployment name, defaults to {}
+ :type defaults: Dict[str, str]
+ :param default_deployment_name: Equivalent to defaults.default_deployment, will be ignored if defaults is present.
+ :type default_deployment_name: str
+ :param scoring_uri: URI to use to perform a prediction, readonly.
+ :type scoring_uri: str
+ :param openapi_uri: URI to check the open API definition of the endpoint.
+ :type openapi_uri: str
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ auth_mode: str = AAD_TOKEN_YAML,
+ description: Optional[str] = None,
+ location: Optional[str] = None,
+ defaults: Optional[Dict[str, str]] = None,
+ default_deployment_name: Optional[str] = None,
+ scoring_uri: Optional[str] = None,
+ openapi_uri: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ super(BatchEndpoint, self).__init__(
+ name=name,
+ tags=tags,
+ properties=properties,
+ auth_mode=auth_mode,
+ description=description,
+ location=location,
+ scoring_uri=scoring_uri,
+ openapi_uri=openapi_uri,
+ **kwargs,
+ )
+
+ self.defaults = defaults
+
+ if not self.defaults and default_deployment_name:
+ self.defaults = {}
+ self.defaults["deployment_name"] = default_deployment_name
+
+ def _to_rest_batch_endpoint(self, location: str) -> BatchEndpointData:
+ validate_endpoint_or_deployment_name(self.name)
+ batch_endpoint = RestBatchEndpoint(
+ description=self.description,
+ auth_mode=snake_to_camel(self.auth_mode),
+ properties=self.properties,
+ defaults=self.defaults,
+ )
+ return BatchEndpointData(location=location, tags=self.tags, properties=batch_endpoint)
+
+ @classmethod
+ def _from_rest_object(cls, obj: BatchEndpointData) -> "BatchEndpoint":
+ return BatchEndpoint(
+ id=obj.id,
+ name=obj.name,
+ tags=obj.tags,
+ properties=obj.properties.properties,
+ auth_mode=camel_to_snake(obj.properties.auth_mode),
+ description=obj.properties.description,
+ location=obj.location,
+ defaults=obj.properties.defaults,
+ provisioning_state=obj.properties.provisioning_state,
+ scoring_uri=obj.properties.scoring_uri,
+ openapi_uri=obj.properties.swagger_uri,
+ )
+
+ def dump(
+ self,
+ dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ return BatchEndpointSchema(context=context).dump(self) # type: ignore
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "BatchEndpoint":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: BatchEndpoint = load_from_dict(BatchEndpointSchema, data, context)
+ return res
+
+ def _to_dict(self) -> Dict:
+ res: dict = BatchEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
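For context, a minimal sketch of how this entity is typically built and serialized locally (names and values are illustrative, and BatchEndpoint is assumed to be re-exported from azure.ai.ml.entities; creating the endpoint in a workspace goes through MLClient, which is outside this file):

from azure.ai.ml.entities import BatchEndpoint  # assumed public re-export of the class above

endpoint = BatchEndpoint(
    name="batch-scoring",                 # lower-cased by the Endpoint base class
    description="Nightly batch scoring endpoint",
    default_deployment_name="blue",       # stored as defaults={"deployment_name": "blue"}
    tags={"team": "ml-platform"},
)

print(endpoint.defaults)  # {'deployment_name': 'blue'}
print(endpoint.dump())    # schema-driven dict produced by BatchEndpointSchema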
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py
new file mode 100644
index 00000000..d878742e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/endpoint.py
@@ -0,0 +1,145 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from abc import abstractmethod
+from os import PathLike
+from typing import IO, Any, AnyStr, Dict, Optional, Union
+
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+class Endpoint(Resource): # pylint: disable=too-many-instance-attributes
+ """Endpoint base class.
+
+ :param auth_mode: The authentication mode, defaults to None
+ :type auth_mode: str
+ :param location: The location of the endpoint, defaults to None
+ :type location: str
+ :param name: Name of the resource.
+ :type name: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: typing.Optional[typing.Dict[str, str]]
+ :param properties: The asset property dictionary.
+ :type properties: typing.Optional[typing.Dict[str, str]]
+ :param description: Description of the resource.
+ :type description: typing.Optional[str]
+ :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to {}
+ :paramtype traffic: typing.Optional[typing.Dict[str, int]]
+ :keyword scoring_uri: str, Endpoint URI, readonly
+ :paramtype scoring_uri: typing.Optional[str]
+ :keyword openapi_uri: str, Endpoint Open API URI, readonly
+ :paramtype openapi_uri: typing.Optional[str]
+ :keyword provisioning_state: str, provisioning state, readonly
+ :paramtype provisioning_state: typing.Optional[str]
+ """
+
+ def __init__(
+ self,
+ auth_mode: Optional[str] = None,
+ location: Optional[str] = None,
+ name: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ properties: Optional[Dict[str, Any]] = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ """Endpoint base class.
+
+ Constructor for Endpoint base class.
+
+ :param auth_mode: The authentication mode, defaults to None
+ :type auth_mode: str
+ :param location: The location of the endpoint, defaults to None
+ :type location: str
+ :param name: Name of the resource.
+ :type name: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: typing.Optional[typing.Dict[str, str]]
+ :param properties: The asset property dictionary.
+ :type properties: typing.Optional[typing.Dict[str, str]]
+ :param description: Description of the resource.
+ :type description: typing.Optional[str]
+ :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to {}
+ :paramtype traffic: typing.Optional[typing.Dict[str, int]]
+ :keyword scoring_uri: str, Endpoint URI, readonly
+ :paramtype scoring_uri: typing.Optional[str]
+ :keyword openapi_uri: str, Endpoint Open API URI, readonly
+ :paramtype openapi_uri: typing.Optional[str]
+ :keyword provisioning_state: str, provisioning state, readonly
+ :paramtype provisioning_state: typing.Optional[str]
+ """
+        # MFE is case-insensitive for names, so convert the name to lower case here.
+ if name:
+ name = name.lower()
+ self._scoring_uri: Optional[str] = kwargs.pop("scoring_uri", None)
+ self._openapi_uri: Optional[str] = kwargs.pop("openapi_uri", None)
+ self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)
+ super().__init__(name, description, tags, properties, **kwargs)
+ self.auth_mode = auth_mode
+ self.location = location
+
+ @property
+ def scoring_uri(self) -> Optional[str]:
+ """URI to use to perform a prediction, readonly.
+
+ :return: The scoring URI
+ :rtype: typing.Optional[str]
+ """
+ return self._scoring_uri
+
+ @property
+ def openapi_uri(self) -> Optional[str]:
+ """URI to check the open api definition of the endpoint.
+
+ :return: The open API URI
+ :rtype: typing.Optional[str]
+ """
+ return self._openapi_uri
+
+ @property
+ def provisioning_state(self) -> Optional[str]:
+ """Endpoint provisioning state, readonly.
+
+ :return: Endpoint provisioning state.
+ :rtype: typing.Optional[str]
+ """
+ return self._provisioning_state
+
+ @abstractmethod
+ def dump(self, dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None, **kwargs: Any) -> Dict:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def _from_rest_object(cls, obj: Any) -> Any:
+ pass
+
+ def _merge_with(self, other: Any) -> None:
+ if other:
+ if self.name != other.name:
+ msg = "The endpoint name: {} and {} are not matched when merging."
+ raise ValidationException(
+ message=msg.format(self.name, other.name),
+ target=ErrorTarget.ENDPOINT,
+ no_personal_data_message=msg.format("[name1]", "[name2]"),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ self.description = other.description or self.description
+ if other.tags:
+ if self.tags is not None:
+ self.tags = {**self.tags, **other.tags}
+ if other.properties:
+ self.properties = {**self.properties, **other.properties}
+ self.auth_mode = other.auth_mode or self.auth_mode
+ if hasattr(other, "traffic"):
+ self.traffic = other.traffic # pylint: disable=attribute-defined-outside-init
+ if hasattr(other, "mirror_traffic"):
+ self.mirror_traffic = other.mirror_traffic # pylint: disable=attribute-defined-outside-init
+ if hasattr(other, "defaults"):
+ self.defaults = other.defaults # pylint: disable=attribute-defined-outside-init
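The _merge_with logic above drives endpoint updates: tag and property dictionaries are unioned, scalar fields are overridden only when the incoming value is truthy, and traffic-style attributes are replaced wholesale. A hedged sketch of that behaviour, using ManagedOnlineEndpoint (defined later in this diff) as the concrete subclass:

from azure.ai.ml.entities import ManagedOnlineEndpoint  # concrete Endpoint subclass, assumed public

current = ManagedOnlineEndpoint(name="my-endpoint", tags={"env": "dev"}, description="v1")
update = ManagedOnlineEndpoint(name="my-endpoint", tags={"owner": "ml"}, traffic={"blue": 100})

current._merge_with(update)
print(current.tags)         # {'env': 'dev', 'owner': 'ml'}  -- tag dicts are unioned
print(current.traffic)      # {'blue': 100}                  -- replaced because update carries a traffic attribute
print(current.description)  # 'v1'                           -- kept, since update.description is None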
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py
new file mode 100644
index 00000000..cdd72536
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_endpoint/online_endpoint.py
@@ -0,0 +1,647 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=no-member
+
+import logging
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthKeys as RestEndpointAuthKeys
+from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthMode
+from azure.ai.ml._restclient.v2022_02_01_preview.models import EndpointAuthToken as RestEndpointAuthToken
+from azure.ai.ml._restclient.v2022_02_01_preview.models import OnlineEndpointData
+from azure.ai.ml._restclient.v2022_02_01_preview.models import OnlineEndpointDetails as RestOnlineEndpoint
+from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration
+from azure.ai.ml._schema._endpoint import KubernetesOnlineEndpointSchema, ManagedOnlineEndpointSchema
+from azure.ai.ml._utils.utils import dict_eq
+from azure.ai.ml.constants._common import (
+ AAD_TOKEN_YAML,
+ AML_TOKEN_YAML,
+ BASE_PATH_CONTEXT_KEY,
+ KEY,
+ PARAMS_OVERRIDE_KEY,
+)
+from azure.ai.ml.constants._endpoint import EndpointYamlFields
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._util import is_compute_in_override, load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+from azure.core.credentials import AccessToken
+
+from ._endpoint_helpers import validate_endpoint_or_deployment_name, validate_identity_type_defined
+from .endpoint import Endpoint
+
+module_logger = logging.getLogger(__name__)
+
+
+class OnlineEndpoint(Endpoint):
+ """Online endpoint entity.
+
+ :keyword name: Name of the resource, defaults to None
+ :paramtype name: typing.Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated. defaults to None
+ :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword properties: The asset property dictionary, defaults to None
+ :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY
+ :type auth_mode: typing.Optional[str]
+ :keyword description: Description of the inference endpoint, defaults to None
+ :paramtype description: typing.Optional[str]
+ :keyword location: Location of the resource, defaults to None
+ :paramtype location: typing.Optional[str]
+ :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None
+ :paramtype traffic: typing.Optional[typing.Dict[str, int]]
+    :keyword mirror_traffic: Duplicated live traffic used to send inference requests to a single deployment, defaults to None
+ :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]]
+ :keyword identity: Identity Configuration, defaults to SystemAssigned
+ :paramtype identity: typing.Optional[IdentityConfiguration]
+ :keyword scoring_uri: Scoring URI, defaults to None
+ :paramtype scoring_uri: typing.Optional[str]
+ :keyword openapi_uri: OpenAPI URI, defaults to None
+ :paramtype openapi_uri: typing.Optional[str]
+ :keyword provisioning_state: Provisioning state of an endpoint, defaults to None
+ :paramtype provisioning_state: typing.Optional[str]
+    :keyword kind: Kind of the resource; there are two kinds: K8s and Managed online endpoints, defaults to None
+ :paramtype kind: typing.Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ properties: Optional[Dict[str, Any]] = None,
+ auth_mode: str = KEY,
+ description: Optional[str] = None,
+ location: Optional[str] = None,
+ traffic: Optional[Dict[str, int]] = None,
+ mirror_traffic: Optional[Dict[str, int]] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ scoring_uri: Optional[str] = None,
+ openapi_uri: Optional[str] = None,
+ provisioning_state: Optional[str] = None,
+ kind: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ """Online endpoint entity.
+
+ Constructor for an Online endpoint entity.
+
+ :keyword name: Name of the resource, defaults to None
+ :paramtype name: typing.Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated. defaults to None
+ :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword properties: The asset property dictionary, defaults to None
+ :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY
+ :type auth_mode: typing.Optional[str]
+ :keyword description: Description of the inference endpoint, defaults to None
+ :paramtype description: typing.Optional[str]
+ :keyword location: Location of the resource, defaults to None
+ :paramtype location: typing.Optional[str]
+ :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None
+ :paramtype traffic: typing.Optional[typing.Dict[str, int]]
+        :keyword mirror_traffic: Duplicated live traffic used to send inference requests to a single deployment, defaults to None
+ :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]]
+ :keyword identity: Identity Configuration, defaults to SystemAssigned
+ :paramtype identity: typing.Optional[IdentityConfiguration]
+ :keyword scoring_uri: Scoring URI, defaults to None
+ :paramtype scoring_uri: typing.Optional[str]
+ :keyword openapi_uri: OpenAPI URI, defaults to None
+ :paramtype openapi_uri: typing.Optional[str]
+ :keyword provisioning_state: Provisioning state of an endpoint, defaults to None
+ :paramtype provisioning_state: typing.Optional[str]
+        :keyword kind: Kind of the resource; there are two kinds: K8s and Managed online endpoints, defaults to None
+ :type kind: typing.Optional[str]
+ """
+ self._provisioning_state = kwargs.pop("provisioning_state", None)
+
+ super(OnlineEndpoint, self).__init__(
+ name=name,
+ properties=properties,
+ tags=tags,
+ auth_mode=auth_mode,
+ description=description,
+ location=location,
+ scoring_uri=scoring_uri,
+ openapi_uri=openapi_uri,
+ provisioning_state=provisioning_state,
+ **kwargs,
+ )
+
+ self.identity = identity
+ self.traffic: Dict = dict(traffic) if traffic else {}
+ self.mirror_traffic: Dict = dict(mirror_traffic) if mirror_traffic else {}
+ self.kind = kind
+
+ @property
+ def provisioning_state(self) -> Optional[str]:
+ """Endpoint provisioning state, readonly.
+
+ :return: Endpoint provisioning state.
+ :rtype: typing.Optional[str]
+ """
+ return self._provisioning_state
+
+ def _to_rest_online_endpoint(self, location: str) -> OnlineEndpointData:
+ # pylint: disable=protected-access
+ identity = (
+ self.identity._to_online_endpoint_rest_object()
+ if self.identity
+ else RestManagedServiceIdentityConfiguration(type="SystemAssigned")
+ )
+ validate_endpoint_or_deployment_name(self.name)
+ validate_identity_type_defined(self.identity)
+ properties = RestOnlineEndpoint(
+ description=self.description,
+ auth_mode=OnlineEndpoint._yaml_auth_mode_to_rest_auth_mode(self.auth_mode),
+ properties=self.properties,
+ traffic=self.traffic,
+ mirror_traffic=self.mirror_traffic,
+ )
+
+ if hasattr(self, "public_network_access") and self.public_network_access:
+ properties.public_network_access = self.public_network_access
+ return OnlineEndpointData(
+ location=location,
+ properties=properties,
+ identity=identity,
+ tags=self.tags,
+ )
+
+ def _to_rest_online_endpoint_traffic_update(self, location: str, no_validation: bool = False) -> OnlineEndpointData:
+ if not no_validation:
+ # validate_deployment_name_matches_traffic(self.deployments, self.traffic)
+ validate_identity_type_defined(self.identity)
+ # validate_uniqueness_of_deployment_names(self.deployments)
+ properties = RestOnlineEndpoint(
+ description=self.description,
+ auth_mode=OnlineEndpoint._yaml_auth_mode_to_rest_auth_mode(self.auth_mode),
+ endpoint=self.name,
+ traffic=self.traffic,
+ properties=self.properties,
+ )
+ return OnlineEndpointData(
+ location=location,
+ properties=properties,
+ identity=self.identity,
+ tags=self.tags,
+ )
+
+ @classmethod
+ def _rest_auth_mode_to_yaml_auth_mode(cls, rest_auth_mode: str) -> str:
+ switcher = {
+ EndpointAuthMode.AML_TOKEN: AML_TOKEN_YAML,
+ EndpointAuthMode.AAD_TOKEN: AAD_TOKEN_YAML,
+ EndpointAuthMode.KEY: KEY,
+ }
+
+ return switcher.get(rest_auth_mode, rest_auth_mode)
+
+ @classmethod
+ def _yaml_auth_mode_to_rest_auth_mode(cls, yaml_auth_mode: Optional[str]) -> str:
+ if yaml_auth_mode is None:
+ return ""
+
+ yaml_auth_mode = yaml_auth_mode.lower()
+
+ switcher = {
+ AML_TOKEN_YAML: EndpointAuthMode.AML_TOKEN,
+ AAD_TOKEN_YAML: EndpointAuthMode.AAD_TOKEN,
+ KEY: EndpointAuthMode.KEY,
+ }
+
+ return switcher.get(yaml_auth_mode, yaml_auth_mode)
+
+ @classmethod
+ def _from_rest_object(cls, obj: OnlineEndpointData) -> "OnlineEndpoint":
+ auth_mode = cls._rest_auth_mode_to_yaml_auth_mode(obj.properties.auth_mode)
+ # pylint: disable=protected-access
+ identity = IdentityConfiguration._from_online_endpoint_rest_object(obj.identity) if obj.identity else None
+
+ endpoint: Any = KubernetesOnlineEndpoint()
+
+ if obj.system_data:
+ properties_dict = {
+ "createdBy": obj.system_data.created_by,
+ "createdAt": obj.system_data.created_at.strftime("%Y-%m-%dT%H:%M:%S.%f%z"),
+ "lastModifiedAt": obj.system_data.last_modified_at.strftime("%Y-%m-%dT%H:%M:%S.%f%z"),
+ }
+ properties_dict.update(obj.properties.properties)
+ else:
+ properties_dict = obj.properties.properties
+
+ if obj.properties.compute:
+ endpoint = KubernetesOnlineEndpoint(
+ id=obj.id,
+ name=obj.name,
+ tags=obj.tags,
+ properties=properties_dict,
+ compute=obj.properties.compute,
+ auth_mode=auth_mode,
+ description=obj.properties.description,
+ location=obj.location,
+ traffic=obj.properties.traffic,
+ provisioning_state=obj.properties.provisioning_state,
+ scoring_uri=obj.properties.scoring_uri,
+ openapi_uri=obj.properties.swagger_uri,
+ identity=identity,
+ kind=obj.kind,
+ )
+ else:
+ endpoint = ManagedOnlineEndpoint(
+ id=obj.id,
+ name=obj.name,
+ tags=obj.tags,
+ properties=properties_dict,
+ auth_mode=auth_mode,
+ description=obj.properties.description,
+ location=obj.location,
+ traffic=obj.properties.traffic,
+ mirror_traffic=obj.properties.mirror_traffic,
+ provisioning_state=obj.properties.provisioning_state,
+ scoring_uri=obj.properties.scoring_uri,
+ openapi_uri=obj.properties.swagger_uri,
+ identity=identity,
+ kind=obj.kind,
+ public_network_access=obj.properties.public_network_access,
+ )
+
+ return cast(OnlineEndpoint, endpoint)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, OnlineEndpoint):
+ return NotImplemented
+ if not other:
+ return False
+ if self.auth_mode is None or other.auth_mode is None:
+ return False
+
+ if self.name is None and other.name is None:
+ return (
+ self.auth_mode.lower() == other.auth_mode.lower()
+ and dict_eq(self.tags, other.tags)
+ and self.description == other.description
+ and dict_eq(self.traffic, other.traffic)
+ )
+
+ if self.name is not None and other.name is not None:
+ # only compare mutable fields
+ return (
+ self.name.lower() == other.name.lower()
+ and self.auth_mode.lower() == other.auth_mode.lower()
+ and dict_eq(self.tags, other.tags)
+ and self.description == other.description
+ and dict_eq(self.traffic, other.traffic)
+ )
+
+ return False
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Endpoint":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+
+ if data.get(EndpointYamlFields.COMPUTE) or is_compute_in_override(params_override):
+ res_kub: Endpoint = load_from_dict(KubernetesOnlineEndpointSchema, data, context)
+ return res_kub
+
+ res_managed: Endpoint = load_from_dict(ManagedOnlineEndpointSchema, data, context)
+ return res_managed
+
+
+class KubernetesOnlineEndpoint(OnlineEndpoint):
+ """K8s Online endpoint entity.
+
+ :keyword name: Name of the resource, defaults to None
+ :paramtype name: typing.Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None
+ :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword properties: The asset property dictionary, defaults to None
+ :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY
+ :type auth_mode: typing.Optional[str]
+ :keyword description: Description of the inference endpoint, defaults to None
+ :paramtype description: typing.Optional[str]
+ :keyword location: Location of the resource, defaults to None
+ :paramtype location: typing.Optional[str]
+ :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None
+ :paramtype traffic: typing.Optional[typing.Dict[str, int]]
+    :keyword mirror_traffic: Duplicated live traffic used to send inference requests to a single deployment, defaults to None
+ :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]]
+ :keyword compute: Compute cluster id, defaults to None
+ :paramtype compute: typing.Optional[str]
+ :keyword identity: Identity Configuration, defaults to SystemAssigned
+ :paramtype identity: typing.Optional[IdentityConfiguration]
+    :keyword kind: Kind of the resource; there are two kinds: K8s and Managed online endpoints, defaults to None
+ :paramtype kind: typing.Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ properties: Optional[Dict[str, Any]] = None,
+ auth_mode: str = KEY,
+ description: Optional[str] = None,
+ location: Optional[str] = None,
+ traffic: Optional[Dict[str, int]] = None,
+ mirror_traffic: Optional[Dict[str, int]] = None,
+ compute: Optional[str] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ kind: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ """K8s Online endpoint entity.
+
+ Constructor for K8s Online endpoint entity.
+
+ :keyword name: Name of the resource, defaults to None
+ :paramtype name: typing.Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None
+ :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword properties: The asset property dictionary, defaults to None
+ :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY
+ :type auth_mode: typing.Optional[str]
+ :keyword description: Description of the inference endpoint, defaults to None
+ :paramtype description: typing.Optional[str]
+ :keyword location: Location of the resource, defaults to None
+ :paramtype location: typing.Optional[str]
+ :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None
+ :paramtype traffic: typing.Optional[typing.Dict[str, int]]
+        :keyword mirror_traffic: Duplicated live traffic used to send inference requests to a single deployment, defaults to None
+ :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]]
+ :keyword compute: Compute cluster id, defaults to None
+ :paramtype compute: typing.Optional[str]
+ :keyword identity: Identity Configuration, defaults to SystemAssigned
+ :paramtype identity: typing.Optional[IdentityConfiguration]
+        :keyword kind: Kind of the resource; there are two kinds: K8s and Managed online endpoints, defaults to None
+ :type kind: typing.Optional[str]
+ """
+ super(KubernetesOnlineEndpoint, self).__init__(
+ name=name,
+ properties=properties,
+ tags=tags,
+ auth_mode=auth_mode,
+ description=description,
+ location=location,
+ traffic=traffic,
+ mirror_traffic=mirror_traffic,
+ identity=identity,
+ kind=kind,
+ **kwargs,
+ )
+
+ self.compute = compute
+
+ def dump(
+ self,
+ dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ res: dict = KubernetesOnlineEndpointSchema(context=context).dump(self)
+ return res
+
+ def _to_rest_online_endpoint(self, location: str) -> OnlineEndpointData:
+ resource = super()._to_rest_online_endpoint(location)
+ resource.properties.compute = self.compute
+ return resource
+
+ def _to_rest_online_endpoint_traffic_update(self, location: str, no_validation: bool = False) -> OnlineEndpointData:
+ resource = super()._to_rest_online_endpoint_traffic_update(location, no_validation)
+ resource.properties.compute = self.compute
+ return resource
+
+ def _merge_with(self, other: "KubernetesOnlineEndpoint") -> None:
+ if other:
+ if self.name != other.name:
+ msg = "The endpoint name: {} and {} are not matched when merging."
+ raise ValidationException(
+ message=msg.format(self.name, other.name),
+ target=ErrorTarget.ONLINE_ENDPOINT,
+ no_personal_data_message=msg.format("[name1]", "[name2]"),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ super()._merge_with(other)
+ self.compute = other.compute or self.compute
+
+ def _to_dict(self) -> Dict:
+ res: dict = KubernetesOnlineEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+
+class ManagedOnlineEndpoint(OnlineEndpoint):
+ """Managed Online endpoint entity.
+
+ :keyword name: Name of the resource, defaults to None
+ :paramtype name: typing.Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None
+ :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword properties: The asset property dictionary, defaults to None
+ :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY
+ :type auth_mode: str
+ :keyword description: Description of the inference endpoint, defaults to None
+ :paramtype description: typing.Optional[str]
+ :keyword location: Location of the resource, defaults to None
+ :paramtype location: typing.Optional[str]
+ :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None
+ :paramtype traffic: typing.Optional[typing.Dict[str, int]]
+    :keyword mirror_traffic: Duplicated live traffic used to send inference requests to a single deployment, defaults to None
+ :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]]
+ :keyword identity: Identity Configuration, defaults to SystemAssigned
+ :paramtype identity: typing.Optional[IdentityConfiguration]
+    :keyword kind: Kind of the resource; there are two kinds: K8s and Managed online endpoints, defaults to None.
+ :paramtype kind: typing.Optional[str]
+ :keyword public_network_access: Whether to allow public endpoint connectivity, defaults to None
+ Allowed values are: "enabled", "disabled"
+ :type public_network_access: typing.Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ properties: Optional[Dict[str, Any]] = None,
+ auth_mode: str = KEY,
+ description: Optional[str] = None,
+ location: Optional[str] = None,
+ traffic: Optional[Dict[str, int]] = None,
+ mirror_traffic: Optional[Dict[str, int]] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ kind: Optional[str] = None,
+ public_network_access: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ """Managed Online endpoint entity.
+
+ Constructor for Managed Online endpoint entity.
+
+ :keyword name: Name of the resource, defaults to None
+ :paramtype name: typing.Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated, defaults to None
+ :paramtype tags: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword properties: The asset property dictionary, defaults to None
+ :paramtype properties: typing.Optional[typing.Dict[str, typing.Any]]
+ :keyword auth_mode: Possible values include: "aml_token", "key", defaults to KEY
+ :type auth_mode: str
+ :keyword description: Description of the inference endpoint, defaults to None
+ :paramtype description: typing.Optional[str]
+ :keyword location: Location of the resource, defaults to None
+ :paramtype location: typing.Optional[str]
+ :keyword traffic: Traffic rules on how the traffic will be routed across deployments, defaults to None
+ :paramtype traffic: typing.Optional[typing.Dict[str, int]]
+        :keyword mirror_traffic: Duplicated live traffic used to send inference requests to a single deployment, defaults to None
+ :paramtype mirror_traffic: typing.Optional[typing.Dict[str, int]]
+ :keyword identity: Identity Configuration, defaults to SystemAssigned
+ :paramtype identity: typing.Optional[IdentityConfiguration]
+        :keyword kind: Kind of the resource; there are two kinds: K8s and Managed online endpoints, defaults to None.
+ :type kind: typing.Optional[str]
+ :keyword public_network_access: Whether to allow public endpoint connectivity, defaults to None
+ Allowed values are: "enabled", "disabled"
+ :type public_network_access: typing.Optional[str]
+ """
+ self.public_network_access = public_network_access
+
+ super(ManagedOnlineEndpoint, self).__init__(
+ name=name,
+ properties=properties,
+ tags=tags,
+ auth_mode=auth_mode,
+ description=description,
+ location=location,
+ traffic=traffic,
+ mirror_traffic=mirror_traffic,
+ identity=identity,
+ kind=kind,
+ **kwargs,
+ )
+
+ def dump(
+ self,
+ dest: Optional[Union[str, PathLike, IO[AnyStr]]] = None,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ context = {BASE_PATH_CONTEXT_KEY: Path(".").parent}
+ res: dict = ManagedOnlineEndpointSchema(context=context).dump(self)
+ return res
+
+ def _to_dict(self) -> Dict:
+ res: dict = ManagedOnlineEndpointSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+
+class EndpointAuthKeys(RestTranslatableMixin):
+ """Keys for endpoint authentication.
+
+ :ivar primary_key: The primary key.
+ :vartype primary_key: str
+ :ivar secondary_key: The secondary key.
+ :vartype secondary_key: str
+ """
+
+ def __init__(self, **kwargs: Any):
+ """Constructor for keys for endpoint authentication.
+
+ :keyword primary_key: The primary key.
+ :paramtype primary_key: str
+ :keyword secondary_key: The secondary key.
+ :paramtype secondary_key: str
+ """
+ self.primary_key = kwargs.get("primary_key", None)
+ self.secondary_key = kwargs.get("secondary_key", None)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestEndpointAuthKeys) -> "EndpointAuthKeys":
+ return cls(primary_key=obj.primary_key, secondary_key=obj.secondary_key)
+
+ def _to_rest_object(self) -> RestEndpointAuthKeys:
+ return RestEndpointAuthKeys(primary_key=self.primary_key, secondary_key=self.secondary_key)
+
+
+class EndpointAuthToken(RestTranslatableMixin):
+ """Endpoint authentication token.
+
+ :ivar access_token: Access token for endpoint authentication.
+ :vartype access_token: str
+ :ivar expiry_time_utc: Access token expiry time (UTC).
+ :vartype expiry_time_utc: float
+ :ivar refresh_after_time_utc: Refresh access token after time (UTC).
+ :vartype refresh_after_time_utc: float
+ :ivar token_type: Access token type.
+ :vartype token_type: str
+ """
+
+ def __init__(self, **kwargs: Any):
+ """Constuctor for Endpoint authentication token.
+
+ :keyword access_token: Access token for endpoint authentication.
+ :paramtype access_token: str
+ :keyword expiry_time_utc: Access token expiry time (UTC).
+ :paramtype expiry_time_utc: float
+ :keyword refresh_after_time_utc: Refresh access token after time (UTC).
+ :paramtype refresh_after_time_utc: float
+ :keyword token_type: Access token type.
+ :paramtype token_type: str
+ """
+ self.access_token = kwargs.get("access_token", None)
+ self.expiry_time_utc = kwargs.get("expiry_time_utc", 0)
+ self.refresh_after_time_utc = kwargs.get("refresh_after_time_utc", 0)
+ self.token_type = kwargs.get("token_type", None)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestEndpointAuthToken) -> "EndpointAuthToken":
+ return cls(
+ access_token=obj.access_token,
+ expiry_time_utc=obj.expiry_time_utc,
+ refresh_after_time_utc=obj.refresh_after_time_utc,
+ token_type=obj.token_type,
+ )
+
+ def _to_rest_object(self) -> RestEndpointAuthToken:
+ return RestEndpointAuthToken(
+ access_token=self.access_token,
+ expiry_time_utc=self.expiry_time_utc,
+ refresh_after_time_utc=self.refresh_after_time_utc,
+ token_type=self.token_type,
+ )
+
+
+class EndpointAadToken:
+ """Endpoint aad token.
+
+ :ivar access_token: Access token for aad authentication.
+ :vartype access_token: str
+ :ivar expiry_time_utc: Access token expiry time (UTC).
+ :vartype expiry_time_utc: float
+ """
+
+ def __init__(self, obj: AccessToken):
+ """Constructor for Endpoint aad token.
+
+ :param obj: Access token object
+ :type obj: AccessToken
+ """
+ self.access_token = obj.token
+ self.expiry_time_utc = obj.expires_on
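To tie the pieces of this module together, a minimal sketch of the entity-to-REST path (names, regions, and traffic splits are illustrative; ManagedOnlineEndpoint is assumed to be re-exported from azure.ai.ml.entities; the identity falls back to SystemAssigned exactly as _to_rest_online_endpoint does above):

from azure.ai.ml.entities import ManagedOnlineEndpoint  # assumed public re-export

endpoint = ManagedOnlineEndpoint(
    name="Chat-Endpoint",               # stored as "chat-endpoint"; MFE names are case-insensitive
    auth_mode="key",
    traffic={"blue": 90, "green": 10},  # live traffic split across deployments
    mirror_traffic={"shadow": 10},      # mirrored requests; responses are not returned to callers
    public_network_access="enabled",
)

rest = endpoint._to_rest_online_endpoint(location="westus2")
print(rest.properties.traffic)  # {'blue': 90, 'green': 10}
print(rest.identity.type)       # SystemAssigned, the fallback when no identity is supplied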
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/data_availability_status.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/data_availability_status.py
new file mode 100644
index 00000000..aa438f3b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/data_availability_status.py
@@ -0,0 +1,15 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from enum import Enum
+from azure.core import CaseInsensitiveEnumMeta
+
+
+class DataAvailabilityStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta):
+ """DataAvailabilityStatus."""
+
+ NONE = "None"
+ PENDING = "Pending"
+ INCOMPLETE = "Incomplete"
+ COMPLETE = "Complete"
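A small sketch of why the CaseInsensitiveEnumMeta metaclass matters here: status strings coming back from the service can be looked up by name without worrying about casing (this relies on the documented behaviour of azure.core's metaclass, stated here as an assumption):

from azure.ai.ml.entities._feature_set.data_availability_status import DataAvailabilityStatus

assert DataAvailabilityStatus["complete"] is DataAvailabilityStatus.COMPLETE  # name lookup ignores case
assert DataAvailabilityStatus.COMPLETE == "Complete"                          # str mixin compares equal to the value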
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/delay_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/delay_metadata.py
new file mode 100644
index 00000000..66599605
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/delay_metadata.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+
+from typing import Any, Optional
+
+
+class DelayMetadata(object):
+ def __init__(
+ self, *, days: Optional[int] = None, hours: Optional[int] = None, minutes: Optional[int] = None, **kwargs: Any
+ ):
+ self.days = days
+ self.hours = hours
+ self.minutes = minutes
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature.py
new file mode 100644
index 00000000..2cc54815
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature.py
@@ -0,0 +1,54 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2023_10_01.models import Feature as RestFeature
+from azure.ai.ml._restclient.v2023_10_01.models import FeatureProperties
+from azure.ai.ml.entities._feature_store_entity.data_column_type import DataColumnType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class Feature(RestTranslatableMixin):
+ """Feature
+
+ :param name: The name of the feature.
+ :type name: str
+ :param data_type: The data type of the feature.
+ :type data_type: ~azure.ai.ml.entities.DataColumnType
+ :param description: The description of the feature. Defaults to None.
+ :type description: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :type tags: Optional[dict[str, str]]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ data_type: DataColumnType,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ **kwargs: Any
+ ):
+ self.name = name
+ self.data_type = data_type
+ self.description = description
+ self.tags = tags
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestFeature) -> Optional["Feature"]:
+ if not obj:
+ return None
+ feature_rest_object_details: FeatureProperties = obj.properties
+ return Feature(
+ name=feature_rest_object_details.feature_name,
+ data_type=feature_rest_object_details.data_type,
+ description=feature_rest_object_details.description,
+ tags=feature_rest_object_details.tags,
+ )
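A minimal sketch of the entity above populated by hand; the DataColumnType import path is taken from this same package, the DOUBLE member is assumed to exist, and the field values are illustrative:

from azure.ai.ml.entities._feature_set.feature import Feature
from azure.ai.ml.entities._feature_store_entity.data_column_type import DataColumnType

feature = Feature(
    name="transaction_amount_7d_sum",
    data_type=DataColumnType.DOUBLE,   # assumed enum member
    description="Rolling 7-day sum of transaction amounts",
    tags={"source": "transactions"},
)
print(feature.name, feature.data_type)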
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_metadata.py
new file mode 100644
index 00000000..652908e9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_metadata.py
@@ -0,0 +1,39 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, List, Optional
+
+from azure.ai.ml._restclient.v2023_10_01.models import (
+ FeaturesetVersionBackfillResponse as RestFeaturesetVersionBackfillResponse,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class FeatureSetBackfillMetadata(RestTranslatableMixin):
+ """Feature Set Backfill Metadata
+
+ :param job_ids: A list of IDs of the backfill jobs. Defaults to None.
+ :type job_ids: Optional[List[str]]
+ :param type: The type of the backfill job. Defaults to None.
+ :type type: Optional[str]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ job_ids: Optional[List[str]] = None,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ # pylint: disable=unused-argument
+ **kwargs: Any
+ ) -> None:
+ self.type = type if type else "BackfillMaterialization"
+ self.job_ids = job_ids
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestFeaturesetVersionBackfillResponse) -> Optional["FeatureSetBackfillMetadata"]:
+ if not obj:
+ return None
+ return FeatureSetBackfillMetadata(job_ids=obj.job_ids)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_request.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_request.py
new file mode 100644
index 00000000..0baebf4c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_backfill_request.py
@@ -0,0 +1,91 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from azure.ai.ml._schema._feature_set.feature_set_backfill_schema import FeatureSetBackfillSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._feature_set.feature_window import FeatureWindow
+from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class FeatureSetBackfillRequest(RestTranslatableMixin):
+ """Feature Set Backfill Request
+
+ :param name: The name of the backfill job request
+ :type name: str
+ :param version: The version of the backfill job request.
+ :type version: str
+ :param feature_window: The time window for the feature set backfill request.
+ :type feature_window: ~azure.ai.ml._restclient.v2023_04_01_preview.models.FeatureWindow
+ :param description: The description of the backfill job request. Defaults to None.
+ :type description: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :type tags: Optional[dict[str, str]]
+ :keyword resource: The compute resource settings. Defaults to None.
+ :paramtype resource: Optional[~azure.ai.ml.entities.MaterializationComputeResource]
+ :param spark_configuration: Specifies the spark configuration. Defaults to None.
+ :type spark_configuration: Optional[dict[str, str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ version: str,
+ feature_window: Optional[FeatureWindow] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ resource: Optional[MaterializationComputeResource] = None,
+ spark_configuration: Optional[Dict[str, str]] = None,
+ data_status: Optional[List[str]] = None,
+ job_id: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ self.name = name
+ self.version = version
+ self.feature_window = feature_window
+ self.description = description
+ self.resource = resource
+ self.tags = tags
+ self.spark_configuration = spark_configuration
+ self.data_status = data_status
+ self.job_id = job_id
+
+ @classmethod
+ # pylint: disable=unused-argument
+ def _resolve_cls_and_type(cls, data: Dict, params_override: Tuple) -> Tuple:
+ """Resolve the class to use for deserializing the data. Return current class if no override is provided.
+
+ :param data: Data to deserialize.
+ :type data: dict
+ :param params_override: Parameters to override, defaults to None
+ :type params_override: typing.Optional[list]
+ :return: Class to use for deserializing the data & its "type". Type will be None if no override is provided.
+ :rtype: tuple[class, typing.Optional[str]]
+ """
+ return cls, None
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "FeatureSetBackfillRequest":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ loaded_schema = load_from_dict(FeatureSetBackfillSchema, data, context, **kwargs)
+ return FeatureSetBackfillRequest(**loaded_schema)
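For context, the same request can be assembled directly in code instead of loaded via _load. A hedged sketch follows; the import paths are the private modules used above, MaterializationComputeResource is assumed to take an instance_type keyword as its import here suggests, and the concrete values are illustrative:

from datetime import datetime

from azure.ai.ml.entities._feature_set.feature_set_backfill_request import FeatureSetBackfillRequest
from azure.ai.ml.entities._feature_set.feature_window import FeatureWindow
from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource

request = FeatureSetBackfillRequest(
    name="transactions",
    version="1",
    feature_window=FeatureWindow(
        feature_window_start=datetime(2024, 1, 1),
        feature_window_end=datetime(2024, 2, 1),
    ),
    resource=MaterializationComputeResource(instance_type="standard_e8s_v3"),  # illustrative SKU
    spark_configuration={"spark.driver.cores": "4"},
)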
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_materialization_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_materialization_metadata.py
new file mode 100644
index 00000000..afcf3fd1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_materialization_metadata.py
@@ -0,0 +1,98 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from datetime import datetime, timedelta
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2023_10_01.models import JobBase as RestJobBase
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._system_data import SystemData
+
+from .materialization_type import MaterializationType
+
+FeaturestoreJobTypeMap: Dict[str, MaterializationType] = {
+ "BackfillMaterialization": MaterializationType.BACKFILL_MATERIALIZATION,
+ "RecurrentMaterialization": MaterializationType.RECURRENT_MATERIALIZATION,
+}
+
+
+class FeatureSetMaterializationMetadata(RestTranslatableMixin):
+ """Feature Set Materialization Metadata
+
+ :param type: The type of the materialization job.
+ :type type: MaterializationType
+ :param feature_window_start_time: The feature window start time for the feature set materialization job.
+ :type feature_window_start_time: Optional[datetime]
+ :param feature_window_end_time: The feature window end time for the feature set materialization job.
+ :type feature_window_end_time: Optional[datetime]
+ :param name: The name of the feature set materialization job.
+ :type name: Optional[str]
+ :param display_name: The display name for the feature set materialization job.
+ :type display_name: Optional[str]
+ :param creation_context: The creation context of the feature set materialization job.
+ :type creation_context: Optional[~azure.ai.ml.entities.SystemData]
+    :param duration: Current elapsed time of the feature set materialization job.
+ :type duration: Optional[~datetime.timedelta]
+ :param status: The status of the feature set materialization job.
+ :type status: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: Optional[dict[str, str]]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ # pylint: disable=redefined-builtin
+ type: Optional[MaterializationType],
+ feature_window_start_time: Optional[datetime],
+ feature_window_end_time: Optional[datetime],
+ name: Optional[str],
+ display_name: Optional[str],
+ creation_context: Optional[SystemData],
+ duration: Optional[timedelta],
+ status: Optional[str],
+ tags: Optional[Dict[str, str]],
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ):
+ self.type = type
+ self.feature_window_start_time = feature_window_start_time
+ self.feature_window_end_time = feature_window_end_time
+ self.name = name
+ self.display_name = display_name
+ self.creation_context = creation_context
+ self.duration = duration
+ self.status = status
+ self.tags = tags
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobBase) -> Optional["FeatureSetMaterializationMetadata"]:
+ if not obj:
+ return None
+ job_properties = obj.properties
+ job_type = job_properties.properties.get("azureml.FeatureStoreJobType", None)
+ feature_window_start_time = job_properties.properties.get("azureml.FeatureWindowStart", None)
+ feature_window_end_time = job_properties.properties.get("azureml.FeatureWindowEnd", None)
+
+ time_format = "%Y-%m-%dT%H:%M:%SZ"
+ feature_window_start_time = (
+ datetime.strptime(feature_window_start_time, time_format) if feature_window_start_time else None
+ )
+ feature_window_end_time = (
+ datetime.strptime(feature_window_end_time, time_format) if feature_window_end_time else None
+ )
+
+ return FeatureSetMaterializationMetadata(
+ type=FeaturestoreJobTypeMap.get(job_type),
+ feature_window_start_time=feature_window_start_time,
+ feature_window_end_time=feature_window_end_time,
+ name=obj.name,
+ display_name=job_properties.display_name,
+ creation_context=SystemData(created_at=obj.system_data.created_at),
+ status=job_properties.status,
+ tags=job_properties.tags,
+ duration=None,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_specification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_specification.py
new file mode 100644
index 00000000..88ed093f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_set_specification.py
@@ -0,0 +1,46 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from os import PathLike
+from typing import Any, Optional, Union
+
+from azure.ai.ml._restclient.v2023_10_01.models import FeaturesetSpecification as RestFeaturesetSpecification
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class FeatureSetSpecification(RestTranslatableMixin):
+ """Feature Set Specification
+
+ :param path: Specifies the feature set spec path to file. Defaults to None.
+ :type path: Optional[str]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_featurestore.py
+ :start-after: [START configure_feature_set]
+ :end-before: [END configure_feature_set]
+ :language: python
+ :dedent: 8
+ :caption: Using Feature Set Spec to create Feature Set
+ """
+
+ def __init__(
+ self, *, path: Optional[Union[PathLike, str]] = None, **kwargs: Any
+ ): # pylint: disable=unused-argument
+ """
+ :param path: Specifies the spec path.
+ :type path: str
+ """
+ self.path = path
+
+ def _to_rest_object(self) -> RestFeaturesetSpecification:
+ return RestFeaturesetSpecification(path=self.path)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestFeaturesetSpecification) -> Optional["FeatureSetSpecification"]:
+ if not obj:
+ return None
+ return FeatureSetSpecification(path=obj.path)
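A minimal sketch of the round-trip implemented above (the folder path is illustrative; in practice it points at a directory containing the feature-set spec YAML):

from azure.ai.ml.entities._feature_set.feature_set_specification import FeatureSetSpecification

spec = FeatureSetSpecification(path="./featuresets/transactions/spec")
rest = spec._to_rest_object()
assert FeatureSetSpecification._from_rest_object(rest).path == spec.path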
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_transformation_code_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_transformation_code_metadata.py
new file mode 100644
index 00000000..5fd8544e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_transformation_code_metadata.py
@@ -0,0 +1,13 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Optional
+
+
+class FeatureTransformationCodeMetadata(object):
+ def __init__(self, *, path: str, transformer_class: Optional[str] = None, **kwargs: Any):
+ self.path = path
+ self.transformer_class = transformer_class
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_window.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_window.py
new file mode 100644
index 00000000..758d1ecf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/feature_window.py
@@ -0,0 +1,34 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from datetime import datetime
+from typing import Any, Optional
+
+from azure.ai.ml._restclient.v2023_10_01.models import FeatureWindow as RestFeatureWindow
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class FeatureWindow(RestTranslatableMixin):
+ """Feature window
+ :keyword feature_window_end: Specifies the feature window end time.
+ :paramtype feature_window_end: ~datetime.datetime
+ :keyword feature_window_start: Specifies the feature window start time.
+ :paramtype feature_window_start: ~datetime.datetime
+ """
+
+ # pylint: disable=unused-argument
+ def __init__(self, *, feature_window_start: datetime, feature_window_end: datetime, **kwargs: Any) -> None:
+ self.feature_window_start = feature_window_start
+ self.feature_window_end = feature_window_end
+
+ def _to_rest_object(self) -> RestFeatureWindow:
+ return RestFeatureWindow(
+ feature_window_start=self.feature_window_start, feature_window_end=self.feature_window_end
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestFeatureWindow) -> Optional["FeatureWindow"]:
+ if not obj:
+ return None
+ return FeatureWindow(feature_window_start=obj.feature_window_start, feature_window_end=obj.feature_window_end)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/featureset_spec_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/featureset_spec_metadata.py
new file mode 100644
index 00000000..4178b074
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/featureset_spec_metadata.py
@@ -0,0 +1,101 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from marshmallow import INCLUDE
+
+from azure.ai.ml._schema._feature_set.featureset_spec_metadata_schema import FeaturesetSpecMetadataSchema
+from azure.ai.ml._schema._feature_set.featureset_spec_properties_schema import FeaturesetSpecPropertiesSchema
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._feature_store_entity.data_column import DataColumn
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .delay_metadata import DelayMetadata
+from .feature import Feature
+from .feature_transformation_code_metadata import FeatureTransformationCodeMetadata
+from .source_metadata import SourceMetadata
+
+
+class FeaturesetSpecMetadata(object):
+ """FeaturesetSpecMetadata for feature-set."""
+
+ def __init__(
+ self,
+ *,
+ source: SourceMetadata,
+ feature_transformation_code: Optional[FeatureTransformationCodeMetadata] = None,
+ features: List[Feature],
+ index_columns: Optional[List[DataColumn]] = None,
+ source_lookback: Optional[DelayMetadata] = None,
+ temporal_join_lookback: Optional[DelayMetadata] = None,
+ **_kwargs: Any,
+ ):
+ if source.type == "featureset" and index_columns:
+ msg = f"You cannot provide index_columns for {source.type} feature source."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.FEATURE_SET,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if not index_columns and source.type != "featureset":
+ msg = f"You need to provide index_columns for {source.type} feature source."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.FEATURE_SET,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.source = source
+ self.feature_transformation_code = feature_transformation_code
+ self.features = features
+ self.index_columns = index_columns
+ self.source_lookback = source_lookback
+ self.temporal_join_lookback = temporal_join_lookback
+
+ @classmethod
+ def load(
+ cls,
+ yaml_path: Union[PathLike, str],
+ **kwargs: Any,
+ ) -> "FeaturesetSpecMetadata":
+ """Construct an FeaturesetSpecMetadata object from yaml file.
+
+ :param yaml_path: Path to a local file as the source.
+ :type yaml_path: PathLike | str
+
+ :return: Constructed FeaturesetSpecMetadata object.
+ :rtype: FeaturesetSpecMetadata
+ """
+ yaml_dict = load_yaml(yaml_path)
+ return cls._load(yaml_data=yaml_dict, yaml_path=yaml_path, **kwargs)
+
+ @classmethod
+ def _load(
+ cls,
+ yaml_data: Optional[Dict],
+ yaml_path: Optional[Union[PathLike, str]],
+ **kwargs: Any,
+ ) -> "FeaturesetSpecMetadata":
+ yaml_data = yaml_data or {}
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ }
+ res: FeaturesetSpecMetadata = load_from_dict(
+ FeaturesetSpecMetadataSchema, yaml_data, context, "", unknown=INCLUDE, **kwargs
+ )
+
+ return res
+
+ def _to_dict(self) -> Dict:
+ res: dict = FeaturesetSpecPropertiesSchema(context={BASE_PATH_CONTEXT_KEY: "./"}, unknown=INCLUDE).dump(self)
+ return res
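A hedged sketch of loading a spec file with the class above; the YAML path is a placeholder and the file is assumed to conform to FeaturesetSpecMetadataSchema (source, features, index_columns and the optional lookback fields).

from azure.ai.ml.entities._feature_set.featureset_spec_metadata import FeaturesetSpecMetadata

# Placeholder path; the YAML must match FeaturesetSpecMetadataSchema.
spec_metadata = FeaturesetSpecMetadata.load("./featureset_spec/FeaturesetSpec.yaml")

print(spec_metadata.source.type)
print(len(spec_metadata.features), "features,", len(spec_metadata.index_columns or []), "index columns")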
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_compute_resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_compute_resource.py
new file mode 100644
index 00000000..5bcff24b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_compute_resource.py
@@ -0,0 +1,41 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Optional
+
+from azure.ai.ml._restclient.v2023_10_01.models import (
+ MaterializationComputeResource as RestMaterializationComputeResource,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class MaterializationComputeResource(RestTranslatableMixin):
+ """Materialization Compute resource
+
+ :keyword instance_type: The compute instance type.
+ :paramtype instance_type: str
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START materialization_compute_resource]
+ :end-before: [END materialization_compute_resource]
+ :language: python
+ :dedent: 8
+ :caption: Creating a MaterializationComputeResource object.
+ """
+
+ def __init__(self, *, instance_type: str, **kwargs: Any): # pylint: disable=unused-argument
+ self.instance_type = instance_type
+
+ def _to_rest_object(self) -> RestMaterializationComputeResource:
+ return RestMaterializationComputeResource(instance_type=self.instance_type)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMaterializationComputeResource) -> Optional["MaterializationComputeResource"]:
+ if not obj:
+ return None
+ return MaterializationComputeResource(instance_type=obj.instance_type)
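A one-line usage sketch; "standard_e8s_v3" is a placeholder Spark instance type, and the class is assumed to be re-exported from azure.ai.ml.entities.

from azure.ai.ml.entities import MaterializationComputeResource

resource = MaterializationComputeResource(instance_type="standard_e8s_v3")  # placeholder instance type
assert resource._to_rest_object().instance_type == "standard_e8s_v3"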
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_settings.py
new file mode 100644
index 00000000..cf6f12e0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_settings.py
@@ -0,0 +1,100 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2023_10_01.models import MaterializationSettings as RestMaterializationSettings
+from azure.ai.ml._restclient.v2023_10_01.models import MaterializationStoreType
+from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._notification.notification import Notification
+from azure.ai.ml.entities._schedule.trigger import RecurrenceTrigger
+
+
+class MaterializationSettings(RestTranslatableMixin):
+ """Defines materialization settings.
+
+ :keyword schedule: The schedule details. Defaults to None.
+ :paramtype schedule: Optional[~azure.ai.ml.entities.RecurrenceTrigger]
+ :keyword offline_enabled: Boolean that specifies if offline store is enabled. Defaults to None.
+ :paramtype offline_enabled: Optional[bool]
+ :keyword online_enabled: Boolean that specifies if online store is enabled. Defaults to None.
+ :paramtype online_enabled: Optional[bool]
+ :keyword notification: The notification details. Defaults to None.
+ :paramtype notification: Optional[~azure.ai.ml.entities.Notification]
+ :keyword resource: The compute resource settings. Defaults to None.
+ :paramtype resource: Optional[~azure.ai.ml.entities.MaterializationComputeResource]
+ :keyword spark_configuration: The spark compute settings. Defaults to None.
+ :paramtype spark_configuration: Optional[dict[str, str]]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START materialization_setting_configuration]
+ :end-before: [END materialization_setting_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring MaterializationSettings.
+ """
+
+ def __init__(
+ self,
+ *,
+ schedule: Optional[RecurrenceTrigger] = None,
+ offline_enabled: Optional[bool] = None,
+ online_enabled: Optional[bool] = None,
+ notification: Optional[Notification] = None,
+ resource: Optional[MaterializationComputeResource] = None,
+ spark_configuration: Optional[Dict[str, str]] = None,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ self.schedule = schedule
+ self.offline_enabled = offline_enabled
+ self.online_enabled = online_enabled
+ self.notification = notification
+ self.resource = resource
+ self.spark_configuration = spark_configuration
+
+ def _to_rest_object(self) -> RestMaterializationSettings:
+ store_type = None
+ if self.offline_enabled and self.online_enabled:
+ store_type = MaterializationStoreType.ONLINE_AND_OFFLINE
+ elif self.offline_enabled:
+ store_type = MaterializationStoreType.OFFLINE
+ elif self.online_enabled:
+ store_type = MaterializationStoreType.ONLINE
+ else:
+ store_type = MaterializationStoreType.NONE
+
+ return RestMaterializationSettings(
+ schedule=self.schedule._to_rest_object() if self.schedule else None, # pylint: disable=protected-access
+ notification=(
+ self.notification._to_rest_object() if self.notification else None # pylint: disable=protected-access
+ ),
+ resource=self.resource._to_rest_object() if self.resource else None, # pylint: disable=protected-access
+ spark_configuration=self.spark_configuration,
+ store_type=store_type,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMaterializationSettings) -> Optional["MaterializationSettings"]:
+ if not obj:
+ return None
+ return MaterializationSettings(
+ schedule=(
+ RecurrenceTrigger._from_rest_object(obj.schedule) # pylint: disable=protected-access
+ if obj.schedule
+ else None
+ ),
+ notification=Notification._from_rest_object(obj.notification), # pylint: disable=protected-access
+ resource=MaterializationComputeResource._from_rest_object(obj.resource), # pylint: disable=protected-access
+ spark_configuration=obj.spark_configuration,
+ offline_enabled=obj.store_type
+ in {MaterializationStoreType.OFFLINE, MaterializationStoreType.ONLINE_AND_OFFLINE},
+ online_enabled=obj.store_type
+ in {MaterializationStoreType.ONLINE, MaterializationStoreType.ONLINE_AND_OFFLINE},
+ )
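A sketch showing how the two boolean flags collapse into a single REST store_type, assuming RecurrenceTrigger and the other entities are re-exported from azure.ai.ml.entities; the instance type and Spark setting are placeholders.

from azure.ai.ml.entities import (
    MaterializationComputeResource,
    MaterializationSettings,
    RecurrenceTrigger,
)

settings = MaterializationSettings(
    schedule=RecurrenceTrigger(frequency="day", interval=1),
    offline_enabled=True,
    online_enabled=False,
    resource=MaterializationComputeResource(instance_type="standard_e8s_v3"),
    spark_configuration={"spark.driver.cores": "2"},  # placeholder Spark setting
)

# offline_enabled=True with online_enabled=False maps to MaterializationStoreType.OFFLINE.
print(settings._to_rest_object().store_type)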
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_type.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_type.py
new file mode 100644
index 00000000..912d69fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/materialization_type.py
@@ -0,0 +1,14 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from enum import Enum
+
+from azure.core import CaseInsensitiveEnumMeta
+
+
+class MaterializationType(str, Enum, metaclass=CaseInsensitiveEnumMeta):
+ """Materialization Type Enum"""
+
+ RECURRENT_MATERIALIZATION = 1
+ BACKFILL_MATERIALIZATION = 2
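Because of CaseInsensitiveEnumMeta, member lookup on this enum ignores case; a tiny sketch:

from azure.ai.ml.entities._feature_set.materialization_type import MaterializationType

# Name lookup is case-insensitive through the metaclass.
assert MaterializationType["backfill_materialization"] is MaterializationType.BACKFILL_MATERIALIZATION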
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_metadata.py
new file mode 100644
index 00000000..1c9e55fe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_metadata.py
@@ -0,0 +1,69 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+ # pylint: disable=redefined-builtin,unused-argument
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml.entities._feature_set.source_process_code_metadata import SourceProcessCodeMetadata
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .delay_metadata import DelayMetadata
+from .timestamp_column_metadata import TimestampColumnMetadata
+
+
+class SourceMetadata(object):
+ def __init__(
+ self,
+ *,
+ type: str,
+ timestamp_column: Optional[TimestampColumnMetadata] = None,
+ path: Optional[str] = None,
+ source_delay: Optional[DelayMetadata] = None,
+ source_process_code: Optional[SourceProcessCodeMetadata] = None,
+ dict: Optional[Dict] = None,
+ **kwargs: Any,
+ ):
+ if type == "custom":
+ # For custom feature source
+ # Required: timestamp_column, dict and source_process_code.
+ # Not supported: path.
+ if path:
+ self.throw_exception("path", type, should_provide=False)
+ if not (timestamp_column and dict and source_process_code):
+ self.throw_exception("timestamp_column/dict/source_process_code", type, should_provide=True)
+ elif type == "featureset":
+ # For featureset feature source
+ # Required: path.
+ # Not supported: timestamp_column, source_delay and source_process_code.
+ if timestamp_column or source_delay or source_process_code:
+ self.throw_exception("timestamp_column/source_delay/source_process_code", type, should_provide=False)
+ if not path:
+ self.throw_exception("path", type, should_provide=True)
+ else:
+ # For all other feature source types
+ # Required: timestamp_column, path.
+ # Not supported: source_process_code, dict.
+ if dict or source_process_code:
+ self.throw_exception("dict/source_process_code", type, should_provide=False)
+ if not (timestamp_column and path):
+ self.throw_exception("timestamp_column/path", type, should_provide=True)
+ self.type = type
+ self.path = path
+ self.timestamp_column = timestamp_column
+ self.source_delay = source_delay
+ self.source_process_code = source_process_code
+ self.kwargs = dict
+
+ @staticmethod
+ def throw_exception(property_names: str, type: str, should_provide: bool):
+ should_or_not = "need to" if should_provide else "cannot"
+ msg = f"You {should_or_not} provide {property_names} for {type} feature source."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.FEATURE_SET,
+ error_category=ErrorCategory.USER_ERROR,
+ )
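A sketch of the three validation branches above; the paths and the "parquet" type are placeholders (any type other than "custom" and "featureset" takes the last branch).

from azure.ai.ml.entities._feature_set.source_metadata import SourceMetadata
from azure.ai.ml.entities._feature_set.timestamp_column_metadata import TimestampColumnMetadata

# A non-custom, non-featureset source needs both path and timestamp_column.
source = SourceMetadata(
    type="parquet",
    path="wasbs://data@account.blob.core.windows.net/transactions/*.parquet",  # placeholder
    timestamp_column=TimestampColumnMetadata(name="timestamp", format="yyyy-MM-dd HH:mm:ss"),
)

# A "featureset" source only takes a path; adding timestamp_column would raise ValidationException.
derived = SourceMetadata(type="featureset", path="azureml:transactions:1")  # placeholder asset reference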
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_process_code_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_process_code_metadata.py
new file mode 100644
index 00000000..415785da
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/source_process_code_metadata.py
@@ -0,0 +1,13 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Optional
+
+
+class SourceProcessCodeMetadata(object):
+ def __init__(self, *, path: str, process_class: Optional[str] = None, **kwargs: Any):
+ self.path = path
+ self.process_class = process_class
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/timestamp_column_metadata.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/timestamp_column_metadata.py
new file mode 100644
index 00000000..833088af
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_set/timestamp_column_metadata.py
@@ -0,0 +1,14 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+ # pylint: disable=redefined-builtin,unused-argument
+
+
+from typing import Any, Optional
+
+
+class TimestampColumnMetadata(object):
+ def __init__(self, *, name: str, format: Optional[str] = None, **kwargs: Any):
+ self.name = name
+ self.format = format
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/_constants.py
new file mode 100644
index 00000000..d6466401
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/_constants.py
@@ -0,0 +1,15 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+OFFLINE_STORE_CONNECTION_NAME = "OfflineStoreConnectionName"
+OFFLINE_MATERIALIZATION_STORE_TYPE = "azure_data_lake_gen2"
+OFFLINE_STORE_CONNECTION_CATEGORY = "ADLSGen2"
+ONLINE_STORE_CONNECTION_NAME = "OnlineStoreConnectionName"
+ONLINE_MATERIALIZATION_STORE_TYPE = "redis"
+ONLINE_STORE_CONNECTION_CATEGORY = "Redis"
+DEFAULT_SPARK_RUNTIME_VERSION = "3.4.0"
+STORE_REGEX_PATTERN = (
+ "^/?subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Storage"
+ "/storageAccounts/([^/]+)/blobServices/default/containers/([^/]+)"
+)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/feature_store.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/feature_store.py
new file mode 100644
index 00000000..0c41f1a3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/feature_store.py
@@ -0,0 +1,226 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace as RestWorkspace
+from azure.ai.ml._schema._feature_store.feature_store_schema import FeatureStoreSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, WorkspaceKind
+from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.entities._workspace.compute_runtime import ComputeRuntime
+from azure.ai.ml.entities._workspace.customer_managed_key import CustomerManagedKey
+from azure.ai.ml.entities._workspace.feature_store_settings import FeatureStoreSettings
+from azure.ai.ml.entities._workspace.networking import ManagedNetwork
+from azure.ai.ml.entities._workspace.workspace import Workspace
+
+from ._constants import DEFAULT_SPARK_RUNTIME_VERSION
+from .materialization_store import MaterializationStore
+
+
+class FeatureStore(Workspace):
+ """Feature Store
+
+ :param name: The name of the feature store.
+ :type name: str
+ :param compute_runtime: The compute runtime of the feature store. Defaults to None.
+ :type compute_runtime: Optional[~azure.ai.ml.entities.ComputeRuntime]
+ :param offline_store: The offline store for feature store.
+ materialization_identity is required when offline_store is passed. Defaults to None.
+ :type offline_store: Optional[~azure.ai.ml.entities.MaterializationStore]
+ :param online_store: The online store for feature store.
+ materialization_identity is required when online_store is passed. Defaults to None.
+ :type online_store: Optional[~azure.ai.ml.entities.MaterializationStore]
+ :param materialization_identity: The identity used for materialization. Defaults to None.
+ :type materialization_identity: Optional[~azure.ai.ml.entities.ManagedIdentityConfiguration]
+ :param description: The description of the feature store. Defaults to None.
+ :type description: Optional[str]
+ :param tags: Tags of the feature store.
+ :type tags: dict
+ :param display_name: The display name for the feature store. This is non-unique within the resource group.
+ Defaults to None.
+ :type display_name: Optional[str]
+ :param location: The location to create the feature store in.
+ If not specified, the same location as the resource group will be used. Defaults to None.
+ :type location: Optional[str]
+ :param resource_group: The name of the resource group to create the feature store in. Defaults to None.
+ :type resource_group: Optional[str]
+ :param hbi_workspace: Boolean for whether the customer data is of high business impact (HBI),
+ containing sensitive business information. Defaults to False.
+ For more information, see
+ https://learn.microsoft.com/azure/machine-learning/concept-data-encryption#encryption-at-rest.
+ :type hbi_workspace: Optional[bool]
+ :param storage_account: The resource ID of an existing storage account to use instead of creating a new one.
+ Defaults to None.
+ :type storage_account: Optional[str]
+ :param container_registry: The resource ID of an existing container registry
+ to use instead of creating a new one. Defaults to None.
+ :type container_registry: Optional[str]
+ :param key_vault: The resource ID of an existing key vault to use instead of creating a new one. Defaults to None.
+ :type key_vault: Optional[str]
+ :param application_insights: The resource ID of an existing application insights
+ to use instead of creating a new one. Defaults to None.
+ :type application_insights: Optional[str]
+ :param customer_managed_key: The key vault details for encrypting data with customer-managed keys.
+ If not specified, Microsoft-managed keys will be used by default. Defaults to None.
+ :type customer_managed_key: Optional[CustomerManagedKey]
+ :param image_build_compute: The name of the compute target to use for building environment
+ Docker images when the container registry is behind a VNet. Defaults to None.
+ :type image_build_compute: Optional[str]
+ :param public_network_access: Whether to allow public endpoint connectivity
+ when a workspace is private link enabled. Defaults to None.
+ :type public_network_access: Optional[str]
+ :param identity: The workspace's Managed Identity (user assigned, or system assigned). Defaults to None.
+ :type identity: Optional[IdentityConfiguration]
+ :param primary_user_assigned_identity: The workspace's primary user assigned identity. Defaults to None.
+ :type primary_user_assigned_identity: Optional[str]
+ :param managed_network: The workspace's Managed Network configuration. Defaults to None.
+ :type managed_network: Optional[ManagedNetwork]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_featurestore.py
+ :start-after: [START create_feature_store]
+ :end-before: [END create_feature_store]
+ :language: python
+ :dedent: 8
+ :caption: Instantiating a Feature Store object
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ compute_runtime: Optional[ComputeRuntime] = None,
+ offline_store: Optional[MaterializationStore] = None,
+ online_store: Optional[MaterializationStore] = None,
+ materialization_identity: Optional[ManagedIdentityConfiguration] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ display_name: Optional[str] = None,
+ location: Optional[str] = None,
+ resource_group: Optional[str] = None,
+ hbi_workspace: bool = False,
+ storage_account: Optional[str] = None,
+ container_registry: Optional[str] = None,
+ key_vault: Optional[str] = None,
+ application_insights: Optional[str] = None,
+ customer_managed_key: Optional[CustomerManagedKey] = None,
+ image_build_compute: Optional[str] = None,
+ public_network_access: Optional[str] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ primary_user_assigned_identity: Optional[str] = None,
+ managed_network: Optional[ManagedNetwork] = None,
+ **kwargs: Any,
+ ) -> None:
+ feature_store_settings = kwargs.pop(
+ "feature_store_settings",
+ FeatureStoreSettings(
+ compute_runtime=(
+ compute_runtime
+ if compute_runtime
+ else ComputeRuntime(spark_runtime_version=DEFAULT_SPARK_RUNTIME_VERSION)
+ ),
+ ),
+ )
+ # TODO: Refactor this so that super().__init__() is not called twice coming from _from_rest_object()
+ super().__init__(
+ name=name,
+ description=description,
+ tags=tags,
+ kind=WorkspaceKind.FEATURE_STORE,
+ display_name=display_name,
+ location=location,
+ resource_group=resource_group,
+ hbi_workspace=hbi_workspace,
+ storage_account=storage_account,
+ container_registry=container_registry,
+ key_vault=key_vault,
+ application_insights=application_insights,
+ customer_managed_key=customer_managed_key,
+ image_build_compute=image_build_compute,
+ public_network_access=public_network_access,
+ managed_network=managed_network,
+ identity=identity,
+ primary_user_assigned_identity=primary_user_assigned_identity,
+ feature_store_settings=feature_store_settings,
+ **kwargs,
+ )
+ self.offline_store = offline_store
+ self.online_store = online_store
+ self.materialization_identity = materialization_identity
+ self.identity = identity
+ self.public_network_access = public_network_access
+ self.managed_network = managed_network
+ # compute_runtime is stored here instead of feature_store_settings because
+ # feature_store_settings falls back to the default spark version when no compute_runtime is given on update
+ self.compute_runtime = compute_runtime
+
+ @classmethod
+ def _from_rest_object(
+ cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None
+ ) -> Optional["FeatureStore"]:
+ if not rest_obj:
+ return None
+
+ workspace_object = Workspace._from_rest_object(rest_obj, v2_service_context)
+ if workspace_object is not None:
+ return FeatureStore(
+ name=str(workspace_object.name),
+ id=workspace_object.id,
+ description=workspace_object.description,
+ tags=workspace_object.tags,
+ compute_runtime=ComputeRuntime._from_rest_object(
+ workspace_object._feature_store_settings.compute_runtime
+ if workspace_object._feature_store_settings
+ else None
+ ),
+ display_name=workspace_object.display_name,
+ discovery_url=workspace_object.discovery_url,
+ location=workspace_object.location,
+ resource_group=workspace_object.resource_group,
+ hbi_workspace=workspace_object.hbi_workspace,
+ storage_account=workspace_object.storage_account,
+ container_registry=workspace_object.container_registry,
+ key_vault=workspace_object.key_vault,
+ application_insights=workspace_object.application_insights,
+ customer_managed_key=workspace_object.customer_managed_key,
+ image_build_compute=workspace_object.image_build_compute,
+ public_network_access=workspace_object.public_network_access,
+ identity=workspace_object.identity,
+ primary_user_assigned_identity=workspace_object.primary_user_assigned_identity,
+ managed_network=workspace_object.managed_network,
+ workspace_id=rest_obj.workspace_id,
+ feature_store_settings=workspace_object._feature_store_settings,
+ )
+
+ return None
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "FeatureStore":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ loaded_schema = load_from_dict(FeatureStoreSchema, data, context, **kwargs)
+ return FeatureStore(**loaded_schema)
+
+ def _to_dict(self) -> Dict:
+ res: dict = FeatureStoreSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
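A minimal sketch of building the FeatureStore entity locally; the name, region, and description are placeholders, and provisioning the workspace through an MLClient is not shown.

from azure.ai.ml.entities import FeatureStore
from azure.ai.ml.entities._workspace.compute_runtime import ComputeRuntime

fs = FeatureStore(
    name="my-feature-store",          # placeholder name
    location="eastus",                # placeholder region
    description="Feature store backing the demo feature sets",
    compute_runtime=ComputeRuntime(spark_runtime_version="3.4.0"),
)
print(fs.name, fs.compute_runtime.spark_runtime_version)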
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/materialization_store.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/materialization_store.py
new file mode 100644
index 00000000..c6a7e6a7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store/materialization_store.py
@@ -0,0 +1,49 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._utils._arm_id_utils import AzureResourceId
+
+
+class MaterializationStore:
+ """Materialization Store
+
+ :param type: The type of the materialization store.
+ :type type: str
+ :param target: The ARM ID of the materialization store target.
+ :type target: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_featurestore.py
+ :start-after: [START configure_materialization_store]
+ :end-before: [END configure_materialization_store]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a Materialization Store
+ """
+
+ def __init__(self, type: str, target: str) -> None: # pylint: disable=redefined-builtin
+ self.type = type
+ _ = AzureResourceId(target)
+ self.__target = target
+
+ @property
+ def target(self) -> str:
+ """Get target value
+
+ :return: The ID of the target.
+ :rtype: str
+ """
+ return self.__target
+
+ @target.setter
+ def target(self, value: str) -> None:
+ """Set target value
+
+ :param value: the ID of the target
+ :type value: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the value is an invalid ARM ID.
+ """
+ _ = AzureResourceId(value)
+ self.__target = value
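A sketch of MaterializationStore with a placeholder ADLS Gen2 container ARM ID; both the constructor and the target setter validate the value by parsing it with AzureResourceId.

from azure.ai.ml.entities import MaterializationStore
from azure.ai.ml.entities._feature_store._constants import OFFLINE_MATERIALIZATION_STORE_TYPE

gen2_container = (
    "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-rg"
    "/providers/Microsoft.Storage/storageAccounts/mystorage"
    "/blobServices/default/containers/offlinestore"
)
offline_store = MaterializationStore(type=OFFLINE_MATERIALIZATION_STORE_TYPE, target=gen2_container)
print(offline_store.target)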
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column.py
new file mode 100644
index 00000000..a4446ad4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column.py
@@ -0,0 +1,80 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+ # pylint: disable=redefined-builtin,unused-argument
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_10_01.models import FeatureDataType, IndexColumn
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .data_column_type import DataColumnType
+
+DataColumnTypeMap: Dict[DataColumnType, FeatureDataType] = {
+ DataColumnType.STRING: FeatureDataType.STRING,
+ DataColumnType.INTEGER: FeatureDataType.INTEGER,
+ DataColumnType.LONG: FeatureDataType.LONG,
+ DataColumnType.FLOAT: FeatureDataType.FLOAT,
+ DataColumnType.DOUBLE: FeatureDataType.DOUBLE,
+ DataColumnType.BINARY: FeatureDataType.BINARY,
+ DataColumnType.DATETIME: FeatureDataType.DATETIME,
+ DataColumnType.BOOLEAN: FeatureDataType.BOOLEAN,
+}
+
+FeatureDataTypeMap: Dict[str, DataColumnType] = {
+ "String": DataColumnType.STRING,
+ "Integer": DataColumnType.INTEGER,
+ "Long": DataColumnType.LONG,
+ "Float": DataColumnType.FLOAT,
+ "Double": DataColumnType.DOUBLE,
+ "Binary": DataColumnType.BINARY,
+ "Datetime": DataColumnType.DATETIME,
+ "Boolean": DataColumnType.BOOLEAN,
+}
+
+
+class DataColumn(RestTranslatableMixin):
+ """A dataframe column
+
+ :param name: The column name
+ :type name: str
+ :param type: The column data type. Defaults to None.
+ :type type: Optional[Union[str, ~azure.ai.ml.entities.DataColumnType]]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ :raises ValidationException: Raised if type is specified and is not a valid DataColumnType or str.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_featurestore.py
+ :start-after: [START configure_feature_store_entity]
+ :end-before: [END configure_feature_store_entity]
+ :language: python
+ :dedent: 8
+ :caption: Using DataColumn when creating an index column for a feature store entity
+ """
+
+ def __init__(self, *, name: str, type: Optional[Union[str, DataColumnType]] = None, **kwargs: Any):
+ if isinstance(type, str):
+ type = DataColumnType[type]
+ elif not isinstance(type, DataColumnType):
+ msg = f"Type should be DataColumnType enum string or enum type, found {type}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ self.name = name
+ self.type = type
+
+ def _to_rest_object(self) -> IndexColumn:
+ return IndexColumn(column_name=self.name, data_type=DataColumnTypeMap.get(self.type, None))
+
+ @classmethod
+ def _from_rest_object(cls, obj: IndexColumn) -> "DataColumn":
+ return DataColumn(name=obj.column_name, type=FeatureDataTypeMap.get(obj.data_type, None))
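A short sketch of DataColumn, assuming both classes are re-exported from azure.ai.ml.entities; the column names are placeholders.

from azure.ai.ml.entities import DataColumn, DataColumnType

account_id = DataColumn(name="account_id", type=DataColumnType.STRING)

# Plain strings are accepted too and resolved through the case-insensitive enum lookup above.
event_time = DataColumn(name="timestamp", type="datetime")

print(account_id._to_rest_object().data_type)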
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column_type.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column_type.py
new file mode 100644
index 00000000..0bdfa002
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/data_column_type.py
@@ -0,0 +1,34 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from enum import Enum
+from typing import Any
+
+from azure.core import CaseInsensitiveEnumMeta
+
+
+class DataColumnType(str, Enum, metaclass=CaseInsensitiveEnumMeta):
+ """Dataframe Column Type Enum
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_featurestore.py
+ :start-after: [START configure_feature_store_entity]
+ :end-before: [END configure_feature_store_entity]
+ :language: python
+ :dedent: 8
+ :caption: Using DataColumnType when instantiating a DataColumn
+ """
+
+ STRING = "string"
+ INTEGER = "integer"
+ LONG = "long"
+ FLOAT = "float"
+ DOUBLE = "double"
+ BINARY = "binary"
+ DATETIME = "datetime"
+ BOOLEAN = "boolean"
+
+ def __str__(self) -> Any:
+ return self.value
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/feature_store_entity.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/feature_store_entity.py
new file mode 100644
index 00000000..6a04bc13
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_feature_store_entity/feature_store_entity.py
@@ -0,0 +1,146 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_10_01.models import (
+ FeaturestoreEntityContainer,
+ FeaturestoreEntityContainerProperties,
+ FeaturestoreEntityVersion,
+ FeaturestoreEntityVersionProperties,
+)
+from azure.ai.ml._schema._feature_store_entity.feature_store_entity_schema import FeatureStoreEntitySchema
+from azure.ai.ml._utils._arm_id_utils import get_arm_id_object_from_id
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets.asset import Asset
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .data_column import DataColumn
+
+
+class FeatureStoreEntity(Asset):
+ """Feature Store Entity
+
+ :param name: The name of the feature store entity resource.
+ :type name: str
+ :param version: The version of the feature store entity resource.
+ :type version: str
+ :param index_columns: Specifies index columns of the feature-store entity resource.
+ :type index_columns: list[~azure.ai.ml.entities.DataColumn]
+ :param stage: The feature store entity stage. Allowed values: Development, Production, Archived.
+ Defaults to "Development".
+ :type stage: Optional[str]
+ :param description: The description of the feature store entity resource. Defaults to None.
+ :type description: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None.
+ :type tags: Optional[dict[str, str]]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ :raises ValidationException: Raised if stage is specified and is not valid.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_featurestore.py
+ :start-after: [START configure_feature_store_entity]
+ :end-before: [END configure_feature_store_entity]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a Feature Store Entity
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ version: str,
+ index_columns: List[DataColumn],
+ stage: Optional[str] = "Development",
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ **kwargs,
+ )
+ if stage and stage not in ["Development", "Production", "Archived"]:
+ msg = f"Stage must be Development, Production, or Archived, found {stage}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ target=ErrorTarget.FEATURE_STORE_ENTITY,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.index_columns = index_columns
+ self.version = version
+ self.latest_version = None
+ self.stage = stage
+
+ def _to_rest_object(self) -> FeaturestoreEntityVersion:
+ feature_store_entity_version_properties = FeaturestoreEntityVersionProperties(
+ description=self.description,
+ index_columns=[column._to_rest_object() for column in self.index_columns],
+ tags=self.tags,
+ properties=self.properties,
+ stage=self.stage,
+ )
+ return FeaturestoreEntityVersion(properties=feature_store_entity_version_properties)
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: FeaturestoreEntityVersion) -> "FeatureStoreEntity":
+ rest_object_details: FeaturestoreEntityVersionProperties = rest_obj.properties
+ arm_id_object = get_arm_id_object_from_id(rest_obj.id)
+ featurestoreEntity = FeatureStoreEntity(
+ name=arm_id_object.asset_name,
+ version=arm_id_object.asset_version,
+ index_columns=[DataColumn._from_rest_object(column) for column in rest_object_details.index_columns],
+ stage=rest_object_details.stage,
+ description=rest_object_details.description,
+ tags=rest_object_details.tags,
+ )
+ return featurestoreEntity
+
+ @classmethod
+ def _from_container_rest_object(cls, rest_obj: FeaturestoreEntityContainer) -> "FeatureStoreEntity":
+ rest_object_details: FeaturestoreEntityContainerProperties = rest_obj.properties
+ arm_id_object = get_arm_id_object_from_id(rest_obj.id)
+ featurestoreEntity = FeatureStoreEntity(
+ name=arm_id_object.asset_name,
+ description=rest_object_details.description,
+ tags=rest_object_details.tags,
+ index_columns=[],
+ version="",
+ )
+ featurestoreEntity.latest_version = rest_object_details.latest_version
+ return featurestoreEntity
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "FeatureStoreEntity":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ loaded_schema = load_from_dict(FeatureStoreEntitySchema, data, context, **kwargs)
+ return FeatureStoreEntity(**loaded_schema)
+
+ def _to_dict(self) -> Dict:
+ res: dict = FeatureStoreEntitySchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
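A sketch of a feature store entity built from the classes above; the names are placeholders, and stage must be Development, Production, or Archived.

from azure.ai.ml.entities import DataColumn, DataColumnType, FeatureStoreEntity

account_entity = FeatureStoreEntity(
    name="account",
    version="1",
    index_columns=[DataColumn(name="account_id", type=DataColumnType.STRING)],
    stage="Development",
    description="Accounts keyed by account_id",
)
print(account_entity._to_rest_object().properties.stage)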
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py
new file mode 100644
index 00000000..43f615c3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/__init__.py
@@ -0,0 +1,16 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""AzureML Retrieval Augmented Generation (RAG) utilities."""
+
+from .input._ai_search_config import AzureAISearchConfig
+from .input._index_data_source import IndexDataSource, GitSource, LocalSource
+from .model_config import ModelConfiguration
+
+__all__ = [
+ "ModelConfiguration",
+ "AzureAISearchConfig",
+ "IndexDataSource",
+ "GitSource",
+ "LocalSource",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py
new file mode 100644
index 00000000..884faf82
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/data_index_func.py
@@ -0,0 +1,748 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+# pylint: disable=no-member
+
+import json
+import re
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import AssetTypes, LegacyAssetTypes
+from azure.ai.ml.entities import PipelineJob
+from azure.ai.ml.entities._builders.base_node import pipeline_node_decorator
+from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration, UserIdentityConfiguration
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.pipeline._io import PipelineInput
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+from azure.ai.ml.constants._common import DataIndexTypes
+from azure.ai.ml.constants._component import LLMRAGComponentUri
+from azure.ai.ml.entities._indexes.entities.data_index import DataIndex
+
+SUPPORTED_INPUTS = [
+ LegacyAssetTypes.PATH,
+ AssetTypes.URI_FILE,
+ AssetTypes.URI_FOLDER,
+ AssetTypes.MLTABLE,
+]
+
+
+def _build_data_index(io_dict: Union[Dict, DataIndex]):
+ if io_dict is None:
+ return io_dict
+ if isinstance(io_dict, DataIndex):
+ component_io = io_dict
+ else:
+ if isinstance(io_dict, dict):
+ component_io = DataIndex(**io_dict)
+ else:
+ msg = "data_index only support dict and DataIndex"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ return component_io
+
+
+@experimental
+@pipeline_node_decorator
+def index_data(
+ *,
+ data_index: DataIndex,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ serverless_instance_type: Optional[str] = None,
+ ml_client: Optional[Any] = None,
+ identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
+ input_data_override: Optional[Input] = None,
+ **kwargs,
+) -> PipelineJob:
+ """
+ Create a PipelineJob object which can be used inside dsl.pipeline.
+
+ :keyword data_index: The data index configuration.
+ :type data_index: DataIndex
+ :keyword description: Description of the job.
+ :type description: str
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :keyword name: Name of the job.
+ :type name: str
+ :keyword display_name: Display name of the job.
+ :type display_name: str
+ :keyword experiment_name: Name of the experiment the job will be created under.
+ :type experiment_name: str
+ :keyword compute: The compute resource the job runs on.
+ :type compute: str
+ :keyword serverless_instance_type: The instance type to use for serverless compute.
+ :type serverless_instance_type: Optional[str]
+ :keyword ml_client: The ml client to use for the job.
+ :type ml_client: Any
+ :keyword identity: Identity configuration for the job.
+ :type identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]]
+ :keyword input_data_override: Input data override for the job.
+ Used to pipe output of step into DataIndex Job in a pipeline.
+ :type input_data_override: Optional[Input]
+ :return: A PipelineJob object.
+ :rtype: ~azure.ai.ml.entities.PipelineJob.
+ """
+ data_index = _build_data_index(data_index)
+
+ if data_index.index.type == DataIndexTypes.FAISS:
+ configured_component = data_index_faiss(
+ ml_client,
+ data_index,
+ description,
+ tags,
+ name,
+ display_name,
+ experiment_name,
+ compute,
+ serverless_instance_type,
+ identity,
+ input_data_override,
+ )
+ elif data_index.index.type in (DataIndexTypes.ACS, DataIndexTypes.PINECONE):
+ if kwargs.get("incremental_update", False):
+ configured_component = data_index_incremental_update_hosted(
+ ml_client,
+ data_index,
+ description,
+ tags,
+ name,
+ display_name,
+ experiment_name,
+ compute,
+ serverless_instance_type,
+ identity,
+ input_data_override,
+ )
+ else:
+ configured_component = data_index_hosted(
+ ml_client,
+ data_index,
+ description,
+ tags,
+ name,
+ display_name,
+ experiment_name,
+ compute,
+ serverless_instance_type,
+ identity,
+ input_data_override,
+ )
+ else:
+ raise ValueError(f"Unsupported index type: {data_index.index.type}")
+
+ configured_component.properties["azureml.mlIndexAssetName"] = data_index.name
+ configured_component.properties["azureml.mlIndexAssetKind"] = data_index.index.type
+ configured_component.properties["azureml.mlIndexAssetSource"] = "Data Asset"
+
+ return configured_component
+
+
+# pylint: disable=too-many-statements
+def data_index_incremental_update_hosted(
+ ml_client: Any,
+ data_index: DataIndex,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ serverless_instance_type: Optional[str] = None,
+ identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
+ input_data_override: Optional[Input] = None,
+):
+ from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline
+
+ crack_and_chunk_and_embed_component = get_component_obj(
+ ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK_AND_EMBED
+ )
+
+ if data_index.index.type == DataIndexTypes.ACS:
+ update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_ACS_INDEX)
+ elif data_index.index.type == DataIndexTypes.PINECONE:
+ update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_PINECONE_INDEX)
+ else:
+ raise ValueError(f"Unsupported hosted index type: {data_index.index.type}")
+
+ register_mlindex_asset_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET)
+
+ @pipeline( # type: ignore [call-overload]
+ name=name if name else f"data_index_incremental_update_{data_index.index.type}",
+ description=description,
+ tags=tags,
+ display_name=(
+ display_name if display_name else f"LLM - Data to {data_index.index.type.upper()} (Incremental Update)"
+ ),
+ experiment_name=experiment_name,
+ compute=compute,
+ get_component=True,
+ )
+ def data_index_pipeline(
+ input_data: Input,
+ embeddings_model: str,
+ index_config: str,
+ index_connection_id: str,
+ chunk_size: int = 768,
+ chunk_overlap: int = 0,
+ input_glob: str = "**/*",
+ citation_url: Optional[str] = None,
+ citation_replacement_regex: Optional[str] = None,
+ aoai_connection_id: Optional[str] = None,
+ embeddings_container: Optional[Input] = None,
+ ):
+ """
+ Generate embeddings for an `input_data` source and
+ push them into a hosted index (such as Azure Cognitive Search or Pinecone).
+
+ :param input_data: The input data to be indexed.
+ :type input_data: Input
+ :param embeddings_model: The embedding model to use when processing source data chunks.
+ :type embeddings_model: str
+ :param index_config: The configuration for the hosted index.
+ :type index_config: str
+ :param index_connection_id: The connection ID for the hosted index.
+ :type index_connection_id: str
+ :param chunk_size: The size of the chunks to break the input data into.
+ :type chunk_size: int
+ :param chunk_overlap: The number of tokens to overlap between chunks.
+ :type chunk_overlap: int
+ :param input_glob: The glob pattern to use when searching for input data.
+ :type input_glob: str
+ :param citation_url: The URL to use when generating citations for the input data.
+ :type citation_url: str
+ :param citation_replacement_regex: The regex to use when generating citations for the input data.
+ :type citation_replacement_regex: str
+ :param aoai_connection_id: The connection ID for the Azure Open AI service.
+ :type aoai_connection_id: str
+ :param embeddings_container: The container to use when caching embeddings.
+ :type embeddings_container: Input
+ :return: The URI of the generated Azure Cognitive Search index.
+ :rtype: str.
+ """
+ crack_and_chunk_and_embed = crack_and_chunk_and_embed_component(
+ input_data=input_data,
+ input_glob=input_glob,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ citation_url=citation_url,
+ citation_replacement_regex=citation_replacement_regex,
+ embeddings_container=embeddings_container,
+ embeddings_model=embeddings_model,
+ embeddings_connection_id=aoai_connection_id,
+ )
+ if compute is None or compute == "serverless":
+ use_automatic_compute(crack_and_chunk_and_embed, instance_type=serverless_instance_type)
+ if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type]
+ crack_and_chunk_and_embed.outputs.embeddings = Output(
+ type="uri_folder", path=f"{embeddings_container.path}/{{name}}" # type: ignore [union-attr]
+ )
+ if identity:
+ crack_and_chunk_and_embed.identity = identity
+
+ if data_index.index.type == DataIndexTypes.ACS:
+ update_index = update_index_component(
+ embeddings=crack_and_chunk_and_embed.outputs.embeddings, acs_config=index_config
+ )
+ update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_ACS"] = index_connection_id
+ elif data_index.index.type == DataIndexTypes.PINECONE:
+ update_index = update_index_component(
+ embeddings=crack_and_chunk_and_embed.outputs.embeddings, pinecone_config=index_config
+ )
+ update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_PINECONE"] = index_connection_id
+ else:
+ raise ValueError(f"Unsupported hosted index type: {data_index.index.type}")
+ if compute is None or compute == "serverless":
+ use_automatic_compute(update_index, instance_type=serverless_instance_type)
+ if identity:
+ update_index.identity = identity
+
+ register_mlindex_asset = register_mlindex_asset_component(
+ storage_uri=update_index.outputs.index,
+ asset_name=data_index.name,
+ )
+ if compute is None or compute == "serverless":
+ use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type)
+ if identity:
+ register_mlindex_asset.identity = identity
+ return {
+ "mlindex_asset_uri": update_index.outputs.index,
+ "mlindex_asset_id": register_mlindex_asset.outputs.asset_id,
+ }
+
+ if input_data_override is not None:
+ input_data = input_data_override
+ else:
+ input_data = Input(
+ type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type]
+ )
+
+ index_config = {
+ "index_name": data_index.index.name if data_index.index.name is not None else data_index.name,
+ "full_sync": True,
+ }
+ if data_index.index.config is not None:
+ index_config.update(data_index.index.config)
+
+ component = data_index_pipeline(
+ input_data=input_data,
+ input_glob=data_index.source.input_glob, # type: ignore [arg-type]
+ chunk_size=data_index.source.chunk_size, # type: ignore [arg-type]
+ chunk_overlap=data_index.source.chunk_overlap, # type: ignore [arg-type]
+ citation_url=data_index.source.citation_url,
+ citation_replacement_regex=(
+ json.dumps(data_index.source.citation_url_replacement_regex._to_dict())
+ if data_index.source.citation_url_replacement_regex
+ else None
+ ),
+ embeddings_model=build_model_protocol(data_index.embedding.model),
+ aoai_connection_id=_resolve_connection_id(ml_client, data_index.embedding.connection),
+ embeddings_container=(
+ Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path)
+ if data_index.embedding.cache_path
+ else None
+ ),
+ index_config=json.dumps(index_config),
+ index_connection_id=_resolve_connection_id(ml_client, data_index.index.connection), # type: ignore [arg-type]
+ )
+ # Hack until full Component classes are implemented that can annotate the optional parameters properly
+ component.inputs["input_glob"]._meta.optional = True
+ component.inputs["chunk_size"]._meta.optional = True
+ component.inputs["chunk_overlap"]._meta.optional = True
+ component.inputs["citation_url"]._meta.optional = True
+ component.inputs["citation_replacement_regex"]._meta.optional = True
+ component.inputs["aoai_connection_id"]._meta.optional = True
+ component.inputs["embeddings_container"]._meta.optional = True
+
+ if data_index.path:
+ component.outputs.mlindex_asset_uri = Output( # type: ignore [attr-defined]
+ type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type]
+ )
+
+ return component
+
+
+def data_index_faiss(
+ ml_client: Any,
+ data_index: DataIndex,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ serverless_instance_type: Optional[str] = None,
+ identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
+ input_data_override: Optional[Input] = None,
+):
+ from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline
+
+ crack_and_chunk_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK)
+ generate_embeddings_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_GENERATE_EMBEDDINGS)
+ create_faiss_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CREATE_FAISS_INDEX)
+ register_mlindex_asset_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET)
+
+ @pipeline( # type: ignore [call-overload]
+ name=name if name else "data_index_faiss",
+ description=description,
+ tags=tags,
+ display_name=display_name if display_name else "LLM - Data to Faiss",
+ experiment_name=experiment_name,
+ compute=compute,
+ get_component=True,
+ )
+ def data_index_faiss_pipeline(
+ input_data: Input,
+ embeddings_model: str,
+ chunk_size: int = 1024,
+ data_source_glob: str = None, # type: ignore [assignment]
+ data_source_url: str = None, # type: ignore [assignment]
+ document_path_replacement_regex: str = None, # type: ignore [assignment]
+ aoai_connection_id: str = None, # type: ignore [assignment]
+ embeddings_container: Input = None, # type: ignore [assignment]
+ ):
+ """
+ Generate embeddings for an `input_data` source and create a Faiss index from them.
+
+ :param input_data: The input data to be indexed.
+ :type input_data: Input
+ :param embeddings_model: The embedding model to use when processing source data chunks.
+ :type embeddings_model: str
+ :param chunk_size: The size of the chunks to break the input data into.
+ :type chunk_size: int
+ :param data_source_glob: The glob pattern to use when searching for input data.
+ :type data_source_glob: str
+ :param data_source_url: The URL to use when generating citations for the input data.
+ :type data_source_url: str
+ :param document_path_replacement_regex: The regex to use when generating citations for the input data.
+ :type document_path_replacement_regex: str
+        :param aoai_connection_id: The connection ID for the Azure OpenAI service.
+ :type aoai_connection_id: str
+ :param embeddings_container: The container to use when caching embeddings.
+ :type embeddings_container: Input
+ :return: The URI of the generated Faiss index.
+ :rtype: str.
+ """
+ crack_and_chunk = crack_and_chunk_component(
+ input_data=input_data,
+ input_glob=data_source_glob,
+ chunk_size=chunk_size,
+ data_source_url=data_source_url,
+ document_path_replacement_regex=document_path_replacement_regex,
+ )
+ if compute is None or compute == "serverless":
+ use_automatic_compute(crack_and_chunk, instance_type=serverless_instance_type)
+ if identity:
+ crack_and_chunk.identity = identity
+
+ generate_embeddings = generate_embeddings_component(
+ chunks_source=crack_and_chunk.outputs.output_chunks,
+ embeddings_container=embeddings_container,
+ embeddings_model=embeddings_model,
+ )
+ if compute is None or compute == "serverless":
+ use_automatic_compute(generate_embeddings, instance_type=serverless_instance_type)
+ if optional_pipeline_input_provided(aoai_connection_id): # type: ignore [arg-type]
+ generate_embeddings.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_AOAI"] = aoai_connection_id
+ if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type]
+ generate_embeddings.outputs.embeddings = Output(
+ type="uri_folder", path=f"{embeddings_container.path}/{{name}}"
+ )
+ if identity:
+ generate_embeddings.identity = identity
+
+ create_faiss_index = create_faiss_index_component(embeddings=generate_embeddings.outputs.embeddings)
+ if compute is None or compute == "serverless":
+ use_automatic_compute(create_faiss_index, instance_type=serverless_instance_type)
+ if identity:
+ create_faiss_index.identity = identity
+
+ register_mlindex_asset = register_mlindex_asset_component(
+ storage_uri=create_faiss_index.outputs.index,
+ asset_name=data_index.name,
+ )
+ if compute is None or compute == "serverless":
+ use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type)
+ if identity:
+ register_mlindex_asset.identity = identity
+ return {
+ "mlindex_asset_uri": create_faiss_index.outputs.index,
+ "mlindex_asset_id": register_mlindex_asset.outputs.asset_id,
+ }
+
+ if input_data_override is not None:
+ input_data = input_data_override
+ else:
+ input_data = Input(
+ type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type]
+ )
+
+ component = data_index_faiss_pipeline(
+ input_data=input_data,
+ embeddings_model=build_model_protocol(data_index.embedding.model),
+ chunk_size=data_index.source.chunk_size, # type: ignore [arg-type]
+ data_source_glob=data_index.source.input_glob, # type: ignore [arg-type]
+ data_source_url=data_index.source.citation_url, # type: ignore [arg-type]
+ document_path_replacement_regex=(
+ json.dumps(data_index.source.citation_url_replacement_regex._to_dict()) # type: ignore [arg-type]
+ if data_index.source.citation_url_replacement_regex
+ else None
+ ),
+ aoai_connection_id=_resolve_connection_id(
+ ml_client, data_index.embedding.connection
+ ), # type: ignore [arg-type]
+ embeddings_container=(
+ Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path) # type: ignore [arg-type]
+ if data_index.embedding.cache_path
+ else None
+ ),
+ )
+ # Hack until full Component classes are implemented that can annotate the optional parameters properly
+ component.inputs["data_source_glob"]._meta.optional = True
+ component.inputs["data_source_url"]._meta.optional = True
+ component.inputs["document_path_replacement_regex"]._meta.optional = True
+ component.inputs["aoai_connection_id"]._meta.optional = True
+ component.inputs["embeddings_container"]._meta.optional = True
+ if data_index.path:
+ component.outputs.mlindex_asset_uri = Output(
+ type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type]
+ )
+
+ return component
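A minimal usage sketch for the builder above, assuming an authenticated MLClient (`ml_client`); the asset path, glob, deployment name, and connection name are illustrative:

    from azure.ai.ml.constants import AssetTypes
    from azure.ai.ml.constants._common import DataIndexTypes
    from azure.ai.ml.entities import Data
    from azure.ai.ml.entities._indexes.entities.data_index import DataIndex, Embedding, IndexSource, IndexStore

    # Describe the source documents, the embedding model, and a Faiss index destination.
    faiss_data_index = DataIndex(
        name="my-docs-faiss",
        source=IndexSource(
            input_data=Data(type=AssetTypes.URI_FOLDER, path="./docs"),
            input_glob="**/*.md",
            chunk_size=1024,
        ),
        embedding=Embedding(
            model="azure_open_ai://deployment/text-embedding-ada-002/model/text-embedding-ada-002",
            connection="my-aoai-connection",
        ),
        index=IndexStore(type=DataIndexTypes.FAISS),
    )

    # Build the pipeline component defined above; submitting the job is left to the caller.
    faiss_pipeline = data_index_faiss(ml_client, faiss_data_index)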
+
+
+# pylint: disable=too-many-statements
+def data_index_hosted(
+ ml_client: Any,
+ data_index: DataIndex,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ serverless_instance_type: Optional[str] = None,
+ identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
+ input_data_override: Optional[Input] = None,
+):
+ from azure.ai.ml.entities._indexes.utils import build_model_protocol, pipeline
+
+ crack_and_chunk_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_CRACK_AND_CHUNK)
+ generate_embeddings_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_GENERATE_EMBEDDINGS)
+
+ if data_index.index.type == DataIndexTypes.ACS:
+ update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_ACS_INDEX)
+ elif data_index.index.type == DataIndexTypes.PINECONE:
+ update_index_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_UPDATE_PINECONE_INDEX)
+ else:
+ raise ValueError(f"Unsupported hosted index type: {data_index.index.type}")
+
+ register_mlindex_asset_component = get_component_obj(ml_client, LLMRAGComponentUri.LLM_RAG_REGISTER_MLINDEX_ASSET)
+
+ @pipeline( # type: ignore [call-overload]
+ name=name if name else f"data_index_{data_index.index.type}",
+ description=description,
+ tags=tags,
+ display_name=display_name if display_name else f"LLM - Data to {data_index.index.type.upper()}",
+ experiment_name=experiment_name,
+ compute=compute,
+ get_component=True,
+ )
+ def data_index_pipeline(
+ input_data: Input,
+ embeddings_model: str,
+ index_config: str,
+ index_connection_id: str,
+ chunk_size: int = 1024,
+ data_source_glob: str = None, # type: ignore [assignment]
+ data_source_url: str = None, # type: ignore [assignment]
+ document_path_replacement_regex: str = None, # type: ignore [assignment]
+ aoai_connection_id: str = None, # type: ignore [assignment]
+ embeddings_container: Input = None, # type: ignore [assignment]
+ ):
+ """
+        Generate embeddings for an `input_data` source
+        and push them into a hosted index (such as Azure Cognitive Search or Pinecone).
+
+ :param input_data: The input data to be indexed.
+ :type input_data: Input
+ :param embeddings_model: The embedding model to use when processing source data chunks.
+ :type embeddings_model: str
+ :param index_config: The configuration for the hosted index.
+ :type index_config: str
+ :param index_connection_id: The connection ID for the hosted index.
+ :type index_connection_id: str
+ :param chunk_size: The size of the chunks to break the input data into.
+ :type chunk_size: int
+ :param data_source_glob: The glob pattern to use when searching for input data.
+ :type data_source_glob: str
+ :param data_source_url: The URL to use when generating citations for the input data.
+ :type data_source_url: str
+ :param document_path_replacement_regex: The regex to use when generating citations for the input data.
+ :type document_path_replacement_regex: str
+        :param aoai_connection_id: The connection ID for the Azure OpenAI service.
+ :type aoai_connection_id: str
+ :param embeddings_container: The container to use when caching embeddings.
+ :type embeddings_container: Input
+        :return: The URI of the generated hosted index.
+ :rtype: str.
+ """
+ crack_and_chunk = crack_and_chunk_component(
+ input_data=input_data,
+ input_glob=data_source_glob,
+ chunk_size=chunk_size,
+ data_source_url=data_source_url,
+ document_path_replacement_regex=document_path_replacement_regex,
+ )
+ if compute is None or compute == "serverless":
+ use_automatic_compute(crack_and_chunk, instance_type=serverless_instance_type)
+ if identity:
+ crack_and_chunk.identity = identity
+
+ generate_embeddings = generate_embeddings_component(
+ chunks_source=crack_and_chunk.outputs.output_chunks,
+ embeddings_container=embeddings_container,
+ embeddings_model=embeddings_model,
+ )
+ if compute is None or compute == "serverless":
+ use_automatic_compute(generate_embeddings, instance_type=serverless_instance_type)
+ if optional_pipeline_input_provided(aoai_connection_id): # type: ignore [arg-type]
+ generate_embeddings.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_AOAI"] = aoai_connection_id
+ if optional_pipeline_input_provided(embeddings_container): # type: ignore [arg-type]
+ generate_embeddings.outputs.embeddings = Output(
+ type="uri_folder", path=f"{embeddings_container.path}/{{name}}"
+ )
+ if identity:
+ generate_embeddings.identity = identity
+
+ if data_index.index.type == DataIndexTypes.ACS:
+ update_index = update_index_component(
+ embeddings=generate_embeddings.outputs.embeddings, acs_config=index_config
+ )
+ update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_ACS"] = index_connection_id
+ elif data_index.index.type == DataIndexTypes.PINECONE:
+ update_index = update_index_component(
+ embeddings=generate_embeddings.outputs.embeddings, pinecone_config=index_config
+ )
+ update_index.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_PINECONE"] = index_connection_id
+ else:
+ raise ValueError(f"Unsupported hosted index type: {data_index.index.type}")
+ if compute is None or compute == "serverless":
+ use_automatic_compute(update_index, instance_type=serverless_instance_type)
+ if identity:
+ update_index.identity = identity
+
+ register_mlindex_asset = register_mlindex_asset_component(
+ storage_uri=update_index.outputs.index,
+ asset_name=data_index.name,
+ )
+ if compute is None or compute == "serverless":
+ use_automatic_compute(register_mlindex_asset, instance_type=serverless_instance_type)
+ if identity:
+ register_mlindex_asset.identity = identity
+ return {
+ "mlindex_asset_uri": update_index.outputs.index,
+ "mlindex_asset_id": register_mlindex_asset.outputs.asset_id,
+ }
+
+ if input_data_override is not None:
+ input_data = input_data_override
+ else:
+ input_data = Input(
+ type=data_index.source.input_data.type, path=data_index.source.input_data.path # type: ignore [arg-type]
+ )
+
+ index_config = {
+ "index_name": data_index.index.name if data_index.index.name is not None else data_index.name,
+ }
+ if data_index.index.config is not None:
+ index_config.update(data_index.index.config)
+
+ component = data_index_pipeline(
+ input_data=input_data,
+ embeddings_model=build_model_protocol(data_index.embedding.model),
+ index_config=json.dumps(index_config),
+ index_connection_id=_resolve_connection_id(ml_client, data_index.index.connection), # type: ignore [arg-type]
+ chunk_size=data_index.source.chunk_size, # type: ignore [arg-type]
+ data_source_glob=data_index.source.input_glob, # type: ignore [arg-type]
+ data_source_url=data_index.source.citation_url, # type: ignore [arg-type]
+ document_path_replacement_regex=(
+ json.dumps(data_index.source.citation_url_replacement_regex._to_dict()) # type: ignore [arg-type]
+ if data_index.source.citation_url_replacement_regex
+ else None
+ ),
+ aoai_connection_id=_resolve_connection_id(
+ ml_client, data_index.embedding.connection # type: ignore [arg-type]
+ ),
+ embeddings_container=(
+ Input(type=AssetTypes.URI_FOLDER, path=data_index.embedding.cache_path) # type: ignore [arg-type]
+ if data_index.embedding.cache_path
+ else None
+ ),
+ )
+ # Hack until full Component classes are implemented that can annotate the optional parameters properly
+ component.inputs["data_source_glob"]._meta.optional = True
+ component.inputs["data_source_url"]._meta.optional = True
+ component.inputs["document_path_replacement_regex"]._meta.optional = True
+ component.inputs["aoai_connection_id"]._meta.optional = True
+ component.inputs["embeddings_container"]._meta.optional = True
+
+ if data_index.path:
+ component.outputs.mlindex_asset_uri = Output(
+ type=AssetTypes.URI_FOLDER, path=data_index.path # type: ignore [arg-type]
+ )
+
+ return component
+
+
+def optional_pipeline_input_provided(input: Optional[PipelineInput]):
+ """
+ Checks if optional pipeline inputs are provided.
+
+ :param input: The pipeline input to check.
+ :type input: Optional[PipelineInput]
+ :return: True if the input is not None and has a value, False otherwise.
+ :rtype: bool.
+ """
+ return input is not None and input._data is not None
+
+
+def use_automatic_compute(component, instance_count=1, instance_type=None):
+ """
+ Configure input `component` to use automatic compute with `instance_count` and `instance_type`.
+
+ This avoids the need to provision a compute cluster to run the component.
+ :param component: The component to configure.
+ :type component: Any
+ :param instance_count: The number of instances to use.
+ :type instance_count: int
+ :param instance_type: The type of instance to use.
+ :type instance_type: str
+ :return: The configured component.
+ :rtype: Any.
+ """
+ component.set_resources(
+ instance_count=instance_count,
+ instance_type=instance_type,
+ properties={"compute_specification": {"automatic": True}},
+ )
+ return component
+
+
+def get_component_obj(ml_client, component_uri):
+ from azure.ai.ml import MLClient
+
+ if not isinstance(component_uri, str):
+ # Assume Component object
+ return component_uri
+
+ matches = re.match(
+ r"azureml://registries/(?P<registry_name>.*)/components/(?P<component_name>.*)"
+ r"/(?P<identifier_type>.*)/(?P<identifier_name>.*)",
+ component_uri,
+ )
+ if matches is None:
+ from azure.ai.ml import load_component
+
+ # Assume local path to component
+ return load_component(source=component_uri)
+
+ registry_name = matches.group("registry_name")
+ registry_client = MLClient(
+ subscription_id=ml_client.subscription_id,
+ resource_group_name=ml_client.resource_group_name,
+ credential=ml_client._credential,
+ registry_name=registry_name,
+ )
+ component_obj = registry_client.components.get(
+ matches.group("component_name"),
+ **{matches.group("identifier_type").rstrip("s"): matches.group("identifier_name")},
+ )
+ return component_obj
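For reference, a registry component URI matching the pattern above has the shape below; the registry, component name, and label are illustrative:

    #   "azureml://registries/azureml/components/llm_rag_crack_and_chunk/labels/latest"
    #   -> registry_name="azureml", component_name="llm_rag_crack_and_chunk",
    #      identifier_type="labels", identifier_name="latest",
    #   resolved via registry_client.components.get("llm_rag_crack_and_chunk", label="latest")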
+
+
+def _resolve_connection_id(ml_client, connection: Optional[str] = None) -> Optional[str]:
+ if connection is None:
+ return None
+
+ if isinstance(connection, str):
+ from azure.ai.ml._utils._arm_id_utils import AMLNamedArmId
+
+ connection_name = AMLNamedArmId(connection).asset_name
+
+ connection = ml_client.connections.get(connection_name)
+ if connection is None:
+ return None
+ return connection.id # type: ignore [attr-defined]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py
new file mode 100644
index 00000000..094d19aa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/entities/data_index.py
@@ -0,0 +1,243 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""DataIndex entities."""
+
+from typing import Dict, Optional
+
+from azure.ai.ml.constants._common import DataIndexTypes
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.entities._assets import Data
+from azure.ai.ml.entities._inputs_outputs.utils import _remove_empty_values
+from azure.ai.ml.entities._mixins import DictMixin
+
+
+@experimental
+class CitationRegex(DictMixin):
+ """
+ :keyword match_pattern: Regex to match citation in the citation_url + input file path.
+ e.g. '(.*)/articles/(.*)(\\.[^.]+)$'.
+ :type match_pattern: str
+ :keyword replacement_pattern: Replacement string for citation. e.g. '\\1/\\2'.
+ :type replacement_pattern: str
+ """
+
+ def __init__(
+ self,
+ *,
+ match_pattern: str,
+ replacement_pattern: str,
+ ):
+ """Initialize a CitationRegex object."""
+ self.match_pattern = match_pattern
+ self.replacement_pattern = replacement_pattern
+
+ def _to_dict(self) -> Dict:
+ """Convert the Source object to a dict.
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = [
+ "match_pattern",
+ "replacement_pattern",
+ ]
+ result = {key: getattr(self, key) for key in keys}
+ return _remove_empty_values(result)
+
+
+@experimental
+class IndexSource(DictMixin):
+ """Congifuration for the destination index to write processed data to.
+ :keyword input_data: Input Data to index files from. MLTable type inputs will use `mode: eval_mount`.
+ :type input_data: Data
+ :keyword input_glob: Connection reference to use for embedding model information,
+ only needed for hosted embeddings models (such as Azure OpenAI).
+ :type input_glob: str, optional
+ :keyword chunk_size: Maximum number of tokens to put in each chunk.
+ :type chunk_size: int, optional
+ :keyword chunk_overlap: Number of tokens to overlap between chunks.
+ :type chunk_overlap: int, optional
+ :keyword citation_url: Base URL to join with file paths to create full source file URL for chunk metadata.
+ :type citation_url: str, optional
+ :keyword citation_url_replacement_regex: Regex match and replacement patterns for citation url. Useful if the paths
+ in `input_data` don't match the desired citation format.
+ :type citation_url_replacement_regex: CitationRegex, optional
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the IndexSource object cannot be validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ input_data: Data,
+ input_glob: Optional[str] = None,
+ chunk_size: Optional[int] = None,
+ chunk_overlap: Optional[int] = None,
+ citation_url: Optional[str] = None,
+ citation_url_replacement_regex: Optional[CitationRegex] = None,
+ ):
+ """Initialize a IndexSource object."""
+ self.input_data = input_data
+ self.input_glob = input_glob
+ self.chunk_size = chunk_size
+ self.chunk_overlap = chunk_overlap
+ self.citation_url = citation_url
+ self.citation_url_replacement_regex = citation_url_replacement_regex
+
+ def _to_dict(self) -> Dict:
+ """Convert the Source object to a dict.
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = [
+ "input_data",
+ "input_glob",
+ "chunk_size",
+ "chunk_overlap",
+ "citation_url",
+ "citation_url_replacement_regex",
+ ]
+ result = {key: getattr(self, key) for key in keys}
+ return _remove_empty_values(result)
+
+
+@experimental
+class Embedding(DictMixin):
+ """Congifuration for the destination index to write processed data to.
+ :keyword model: The model to use to embed data. E.g. 'hugging_face://model/sentence-transformers/all-mpnet-base-v2'
+ or 'azure_open_ai://deployment/{deployment_name}/model/{model_name}'
+ :type model: str
+ :keyword connection: Connection reference to use for embedding model information,
+ only needed for hosted embeddings models (such as Azure OpenAI).
+ :type connection: str, optional
+ :keyword cache_path: Folder containing previously generated embeddings.
+        Should be parent folder of the 'embeddings' output path used for this component.
+ Will compare input data to existing embeddings and only embed changed/new data, reusing existing chunks.
+ :type cache_path: str, optional
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the Embedding object cannot be validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ model: str,
+ connection: Optional[str] = None,
+ cache_path: Optional[str] = None,
+ ):
+ """Initialize a Embedding object."""
+ self.model = model
+ self.connection = connection
+ self.cache_path = cache_path
+
+ def _to_dict(self) -> Dict:
+ """Convert the Source object to a dict.
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = [
+ "model",
+ "connection",
+ "cache_path",
+ ]
+ result = {key: getattr(self, key) for key in keys}
+ return _remove_empty_values(result)
+
+
+@experimental
+class IndexStore(DictMixin):
+ """Congifuration for the destination index to write processed data to.
+ :keyword type: The type of index to write to. Currently supported types are 'acs', 'pinecone', and 'faiss'.
+ :type type: str
+ :keyword name: Name of index to update/create, only needed for hosted indexes
+ (such as Azure Cognitive Search and Pinecone).
+ :type name: str, optional
+ :keyword connection: Connection reference to use for index information,
+ only needed for hosted indexes (such as Azure Cognitive Search and Pinecone).
+ :type connection: str, optional
+    :keyword config: Configuration for the index.
+        Primarily used to configure AI Search and Pinecone specific settings,
+        such as a custom `field_mapping` for known field types.
+ :type config: dict, optional
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the IndexStore object cannot be validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ type: str = DataIndexTypes.FAISS,
+ name: Optional[str] = None,
+ connection: Optional[str] = None,
+ config: Optional[Dict] = None,
+ ):
+ """Initialize a IndexStore object."""
+ self.type = type
+ self.name = name
+ self.connection = connection
+ self.config = config
+
+ def _to_dict(self) -> Dict:
+ """Convert the Source object to a dict.
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = ["type", "name", "connection", "config"]
+ result = {key: getattr(self, key) for key in keys}
+ return _remove_empty_values(result)
+
+
+@experimental
+class DataIndex(Data):
+ """Data asset with a creating data index job.
+ :param name: Name of the asset.
+ :type name: str
+ :param path: The path to the asset being created by data index job.
+ :type path: str
+ :param source: The source data to be indexed.
+ :type source: IndexSource
+ :param embedding: The embedding model to use when processing source data chunks.
+ :type embedding: Embedding
+ :param index: The destination index to write processed data to.
+ :type index: IndexStore
+ :param version: Version of the asset created by running this DataIndex Job.
+ :type version: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ source: IndexSource,
+ embedding: Embedding,
+ index: IndexStore,
+ incremental_update: bool = False,
+ path: Optional[str] = None,
+ version: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs,
+ ):
+ """Initialize a DataIndex object."""
+ super().__init__(
+ name=name,
+ version=version,
+ description=description,
+ tags=tags,
+ properties=properties,
+ path=path,
+ **kwargs,
+ )
+ self.source = source
+ self.embedding = embedding
+ self.index = index
+ self.incremental_update = incremental_update
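A construction sketch for a hosted Azure AI Search index, reusing the imports from the earlier Faiss sketch; the connection names, datastore paths, and field mappings are illustrative:

    acs_data_index = DataIndex(
        name="product-docs-acs",
        source=IndexSource(
            input_data=Data(type=AssetTypes.URI_FOLDER, path="azureml://datastores/workspaceblobstore/paths/docs/"),
            citation_url="https://learn.microsoft.com",
            citation_url_replacement_regex=CitationRegex(
                match_pattern="(.*)/articles/(.*)(\\.[^.]+)$", replacement_pattern="\\1/\\2"
            ),
        ),
        embedding=Embedding(
            model="azure_open_ai://deployment/text-embedding-ada-002/model/text-embedding-ada-002",
            connection="my-aoai-connection",
            cache_path="azureml://datastores/workspaceblobstore/paths/embeddings_cache/product-docs/",
        ),
        index=IndexStore(
            type=DataIndexTypes.ACS,
            name="product-docs",
            connection="my-acs-connection",
            config={"field_mapping": {"content": "content", "url": "url"}},
        ),
    )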
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py
new file mode 100644
index 00000000..b2163c40
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_ai_search_config.py
@@ -0,0 +1,31 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# General todo: need to determine which args are required or optional when parsed out into groups like this.
+# General todo: move these to more permanent locations?
+
+# Defines stuff related to the resulting created index, like the index type.
+
+from typing import Optional
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureAISearchConfig:
+ """Config class for creating an Azure AI Search index.
+
+ :param index_name: The name of the Azure AI Search index.
+ :type index_name: Optional[str]
+ :param connection_id: The Azure AI Search connection ID.
+ :type connection_id: Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ index_name: Optional[str] = None,
+ connection_id: Optional[str] = None,
+ ) -> None:
+ self.index_name = index_name
+ self.connection_id = connection_id
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py
new file mode 100644
index 00000000..0eec691a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_config.py
@@ -0,0 +1,47 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+
+
+class IndexConfig: # pylint: disable=too-many-instance-attributes
+ """Convenience class that contains all config values that for index creation that are
+ NOT specific to the index source data or the created index type. Meant for internal use only
+ to simplify function headers. The user-entry point is a function that
+ should still contain all the fields in this class as individual function parameters.
+
+ Params omitted for brevity and to avoid maintaining duplicate docs. See index creation function
+ for actual parameter descriptions.
+ """
+
+ def __init__(
+ self,
+ *,
+ output_index_name: str,
+ vector_store: str,
+ data_source_url: Optional[str] = None,
+ chunk_size: Optional[int] = None,
+ chunk_overlap: Optional[int] = None,
+ input_glob: Optional[str] = None,
+ max_sample_files: Optional[int] = None,
+ chunk_prepend_summary: Optional[bool] = None,
+ document_path_replacement_regex: Optional[str] = None,
+ embeddings_container: Optional[str] = None,
+ embeddings_model: str,
+ aoai_connection_id: str,
+ _dry_run: bool = False
+ ):
+ self.output_index_name = output_index_name
+ self.vector_store = vector_store
+ self.data_source_url = data_source_url
+ self.chunk_size = chunk_size
+ self.chunk_overlap = chunk_overlap
+ self.input_glob = input_glob
+ self.max_sample_files = max_sample_files
+ self.chunk_prepend_summary = chunk_prepend_summary
+ self.document_path_replacement_regex = document_path_replacement_regex
+ self.embeddings_container = embeddings_container
+ self.embeddings_model = embeddings_model
+ self.aoai_connection_id = aoai_connection_id
+ self._dry_run = _dry_run
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py
new file mode 100644
index 00000000..92b62b6b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/input/_index_data_source.py
@@ -0,0 +1,62 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Union
+
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.constants._common import IndexInputType
+
+
+# General todo: need to determine which args are required or optional when parsed out into groups like this.
+# General todo: move these to more permanent locations?
+
+
+# Defines stuff related to supplying inputs for an index AKA the base data.
+@experimental
+class IndexDataSource:
+ """Base class for configs that define data that will be processed into an ML index.
+ This class should not be instantiated directly. Use one of its child classes instead.
+
+ :param input_type: A type enum describing the source of the index. Used to avoid
+ direct type checking.
+ :type input_type: Union[str, ~azure.ai.ml.constants._common.IndexInputType]
+ """
+
+ def __init__(self, *, input_type: Union[str, IndexInputType]):
+ self.input_type = input_type
+
+
+# Field bundle for creating an index from files located in a Git repo.
+# TODO Does git_url need to specifically be an SSH or HTTPS style link?
+# TODO What is git connection id?
+@experimental
+class GitSource(IndexDataSource):
+ """Config class for creating an ML index from files located in a git repository.
+
+ :param url: A link to the repository to use.
+ :type url: str
+ :param branch_name: The name of the branch to use from the target repository.
+ :type branch_name: str
+ :param connection_id: The connection ID for GitHub
+ :type connection_id: str
+ """
+
+ def __init__(self, *, url: str, branch_name: str, connection_id: str):
+ self.url = url
+ self.branch_name = branch_name
+ self.connection_id = connection_id
+ super().__init__(input_type=IndexInputType.GIT)
+
+
+@experimental
+class LocalSource(IndexDataSource):
+ """Config class for creating an ML index from a collection of local files.
+
+ :param input_data: An input object describing the local location of index source files.
+ :type input_data: ~azure.ai.ml.Input
+ """
+
+ def __init__(self, *, input_data: str): # todo Make sure type of input_data is correct
+ self.input_data = Input(type="uri_folder", path=input_data)
+ super().__init__(input_type=IndexInputType.LOCAL)
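Illustrative instantiations of the two source configs; the URL, branch, and connection name are placeholders:

    local_source = LocalSource(input_data="./data/documents")
    git_source = GitSource(
        url="https://github.com/contoso/docs.git",
        branch_name="main",
        connection_id="my-github-connection",
    )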
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
new file mode 100644
index 00000000..c9e54da4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/model_config.py
@@ -0,0 +1,122 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+from azure.ai.ml.entities._workspace.connections.connection_subtypes import (
+ AzureOpenAIConnection,
+ AadCredentialConfiguration,
+)
+
+
+@experimental
+@dataclass
+class ModelConfiguration:
+ """Configuration for a embedding model.
+
+ :param api_base: The base URL for the API.
+ :type api_base: Optional[str]
+ :param api_key: The API key.
+ :type api_key: Optional[str]
+ :param api_version: The API version.
+ :type api_version: Optional[str]
+ :param model_name: The name of the model.
+ :type model_name: Optional[str]
+    :param deployment_name: The deployment name of the model.
+    :type deployment_name: Optional[str]
+ :param connection_name: The name of the workspace connection of this model.
+ :type connection_name: Optional[str]
+ :param connection_type: The type of the workspace connection of this model.
+ :type connection_type: Optional[str]
+ :param model_kwargs: Additional keyword arguments for the model.
+ :type model_kwargs: Dict[str, Any]
+ """
+
+ api_base: Optional[str]
+ api_key: Optional[str]
+ api_version: Optional[str]
+ connection_name: Optional[str]
+ connection_type: Optional[str]
+ model_name: Optional[str]
+ deployment_name: Optional[str]
+ model_kwargs: Dict[str, Any]
+
+ def __init__(
+ self,
+ *,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ api_version: Optional[str],
+ connection_name: Optional[str],
+ connection_type: Optional[str],
+ model_name: Optional[str],
+ deployment_name: Optional[str],
+ model_kwargs: Dict[str, Any]
+ ):
+ self.api_base = api_base
+ self.api_key = api_key
+ self.api_version = api_version
+ self.connection_name = connection_name
+ self.connection_type = connection_type
+ self.model_name = model_name
+ self.deployment_name = deployment_name
+ self.model_kwargs = model_kwargs
+
+ @staticmethod
+ def from_connection(
+ connection: WorkspaceConnection,
+ model_name: Optional[str] = None,
+ deployment_name: Optional[str] = None,
+ **kwargs
+ ) -> "ModelConfiguration":
+ """Create an model configuration from a Connection.
+
+ :param connection: The WorkspaceConnection object.
+ :type connection: ~azure.ai.ml.entities.WorkspaceConnection
+ :param model_name: The name of the model.
+ :type model_name: Optional[str]
+ :param deployment_name: The name of the deployment.
+ :type deployment_name: Optional[str]
+ :return: The model configuration.
+ :rtype: ~azure.ai.ml.entities._indexes.entities.ModelConfiguration
+ :raises TypeError: If the connection is not an AzureOpenAIConnection.
+ :raises ValueError: If the connection does not contain an OpenAI key.
+ """
+ if isinstance(connection, AzureOpenAIConnection) or camel_to_snake(connection.type) == "azure_open_ai":
+ connection_type = "azure_open_ai"
+ api_version = connection.api_version # type: ignore[attr-defined]
+ if not model_name or not deployment_name:
+ raise ValueError("Please specify model_name and deployment_name.")
+ elif connection.type and connection.type.lower() == "serverless":
+ connection_type = "serverless"
+ api_version = None
+ if not connection.id:
+ raise TypeError("The connection id is missing from the serverless connection object.")
+ else:
+ raise TypeError("Connection object is not supported.")
+
+ if isinstance(connection.credentials, AadCredentialConfiguration):
+ key = None
+ else:
+ key = connection.credentials.get("key") # type: ignore[union-attr]
+ if key is None and connection_type == "azure_open_ai":
+ import os
+
+ if "AZURE_OPENAI_API_KEY" in os.environ:
+ key = os.getenv("AZURE_OPENAI_API_KEY")
+ else:
+ raise ValueError("Unable to retrieve openai key from connection object or env variable.")
+
+ return ModelConfiguration(
+ api_base=connection.target,
+ api_key=key,
+ api_version=api_version,
+ connection_name=connection.name,
+ connection_type=connection_type,
+ model_name=model_name,
+ deployment_name=deployment_name,
+ model_kwargs=kwargs,
+ )
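A sketch of the intended call pattern, assuming a workspace connection named "my-aoai-connection" whose key is available on the connection or via the AZURE_OPENAI_API_KEY environment variable; the model and deployment names are illustrative:

    aoai_connection = ml_client.connections.get("my-aoai-connection")
    model_config = ModelConfiguration.from_connection(
        aoai_connection,
        model_name="text-embedding-ada-002",
        deployment_name="text-embedding-ada-002",
    )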
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py
new file mode 100644
index 00000000..f65f5505
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/__init__.py
@@ -0,0 +1,10 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""AzureML Retrieval Augmented Generation (RAG) utilities."""
+
+from ._models import build_model_protocol
+from ._open_ai_utils import build_open_ai_protocol, build_connection_id
+from ._pipeline_decorator import pipeline
+
+__all__ = ["build_model_protocol", "build_open_ai_protocol", "build_connection_id", "pipeline"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py
new file mode 100644
index 00000000..d3e8c952
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_models.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""DataIndex embedding model helpers."""
+import re
+from typing import Optional
+
+OPEN_AI_PROTOCOL_TEMPLATE = "azure_open_ai://deployment/{}/model/{}"
+OPEN_AI_PROTOCOL_REGEX_PATTERN = OPEN_AI_PROTOCOL_TEMPLATE.format(".*", ".*")
+OPEN_AI_SHORT_FORM_PROTOCOL_TEMPLATE = "azure_open_ai://deployments?/{}"
+OPEN_AI_SHORT_FORM_PROTOCOL_REGEX_PATTERN = OPEN_AI_SHORT_FORM_PROTOCOL_TEMPLATE.format(".*")
+
+HUGGINGFACE_PROTOCOL_TEMPLATE = "hugging_face://model/{}"
+HUGGINGFACE_PROTOCOL_REGEX_PATTERN = HUGGINGFACE_PROTOCOL_TEMPLATE.format(".*")
+
+
+def build_model_protocol(model: Optional[str] = None):
+ if not model or re.match(OPEN_AI_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+ return model
+    if re.match(OPEN_AI_SHORT_FORM_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+ return model
+ if re.match(HUGGINGFACE_PROTOCOL_REGEX_PATTERN, model, re.IGNORECASE):
+ return model
+
+ return OPEN_AI_PROTOCOL_TEMPLATE.format(model, model)
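Behavior sketch (the model name is illustrative):

    # build_model_protocol("text-embedding-ada-002")
    #   -> "azure_open_ai://deployment/text-embedding-ada-002/model/text-embedding-ada-002"
    # Strings already using the azure_open_ai:// or hugging_face:// protocols are returned unchanged.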
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py
new file mode 100644
index 00000000..d38a447f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_open_ai_utils.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+
+from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource
+from azure.ai.ml._scope_dependent_operations import OperationScope
+
+OPEN_AI_PROTOCOL_TEMPLATE = "azure_open_ai://deployment/{}/model/{}"
+
+
+def build_open_ai_protocol(
+ model: Optional[str] = None,
+ deployment: Optional[str] = None,
+):
+ if not deployment or not model:
+ return None
+ return OPEN_AI_PROTOCOL_TEMPLATE.format(deployment, model)
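Behavior sketch (model and deployment names are illustrative):

    # build_open_ai_protocol(model="text-embedding-ada-002", deployment="ada-deployment")
    #   -> "azure_open_ai://deployment/ada-deployment/model/text-embedding-ada-002"
    # Returns None if either the model or the deployment is missing.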
+
+
+def build_connection_id(id: Optional[str], scope: OperationScope):
+ if not id or not scope.subscription_id or not scope.resource_group_name or not scope.workspace_name:
+ return id
+
+ if is_ARM_id_for_resource(id, "connections", True):
+ return id
+
+ # pylint: disable=line-too-long
+ template = "/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{workspace_name}/connections/{id}"
+ return template.format(
+ subscription_id=scope.subscription_id,
+ resource_group_name=scope.resource_group_name,
+ workspace_name=scope.workspace_name,
+ id=id,
+ )
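Behavior sketch; the placeholders stand in for the scope values, and the result is wrapped here for readability:

    # build_connection_id("my-aoai-connection", scope) with a fully populated scope returns:
    #   /subscriptions/<subscription_id>/resourceGroups/<resource_group_name>/providers/
    #   Microsoft.MachineLearningServices/workspaces/<workspace_name>/connections/my-aoai-connection
    # ARM-style connection IDs and incomplete scopes are returned unchanged.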
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py
new file mode 100644
index 00000000..e70f97f2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_indexes/utils/_pipeline_decorator.py
@@ -0,0 +1,248 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import inspect
+import logging
+from functools import wraps
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, TypeVar, Union, overload
+
+from typing_extensions import ParamSpec
+
+from azure.ai.ml.entities import Data, Model, PipelineJob, PipelineJobSettings
+from azure.ai.ml.entities._builders.pipeline import Pipeline
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput, _GroupAttrDict
+from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression
+from azure.ai.ml.exceptions import UserErrorException
+
+from azure.ai.ml.dsl._pipeline_component_builder import PipelineComponentBuilder, _is_inside_dsl_pipeline_func
+from azure.ai.ml.dsl._pipeline_decorator import _validate_args
+from azure.ai.ml.dsl._settings import _dsl_settings_stack
+from azure.ai.ml.dsl._utils import _resolve_source_file
+
+SUPPORTED_INPUT_TYPES = (
+ PipelineInput,
+ NodeOutput,
+ Input,
+ Model,
+ Data, # For the case use a Data object as an input, we will convert it to Input object
+ Pipeline, # For the case use a pipeline node as the input, we use its only one output as the real input.
+ str,
+ bool,
+ int,
+ float,
+ PipelineExpression,
+ _GroupAttrDict,
+)
+module_logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+P = ParamSpec("P")
+
+
+# Overload that returns a decorator when func is None
+@overload
+def pipeline(
+ func: None,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+) -> Callable[[Callable[P, T]], Callable[P, PipelineJob]]: ...
+
+
+# Overload that returns a decorated function when func isn't None
+@overload
+def pipeline(
+ func: Callable[P, T],
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+) -> Callable[P, PipelineJob]: ...
+
+
+def pipeline(
+ func: Optional[Callable[P, T]] = None,
+ *,
+ name: Optional[str] = None,
+ version: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ tags: Optional[Union[Dict[str, str], str]] = None,
+ **kwargs: Any,
+) -> Union[Callable[[Callable[P, T]], Callable[P, PipelineJob]], Callable[P, PipelineJob]]:
+ """Build a pipeline which contains all component nodes defined in this function.
+
+ :param func: The user pipeline function to be decorated.
+ :type func: types.FunctionType
+ :keyword name: The name of pipeline component, defaults to function name.
+ :paramtype name: str
+ :keyword version: The version of pipeline component, defaults to "1".
+ :paramtype version: str
+ :keyword display_name: The display name of pipeline component, defaults to function name.
+ :paramtype display_name: str
+ :keyword description: The description of the built pipeline.
+ :paramtype description: str
+ :keyword experiment_name: Name of the experiment the job will be created under, \
+ if None is provided, experiment will be set to current directory.
+ :paramtype experiment_name: str
+ :keyword tags: The tags of pipeline component.
+ :paramtype tags: dict[str, str]
+ :return: Either
+ * A decorator, if `func` is None
+ * The decorated `func`
+
+ :rtype: Union[
+ Callable[[Callable], Callable[..., ~azure.ai.ml.entities.PipelineJob]],
+ Callable[P, ~azure.ai.ml.entities.PipelineJob]
+
+ ]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../../../../samples/ml_samples_pipeline_job_configurations.py
+ :start-after: [START configure_pipeline]
+ :end-before: [END configure_pipeline]
+ :language: python
+ :dedent: 8
+ :caption: Shows how to create a pipeline using this decorator.
+ """
+
+ # get_component force pipeline to return Pipeline instead of PipelineJob so we can set optional argument
+ # need to remove get_component and rely on azure.ai.ml.dsl.pipeline
+ get_component = kwargs.get("get_component", False)
+
+ def pipeline_decorator(func: Callable[P, T]) -> Callable:
+ if not isinstance(func, Callable): # type: ignore
+            raise UserErrorException(f"Dsl pipeline decorator accepts only function type, got {type(func)}.")
+
+ non_pipeline_inputs = kwargs.get("non_pipeline_inputs", []) or kwargs.get("non_pipeline_parameters", [])
+ # compute variable names changed from default_compute_targe -> compute -> default_compute -> none
+ # to support legacy usage, we support them with priority.
+ compute = kwargs.get("compute", None)
+ default_compute_target = kwargs.get("default_compute_target", None)
+ default_compute_target = kwargs.get("default_compute", None) or default_compute_target
+ continue_on_step_failure = kwargs.get("continue_on_step_failure", None)
+ on_init = kwargs.get("on_init", None)
+ on_finalize = kwargs.get("on_finalize", None)
+
+ default_datastore = kwargs.get("default_datastore", None)
+ force_rerun = kwargs.get("force_rerun", None)
+ job_settings = {
+ "default_datastore": default_datastore,
+ "continue_on_step_failure": continue_on_step_failure,
+ "force_rerun": force_rerun,
+ "default_compute": default_compute_target,
+ "on_init": on_init,
+ "on_finalize": on_finalize,
+ }
+ func_entry_path = _resolve_source_file()
+ if not func_entry_path:
+ func_path = Path(inspect.getfile(func))
+ # in notebook, func_path may be a fake path and will raise error when trying to resolve this fake path
+ if func_path.exists():
+ func_entry_path = func_path.resolve().absolute()
+
+ job_settings = {k: v for k, v in job_settings.items() if v is not None}
+ pipeline_builder = PipelineComponentBuilder(
+ func=func,
+ name=name,
+ version=version,
+ display_name=display_name,
+ description=description,
+ default_datastore=default_datastore,
+ tags=tags,
+ source_path=str(func_entry_path),
+ non_pipeline_inputs=non_pipeline_inputs,
+ )
+
+ @wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Union[Pipeline, PipelineJob]:
+ # Default args will be added here.
+        # Note: push/pop the stack here instead of inside build()
+ # Because we only want to enable dsl settings on top level pipeline
+ _dsl_settings_stack.push() # use this stack to track on_init/on_finalize settings
+ try:
+ # Convert args to kwargs
+ provided_positional_kwargs = _validate_args(func, args, kwargs, non_pipeline_inputs)
+
+ # When pipeline supports variable params, update pipeline component to support the inputs in **kwargs.
+ pipeline_parameters = {
+ k: v for k, v in provided_positional_kwargs.items() if k not in non_pipeline_inputs
+ }
+ pipeline_builder._update_inputs(pipeline_parameters)
+
+ non_pipeline_params_dict = {
+ k: v for k, v in provided_positional_kwargs.items() if k in non_pipeline_inputs
+ }
+
+ # TODO: cache built pipeline component
+ pipeline_component = pipeline_builder.build(
+ user_provided_kwargs=provided_positional_kwargs,
+ non_pipeline_inputs_dict=non_pipeline_params_dict,
+ non_pipeline_inputs=non_pipeline_inputs,
+ )
+ finally:
+ # use `finally` to ensure pop operation from the stack
+ dsl_settings = _dsl_settings_stack.pop()
+
+ # update on_init/on_finalize settings if init/finalize job is set
+ if dsl_settings.init_job_set:
+ job_settings["on_init"] = dsl_settings.init_job_name(pipeline_component.jobs)
+ if dsl_settings.finalize_job_set:
+ job_settings["on_finalize"] = dsl_settings.finalize_job_name(pipeline_component.jobs)
+
+ # TODO: pass compute & default_compute separately?
+ common_init_args: Any = {
+ "experiment_name": experiment_name,
+ "component": pipeline_component,
+ "inputs": pipeline_parameters,
+ "tags": tags,
+ }
+ built_pipeline: Any = None
+ if _is_inside_dsl_pipeline_func() or get_component:
+ # on_init/on_finalize is not supported for pipeline component
+ if job_settings.get("on_init") is not None or job_settings.get("on_finalize") is not None:
+ raise UserErrorException("On_init/on_finalize is not supported for pipeline component.")
+ # Build pipeline node instead of pipeline job if inside dsl.
+ built_pipeline = Pipeline(_from_component_func=True, **common_init_args)
+ if job_settings:
+ module_logger.warning(
+ ("Job settings %s on pipeline function %r are ignored when using inside PipelineJob."),
+ job_settings,
+ func.__name__,
+ )
+ else:
+ built_pipeline = PipelineJob(
+ jobs=pipeline_component.jobs,
+ compute=compute,
+ settings=PipelineJobSettings(**job_settings),
+ **common_init_args,
+ )
+
+ return built_pipeline
+
+ # Bug Item number: 2883169
+ wrapper._is_dsl_func = True # type: ignore
+ wrapper._job_settings = job_settings # type: ignore
+ wrapper._pipeline_builder = pipeline_builder # type: ignore
+ return wrapper
+
+ # enable use decorator without "()" if all arguments are default values
+ if func is not None:
+ return pipeline_decorator(func)
+ return pipeline_decorator
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py
new file mode 100644
index 00000000..90affdda
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py
@@ -0,0 +1,73 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+"""This package includes the type classes which could be used in dsl.pipeline,
+command function, or any other place that requires job inputs/outputs.
+
+.. note::
+
+ The following pseudo-code shows how to create a pipeline with such classes.
+
+ .. code-block:: python
+
+ @pipeline()
+ def some_pipeline(
+ input_param: Input(type="uri_folder", path="xxx", mode="ro_mount"),
+ int_param0: Input(type="integer", default=0, min=-3, max=10),
+ int_param1 = 2
+ str_param = 'abc',
+ ):
+ pass
+
+
+ The following pseudo-code shows how to create a command with such classes.
+
+ .. code-block:: python
+
+ my_command = command(
+ name="my_command",
+ display_name="my_command",
+ description="This is a command",
+ tags=dict(),
+ command="python train.py --input-data ${{inputs.input_data}} --lr ${{inputs.learning_rate}}",
+ code="./src",
+ compute="cpu-cluster",
+ environment="my-env:1",
+ distribution=MpiDistribution(process_count_per_instance=4),
+ environment_variables=dict(foo="bar"),
+ # Customers can still do this:
+ # resources=Resources(instance_count=2, instance_type="STANDARD_D2"),
+ # limits=Limits(timeout=300),
+ inputs={
+ "float": Input(type="number", default=1.1, min=0, max=5),
+ "integer": Input(type="integer", default=2, min=-1, max=4),
+ "integer1": 2,
+ "string0": Input(type="string", default="default_str0"),
+ "string1": "default_str1",
+ "boolean": Input(type="boolean", default=False),
+ "uri_folder": Input(type="uri_folder", path="https://my-blob/path/to/data", mode="ro_mount"),
+ "uri_file": Input(type="uri_file", path="https://my-blob/path/to/data", mode="download"),
+ },
+ outputs={"my_model": Output(type="mlflow_model")},
+ )
+ node = my_command()
+"""
+
+from .enum_input import EnumInput
+from .external_data import Database, FileSystem
+from .group_input import GroupInput
+from .input import Input
+from .output import Output
+from .utils import _get_param_with_standard_annotation, is_group
+
+__all__ = [
+ "Input",
+ "Output",
+ "EnumInput",
+ "GroupInput",
+ "is_group",
+ "_get_param_with_standard_annotation",
+ "Database",
+ "FileSystem",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py
new file mode 100644
index 00000000..3a726b38
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py
@@ -0,0 +1,34 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Any
+
+from azure.ai.ml._schema.component.input_output import SUPPORTED_PARAM_TYPES
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+
+
+class _InputOutputBase(DictMixin, RestTranslatableMixin):
+ def __init__(
+ self,
+ *,
+ # pylint: disable=redefined-builtin
+ type: Any,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Base class for Input & Output class.
+
+ This class is introduced to support literal output in the future.
+
+ :param type: The type of the Input/Output.
+ :type type: str
+ """
+ self.type = type
+
+ def _is_literal(self) -> bool:
+ """Check whether input is a literal
+
+ :return: True if this input is literal input.
+ :rtype: bool
+ """
+ return self.type in SUPPORTED_PARAM_TYPES
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py
new file mode 100644
index 00000000..d6c88eef
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py
@@ -0,0 +1,133 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from enum import EnumMeta
+from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
+
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .input import Input
+
+
+class EnumInput(Input):
+ """Enum parameter parse the value according to its enum values."""
+
+ def __init__(
+ self,
+ *,
+ enum: Optional[Union[EnumMeta, Sequence[str]]] = None,
+ default: Any = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Enum parameter parse the value according to its enum values.
+
+ :param enum: Enum values.
+ :type enum: Union[EnumMeta, Sequence[str]]
+ :param default: Default value of the parameter
+ :type default: Any
+ :param description: Description of the parameter
+ :type description: str
+ """
+ enum_values = self._assert_enum_valid(enum)
+ self._enum_class: Optional[EnumMeta] = None
+        # This is used to parse the enum class instead of the enum str value if an enum class is provided.
+ if isinstance(enum, EnumMeta):
+ self._enum_class = enum
+ self._str2enum = dict(zip(enum_values, enum))
+ else:
+ self._str2enum = {v: v for v in enum_values}
+ super().__init__(type="string", default=default, enum=enum_values, description=description)
+
+ @property
+ def _allowed_types(self) -> Tuple:
+ return (
+ (str,)
+ if not self._enum_class
+ else (
+ self._enum_class,
+ str,
+ )
+ )
+
+ @classmethod
+ def _assert_enum_valid(cls, enum: Optional[Union[EnumMeta, Sequence[str]]]) -> List:
+ """Check whether the enum is valid and return the values of the enum.
+
+ :param enum: The enum to validate
+ :type enum: Type
+ :return: The enum values
+ :rtype: List[Any]
+ """
+ if isinstance(enum, EnumMeta):
+ enum_values = [str(option.value) for option in enum] # type: ignore
+ elif isinstance(enum, Iterable):
+ enum_values = list(enum)
+ else:
+ msg = "enum must be a subclass of Enum or an iterable."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ if len(enum_values) <= 0:
+ msg = "enum must have enum values."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ if any(not isinstance(v, str) for v in enum_values):
+ msg = "enum values must be str type."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ return enum_values
+
+ def _parse(self, val: str) -> Any:
+ """Parse the enum value from a string value or the enum value.
+
+ :param val: The string to parse
+ :type val: str
+ :return: The enum value
+ :rtype: Any
+ """
+ if val is None:
+ return val
+
+ if self._enum_class and isinstance(val, self._enum_class):
+ return val # Directly return the enum value if it is the enum.
+
+ if val not in self._str2enum:
+ msg = "Not a valid enum value: '{}', valid values: {}"
+ raise ValidationException(
+ message=msg.format(val, ", ".join(self.enum)),
+ no_personal_data_message=msg.format("[val]", "[enum]"),
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return self._str2enum[val]
+
+ def _update_default(self, default_value: Any) -> None:
+ """Enum parameters support updating the default value with a string value.
+
+ :param default_value: The default value for the input
+ :type default_value: Any
+ """
+ enum_val = self._parse(default_value)
+ if self._enum_class and isinstance(enum_val, self._enum_class):
+ enum_val = enum_val.value
+ self.default = enum_val
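+
+
+# NOTE (editor's sketch, not part of the shipped module): a minimal usage
+# illustration of EnumInput with an Enum class. `Color` is a hypothetical enum;
+# `_parse` is a private helper, shown here only to document the parsing behavior.
+#
+#     from enum import Enum
+#
+#     class Color(Enum):
+#         RED = "red"
+#         BLUE = "blue"
+#
+#     color_input = EnumInput(enum=Color, default="red")
+#     color_input._parse("red")        # -> Color.RED (string mapped to the enum member)
+#     color_input._parse(Color.BLUE)   # -> Color.BLUE (enum members pass through)
+#     color_input._parse("green")      # raises ValidationException (not a valid enum value)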
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py
new file mode 100644
index 00000000..8a4fe21f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py
@@ -0,0 +1,207 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from inspect import Parameter
+from typing import Dict, List, Optional, Union
+
+from azure.ai.ml.constants._component import ExternalDataType
+from azure.ai.ml.entities._inputs_outputs.utils import _remove_empty_values
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+
+
+class StoredProcedureParameter(DictMixin, RestTranslatableMixin):
+ """Define a stored procedure parameter class for the DataTransfer import database task.
+
+ :keyword name: The name of the database stored procedure.
+ :paramtype name: str
+ :keyword value: The value of the database stored procedure.
+ :paramtype value: str
+ :keyword type: The type of the database stored procedure.
+ :paramtype type: str
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ value: Optional[str] = None,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ ) -> None:
+ self.type = type
+ self.name = name
+ self.value = value
+
+
+class Database(DictMixin, RestTranslatableMixin):
+ """Define a database class for a DataTransfer Component or Job.
+
+ :keyword query: The SQL query to retrieve data from the database.
+ :paramtype query: str
+ :keyword table_name: The name of the database table.
+ :paramtype table_name: str
+ :keyword stored_procedure: The name of the stored procedure.
+ :paramtype stored_procedure: str
+ :keyword stored_procedure_params: The parameters for the stored procedure.
+ :paramtype stored_procedure_params: List
+ :keyword connection: The connection string for the database.
+ The credential information should be stored in the connection.
+ :paramtype connection: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the Database object cannot be successfully validated.
+ Details will be provided in the error message.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_input_output_configurations.py
+ :start-after: [START configure_database]
+ :end-before: [END configure_database]
+ :language: python
+ :dedent: 8
+ :caption: Creating a database and querying a database table.
+ """
+
+ _EMPTY = Parameter.empty
+
+ def __init__(
+ self,
+ *,
+ query: Optional[str] = None,
+ table_name: Optional[str] = None,
+ stored_procedure: Optional[str] = None,
+ stored_procedure_params: Optional[List[Dict]] = None,
+ connection: Optional[str] = None,
+ ) -> None:
+ # As an annotation, the name cannot be initialized here.
+ # The name will be updated from the annotated variable name.
+ self.name = None
+ self.type = ExternalDataType.DATABASE
+ self.connection = connection
+ self.query = query
+ self.table_name = table_name
+ self.stored_procedure = stored_procedure
+ self.stored_procedure_params = stored_procedure_params
+
+ def _to_dict(self, remove_name: bool = True) -> Dict:
+ """Convert the Source object to a dict.
+
+ :param remove_name: Whether to remove the `name` key from the dict representation. Defaults to True.
+ :type remove_name: bool
+ :return: The dictionary representation of the class
+ :rtype: Dict
+ """
+ keys = [
+ "name",
+ "type",
+ "query",
+ "stored_procedure",
+ "stored_procedure_params",
+ "connection",
+ "table_name",
+ ]
+ if remove_name:
+ keys.remove("name")
+ result = {key: getattr(self, key) for key in keys}
+ res: dict = _remove_empty_values(result)
+ return res
+
+ def _to_rest_object(self) -> Dict:
+ # This builds the component REST object when a Source is used as a component input. For job input usage,
+ # the REST object is generated by extracting the Source's properties; see to_rest_dataset_literal_inputs().
+ result = self._to_dict()
+ return result
+
+ def _update_name(self, name: str) -> None:
+ self.name = name
+
+ @classmethod
+ def _from_rest_object(cls, obj: Dict) -> "Database":
+ return Database(**obj)
+
+ @property
+ def stored_procedure_params(self) -> Optional[List]:
+ """Get or set the parameters for the stored procedure.
+
+ :return: The parameters for the stored procedure.
+ :rtype: List[StoredProcedureParameter]
+ """
+
+ return self._stored_procedure_params
+
+ @stored_procedure_params.setter
+ def stored_procedure_params(self, value: Union[Dict[str, str], List, None]) -> None:
+ """Set the parameters for the stored procedure.
+
+ :param value: The parameters for the stored procedure.
+ :type value: Union[Dict[str, str], List, None]
+ """
+ if value is None:
+ self._stored_procedure_params = value
+ else:
+ if not isinstance(value, list):
+ value = [value]
+ for index, item in enumerate(value):
+ if isinstance(item, dict):
+ value[index] = StoredProcedureParameter(**item)
+ self._stored_procedure_params = value
+
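+# NOTE (editor's sketch, not from the SDK docs): a hedged example of building a
+# Database source that calls a stored procedure. The connection and procedure
+# names below are placeholders.
+#
+#     db_source = Database(
+#         stored_procedure="usp_get_rows",
+#         stored_procedure_params=[{"name": "top_n", "value": "10", "type": "Int"}],
+#         connection="azureml:my_sql_connection",
+#     )
+#     # dict items are normalized to StoredProcedureParameter instances by the setter
+#     type(db_source.stored_procedure_params[0])   # -> StoredProcedureParameter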
+
+class FileSystem(DictMixin, RestTranslatableMixin):
+ """Define a file system class for a DataTransfer Component or Job.
+
+ e.g. source_s3 = FileSystem(path='s3://my_bucket/my_folder', connection='azureml:my_s3_connection')
+
+ :param path: The path to which the input points, for example a file system path. Defaults to None.
+ :type path: str
+ :param connection: The workspace connection to use. Storage connections are not supported here; the credential
+ information should be stored in a workspace connection. Defaults to None.
+ :type connection: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Source cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ _EMPTY = Parameter.empty
+
+ def __init__(
+ self,
+ *,
+ path: Optional[str] = None,
+ connection: Optional[str] = None,
+ ) -> None:
+ self.type = ExternalDataType.FILE_SYSTEM
+ self.name: Optional[str] = None
+ self.connection = connection
+ self.path: Optional[str] = None
+
+ if path is not None and not isinstance(path, str):
+ # This logic makes dsl data binding expressions work the same way as in YAML.
+ # It's written to handle InputOutputBase, but importing InputOutputBase here would cause a circular import.
+ self.path = str(path)
+ else:
+ self.path = path
+
+ def _to_dict(self, remove_name: bool = True) -> Dict:
+ """Convert the Source object to a dict.
+
+ :param remove_name: Whether to remove the `name` key from the dict representation. Defaults to True.
+ :type remove_name: bool
+ :return: The dictionary representation of the object
+ :rtype: Dict
+ """
+ keys = ["name", "path", "type", "connection"]
+ if remove_name:
+ keys.remove("name")
+ result = {key: getattr(self, key) for key in keys}
+ res: dict = _remove_empty_values(result)
+ return res
+
+ def _to_rest_object(self) -> Dict:
+ # This builds the component REST object when a Source is used as a component input. For job input usage,
+ # the REST object is generated by extracting the Source's properties; see to_rest_dataset_literal_inputs().
+ result = self._to_dict()
+ return result
+
+ def _update_name(self, name: str) -> None:
+ self.name = name
+
+ @classmethod
+ def _from_rest_object(cls, obj: Dict) -> "FileSystem":
+ return FileSystem(**obj)
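+
+
+# NOTE (editor's sketch, assumed behavior): _to_dict() keeps only populated keys
+# and, by default, drops the not-yet-assigned name. Values below are placeholders.
+#
+#     fs = FileSystem(path="s3://my_bucket/my_folder", connection="azureml:my_s3_connection")
+#     fs._to_dict()
+#     # -> {"path": "s3://my_bucket/my_folder", "type": ExternalDataType.FILE_SYSTEM,
+#     #     "connection": "azureml:my_s3_connection"}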
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py
new file mode 100644
index 00000000..e7fc565c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py
@@ -0,0 +1,251 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import copy
+from enum import Enum as PyEnum
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from azure.ai.ml.constants._component import IOConstants
+from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationException
+
+from .input import Input
+from .output import Output
+from .utils import is_group
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._job.pipeline._io import _GroupAttrDict
+
+
+class GroupInput(Input):
+ """Define a group input object.
+
+ :param values: The values of the group input.
+ :type values: dict
+ :param _group_class: The class representing the group.
+ :type _group_class: Any
+ """
+
+ def __init__(self, values: dict, _group_class: Any) -> None:
+ super().__init__(type=IOConstants.GROUP_TYPE_NAME)
+ self.assert_group_value_valid(values)
+ self.values: Any = values
+ # Create an empty default from the values.
+ # Note: Outputs do not have a default, so None is used.
+ self.default = self._create_default()
+ # Save group class for init function generation
+ self._group_class = _group_class
+
+ @classmethod
+ def _create_group_attr_dict(cls, dct: dict) -> "_GroupAttrDict":
+ from .._job.pipeline._io import _GroupAttrDict
+
+ return _GroupAttrDict(dct)
+
+ @classmethod
+ def _is_group_attr_dict(cls, obj: object) -> bool:
+ from .._job.pipeline._io import _GroupAttrDict
+
+ return isinstance(obj, _GroupAttrDict)
+
+ def __getattr__(self, item: Any) -> Any:
+ try:
+ # TODO: Bug Item number: 2883363
+ return super().__getattr__(item) # type: ignore
+ except AttributeError:
+ # TODO: why is values not a dict in some cases?
+ if isinstance(self.values, dict) and item in self.values:
+ return self.values[item]
+ raise
+
+ def _create_default(self) -> "_GroupAttrDict":
+ from .._job.pipeline._io import PipelineInput
+
+ default_dict: dict = {}
+ # Note: no top-level group names at this time.
+ for k, v in self.values.items():
+ # Skip creating defaults for outputs or port inputs.
+ if isinstance(v, Output):
+ continue
+
+ # Create PipelineInput object if not subgroup
+ if not isinstance(v, GroupInput):
+ default_dict[k] = PipelineInput(name=k, data=v.default, meta=v)
+ continue
+ # Copy and insert k into group names for subgroup
+ default_dict[k] = copy.deepcopy(v.default)
+ default_dict[k].insert_group_name_for_items(k)
+ return self._create_group_attr_dict(default_dict)
+
+ @classmethod
+ def assert_group_value_valid(cls, values: Dict) -> None:
+ """Check if all values in the group are supported types.
+
+ :param values: The values of the group.
+ :type values: dict
+ :raises ValueError: If a value in the group is not a supported type or if a parameter name is duplicated.
+ :raises UserErrorException: If a value in the group has an unsupported type.
+ """
+ names = set()
+ msg = (
+ f"Parameter {{!r}} with type {{!r}} is not supported in group. "
+ f"Supported types are: {list(IOConstants.INPUT_TYPE_COMBINATION.keys())}"
+ )
+ for key, value in values.items():
+ if not isinstance(value, (Input, Output)):
+ raise ValueError(msg.format(key, type(value).__name__))
+ if value.type is None:
+ # Skip the check for parameters translated from a pipeline job (type information was lost).
+ continue
+ if value.type not in IOConstants.INPUT_TYPE_COMBINATION and not isinstance(value, GroupInput):
+ raise UserErrorException(msg.format(key, value.type))
+ if key in names:
+ if not isinstance(value, Input):
+ raise ValueError(f"Duplicate parameter name {value.name!r} found in Group values.")
+ names.add(key)
+
+ def flatten(self, group_parameter_name: str) -> Dict:
+ """Flatten the group and return all parameters.
+
+ :param group_parameter_name: The name of the group parameter.
+ :type group_parameter_name: str
+ :return: A dictionary of flattened parameters.
+ :rtype: dict
+ """
+ all_parameters = {}
+ group_parameter_name = group_parameter_name if group_parameter_name else ""
+ for key, value in self.values.items():
+ flattened_name = ".".join([group_parameter_name, key])
+ if isinstance(value, GroupInput):
+ all_parameters.update(value.flatten(flattened_name))
+ else:
+ all_parameters[flattened_name] = value
+ return all_parameters
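+
+ # NOTE (editor's sketch, assumed usage of this internal class): flatten()
+ # produces dot-delimited parameter names and recurses into nested groups.
+ #
+ #     hp = GroupInput(
+ #         values={
+ #             "lr": Input(type="number"),
+ #             "sub": GroupInput(values={"epochs": Input(type="integer")}, _group_class=None),
+ #         },
+ #         _group_class=None,
+ #     )
+ #     sorted(hp.flatten("hp"))   # -> ["hp.lr", "hp.sub.epochs"]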
+
+ def _to_dict(self) -> dict:
+ attr_dict = super()._to_dict()
+ attr_dict["values"] = {k: v._to_dict() for k, v in self.values.items()} # pylint: disable=protected-access
+ return attr_dict
+
+ @staticmethod
+ def custom_class_value_to_attr_dict(value: Any, group_names: Optional[List] = None) -> Any:
+ """Convert a custom parameter group class object to GroupAttrDict.
+
+ :param value: The value to convert.
+ :type value: any
+ :param group_names: The names of the parent groups.
+ :type group_names: list
+ :return: The converted value as a GroupAttrDict.
+ :rtype: GroupAttrDict or any
+ """
+ if not is_group(value):
+ return value
+ group_definition = getattr(value, IOConstants.GROUP_ATTR_NAME)
+ group_names = [*group_names] if group_names else []
+ attr_dict = {}
+ from .._job.pipeline._io import PipelineInput
+
+ for k, v in value.__dict__.items():
+ if is_group(v):
+ attr_dict[k] = GroupInput.custom_class_value_to_attr_dict(v, [*group_names, k])
+ continue
+ data = v.value if isinstance(v, PyEnum) else v
+ if GroupInput._is_group_attr_dict(data):
+ attr_dict[k] = data
+ continue
+ attr_dict[k] = PipelineInput(name=k, meta=group_definition.get(k), data=data, group_names=group_names)
+ return GroupInput._create_group_attr_dict(attr_dict)
+
+ @staticmethod
+ def validate_conflict_keys(keys: List) -> None:
+ """Validate conflicting keys in a flattened input dictionary, like {'a.b.c': 1, 'a.b': 1}.
+
+ :param keys: The keys to validate.
+ :type keys: list
+ :raises ValidationException: If conflicting keys are found.
+ """
+ conflict_msg = "Conflict parameter key '%s' and '%s'."
+
+ def _group_count(s: str) -> int:
+ return len(s.split(".")) - 1
+
+ # Sort keys by group depth (number of dots).
+ keys = sorted(list(keys), key=_group_count)
+ for idx, key1 in enumerate(keys[:-1]):
+ for key2 in keys[idx + 1 :]:
+ if _group_count(key2) == 0:
+ continue
+ # Skip case a.b.c and a.b.c1
+ if _group_count(key1) == _group_count(key2):
+ continue
+ if not key2.startswith(key1):
+ continue
+ # Invalid case 'a.b' in 'a.b.c'
+ raise ValidationException(
+ message=conflict_msg % (key1, key2),
+ no_personal_data_message=conflict_msg % ("[key1]", "[key2]"),
+ target=ErrorTarget.PIPELINE,
+ )
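+
+ # NOTE (editor's sketch, assumed behavior): a key that is a strict dotted
+ # prefix of another key is rejected; sibling keys at the same depth are fine.
+ #
+ #     GroupInput.validate_conflict_keys(["a.b.c", "a.b.d", "e"])   # passes
+ #     GroupInput.validate_conflict_keys(["a.b", "a.b.c"])          # raises ValidationException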
+
+ @staticmethod
+ def restore_flattened_inputs(inputs: Dict) -> Dict:
+ """Restore flattened inputs to structured groups.
+
+ :param inputs: The flattened input dictionary.
+ :type inputs: dict
+ :return: The restored structured inputs.
+ :rtype: dict
+ """
+ GroupInput.validate_conflict_keys(list(inputs.keys()))
+ restored_inputs = {}
+ group_inputs: Dict = {}
+ # 1. Build all group parameters dict
+ for name, data in inputs.items():
+ # For a.b.c, the group names are [a, b].
+ name_splits = name.split(".")
+ group_names, param_name = name_splits[:-1], name_splits[-1]
+ if not group_names:
+ restored_inputs[name] = data
+ continue
+ # change {'a.b.c': data} -> {'a': {'b': {'c': data}}}
+ target_dict = group_inputs
+ for group_name in group_names:
+ if group_name not in target_dict:
+ target_dict[group_name] = {}
+ target_dict = target_dict[group_name]
+ target_dict[param_name] = data
+
+ def restore_from_dict_recursively(_data: dict) -> Union[GroupInput, "_GroupAttrDict"]:
+ for key, val in _data.items():
+ if type(val) == dict: # pylint: disable=unidiomatic-typecheck
+ _data[key] = restore_from_dict_recursively(val)
+ # Create GroupInput for definition and _GroupAttrDict for PipelineInput
+ # Regard all Input instances as parameter definitions, as data will not appear in a group for now.
+ if all(isinstance(val, Input) for val in _data.values()):
+ return GroupInput(values=_data, _group_class=None)
+ return GroupInput._create_group_attr_dict(dct=_data)
+
+ # 2. Rehydrate dict to GroupInput(definition) or GroupAttrDict.
+ for name, data in group_inputs.items():
+ restored_inputs[name] = restore_from_dict_recursively(data)
+ return restored_inputs
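+
+ # NOTE (editor's sketch, assumed behavior): dotted keys mapping to Input
+ # definitions are rebuilt into nested GroupInput definitions; keys without a
+ # dot pass through unchanged.
+ #
+ #     restored = GroupInput.restore_flattened_inputs({
+ #         "train.lr": Input(type="number"),
+ #         "train.epochs": Input(type="integer"),
+ #         "seed": Input(type="integer"),
+ #     })
+ #     isinstance(restored["train"], GroupInput)   # -> True
+ #     sorted(restored["train"].values.keys())     # -> ["epochs", "lr"]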
+
+ def _update_default(self, default_value: object = None) -> None:
+ default_cls = type(default_value)
+
+ # Assert that '__dsl_group__' is present on the class of the default value.
+ if self._is_group_attr_dict(default_value):
+ self.default = default_value
+ self.optional = False
+ return
+ if default_value and not is_group(default_cls):
+ raise ValueError(f"Default value must be instance of parameter group, got {default_cls}.")
+ if hasattr(default_value, "__dict__"):
+ # Convert a default value with a custom type to _AttrDict.
+ self.default = GroupInput.custom_class_value_to_attr_dict(default_value)
+ # Update item annotation
+ for key, annotation in self.values.items():
+ if not hasattr(default_value, key):
+ continue
+ annotation._update_default(getattr(default_value, key)) # pylint: disable=protected-access
+ self.optional = default_value is None
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py
new file mode 100644
index 00000000..4a945108
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py
@@ -0,0 +1,547 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=redefined-builtin
+# disable redefined-builtin to use type/min/max as argument name
+
+import math
+from inspect import Parameter
+from typing import Any, Dict, List, Optional, Union, overload
+
+from typing_extensions import Literal
+
+from azure.ai.ml.constants._component import ComponentParameterTypes, IOConstants
+from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty
+from azure.ai.ml.exceptions import (
+ ErrorCategory,
+ ErrorTarget,
+ UserErrorException,
+ ValidationErrorType,
+ ValidationException,
+)
+
+from .base import _InputOutputBase
+from .utils import _get_param_with_standard_annotation, _remove_empty_values
+
+
+class Input(_InputOutputBase): # pylint: disable=too-many-instance-attributes
+ """Initialize an Input object.
+
+ :keyword type: The type of the data input. Accepted values are
+ 'uri_folder', 'uri_file', 'mltable', 'mlflow_model', 'custom_model', 'integer', 'number', 'string', and
+ 'boolean'. Defaults to 'uri_folder'.
+ :paramtype type: str
+ :keyword path: The path to the input data. Paths can be local paths, remote data uris, or a registered AzureML asset
+ ID.
+ :paramtype path: Optional[str]
+ :keyword mode: The access mode of the data input. Accepted values are:
+ * 'ro_mount': Mount the data to the compute target as read-only,
+ * 'download': Download the data to the compute target,
+ * 'direct': Pass in the URI as a string to be accessed at runtime
+ :paramtype mode: Optional[str]
+ :keyword path_on_compute: The access path of the data input for compute
+ :paramtype path_on_compute: Optional[str]
+ :keyword default: The default value of the input. If a default is set, the input data will be optional.
+ :paramtype default: Union[str, int, float, bool]
+ :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job
+ execution will fail.
+ :paramtype min: Union[int, float]
+ :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job
+ execution will fail.
+ :paramtype max: Union[int, float]
+ :keyword optional: Specifies if the input is optional.
+ :paramtype optional: Optional[bool]
+ :keyword description: Description of the input
+ :paramtype description: Optional[str]
+ :keyword datastore: The datastore to upload local files to.
+ :paramtype datastore: str
+ :keyword intellectual_property: Intellectual property for the input.
+ :paramtype intellectual_property: IntellectualProperty
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated.
+ Details will be provided in the error message.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START create_inputs_outputs]
+ :end-before: [END create_inputs_outputs]
+ :language: python
+ :dedent: 8
+ :caption: Creating a CommandJob with two inputs.
+ """
+
+ _EMPTY = Parameter.empty
+ _IO_KEYS = [
+ "path",
+ "type",
+ "mode",
+ "path_on_compute",
+ "description",
+ "default",
+ "min",
+ "max",
+ "enum",
+ "optional",
+ "datastore",
+ ]
+
+ @overload
+ def __init__(
+ self,
+ *,
+ type: str,
+ path: Optional[str] = None,
+ mode: Optional[str] = None,
+ optional: Optional[bool] = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """"""
+
+ @overload
+ def __init__(
+ self,
+ *,
+ type: Literal["number"] = "number",
+ default: Optional[float] = None,
+ min: Optional[float] = None,
+ max: Optional[float] = None,
+ optional: Optional[bool] = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a number input.
+
+ :keyword type: The type of the data input. Can only be set to "number".
+ :paramtype type: str
+ :keyword default: The default value of the input. If a default is set, the input data will be optional.
+ :paramtype default: Union[str, int, float, bool]
+ :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job
+ execution will fail.
+ :paramtype min: Optional[float]
+ :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job
+ execution will fail.
+ :paramtype max: Optional[float]
+ :keyword optional: Specifies if the input is optional.
+ :paramtype optional: bool
+ :keyword description: Description of the input
+ :paramtype description: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ @overload
+ def __init__(
+ self,
+ *,
+ type: Literal["integer"] = "integer",
+ default: Optional[int] = None,
+ min: Optional[int] = None,
+ max: Optional[int] = None,
+ optional: Optional[bool] = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize an integer input.
+
+ :keyword type: The type of the data input. Can only be set to "integer".
+ :paramtype type: str
+ :keyword default: The default value of the input. If a default is set, the input data will be optional.
+ :paramtype default: Union[str, int, float, bool]
+ :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job
+ execution will fail.
+ :paramtype min: Optional[int]
+ :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job
+ execution will fail.
+ :paramtype max: Optional[int]
+ :keyword optional: Specifies if the input is optional.
+ :paramtype optional: bool
+ :keyword description: Description of the input
+ :paramtype description: str
+ """
+
+ @overload
+ def __init__(
+ self,
+ *,
+ type: Literal["string"] = "string",
+ default: Optional[str] = None,
+ optional: Optional[bool] = None,
+ description: Optional[str] = None,
+ path: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a string input.
+
+ :keyword type: The type of the data input. Can only be set to "string".
+ :paramtype type: str
+ :keyword default: The default value of this input. When a `default` is set, the input will be optional.
+ :paramtype default: str
+ :keyword optional: Determine if this input is optional.
+ :paramtype optional: bool
+ :keyword description: Description of the input.
+ :paramtype description: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ @overload
+ def __init__(
+ self,
+ *,
+ type: Literal["boolean"] = "boolean",
+ default: Optional[bool] = None,
+ optional: Optional[bool] = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a bool input.
+
+ :keyword type: The type of the data input. Can only be set to "boolean".
+ :paramtype type: str
+ :keyword path: The path to the input data. Paths can be local paths, remote data uris, or a registered AzureML
+ asset id.
+ :paramtype path: str
+ :keyword default: The default value of the input. If a default is set, the input data will be optional.
+ :paramtype default: Union[str, int, float, bool]
+ :keyword optional: Specifies if the input is optional.
+ :paramtype optional: bool
+ :keyword description: Description of the input
+ :paramtype description: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+
+ def __init__(
+ self,
+ *,
+ type: str = "uri_folder",
+ path: Optional[str] = None,
+ mode: Optional[str] = None,
+ path_on_compute: Optional[str] = None,
+ default: Optional[Union[str, int, float, bool]] = None,
+ optional: Optional[bool] = None,
+ min: Optional[Union[int, float]] = None,
+ max: Optional[Union[int, float]] = None,
+ enum: Any = None,
+ description: Optional[str] = None,
+ datastore: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ super(Input, self).__init__(type=type)
+ # As an annotation, the _port_name cannot be initialized here.
+ self._port_name = None
+ self.description = description
+ self.path: Any = None
+
+ if path is not None and not isinstance(path, str):
+ # This logic makes dsl data binding expressions work the same way as in YAML.
+ # It's written to handle InputOutputBase, but importing InputOutputBase here would cause a circular import.
+ self.path = str(path)
+ else:
+ self.path = path
+ self.path_on_compute = path_on_compute
+ self.mode = None if self._is_primitive_type else mode
+ self._update_default(default)
+ self.optional = optional
+ # Set a flag to mark whether optional=True was inferred by us.
+ self._is_inferred_optional = False
+ self.min = min
+ self.max = max
+ self.enum = enum
+ self.datastore = datastore
+ intellectual_property = kwargs.pop("intellectual_property", None)
+ if intellectual_property:
+ self._intellectual_property = (
+ intellectual_property
+ if isinstance(intellectual_property, IntellectualProperty)
+ else IntellectualProperty(**intellectual_property)
+ )
+ # normalize properties like ["default", "min", "max", "optional"]
+ self._normalize_self_properties()
+
+ self._validate_parameter_combinations()
+
+ @property
+ def _allowed_types(self) -> Any:
+ if self._multiple_types:
+ return None
+ return IOConstants.PRIMITIVE_STR_2_TYPE.get(self.type)
+
+ @property
+ def _is_primitive_type(self) -> bool:
+ if self._multiple_types:
+ # Note: we assume that no primitive type is included when there are multiple types.
+ return False
+ return self.type in IOConstants.PRIMITIVE_STR_2_TYPE
+
+ @property
+ def _multiple_types(self) -> bool:
+ """Returns True if this input has multiple types.
+
+ Currently, two scenarios need to check this property:
+ 1. Before using `in`, since a list type may raise an exception; `in` is used for validation/transformation.
+ 2. `str()` of a list is not ideal, so the string representation must be created manually.
+
+ :return: Whether this input has multiple types
+ :rtype: bool
+ """
+ return isinstance(self.type, list)
+
+ def _is_literal(self) -> bool:
+ """Whether this input is a literal
+
+ Override this function as `self.type` can be list and not hashable for operation `in`.
+
+ :return: Whether this input is a literal
+ :rtype: bool
+ """
+ return not self._multiple_types and super(Input, self)._is_literal()
+
+ def _is_enum(self) -> bool:
+ """Whether input is an enum
+
+ :return: True if the input is enum.
+ :rtype: bool
+ """
+ res: bool = self.type == ComponentParameterTypes.STRING and bool(self.enum)
+ return res
+
+ def _to_dict(self) -> Dict:
+ """Convert the Input object to a dict.
+
+ :return: Dictionary representation of Input
+ :rtype: Dict
+ """
+ keys = self._IO_KEYS
+ result = {key: getattr(self, key) for key in keys}
+ res: dict = _remove_empty_values(result)
+ return res
+
+ def _parse(self, val: Any) -> Union[int, float, bool, str, Any]:
+ """Parse value passed from command line.
+
+ :param val: The input value
+ :type val: T
+ :return: The parsed value.
+ :rtype: Union[int, float, bool, str, T]
+ """
+ if self.type == "integer":
+ return int(float(val))  # The backend may return 10.0 for an integer, so parse to float before int.
+ if self.type == "number":
+ return float(val)
+ if self.type == "boolean":
+ lower_val = str(val).lower()
+ if lower_val not in {"true", "false"}:
+ msg = "Boolean parameter '{}' only accept True/False, got {}."
+ raise ValidationException(
+ message=msg.format(self._port_name, val),
+ no_personal_data_message=msg.format("[self._port_name]", "[val]"),
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return lower_val == "true"
+ if self.type == "string":
+ return val if isinstance(val, str) else str(val)
+ return val
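+
+ # NOTE (editor's sketch, private API shown only to document behavior): _parse()
+ # coerces command-line strings to the declared primitive type.
+ #
+ #     Input(type="integer")._parse("10.0")   # -> 10
+ #     Input(type="number")._parse("0.01")    # -> 0.01
+ #     Input(type="boolean")._parse("True")   # -> True
+ #     Input(type="string")._parse(42)        # -> "42"
+ #     Input(type="boolean")._parse("yes")    # raises ValidationException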
+
+ def _parse_and_validate(self, val: Any) -> Union[int, float, bool, str, Any]:
+ """Parse the val passed from the command line and validate the value.
+
+ :param val: The input string value from the command line.
+ :type val: T
+ :return: The parsed value, an exception will be raised if the value is invalid.
+ :rtype: Union[int, float, bool, str, T]
+ """
+ if self._is_primitive_type:
+ val = self._parse(val) if isinstance(val, str) else val
+ self._validate_or_throw(val)
+ return val
+
+ def _update_name(self, name: Any) -> None:
+ self._port_name = name
+
+ def _update_default(self, default_value: Any) -> None:
+ """Update provided default values.
+
+ :param default_value: The default value of the Input
+ :type default_value: Any
+ """
+ name = "" if not self._port_name else f"{self._port_name!r} "
+ msg_prefix = f"Default value of Input {name}"
+
+ if not self._is_primitive_type and default_value is not None:
+ msg = f"{msg_prefix}cannot be set: Non-primitive type Input has no default value."
+ raise UserErrorException(msg)
+ if isinstance(default_value, float) and not math.isfinite(default_value):
+ # Since nan/inf cannot be stored in the backend, just ignore them.
+ # logger.warning("Float default value %r is not allowed, ignored." % default_value)
+ return
+ # Make sure the type of the default value is allowed, or that it can be parsed.
+ if default_value is not None:
+ if type(default_value) not in IOConstants.PRIMITIVE_TYPE_2_STR:
+ msg = (
+ f"{msg_prefix}cannot be set: type must be one of "
+ f"{list(IOConstants.PRIMITIVE_TYPE_2_STR.values())}, got '{type(default_value)}'."
+ )
+ raise UserErrorException(msg)
+
+ if not isinstance(default_value, self._allowed_types):
+ try:
+ default_value = self._parse(default_value)
+ # Re-raise the original custom-defined validation exception if raised by self._parse.
+ except ValidationException as e:
+ raise e
+ except Exception as e:
+ msg = f"{msg_prefix}cannot be parsed, got '{default_value}', type = {type(default_value)!r}."
+ raise UserErrorException(msg) from e
+ self.default = default_value
+
+ def _validate_or_throw(self, value: Any) -> None:
+ """Validate input parameter value, throw exception if not as expected.
+
+ It throws an exception if validation fails; otherwise it does nothing.
+
+ :param value: A value to validate
+ :type value: Any
+ """
+ if not self.optional and value is None:
+ msg = "Parameter {} cannot be None since it is not optional."
+ raise ValidationException(
+ message=msg.format(self._port_name),
+ no_personal_data_message=msg.format("[self._port_name]"),
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ if self._allowed_types and value is not None:
+ if not isinstance(value, self._allowed_types):
+ msg = "Unexpected data type for parameter '{}'. Expected {} but got {}."
+ raise ValidationException(
+ message=msg.format(self._port_name, self._allowed_types, type(value)),
+ no_personal_data_message=msg.format("[_port_name]", self._allowed_types, type(value)),
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ # For numeric values, additionally check the min/max bounds.
+ if not self._multiple_types and self.type in ("integer", "number"):
+ if self.min is not None and value < self.min:
+ msg = "Parameter '{}' should not be less than {}."
+ raise ValidationException(
+ message=msg.format(self._port_name, self.min),
+ no_personal_data_message=msg.format("[_port_name]", self.min),
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ if self.max is not None and value > self.max:
+ msg = "Parameter '{}' should not be greater than {}."
+ raise ValidationException(
+ message=msg.format(self._port_name, self.max),
+ no_personal_data_message=msg.format("[_port_name]", self.max),
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def _get_python_builtin_type_str(self) -> str:
+ """Get the Python builtin type of the current input as a string, e.g. 'str'.
+
+ Return yaml type if not available.
+
+ :return: The name of the input type
+ :rtype: str
+ """
+ if self._multiple_types:
+ return "[" + ", ".join(self.type) + "]"
+ if self._is_primitive_type:
+ res_primitive_type: str = IOConstants.PRIMITIVE_STR_2_TYPE[self.type].__name__
+ return res_primitive_type
+ res: str = self.type
+ return res
+
+ def _validate_parameter_combinations(self) -> None:
+ """Validate different parameter combinations according to type."""
+ parameters = ["type", "path", "mode", "default", "min", "max"]
+ parameters_dict: dict = {key: getattr(self, key, None) for key in parameters}
+ type = parameters_dict.pop("type")
+
+ # validate parameter combination
+ if not self._multiple_types and type in IOConstants.INPUT_TYPE_COMBINATION:
+ valid_parameters = IOConstants.INPUT_TYPE_COMBINATION[type]
+ for key, value in parameters_dict.items():
+ if key not in valid_parameters and value is not None:
+ msg = "Invalid parameter for '{}' Input, parameter '{}' should be None but got '{}'"
+ raise ValidationException(
+ message=msg.format(type, key, value),
+ no_personal_data_message=msg.format("[type]", "[parameter]", "[parameter_value]"),
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def _simple_parse(self, value: Any, _type: Any = None) -> Any:
+ if self._multiple_types:
+ return value
+ if _type is None:
+ _type = self.type
+ if _type in IOConstants.PARAM_PARSERS:
+ return IOConstants.PARAM_PARSERS[_type](value)
+ return value
+
+ def _normalize_self_properties(self) -> None:
+ # parse value from string to its original type. eg: "false" -> False
+ for key in ["min", "max"]:
+ if getattr(self, key) is not None:
+ origin_value = getattr(self, key)
+ new_value = self._simple_parse(origin_value)
+ setattr(self, key, new_value)
+ if self.optional:
+ self.optional = self._simple_parse(getattr(self, "optional", "false"), _type="boolean")
+
+ @classmethod
+ def _get_input_by_type(cls, t: type, optional: Any = None) -> Optional["Input"]:
+ if t in IOConstants.PRIMITIVE_TYPE_2_STR:
+ return cls(type=IOConstants.PRIMITIVE_TYPE_2_STR[t], optional=optional)
+ return None
+
+ @classmethod
+ def _get_default_unknown_input(cls, optional: Optional[bool] = None) -> "Input":
+ # Set type to None here to avoid schema validation failure.
+ res: Input = cls(type=None, optional=optional) # type: ignore
+ return res
+
+ @classmethod
+ def _get_param_with_standard_annotation(cls, func: Any) -> Dict:
+ return _get_param_with_standard_annotation(func, is_func=True)
+
+ def _to_rest_object(self) -> Dict:
+ # This builds the component REST object when an Input is used as a component input. For job input usage,
+ # the REST object is generated by extracting the Input's properties; see to_rest_dataset_literal_inputs().
+ result = self._to_dict()
+ # parse string -> String, integer -> Integer, etc.
+ if result["type"] in IOConstants.TYPE_MAPPING_YAML_2_REST:
+ result["type"] = IOConstants.TYPE_MAPPING_YAML_2_REST[result["type"]]
+ return result
+
+ @classmethod
+ def _map_from_rest_type(cls, _type: Union[str, List]) -> Union[str, List]:
+ # this is for component rest object when using Input as component inputs
+ reversed_data_type_mapping = {v: k for k, v in IOConstants.TYPE_MAPPING_YAML_2_REST.items()}
+ # parse String -> string, Integer -> integer, etc
+ if not isinstance(_type, list) and _type in reversed_data_type_mapping:
+ res: str = reversed_data_type_mapping[_type]
+ return res
+ return _type
+
+ @classmethod
+ def _from_rest_object(cls, obj: Dict) -> "Input":
+ obj["type"] = cls._map_from_rest_type(obj["type"])
+
+ return cls(**obj)
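+
+
+# NOTE (editor's sketch, assumed usage mirroring the class docstring): declare a
+# data input and a bounded numeric parameter. Paths and bounds are placeholders.
+#
+#     data_input = Input(type="uri_folder", path="azureml:my_dataset:1", mode="ro_mount")
+#     lr_input = Input(type="number", default=0.01, min=0.0, max=1.0)
+#     lr_input._parse_and_validate("0.5")   # -> 0.5
+#     lr_input._parse_and_validate("2.0")   # raises ValidationException (greater than max)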
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py
new file mode 100644
index 00000000..1c4dcd06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py
@@ -0,0 +1,180 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=redefined-builtin
+import re
+from typing import Any, Dict, Optional, overload
+
+from typing_extensions import Literal
+
+from azure.ai.ml.constants import AssetTypes
+from azure.ai.ml.constants._component import IOConstants
+from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty
+from azure.ai.ml.exceptions import UserErrorException
+
+from .base import _InputOutputBase
+from .utils import _remove_empty_values
+
+
+class Output(_InputOutputBase):
+ _IO_KEYS = ["name", "version", "path", "path_on_compute", "type", "mode", "description", "early_available"]
+
+ @overload
+ def __init__(
+ self,
+ *,
+ type: str,
+ path: Optional[str] = None,
+ mode: Optional[str] = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ): ...
+
+ @overload
+ def __init__(
+ self,
+ type: Literal["uri_file"] = "uri_file",
+ path: Optional[str] = None,
+ mode: Optional[str] = None,
+ description: Optional[str] = None,
+ ):
+ """Define a URI file output.
+
+ :keyword type: The type of the data output. Can only be set to 'uri_file'.
+ :paramtype type: str
+ :keyword path: The remote path where the output should be stored.
+ :paramtype path: str
+ :keyword mode: The access mode of the data output. Accepted values are
+ * 'rw_mount': Read-write mount the data,
+ * 'upload': Upload the data from the compute target,
+ * 'direct': Pass in the URI as a string
+ :paramtype mode: str
+ :keyword description: The description of the output.
+ :paramtype description: str
+ :keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without
+ setting a version.
+ :paramtype name: str
+ :keyword version: The version used to register the output as a Data or Model asset. A version can be set only
+ when name is set.
+ :paramtype version: str
+ """
+
+ def __init__( # type: ignore[misc]
+ self,
+ *,
+ type: str = AssetTypes.URI_FOLDER,
+ path: Optional[str] = None,
+ mode: Optional[str] = None,
+ description: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Define an output.
+
+ :keyword type: The type of the data output. Accepted values are 'uri_folder', 'uri_file', 'mltable',
+ 'mlflow_model', 'custom_model', and user-defined types. Defaults to 'uri_folder'.
+ :paramtype type: str
+ :keyword path: The remote path where the output should be stored.
+ :paramtype path: Optional[str]
+ :keyword mode: The access mode of the data output. Accepted values are
+ * 'rw_mount': Read-write mount the data
+ * 'upload': Upload the data from the compute target
+ * 'direct': Pass in the URI as a string
+ :paramtype mode: Optional[str]
+ :keyword path_on_compute: The access path of the data output for compute
+ :paramtype path_on_compute: Optional[str]
+ :keyword description: The description of the output.
+ :paramtype description: Optional[str]
+ :keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without
+ setting a version.
+ :paramtype name: str
+ :keyword version: The version used to register the output as a Data or Model asset. A version can be set only
+ when name is set.
+ :paramtype version: str
+ :keyword is_control: Determine if the output is a control output.
+ :paramtype is_control: bool
+ :keyword early_available: Mark the output for early node orchestration.
+ :paramtype early_available: bool
+ :keyword intellectual_property: Intellectual property associated with the output.
+ It can be an instance of `IntellectualProperty` or a dictionary that will be used to create an instance.
+ :paramtype intellectual_property: Union[
+ ~azure.ai.ml.entities._assets.intellectual_property.IntellectualProperty, dict]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START create_inputs_outputs]
+ :end-before: [END create_inputs_outputs]
+ :language: python
+ :dedent: 8
+ :caption: Creating a CommandJob with a folder output.
+ """
+ super(Output, self).__init__(type=type)
+ # As an annotation, the _port_name cannot be initialized here.
+ self._port_name = None
+ self.name = kwargs.pop("name", None)
+ self.version = kwargs.pop("version", None)
+ self._is_primitive_type = self.type in IOConstants.PRIMITIVE_STR_2_TYPE
+ self.description = description
+ self.path = path
+ self.path_on_compute = kwargs.pop("path_on_compute", None)
+ self.mode = mode
+ # Use this field to mark an Output for early node orchestration; currently hidden in kwargs.
+ self.early_available = kwargs.pop("early_available", None)
+ self._intellectual_property = None
+ intellectual_property = kwargs.pop("intellectual_property", None)
+ if intellectual_property:
+ self._intellectual_property = (
+ intellectual_property
+ if isinstance(intellectual_property, IntellectualProperty)
+ else IntellectualProperty(**intellectual_property)
+ )
+ self._assert_name_and_version()
+ # normalize properties
+ self._normalize_self_properties()
+
+ def _get_hint(self, new_line_style: bool = False) -> Optional[str]:
+ comment_str = self.description.replace('"', '\\"') if self.description else self.type
+ return '"""%s"""' % comment_str if comment_str and new_line_style else comment_str
+
+ def _to_dict(self) -> Dict:
+ """Convert the Output object to a dict.
+
+ :return: The dictionary representation of Output
+ :rtype: Dict
+ """
+ keys = self._IO_KEYS
+ result = {key: getattr(self, key) for key in keys}
+ res: dict = _remove_empty_values(result)
+ return res
+
+ def _to_rest_object(self) -> Dict:
+ # This builds the component REST object when an Output is used as a component output. For job output usage,
+ # the REST object is generated by extracting the Output's properties; see to_rest_data_outputs().
+ return self._to_dict()
+
+ def _simple_parse(self, value: Any, _type: Any = None) -> Any:
+ if _type is None:
+ _type = self.type
+ if _type in IOConstants.PARAM_PARSERS:
+ return IOConstants.PARAM_PARSERS[_type](value)
+ return value
+
+ def _normalize_self_properties(self) -> None:
+ # parse value from string to its original type. eg: "false" -> False
+ if self.early_available:
+ self.early_available = self._simple_parse(getattr(self, "early_available", "false"), _type="boolean")
+
+ @classmethod
+ def _from_rest_object(cls, obj: Dict) -> "Output":
+ # this is for component rest object when using Output as component outputs
+ return Output(**obj)
+
+ def _assert_name_and_version(self) -> None:
+ if self.name and not (re.match("^[A-Za-z0-9_-]*$", self.name) and len(self.name) <= 255):
+ raise UserErrorException(
+ f"The output name {self.name} can only contain alphanumeric characters, dashes and underscores, "
+ f"with a limit of 255 characters."
+ )
+ if self.version and not self.name:
+ raise UserErrorException("Output name is required when output version is specified.")
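+
+
+# NOTE (editor's sketch, assumed usage): declare outputs, optionally registering
+# them as named assets. The names and paths below are placeholders.
+#
+#     model_output = Output(type="uri_folder", mode="rw_mount", name="trained_model", version="1")
+#     metrics_output = Output(type="uri_file", path="azureml://datastores/workspaceblobstore/paths/metrics.json")
+#     Output(name="bad name!")   # raises UserErrorException (invalid characters in name)
+#     Output(version="1")        # raises UserErrorException (version requires a name)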
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py
new file mode 100644
index 00000000..bd752571
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py
@@ -0,0 +1,479 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+# allow protected access for protected helper functions
+
+import copy
+from collections import OrderedDict
+from enum import Enum as PyEnum
+from enum import EnumMeta
+from inspect import Parameter, getmro, signature
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
+
+from typing_extensions import Annotated, Literal, TypeAlias
+
+from azure.ai.ml.constants._component import IOConstants
+from azure.ai.ml.exceptions import UserErrorException
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from .input import Input
+ from .output import Output
+
+SUPPORTED_RETURN_TYPES_PRIMITIVE = list(IOConstants.PRIMITIVE_TYPE_2_STR.keys())
+
+Annotation: TypeAlias = Union[str, Type, Annotated[Any, Any], None] # type: ignore
+
+
+def is_group(obj: object) -> bool:
+ """Return True if obj is a group or an instance of a parameter group class.
+
+ :param obj: The object to check.
+ :type obj: Any
+ :return: True if obj is a group or an instance, False otherwise.
+ :rtype: bool
+ """
+ return hasattr(obj, IOConstants.GROUP_ATTR_NAME)
+
+
+def _get_annotation_by_value(val: Any) -> Union["Input", Type["Input"]]:
+ # TODO: we'd better remove this potential recursive import
+ from .enum_input import EnumInput
+ from .input import Input
+
+ annotation: Any = None
+
+ def _is_dataset(data: Any) -> bool:
+ from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin
+
+ DATASET_TYPES = JobIOMixin
+ return isinstance(data, DATASET_TYPES)
+
+ if _is_dataset(val):
+ annotation = Input
+ elif val is Parameter.empty or val is None:
+ # If there is no default value, or the default is None, create the annotation as the basic parameter type;
+ # it can be replaced later using the component parameter definition.
+ annotation = Input._get_default_unknown_input()
+ elif isinstance(val, PyEnum):
+ # Handle enum values
+ annotation = EnumInput(enum=val.__class__)
+ else:
+ _new_annotation = _get_annotation_cls_by_type(type(val), raise_error=False)
+ if not _new_annotation:
+ # Fall back to default
+ annotation = Input._get_default_unknown_input()
+ else:
+ return _new_annotation
+ return cast(Union["Input", Type["Input"]], annotation)
+
+
+def _get_annotation_cls_by_type(
+ t: type, raise_error: bool = False, optional: Optional[bool] = None
+) -> Optional["Input"]:
+ # TODO: we'd better remove this potential recursive import
+ from .input import Input
+
+ cls = Input._get_input_by_type(t, optional=optional)
+ if cls is None and raise_error:
+ raise UserErrorException(f"Can't convert type {t} to azure.ai.ml.Input")
+ return cls
+
+
+# pylint: disable=too-many-statements
+def _get_param_with_standard_annotation(
+ cls_or_func: Union[Callable, Type], is_func: bool = False, skip_params: Optional[List[str]] = None
+) -> Dict[str, Union[Annotation, "Input", "Output"]]:
+ """Standardize function parameters or class fields with dsl.types annotation.
+
+ :param cls_or_func: Either a class or a function
+ :type cls_or_func: Union[Callable, Type]
+ :param is_func: Whether `cls_or_func` is a function. Defaults to False.
+ :type is_func: bool
+ :param skip_params: Names of parameters to skip when collecting fields.
+ :type skip_params: Optional[List[str]]
+ :return: A dictionary of field annotations
+ :rtype: Dict[str, Union[Annotation, "Input", "Output"]]
+ """
+ # TODO: we'd better remove this potential recursive import
+ from .group_input import GroupInput
+ from .input import Input
+ from .output import Output
+
+ def _is_dsl_type_cls(t: Any) -> bool:
+ if type(t) is not type: # pylint: disable=unidiomatic-typecheck
+ return False
+ return issubclass(t, (Input, Output))
+
+ def _is_dsl_types(o: object) -> bool:
+ return _is_dsl_type_cls(type(o))
+
+ def _get_fields(annotations: Dict) -> Dict:
+ """Return field names to annotations mapping in class.
+
+ :param annotations: The annotations
+ :type annotations: Dict[str, Union[Annotation, Input, Output]]
+ :return: The field dict
+ :rtype: Dict[str, Union[Annotation, Input, Output]]
+ """
+ annotation_fields = OrderedDict()
+ for name, annotation in annotations.items():
+ # Skip return type
+ if name == "return":
+ continue
+ # Handle EnumMeta annotation
+ if isinstance(annotation, EnumMeta):
+ from .enum_input import EnumInput
+
+ annotation = EnumInput(type="string", enum=annotation)
+ # Handle Group annotation
+ if is_group(annotation):
+ _deep_copy: GroupInput = copy.deepcopy(getattr(annotation, IOConstants.GROUP_ATTR_NAME))
+ annotation = _deep_copy
+ # Try creating the annotation from the type for declarations like 'param: int'.
+ if not _is_dsl_type_cls(annotation) and not _is_dsl_types(annotation):
+ origin_annotation = annotation
+ annotation = cast(Input, _get_annotation_cls_by_type(annotation, raise_error=False))
+ if not annotation:
+ msg = f"Unsupported annotation type {origin_annotation!r} for parameter {name!r}."
+ raise UserErrorException(msg)
+ annotation_fields[name] = annotation
+ return annotation_fields
+
+ def _merge_field_keys(
+ annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any]
+ ) -> List[str]:
+ """Merge field keys from annotations and cls dict to get all fields in class.
+
+ :param annotation_fields: The field annotations
+ :type annotation_fields: Dict[str, Union[Annotation, Input, Output]]
+ :param defaults_dict: The map of variable name to default value
+ :type defaults_dict: Dict[str, Any]
+ :return: A list of field keys
+ :rtype: List[str]
+ """
+ anno_keys = list(annotation_fields.keys())
+ dict_keys = defaults_dict.keys()
+ if not dict_keys:
+ return anno_keys
+ return [*anno_keys, *[key for key in dict_keys if key not in anno_keys]]
+
+ def _update_annotation_with_default(
+ anno: Union[Annotation, Input, Output], name: str, default: Any
+ ) -> Union[Annotation, Input, Output]:
+ """Create annotation if is type class and update the default.
+
+ :param anno: The annotation
+ :type anno: Union[Annotation, Input, Output]
+ :param name: The port name
+ :type name: str
+ :param default: The default value
+ :type default: Any
+ :return: The updated annotation
+ :rtype: Union[Annotation, Input, Output]
+ """
+ # Create instance if is type class
+ complete_annotation = anno
+ if _is_dsl_type_cls(anno):
+ if anno is not None and not isinstance(anno, (str, Input, Output)):
+ complete_annotation = anno()
+ if complete_annotation is not None and not isinstance(complete_annotation, str):
+ complete_annotation._port_name = name
+ if default is Input._EMPTY:
+ return complete_annotation
+ if isinstance(complete_annotation, Input):
+ # Non-parameter Input has no default attribute
+ if complete_annotation._is_primitive_type and complete_annotation.default is not None:
+ # logger.warning(
+ # f"Warning: Default value of f{complete_annotation.name!r} is set twice: "
+ # f"{complete_annotation.default!r} and {default!r}, will use {default!r}"
+ # )
+ pass
+ complete_annotation._update_default(default)
+ if isinstance(complete_annotation, Output) and default is not None:
+ msg = (
+ f"Default value of Output {complete_annotation._port_name!r} cannot be set:"
+ f"Output has no default value."
+ )
+ raise UserErrorException(msg)
+ return complete_annotation
+
+ def _update_fields_with_default(
+ annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any]
+ ) -> Dict[str, Union[Annotation, Input, Output]]:
+ """Use public values in class dict to update annotations.
+
+ :param annotation_fields: The field annotations
+ :type annotation_fields: Dict[str, Union[Annotation, Input, Output]]
+ :param defaults_dict: A map of variable name to default value
+ :type defaults_dict: Dict[str, Any]
+ :return: The updated fields
+ :rtype: Dict[str, Union[Annotation, Input, Output]]
+ """
+ all_fields = OrderedDict()
+ all_field_keys = _merge_field_keys(annotation_fields, defaults_dict)
+ for name in all_field_keys:
+ # Get or create annotation
+ annotation = (
+ annotation_fields[name]
+ if name in annotation_fields
+ else _get_annotation_by_value(defaults_dict.get(name, Input._EMPTY))
+ )
+ # Create annotation if is class type and update default
+ annotation = _update_annotation_with_default(annotation, name, defaults_dict.get(name, Input._EMPTY))
+ all_fields[name] = annotation
+ return all_fields
+
+ def _merge_and_reorder(
+ inherited_fields: Dict[str, Union[Annotation, Input, Output]],
+ cls_fields: Dict[str, Union[Annotation, Input, Output]],
+ ) -> Dict[str, Union[Annotation, Input, Output]]:
+ """Merge inherited fields with cls fields.
+
+ The order inside each part will not be changed. Order will be:
+
+ {inherited_no_default_fields} + {cls_no_default_fields} + {inherited_default_fields} + {cls_default_fields}.
+
+
+ :param inherited_fields: The inherited fields
+ :type inherited_fields: Dict[str, Union[Annotation, Input, Output]]
+ :param cls_fields: The class fields
+ :type cls_fields: Dict[str, Union[Annotation, Input, Output]]
+ :return: The merged fields
+ :rtype: Dict[str, Union[Annotation, Input, Output]]
+
+ .. admonition:: Additional Note
+
+ :class: note
+
+ If cls overwrites an inherited no-default field with a default, it will be put in the
+ cls_default_fields part and deleted from inherited_no_default_fields:
+
+ .. code-block:: python
+
+ @dsl.group
+ class SubGroup:
+ int_param0: Integer
+ int_param1: int
+
+ @dsl.group
+ class Group(SubGroup):
+ int_param3: Integer
+ int_param1: int = 1
+
+ The init function of Group will be `def __init__(self, *, int_param0, int_param3, int_param1=1)`.
+ """
+
+ def _split(
+ _fields: Dict[str, Union[Annotation, Input, Output]]
+ ) -> Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]:
+ """Split fields to two parts from the first default field.
+
+ :param _fields: The fields
+ :type _fields: Dict[str, Union[Annotation, Input, Output]]
+ :return: A 2-tuple of (fields with no defaults, fields with defaults)
+ :rtype: Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]
+ """
+ _no_defaults_fields, _defaults_fields = {}, {}
+ seen_default = False
+ for key, val in _fields.items():
+ if val is not None and not isinstance(val, str):
+ if val.get("default", None) or seen_default:
+ seen_default = True
+ _defaults_fields[key] = val
+ else:
+ _no_defaults_fields[key] = val
+ return _no_defaults_fields, _defaults_fields
+
+ inherited_no_default, inherited_default = _split(inherited_fields)
+ cls_no_default, cls_default = _split(cls_fields)
+ # Cross-compare and delete from inherited_fields if the same key appears in cls_fields.
+ # pylint: disable=consider-iterating-dictionary
+ for key in cls_default.keys():
+ if key in inherited_no_default.keys():
+ del inherited_no_default[key]
+ for key in cls_no_default.keys():
+ if key in inherited_default.keys():
+ del inherited_default[key]
+ return OrderedDict(
+ {
+ **inherited_no_default,
+ **cls_no_default,
+ **inherited_default,
+ **cls_default,
+ }
+ )
+
+ def _get_inherited_fields() -> Dict[str, Union[Annotation, Input, Output]]:
+ """Get all fields inherited from @group decorated base classes.
+
+ :return: The field dict
+ :rtype: Dict[str, Union[Annotation, Input, Output]]
+ """
+ # Return value of _get_param_with_standard_annotation
+ _fields: Dict[str, Union[Annotation, Input, Output]] = OrderedDict({})
+ if is_func:
+ return _fields
+ # In reversed order so that more derived classes
+ # override earlier field definitions in base classes.
+ if isinstance(cls_or_func, type):
+ for base in cls_or_func.__mro__[-1:0:-1]:
+ if is_group(base):
+ # merge and reorder fields from current base with previous
+ _fields = _merge_and_reorder(
+ _fields, copy.deepcopy(getattr(base, IOConstants.GROUP_ATTR_NAME).values)
+ )
+ return _fields
+
+ skip_params = skip_params or []
+ inherited_fields = _get_inherited_fields()
+ # From annotations get field with type
+ annotations: Dict[str, Annotation] = getattr(cls_or_func, "__annotations__", {})
+ annotations = {k: v for k, v in annotations.items() if k not in skip_params}
+ annotations = _update_io_from_mldesigner(annotations)
+ annotation_fields = _get_fields(annotations)
+ defaults_dict: Dict[str, Any] = {}
+ # Update fields using class field defaults from the class dict or signature(func).parameters
+ if not is_func:
+ # Only consider public fields in class dict
+ defaults_dict = {
+ key: val for key, val in cls_or_func.__dict__.items() if not key.startswith("_") and key not in skip_params
+ }
+ else:
+ # Infer parameter type from the default value if it is a function
+ defaults_dict = {
+ key: val.default
+ for key, val in signature(cls_or_func).parameters.items()
+ if key not in skip_params and val.kind != val.VAR_KEYWORD
+ }
+ fields = _update_fields_with_default(annotation_fields, defaults_dict)
+ all_fields = _merge_and_reorder(inherited_fields, fields)
+ return all_fields
+
+
+def _update_io_from_mldesigner(annotations: Dict[str, Annotation]) -> Dict[str, Union[Annotation, "Input", "Output"]]:
+ """Translates IOBase from mldesigner package to azure.ml.entities.Input/Output.
+
+ This function depends on:
+
+ * `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output instance annotations
+ to IO entities.
+ * class names of `mldesigner._input_output` to translate Input/Output class annotations
+ to IO entities.
+
+ :param annotations: A map of variable names to annotations
+ :type annotations: Dict[str, Annotation]
+ :return: Dict with mldesigner IO types converted to azure-ai-ml Input/Output
+ :rtype: Dict[str, Union[Annotation, Input, Output]]
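+
+ A hedged illustrative sketch (the annotation value is hypothetical and assumes the
+ mldesigner package exposes ``Input``):
+
+ .. code-block:: python
+
+ from mldesigner import Input as MldesignerInput
+
+ annotations = {"training_data": MldesignerInput(type="uri_folder")}
+ converted = _update_io_from_mldesigner(annotations)
+ # converted["training_data"] is an azure.ai.ml Input with type "uri_folder"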
+ """
+ from typing_extensions import get_args, get_origin
+
+ from azure.ai.ml import Input, Output
+
+ from .enum_input import EnumInput
+
+ mldesigner_pkg = "mldesigner"
+ param_name = "_Param"
+ return_annotation_key = "return"
+
+ def _is_primitive_type(io: type) -> bool:
+ """Checks whether type is a primitive type
+
+ :param io: A type
+ :type io: type
+ :return: Return true if type is subclass of mldesigner._input_output._Param
+ :rtype: bool
+ """
+ return any(io.__module__.startswith(mldesigner_pkg) and item.__name__ == param_name for item in getmro(io))
+
+ def _is_input_or_output_type(io: type, type_str: Literal["Input", "Output", "Meta"]) -> bool:
+ """Checks whether a type is an Input or Output type
+
+ :param io: A type
+ :type io: type
+ :param type_str: The kind of type to check for
+ :type type_str: Literal["Input", "Output", "Meta"]
+ :return: Return true if type name contains type_str
+ :rtype: bool
+ """
+ if isinstance(io, type) and io.__module__.startswith(mldesigner_pkg):
+ if type_str in io.__name__:
+ return True
+ return False
+
+ result = {}
+ for key, io in annotations.items(): # pylint: disable=too-many-nested-blocks
+ if isinstance(io, type):
+ if _is_input_or_output_type(io, "Input"):
+ # mldesigner.Input -> entities.Input
+ io = Input
+ elif _is_input_or_output_type(io, "Output"):
+ # mldesigner.Output -> entities.Output
+ io = Output
+ elif _is_primitive_type(io):
+ io = (
+ Output(type=io.TYPE_NAME) # type: ignore
+ if key == return_annotation_key
+ else Input(type=io.TYPE_NAME) # type: ignore
+ )
+ elif hasattr(io, "_to_io_entity_args_dict"):
+ try:
+ if _is_input_or_output_type(type(io), "Input"):
+ # mldesigner.Input() -> entities.Input()
+ if io is not None:
+ io = Input(**io._to_io_entity_args_dict())
+ elif _is_input_or_output_type(type(io), "Output"):
+ # mldesigner.Output() -> entities.Output()
+ if io is not None:
+ io = Output(**io._to_io_entity_args_dict())
+ elif _is_primitive_type(type(io)):
+ if io is not None and not isinstance(io, str):
+ if io._is_enum():
+ io = EnumInput(**io._to_io_entity_args_dict())
+ else:
+ io = (
+ Output(**io._to_io_entity_args_dict())
+ if key == return_annotation_key
+ else Input(**io._to_io_entity_args_dict())
+ )
+ except BaseException as e:
+ raise UserErrorException(f"Failed to parse {io} to azure-ai-ml Input/Output: {str(e)}") from e
+ # Handle Annotated annotation
+ elif get_origin(io) is Annotated:
+ hint_type, arg, *hint_args = get_args(io) # pylint: disable=unused-variable
+ if hint_type in SUPPORTED_RETURN_TYPES_PRIMITIVE:
+ if not _is_input_or_output_type(type(arg), "Meta"):
+ raise UserErrorException(
+ f"Annotated Metadata class only support "
+ f"mldesigner._input_output.Meta, "
+ f"it is {type(arg)} now."
+ )
+ if arg.type is not None and arg.type != hint_type:
+ raise UserErrorException(
+ f"Meta class type {arg.type} should be same as Annotated type: " f"{hint_type}"
+ )
+ arg.type = hint_type
+ io = (
+ Output(**arg._to_io_entity_args_dict())
+ if key == return_annotation_key
+ else Input(**arg._to_io_entity_args_dict())
+ )
+ result[key] = io
+ return result
+
+
+def _remove_empty_values(data: Any) -> Any:
+ """Recursively removes None values from a dict
+
+ :param data: The value to remove None from
+ :type data: T
+ :return:
+ * `data` if `data` is not a dict
+ * `data` with None values recursively filtered out if data is a dict
+ :rtype: T
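+
+ Illustrative sketch:
+
+ .. code-block:: python
+
+ _remove_empty_values({"a": 1, "b": None, "c": {"d": None}})
+ # -> {"a": 1, "c": {}}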
+ """
+ if not isinstance(data, dict):
+ return data
+ return {k: _remove_empty_values(v) for k, v in data.items() if v is not None}
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_input_output_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_input_output_helpers.py
new file mode 100644
index 00000000..1a13ab41
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_input_output_helpers.py
@@ -0,0 +1,427 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import collections.abc
+import re
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ CustomModelJobInput as RestCustomModelJobInput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ CustomModelJobOutput as RestCustomModelJobOutput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import InputDeliveryMode
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobInput as RestJobInput
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobInputType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobOutput as RestJobOutput
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobOutputType, LiteralJobInput
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ MLFlowModelJobInput as RestMLFlowModelJobInput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ MLFlowModelJobOutput as RestMLFlowModelJobOutput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ MLTableJobInput as RestMLTableJobInput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ MLTableJobOutput as RestMLTableJobOutput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import OutputDeliveryMode
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ TritonModelJobInput as RestTritonModelJobInput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ TritonModelJobOutput as RestTritonModelJobOutput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ UriFileJobInput as RestUriFileJobInput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ UriFileJobOutput as RestUriFileJobOutput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ UriFolderJobInput as RestUriFolderJobInput,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ UriFolderJobOutput as RestUriFolderJobOutput,
+)
+from azure.ai.ml._utils.utils import is_data_binding_expression
+from azure.ai.ml.constants import AssetTypes, InputOutputModes, JobType
+from azure.ai.ml.constants._component import IOConstants
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.input_output_entry import InputOutputEntry
+from azure.ai.ml.entities._util import normalize_job_input_output_type
+from azure.ai.ml.exceptions import (
+ ErrorCategory,
+ ErrorTarget,
+ JobException,
+ ValidationErrorType,
+ ValidationException,
+)
+
+INPUT_MOUNT_MAPPING_FROM_REST = {
+ InputDeliveryMode.READ_WRITE_MOUNT: InputOutputModes.RW_MOUNT,
+ InputDeliveryMode.READ_ONLY_MOUNT: InputOutputModes.RO_MOUNT,
+ InputDeliveryMode.DOWNLOAD: InputOutputModes.DOWNLOAD,
+ InputDeliveryMode.DIRECT: InputOutputModes.DIRECT,
+ InputDeliveryMode.EVAL_MOUNT: InputOutputModes.EVAL_MOUNT,
+ InputDeliveryMode.EVAL_DOWNLOAD: InputOutputModes.EVAL_DOWNLOAD,
+}
+
+INPUT_MOUNT_MAPPING_TO_REST = {
+ InputOutputModes.MOUNT: InputDeliveryMode.READ_ONLY_MOUNT,
+ InputOutputModes.RW_MOUNT: InputDeliveryMode.READ_WRITE_MOUNT,
+ InputOutputModes.RO_MOUNT: InputDeliveryMode.READ_ONLY_MOUNT,
+ InputOutputModes.DOWNLOAD: InputDeliveryMode.DOWNLOAD,
+ InputOutputModes.EVAL_MOUNT: InputDeliveryMode.EVAL_MOUNT,
+ InputOutputModes.EVAL_DOWNLOAD: InputDeliveryMode.EVAL_DOWNLOAD,
+ InputOutputModes.DIRECT: InputDeliveryMode.DIRECT,
+}
+
+
+OUTPUT_MOUNT_MAPPING_FROM_REST = {
+ OutputDeliveryMode.READ_WRITE_MOUNT: InputOutputModes.RW_MOUNT,
+ OutputDeliveryMode.UPLOAD: InputOutputModes.UPLOAD,
+ OutputDeliveryMode.DIRECT: InputOutputModes.DIRECT,
+}
+
+OUTPUT_MOUNT_MAPPING_TO_REST = {
+ InputOutputModes.MOUNT: OutputDeliveryMode.READ_WRITE_MOUNT,
+ InputOutputModes.UPLOAD: OutputDeliveryMode.UPLOAD,
+ InputOutputModes.RW_MOUNT: OutputDeliveryMode.READ_WRITE_MOUNT,
+ InputOutputModes.DIRECT: OutputDeliveryMode.DIRECT,
+}
+
+
+# TODO: Remove this as both rest type and sdk type are snake case now.
+def get_output_type_mapping_from_rest() -> Dict[str, str]:
+ """Gets the mapping of JobOutputType to AssetType
+
+ :return: Mapping of JobOutputType to AssetType
+ :rtype: Dict[str, str]
+ """
+ return {
+ JobOutputType.URI_FILE: AssetTypes.URI_FILE,
+ JobOutputType.URI_FOLDER: AssetTypes.URI_FOLDER,
+ JobOutputType.MLTABLE: AssetTypes.MLTABLE,
+ JobOutputType.MLFLOW_MODEL: AssetTypes.MLFLOW_MODEL,
+ JobOutputType.CUSTOM_MODEL: AssetTypes.CUSTOM_MODEL,
+ JobOutputType.TRITON_MODEL: AssetTypes.TRITON_MODEL,
+ }
+
+
+def get_input_rest_cls_dict() -> Dict[str, RestJobInput]:
+ """Gets the mapping of AssetType to RestJobInput
+
+ :return: Map of AssetType to RestJobInput
+ :rtype: Dict[str, RestJobInput]
+ """
+ return {
+ AssetTypes.URI_FILE: RestUriFileJobInput,
+ AssetTypes.URI_FOLDER: RestUriFolderJobInput,
+ AssetTypes.MLTABLE: RestMLTableJobInput,
+ AssetTypes.MLFLOW_MODEL: RestMLFlowModelJobInput,
+ AssetTypes.CUSTOM_MODEL: RestCustomModelJobInput,
+ AssetTypes.TRITON_MODEL: RestTritonModelJobInput,
+ }
+
+
+def get_output_rest_cls_dict() -> Dict[str, RestJobOutput]:
+ """Get output rest init cls dict.
+
+ :return: Map of AssetType to RestJobOutput
+ :rtype: Dict[str, RestJobOutput]
+ """
+ return {
+ AssetTypes.URI_FILE: RestUriFileJobOutput,
+ AssetTypes.URI_FOLDER: RestUriFolderJobOutput,
+ AssetTypes.MLTABLE: RestMLTableJobOutput,
+ AssetTypes.MLFLOW_MODEL: RestMLFlowModelJobOutput,
+ AssetTypes.CUSTOM_MODEL: RestCustomModelJobOutput,
+ AssetTypes.TRITON_MODEL: RestTritonModelJobOutput,
+ }
+
+
+def build_input_output(
+ item: Union[InputOutputEntry, Input, Output, str, bool, int, float],
+ inputs: bool = True,
+) -> Union[InputOutputEntry, Input, Output, str, bool, int, float]:
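+ """Normalizes a single job input or output value.
+
+ Objects already constructed at YAML load time or in the SDK (Input, Output, InputOutputEntry)
+ and literal values are returned as-is; mappings are parsed into an InputOutputEntry when a
+ "data" key is present, otherwise into an Input (inputs=True) or Output (inputs=False).
+
+ :param item: The raw input or output value.
+ :type item: Union[InputOutputEntry, Input, Output, str, bool, int, float]
+ :param inputs: Whether a mapping should be parsed as an Input (True) or an Output (False).
+ :type inputs: bool
+ :return: The normalized input or output value.
+ :rtype: Union[InputOutputEntry, Input, Output, str, bool, int, float]
+ """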
+ if isinstance(item, (Input, InputOutputEntry, Output)):
+ # return objects constructed at yaml load or specified in sdk
+ return item
+ # parse dictionary into supported class
+ if isinstance(item, collections.abc.Mapping):
+ if item.get("data"):
+ return InputOutputEntry(**item)
+ # else default to JobInput
+ return Input(**item) if inputs else Output(**item)
+ # return literal inputs as-is
+ return item
+
+
+def _validate_inputs_for(input_consumer_name: str, input_consumer: str, inputs: Optional[Dict]) -> None:
+ implicit_inputs = re.findall(r"\${{inputs\.([\w\.-]+)}}", input_consumer)
+ # optional inputs no need to validate whether they're in inputs
+ optional_inputs = re.findall(r"\[[\w\.\s-]*\${{inputs\.([\w\.-]+)}}]", input_consumer)
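+ # Illustrative sketch (hypothetical command string): for
+ # "python train.py --data ${{inputs.data}} [--lr ${{inputs.lr}}]",
+ # implicit_inputs == ["data", "lr"] and optional_inputs == ["lr"],
+ # so only "data" must be present in `inputs`.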
+ for key in implicit_inputs:
+ if inputs is not None and inputs.get(key, None) is None and key not in optional_inputs:
+ msg = "Inputs to job does not contain '{}' referenced in " + input_consumer_name
+ raise ValidationException(
+ message=msg.format(key),
+ no_personal_data_message=msg.format("[key]"),
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+
+def validate_inputs_for_command(command: Optional[str], inputs: Optional[Dict]) -> None:
+ if command is not None:
+ _validate_inputs_for("command", command, inputs)
+
+
+def validate_inputs_for_args(args: str, inputs: Optional[Dict[str, Any]]) -> None:
+ _validate_inputs_for("args", args, inputs)
+
+
+def validate_key_contains_allowed_characters(key: str) -> None:
+ if re.match(r"^[a-zA-Z_]+[a-zA-Z0-9_]*$", key) is None:
+ msg = "Key name {} must be composed letters, numbers, and underscore"
+ raise ValidationException(
+ message=msg.format(key),
+ no_personal_data_message=msg.format("[key]"),
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+
+def validate_pipeline_input_key_characters(key: str) -> None:
+ # Pipeline input keys allow '.' to support parameter groups in keys.
+ # Note: ([a-zA-Z_]+[a-zA-Z0-9_]*) is a valid single key,
+ # so a valid pipeline key is: ^{single_key}([.]{single_key})*$
+ if re.match(IOConstants.VALID_KEY_PATTERN, key) is None:
+ msg = (
+ "Pipeline input key name {} must be composed of letters, numbers, and underscores, optionally split by dots."
+ )
+ raise ValidationException(
+ message=msg.format(key),
+ no_personal_data_message=msg.format("[key]"),
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+
+def to_rest_dataset_literal_inputs(
+ inputs: Optional[Dict],
+ *,
+ job_type: Optional[str],
+) -> Dict[str, RestJobInput]:
+ """Turns dataset and literal inputs into dictionary of REST JobInput.
+
+ :param inputs: Dictionary of dataset and literal inputs to job
+ :type inputs: Dict[str, Union[int, str, float, bool, JobInput]]
+ :return: A dictionary mapping input name to a RestJobInput
+ :rtype: Dict[str, RestJobInput]
+ :keyword job_type: When job_type is pipeline, enable dot('.') in parameter keys to support parameter group.
+ TODO: Remove this after move name validation to Job's customized validate.
+ :paramtype job_type: str
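+
+ A minimal illustrative sketch (the asset path is hypothetical):
+
+ .. code-block:: python
+
+ from azure.ai.ml import Input
+
+ rest_inputs = to_rest_dataset_literal_inputs(
+ {"lr": 0.01, "data": Input(type="uri_folder", path="azureml://datastores/workspaceblobstore/paths/data/")},
+ job_type="command",
+ )
+ # "lr" becomes a LiteralJobInput with value "0.01"; "data" becomes a UriFolderJobInput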
+ """
+ rest_inputs = {}
+
+ if inputs is not None:
+ # Pack up the inputs into REST format
+ for input_name, input_value in inputs.items():
+ if job_type == JobType.PIPELINE:
+ validate_pipeline_input_key_characters(input_name)
+ elif job_type:
+ # We pass job_type=None for pipeline nodes and want to skip this check for them.
+ validate_key_contains_allowed_characters(input_name)
+ if isinstance(input_value, Input):
+ if (
+ input_value.path
+ and isinstance(input_value.path, str)
+ and is_data_binding_expression(input_value.path)
+ ):
+ input_data = LiteralJobInput(value=input_value.path)
+ # set mode attribute manually for binding job input
+ if input_value.mode:
+ input_data.mode = INPUT_MOUNT_MAPPING_TO_REST[input_value.mode]
+ if getattr(input_value, "path_on_compute", None) is not None:
+ input_data.pathOnCompute = input_value.path_on_compute
+ input_data.job_input_type = JobInputType.LITERAL
+ else:
+ target_cls_dict = get_input_rest_cls_dict()
+
+ if input_value.type in target_cls_dict:
+ input_data = target_cls_dict[input_value.type](
+ uri=input_value.path,
+ mode=(INPUT_MOUNT_MAPPING_TO_REST[input_value.mode.lower()] if input_value.mode else None),
+ )
+ else:
+ msg = f"Job input type {input_value.type} is not supported as job input."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ elif input_value is None:
+ # If the input is None, we need to pass the origin None to the REST API
+ input_data = LiteralJobInput(value=None)
+ else:
+ # otherwise, the input is a literal input
+ if isinstance(input_value, dict):
+ input_data = LiteralJobInput(value=str(input_value["value"]))
+ # set mode attribute manually for binding job input
+ if "mode" in input_value:
+ input_data.mode = input_value["mode"]
+ else:
+ input_data = LiteralJobInput(value=str(input_value))
+ input_data.job_input_type = JobInputType.LITERAL
+ # Pack up inputs into PipelineInputs or ComponentJobInputs depending on caller
+ rest_inputs[input_name] = input_data
+ return rest_inputs
+
+
+def from_rest_inputs_to_dataset_literal(inputs: Dict[str, RestJobInput]) -> Dict:
+ """Turns REST dataset and literal inputs into the SDK format.
+
+ :param inputs: Dictionary mapping input name to ComponentJobInput or PipelineInput
+ :type inputs: Dict[str, Union[ComponentJobInput, PipelineInput]]
+ :return: A dictionary mapping input name to a literal value or JobInput
+ :rtype: Dict[str, Union[int, str, float, bool, JobInput]]
+ """
+ if inputs is None:
+ return {}
+ from_rest_inputs = {}
+ # Unpack the inputs
+ for input_name, input_value in inputs.items():
+ # TODO:Brandon Clarify with PMs if user should be able to define null input objects
+ if input_value is None:
+ continue
+
+ # TODO: Remove this as both rest type and sdk type are snake case now.
+ type_transfer_dict = get_output_type_mapping_from_rest()
+ # deal with invalid input type submitted by feb api
+ # todo: backend help convert node level input/output type
+ normalize_job_input_output_type(input_value)
+
+ if input_value.job_input_type in type_transfer_dict:
+ if input_value.uri:
+ path = input_value.uri
+ if getattr(input_value, "pathOnCompute", None) is not None:
+ sourcePathOnCompute = input_value.pathOnCompute
+ else:
+ sourcePathOnCompute = None
+ input_data = Input(
+ type=type_transfer_dict[input_value.job_input_type],
+ path=path,
+ mode=(INPUT_MOUNT_MAPPING_FROM_REST[input_value.mode] if input_value.mode else None),
+ path_on_compute=sourcePathOnCompute,
+ )
+ elif input_value.job_input_type == JobInputType.LITERAL:
+ # otherwise, the input is a literal, so just unpack the InputData value field
+ input_data = input_value.value
+ else:
+ msg = f"Job input type {input_value.job_input_type} is not supported as job input."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ from_rest_inputs[input_name] = input_data # pylint: disable=possibly-used-before-assignment
+ return from_rest_inputs
+
+
+def to_rest_data_outputs(outputs: Optional[Dict]) -> Dict[str, RestJobOutput]:
+ """Turns job outputs into REST format.
+
+ :param outputs: Dictionary of dataset outputs from job
+ :type outputs: Dict[str, JobOutput]
+ :return: A dictionary mapping output name to a RestJobOutput
+ :rtype: Dict[str, RestJobOutput]
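+
+ A minimal illustrative sketch (output names are hypothetical):
+
+ .. code-block:: python
+
+ from azure.ai.ml import Output
+
+ rest_outputs = to_rest_data_outputs({"model": Output(type="mlflow_model"), "report": None})
+ # "model" maps to an MLFlowModelJobOutput; "report" defaults to a UriFolderJobOutput with mode=None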
+ """
+ rest_outputs = {}
+ if outputs is not None:
+ for output_name, output_value in outputs.items():
+ validate_key_contains_allowed_characters(output_name)
+ if output_value is None:
+ # pipeline output could be none, default to URI folder with None mode
+ output_cls = RestUriFolderJobOutput
+ rest_outputs[output_name] = output_cls(mode=None)
+ else:
+ target_cls_dict = get_output_rest_cls_dict()
+
+ output_value_type = output_value.type if output_value.type else AssetTypes.URI_FOLDER
+ if output_value_type in target_cls_dict:
+ output = target_cls_dict[output_value_type](
+ asset_name=output_value.name,
+ asset_version=output_value.version,
+ uri=output_value.path,
+ mode=(OUTPUT_MOUNT_MAPPING_TO_REST[output_value.mode.lower()] if output_value.mode else None),
+ pathOnCompute=getattr(output_value, "path_on_compute", None),
+ description=output_value.description,
+ )
+ else:
+ msg = "unsupported JobOutput type: {}".format(output_value.type)
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ rest_outputs[output_name] = output
+ return rest_outputs
+
+
+def from_rest_data_outputs(outputs: Dict[str, RestJobOutput]) -> Dict[str, Output]:
+ """Turns REST outputs into the SDK format.
+
+ :param outputs: Dictionary mapping output name to RestJobOutput
+ :type outputs: Dict[str, RestJobOutput]
+ :return: A dictionary mapping output name to an Output
+ :rtype: Dict[str, Output]
+ """
+ output_type_mapping = get_output_type_mapping_from_rest()
+ from_rest_outputs = {}
+ if outputs is None:
+ return {}
+ for output_name, output_value in outputs.items():
+ if output_value is None:
+ continue
+ # deal with invalid output type submitted by feb api
+ # todo: backend help convert node level input/output type
+ normalize_job_input_output_type(output_value)
+ if getattr(output_value, "pathOnCompute", None) is not None:
+ sourcePathOnCompute = output_value.pathOnCompute
+ else:
+ sourcePathOnCompute = None
+ if output_value.job_output_type in output_type_mapping:
+ from_rest_outputs[output_name] = Output(
+ type=output_type_mapping[output_value.job_output_type],
+ path=output_value.uri,
+ mode=(OUTPUT_MOUNT_MAPPING_FROM_REST[output_value.mode] if output_value.mode else None),
+ path_on_compute=sourcePathOnCompute,
+ description=output_value.description,
+ name=output_value.asset_name,
+ version=(output_value.asset_version if hasattr(output_value, "asset_version") else None),
+ )
+ else:
+ msg = "unsupported JobOutput type: {}".format(output_value.job_output_type)
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ return from_rest_outputs
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_studio_url_from_job_id.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_studio_url_from_job_id.py
new file mode 100644
index 00000000..63ad6f06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/_studio_url_from_job_id.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import re
+from typing import Optional
+
+from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _get_default_cloud_name
+
+JOB_ID_RE_PATTERN = re.compile(
+ (
+ r"\/subscriptions\/(?P<subscription>[\w,-]+)\/resourceGroups\/(?P<resource_group>[\w,-]+)\/providers"
+ r"\/Microsoft\.MachineLearningServices\/workspaces\/(?P<workspace>[\w,-]+)\/jobs\/(?P<run_id>[\w,-]+)"
+ ) # fmt: skip
+)
+
+
+def studio_url_from_job_id(job_id: str) -> Optional[str]:
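+ """Builds an Azure ML studio URL from a job's ARM resource id.
+
+ A hedged sketch of the expected input shape (all identifiers below are hypothetical):
+
+ .. code-block:: python
+
+ job_id = (
+ "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-rg"
+ "/providers/Microsoft.MachineLearningServices/workspaces/my-ws/jobs/my-run"
+ )
+ url = studio_url_from_job_id(job_id)  # None if the id does not match the pattern
+
+ :param job_id: The ARM resource id of the job.
+ :type job_id: str
+ :return: The studio URL for the job, or None if job_id does not match the expected pattern.
+ :rtype: Optional[str]
+ """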
+ resource_id = _get_aml_resource_id_from_metadata(_get_default_cloud_name())
+ m = JOB_ID_RE_PATTERN.match(job_id)
+ if m:
+ return (
+ f"{resource_id}/runs/{m.group('run_id')}?wsid=/subscriptions/{m.group('subscription')}"
+ f"/resourcegroups/{m.group('resource_group')}/workspaces/{m.group('workspace')}"
+ ) # fmt: skip
+ return None
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/__init__.py
new file mode 100644
index 00000000..e99e9321
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/__init__.py
@@ -0,0 +1,16 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .search_space import SearchSpace
+from .stack_ensemble_settings import StackEnsembleSettings
+from .training_settings import ClassificationTrainingSettings, TrainingSettings
+
+__all__ = [
+ "ClassificationTrainingSettings",
+ "TrainingSettings",
+ "SearchSpace",
+ "StackEnsembleSettings",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_job.py
new file mode 100644
index 00000000..9e1b4d05
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_job.py
@@ -0,0 +1,283 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ JobBase,
+ MLTableJobInput,
+ QueueSettings,
+ ResourceConfiguration,
+ TaskType,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import TYPE, AssetTypes
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin
+from azure.ai.ml.entities._job.pipeline._io import AutoMLNodeIOMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+class AutoMLJob(Job, JobIOMixin, AutoMLNodeIOMixin, ABC):
+ """Initialize an AutoML job entity.
+
+ Constructor for an AutoMLJob.
+
+ :keyword resources: Resource configuration for the AutoML job, defaults to None
+ :paramtype resources: typing.Optional[ResourceConfiguration]
+ :keyword identity: Identity that training job will use while running on compute, defaults to None
+ :paramtype identity: typing.Optional[ typing.Union[ManagedIdentityConfiguration, AmlTokenConfiguration
+ , UserIdentityConfiguration] ]
+ :keyword environment_id: The environment id for the AutoML job, defaults to None
+ :paramtype environment_id: typing.Optional[str]
+ :keyword environment_variables: The environment variables for the AutoML job, defaults to None
+ :paramtype environment_variables: typing.Optional[Dict[str, str]]
+ :keyword outputs: The outputs for the AutoML job, defaults to None
+ :paramtype outputs: typing.Optional[Dict[str, str]]
+ :keyword queue_settings: The queue settings for the AutoML job, defaults to None
+ :paramtype queue_settings: typing.Optional[QueueSettings]
+ :raises ValidationException: task type validation error
+ :raises NotImplementedError: Raises NotImplementedError
+ :return: An AutoML Job
+ :rtype: AutoMLJob
+ """
+
+ def __init__(
+ self,
+ *,
+ resources: Optional[ResourceConfiguration] = None,
+ identity: Optional[
+ Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ queue_settings: Optional[QueueSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize an AutoML job entity.
+
+ Constructor for an AutoMLJob.
+
+ :keyword resources: Resource configuration for the AutoML job, defaults to None
+ :paramtype resources: typing.Optional[ResourceConfiguration]
+ :keyword identity: Identity that training job will use while running on compute, defaults to None
+ :paramtype identity: typing.Optional[ typing.Union[ManagedIdentityConfiguration, AmlTokenConfiguration
+ , UserIdentityConfiguration] ]
+ :keyword environment_id: The environment id for the AutoML job, defaults to None
+ :paramtype environment_id: typing.Optional[str]
+ :keyword environment_variables: The environment variables for the AutoML job, defaults to None
+ :paramtype environment_variables: typing.Optional[Dict[str, str]]
+ :keyword outputs: The outputs for the AutoML job, defaults to None
+ :paramtype outputs: typing.Optional[Dict[str, str]]
+ :keyword queue_settings: The queue settings for the AutoML job, defaults to None
+ :paramtype queue_settings: typing.Optional[QueueSettings]
+ :raises ValidationException: task type validation error
+ :raises NotImplementedError: Raises NotImplementedError
+ """
+ kwargs[TYPE] = JobType.AUTOML
+ self.environment_id = kwargs.pop("environment_id", None)
+ self.environment_variables = kwargs.pop("environment_variables", None)
+ self.outputs = kwargs.pop("outputs", None)
+
+ super().__init__(**kwargs)
+
+ self.resources = resources
+ self.identity = identity
+ self.queue_settings = queue_settings
+
+ @property
+ @abstractmethod
+ def training_data(self) -> Input:
+ """The training data for the AutoML job.
+
+ :raises NotImplementedError: Raises NotImplementedError
+ :return: Returns the training data for the AutoML job.
+ :rtype: Input
+ """
+ raise NotImplementedError()
+
+ @training_data.setter
+ def training_data(self, value: Any) -> None:
+ # Store on the backing attribute to avoid recursing into this setter.
+ self._training_data = value
+
+ @property
+ @abstractmethod
+ def validation_data(self) -> Input:
+ """The validation data for the AutoML job.
+
+ :raises NotImplementedError: Raises NotImplementedError
+ :return: Returns the validation data for the AutoML job.
+ :rtype: Input
+ """
+ raise NotImplementedError()
+
+ @validation_data.setter
+ def validation_data(self, value: Any) -> None:
+ # Store on the backing attribute to avoid recursing into this setter.
+ self._validation_data = value
+
+ @property
+ @abstractmethod
+ def test_data(self) -> Input:
+ """The test data for the AutoML job.
+
+ :raises NotImplementedError: Raises NotImplementedError
+ :return: Returns the test data for the AutoML job.
+ :rtype: Input
+ """
+ raise NotImplementedError()
+
+ @test_data.setter
+ def test_data(self, value: Any) -> None:
+ # Store on the backing attribute to avoid recursing into this setter.
+ self._test_data = value
+
+ @classmethod
+ def _load_from_rest(cls, obj: JobBase) -> "AutoMLJob":
+ """Loads the rest object to a dict containing items to init the AutoMLJob objects.
+
+ :param obj: Azure Resource Manager resource envelope.
+ :type obj: JobBase
+ :raises ValidationException: task type validation error
+ :return: An AutoML Job
+ :rtype: AutoMLJob
+ """
+ task_type = (
+ camel_to_snake(obj.properties.task_details.task_type) if obj.properties.task_details.task_type else None
+ )
+ class_type = cls._get_task_mapping().get(task_type, None)
+ if class_type:
+ res: AutoMLJob = class_type._from_rest_object(obj)
+ return res
+ msg = f"Unsupported task type: {obj.properties.task_details.task_type}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "AutoMLJob":
+ """Loads the dictionary objects to an AutoMLJob object.
+
+ :param data: A data dictionary.
+ :type data: typing.Dict
+ :param context: A context dictionary.
+ :type context: typing.Dict
+ :param additional_message: An additional message to be logged in the ValidationException.
+ :type additional_message: str
+
+ :raises ValidationException: task type validation error
+ :return: An AutoML Job
+ :rtype: AutoMLJob
+ """
+ task_type = data.get(AutoMLConstants.TASK_TYPE_YAML)
+ class_type = cls._get_task_mapping().get(task_type, None)
+ if class_type:
+ res: AutoMLJob = class_type._load_from_dict(
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ return res
+ msg = f"Unsupported task type: {task_type}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "AutoMLJob":
+ """Create an automl job instance from schema parsed dict.
+
+ :param loaded_data: A loaded_data dictionary.
+ :type loaded_data: typing.Dict
+ :raises ValidationException: task type validation error
+ :return: An AutoML Job
+ :rtype: AutoMLJob
+ """
+ task_type = loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML)
+ class_type = cls._get_task_mapping().get(task_type, None)
+ if class_type:
+ res: AutoMLJob = class_type._create_instance_from_schema_dict(loaded_data=loaded_data)
+ return res
+ msg = f"Unsupported task type: {task_type}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @classmethod
+ def _get_task_mapping(cls) -> Dict:
+ """Create a mapping of task type to job class.
+
+ :return: A mapping of task type to job class.
+ :rtype: Dict
+ """
+ from .image import (
+ ImageClassificationJob,
+ ImageClassificationMultilabelJob,
+ ImageInstanceSegmentationJob,
+ ImageObjectDetectionJob,
+ )
+ from .nlp import TextClassificationJob, TextClassificationMultilabelJob, TextNerJob
+ from .tabular import ClassificationJob, ForecastingJob, RegressionJob
+
+ # create a mapping of task type to job class
+ return {
+ camel_to_snake(TaskType.CLASSIFICATION): ClassificationJob,
+ camel_to_snake(TaskType.REGRESSION): RegressionJob,
+ camel_to_snake(TaskType.FORECASTING): ForecastingJob,
+ camel_to_snake(TaskType.IMAGE_CLASSIFICATION): ImageClassificationJob,
+ camel_to_snake(TaskType.IMAGE_CLASSIFICATION_MULTILABEL): ImageClassificationMultilabelJob,
+ camel_to_snake(TaskType.IMAGE_OBJECT_DETECTION): ImageObjectDetectionJob,
+ camel_to_snake(TaskType.IMAGE_INSTANCE_SEGMENTATION): ImageInstanceSegmentationJob,
+ camel_to_snake(TaskType.TEXT_NER): TextNerJob,
+ camel_to_snake(TaskType.TEXT_CLASSIFICATION): TextClassificationJob,
+ camel_to_snake(TaskType.TEXT_CLASSIFICATION_MULTILABEL): TextClassificationMultilabelJob,
+ }
+
+ def _resolve_data_inputs(self, rest_job: "AutoMLJob") -> None:
+ """Resolve JobInputs to MLTableJobInputs within data_settings.
+
+ :param rest_job: The rest job object.
+ :type rest_job: AutoMLJob
+ """
+ if isinstance(rest_job.training_data, Input):
+ rest_job.training_data = MLTableJobInput(uri=rest_job.training_data.path)
+ if isinstance(rest_job.validation_data, Input):
+ rest_job.validation_data = MLTableJobInput(uri=rest_job.validation_data.path)
+ if hasattr(rest_job, "test_data") and isinstance(rest_job.test_data, Input):
+ rest_job.test_data = MLTableJobInput(uri=rest_job.test_data.path)
+
+ def _restore_data_inputs(self) -> None:
+ """Restore MLTableJobInputs to JobInputs within data_settings."""
+ if isinstance(self.training_data, MLTableJobInput):
+ self.training_data = Input(type=AssetTypes.MLTABLE, path=self.training_data.uri)
+ if isinstance(self.validation_data, MLTableJobInput):
+ self.validation_data = Input(type=AssetTypes.MLTABLE, path=self.validation_data.uri)
+ if hasattr(self, "test_data") and isinstance(self.test_data, MLTableJobInput):
+ self.test_data = Input(type=AssetTypes.MLTABLE, path=self.test_data.uri)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_vertical.py
new file mode 100644
index 00000000..f11be81c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/automl_vertical.py
@@ -0,0 +1,134 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from abc import abstractmethod
+from typing import Any, Optional
+
+from azure.ai.ml import Input
+
+from .automl_job import AutoMLJob
+
+
+class AutoMLVertical(AutoMLJob):
+ """Abstract class for AutoML verticals.
+
+ :param task_type: The type of task to run. Possible values include: "classification", "regression", "forecasting".
+ :type task_type: str
+ :param training_data: Training data input
+ :type training_data: Input
+ :param validation_data: Validation data input
+ :type validation_data: Input
+ :param test_data: Test data input, defaults to None
+ :type test_data: typing.Optional[Input]
+ :raises ValueError: If task_type is not one of "classification", "regression", "forecasting".
+ :raises ValueError: If training_data is not of type Input.
+ :raises ValueError: If validation_data is not of type Input.
+ :raises ValueError: If test_data is not of type Input.
+ """
+
+ @abstractmethod
+ def __init__(
+ self,
+ task_type: str,
+ training_data: Input,
+ validation_data: Input,
+ test_data: Optional[Input] = None,
+ **kwargs: Any
+ ) -> None:
+ """Initialize AutoMLVertical.
+
+ Constructor for AutoMLVertical.
+
+ :param task_type: The type of task to run. Possible values include: "classification", "regression"
+ , "forecasting".
+ :type task_type: str
+ :param training_data: Training data input
+ :type training_data: Input
+ :param validation_data: Validation data input
+ :type validation_data: Input
+ :param test_data: Test data input, defaults to None
+ :type test_data: typing.Optional[Input]
+ :raises ValueError: If task_type is not one of "classification", "regression", "forecasting".
+ :raises ValueError: If training_data is not of type Input.
+ :raises ValueError: If validation_data is not of type Input.
+ :raises ValueError: If test_data is not of type Input.
+ """
+ self._task_type = task_type
+ self.training_data = training_data
+ self.validation_data = validation_data
+ self.test_data = test_data # type: ignore
+ super().__init__(**kwargs)
+
+ @property
+ def task_type(self) -> str:
+ """Get task type.
+
+ :return: The type of task to run. Possible values include: "classification", "regression", "forecasting".
+ :rtype: str
+ """
+ return self._task_type
+
+ @task_type.setter
+ def task_type(self, task_type: str) -> None:
+ """Set task type.
+
+ :param task_type: The type of task to run. Possible values include: "classification", "regression"
+ , "forecasting".
+ :type task_type: str
+ """
+ self._task_type = task_type
+
+ @property
+ def training_data(self) -> Input:
+ """Get training data.
+
+ :return: Training data input
+ :rtype: Input
+ """
+ return self._training_data
+
+ @training_data.setter
+ def training_data(self, training_data: Input) -> None:
+ """Set training data.
+
+ :param training_data: Training data input
+ :type training_data: Input
+ """
+ self._training_data = training_data
+
+ @property
+ def validation_data(self) -> Input:
+ """Get validation data.
+
+ :return: Validation data input
+ :rtype: Input
+ """
+ return self._validation_data
+
+ @validation_data.setter
+ def validation_data(self, validation_data: Input) -> None:
+ """Set validation data.
+
+ :param validation_data: Validation data input
+ :type validation_data: Input
+ """
+ self._validation_data = validation_data
+
+ @property
+ def test_data(self) -> Input:
+ """Get test data.
+
+ :return: Test data input
+ :rtype: Input
+ """
+ return self._test_data
+
+ @test_data.setter
+ def test_data(self, test_data: Input) -> None:
+ """Set test data.
+
+ :param test_data: Test data input
+ :type test_data: Input
+ """
+ self._test_data = test_data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/featurization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/featurization_settings.py
new file mode 100644
index 00000000..c9e73d21
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/featurization_settings.py
@@ -0,0 +1,32 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class FeaturizationSettings(RestTranslatableMixin):
+ """Base Featurization settings."""
+
+ def __init__(
+ self,
+ *,
+ dataset_language: Optional[str] = None,
+ ):
+ self.dataset_language = dataset_language
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, FeaturizationSettings):
+ return NotImplemented
+
+ return self.dataset_language == other.dataset_language
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+
+class FeaturizationSettingsType:
+ NLP = "nlp"
+ TABULAR = "tabular"
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/__init__.py
new file mode 100644
index 00000000..46964086
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/__init__.py
@@ -0,0 +1,35 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from .automl_image import AutoMLImage
+from .image_classification_job import ImageClassificationJob
+from .image_classification_multilabel_job import ImageClassificationMultilabelJob
+from .image_classification_search_space import ImageClassificationSearchSpace
+from .image_instance_segmentation_job import ImageInstanceSegmentationJob
+from .image_limit_settings import ImageLimitSettings
+from .image_model_settings import (
+ ImageModelSettingsClassification,
+ ImageModelSettingsObjectDetection,
+ LogTrainingMetrics,
+ LogValidationLoss,
+)
+from .image_object_detection_job import ImageObjectDetectionJob
+from .image_object_detection_search_space import ImageObjectDetectionSearchSpace
+from .image_sweep_settings import ImageSweepSettings
+
+__all__ = [
+ "AutoMLImage",
+ "LogTrainingMetrics",
+ "LogValidationLoss",
+ "ImageClassificationJob",
+ "ImageClassificationMultilabelJob",
+ "ImageClassificationSearchSpace",
+ "ImageInstanceSegmentationJob",
+ "ImageLimitSettings",
+ "ImageObjectDetectionJob",
+ "ImageObjectDetectionSearchSpace",
+ "ImageSweepSettings",
+ "ImageModelSettingsClassification",
+ "ImageModelSettingsObjectDetection",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image.py
new file mode 100644
index 00000000..a07bba4a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image.py
@@ -0,0 +1,244 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from abc import ABC
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import LogVerbosity, SamplingAlgorithmType
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.automl.automl_vertical import AutoMLVertical
+from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings
+from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings
+from azure.ai.ml.entities._job.sweep.early_termination_policy import (
+ BanditPolicy,
+ MedianStoppingPolicy,
+ TruncationSelectionPolicy,
+)
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class AutoMLImage(AutoMLVertical, ABC):
+ """Base class for all AutoML Image jobs.
+ You should not instantiate this class directly.
+ Instead, use one of the task-specific AutoML Image job classes.
+
+ :keyword task_type: Required. Type of task to run.
+ Possible values include: "ImageClassification", "ImageClassificationMultilabel",
+ "ImageObjectDetection", "ImageInstanceSegmentation"
+ :paramtype task_type: str
+ :keyword limits: Limit settings for all AutoML Image jobs. Defaults to None.
+ :paramtype limits: Optional[~azure.ai.ml.automl.ImageLimitSettings]
+ :keyword sweep: Sweep settings for all AutoML Image jobs. Defaults to None.
+ :paramtype sweep: Optional[~azure.ai.ml.automl.ImageSweepSettings]
+ :keyword kwargs: Additional keyword arguments for AutoMLImage.
+ :paramtype kwargs: Dict[str, Any]
+ """
+
+ def __init__(
+ self,
+ *,
+ task_type: str,
+ limits: Optional[ImageLimitSettings] = None,
+ sweep: Optional[ImageSweepSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.log_verbosity = kwargs.pop("log_verbosity", LogVerbosity.INFO)
+ self.target_column_name = kwargs.pop("target_column_name", None)
+ self.validation_data_size = kwargs.pop("validation_data_size", None)
+
+ super().__init__(
+ task_type=task_type,
+ training_data=kwargs.pop("training_data", None),
+ validation_data=kwargs.pop("validation_data", None),
+ **kwargs,
+ )
+
+ # Set default value for self._limits as it is a required property in rest object.
+ self._limits = limits or ImageLimitSettings()
+ self._sweep = sweep
+
+ @property
+ def log_verbosity(self) -> LogVerbosity:
+ """Returns the verbosity of the logger.
+
+ :return: The log verbosity.
+ :rtype: ~azure.ai.ml._restclient.v2023_04_01_preview.models.LogVerbosity
+ """
+ return self._log_verbosity
+
+ @log_verbosity.setter
+ def log_verbosity(self, value: Union[str, LogVerbosity]) -> None:
+ """Sets the verbosity of the logger.
+
+ :param value: The value to set the log verbosity to.
+ Possible values include: "NotSet", "Debug", "Info", "Warning", "Error", "Critical".
+ :type value: Union[str, ~azure.ai.ml._restclient.v2023_04_01_preview.models.LogVerbosity]
+ """
+ self._log_verbosity = None if value is None else LogVerbosity[camel_to_snake(value).upper()]
+
+ @property
+ def limits(self) -> ImageLimitSettings:
+ """Returns the limit settings for all AutoML Image jobs.
+
+ :return: The limit settings.
+ :rtype: ~azure.ai.ml.automl.ImageLimitSettings
+ """
+ return self._limits
+
+ @limits.setter
+ def limits(self, value: Union[Dict, ImageLimitSettings]) -> None:
+ if isinstance(value, ImageLimitSettings):
+ self._limits = value
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for limit settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_limits(**value)
+
+ @property
+ def sweep(self) -> Optional[ImageSweepSettings]:
+ """Returns the sweep settings for all AutoML Image jobs.
+
+ :return: The sweep settings.
+ :rtype: ~azure.ai.ml.automl.ImageSweepSettings
+ """
+ return self._sweep
+
+ @sweep.setter
+ def sweep(self, value: Union[Dict, ImageSweepSettings]) -> None:
+ """Sets the sweep settings for all AutoML Image jobs.
+
+ :param value: The value to set the sweep settings to.
+ :type value: Union[Dict, ~azure.ai.ml.automl.ImageSweepSettings]
+ :raises ~azure.ai.ml.exceptions.ValidationException: If value is not a dictionary.
+ :return: None
+ """
+ if isinstance(value, ImageSweepSettings):
+ self._sweep = value
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for sweep settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_sweep(**value)
+
+ def set_data(
+ self,
+ *,
+ training_data: Input,
+ target_column_name: str,
+ validation_data: Optional[Input] = None,
+ validation_data_size: Optional[float] = None,
+ ) -> None:
+ """Data settings for all AutoML Image jobs.
+
+ :keyword training_data: Required. Training data.
+ :type training_data: ~azure.ai.ml.entities.Input
+ :keyword target_column_name: Required. Target column name.
+ :type target_column_name: str
+ :keyword validation_data: Optional. Validation data.
+ :type validation_data: Optional[~azure.ai.ml.entities.Input]
+ :keyword validation_data_size: Optional. The fraction of training dataset that needs to be set aside for
+ validation purposes. Values should be in the range (0.0, 1.0).
+ Applied only when validation dataset is not provided.
+ :type validation_data_size: Optional[float]
+ :return: None
+ """
+ self.target_column_name = self.target_column_name if target_column_name is None else target_column_name
+ self.training_data = self.training_data if training_data is None else training_data
+ self.validation_data = self.validation_data if validation_data is None else validation_data
+ self.validation_data_size = self.validation_data_size if validation_data_size is None else validation_data_size
+
+ def set_limits(
+ self,
+ *,
+ max_concurrent_trials: Optional[int] = None,
+ max_trials: Optional[int] = None,
+ timeout_minutes: Optional[int] = None,
+ ) -> None:
+ """Limit settings for all AutoML Image Jobs.
+
+ :keyword max_concurrent_trials: Maximum number of trials to run concurrently. Defaults to None.
+ :type max_concurrent_trials: Optional[int]
+ :keyword max_trials: Maximum number of trials to run. Defaults to None.
+ :type max_trials: Optional[int]
+ :keyword timeout_minutes: AutoML job timeout, in minutes. Defaults to None.
+ :type timeout_minutes: Optional[int]
+ :return: None
+ """
+ self._limits = self._limits or ImageLimitSettings()
+ self._limits.max_concurrent_trials = (
+ max_concurrent_trials if max_concurrent_trials is not None else self._limits.max_concurrent_trials
+ )
+ self._limits.max_trials = max_trials if max_trials is not None else self._limits.max_trials
+ self._limits.timeout_minutes = timeout_minutes if timeout_minutes is not None else self._limits.timeout_minutes
+
+ def set_sweep(
+ self,
+ *,
+ sampling_algorithm: Union[
+ str, SamplingAlgorithmType.RANDOM, SamplingAlgorithmType.GRID, SamplingAlgorithmType.BAYESIAN
+ ],
+ early_termination: Optional[Union[BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy]] = None,
+ ) -> None:
+ """Sweep settings for all AutoML Image jobs.
+
+ :keyword sampling_algorithm: Required. Type of the hyperparameter sampling
+ algorithms. Possible values include: "Grid", "Random", "Bayesian".
+ :type sampling_algorithm: Union[str, ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.RANDOM,
+ ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.GRID,
+ ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.BAYESIAN]
+ :keyword early_termination: Type of early termination policy.
+ :type early_termination: Union[
+ ~azure.mgmt.machinelearningservices.models.BanditPolicy,
+ ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy,
+ ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy]
+ :return: None
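+
+ A hedged usage sketch (``image_job`` stands for any AutoML image job instance):
+
+ .. code-block:: python
+
+ from azure.ai.ml.sweep import BanditPolicy
+
+ image_job.set_sweep(
+ sampling_algorithm="Random",
+ early_termination=BanditPolicy(slack_factor=0.2, evaluation_interval=2),
+ )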
+ """
+ if self._sweep:
+ self._sweep.sampling_algorithm = sampling_algorithm
+ else:
+ self._sweep = ImageSweepSettings(sampling_algorithm=sampling_algorithm)
+
+ self._sweep.early_termination = early_termination or self._sweep.early_termination
+
+ def __eq__(self, other: object) -> bool:
+ """Compares two AutoMLImage objects for equality.
+
+ :param other: The other AutoMLImage object to compare to.
+ :type other: ~azure.ai.ml.automl.AutoMLImage
+ :return: True if the two AutoMLImage objects are equal; False otherwise.
+ :rtype: bool
+ """
+ if not isinstance(other, AutoMLImage):
+ return NotImplemented
+
+ return (
+ self.target_column_name == other.target_column_name
+ and self.training_data == other.training_data
+ and self.validation_data == other.validation_data
+ and self.validation_data_size == other.validation_data_size
+ and self._limits == other._limits
+ and self._sweep == other._sweep
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Compares two AutoMLImage objects for inequality.
+
+ :param other: The other AutoMLImage object to compare to.
+ :type other: ~azure.ai.ml.automl.AutoMLImage
+ :return: True if the two AutoMLImage objects are not equal; False otherwise.
+ :rtype: bool
+ """
+ return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_classification_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_classification_base.py
new file mode 100644
index 00000000..ef0c8a2d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_classification_base.py
@@ -0,0 +1,439 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import LearningRateScheduler, StochasticOptimizer
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._job.automl.image.automl_image import AutoMLImage
+from azure.ai.ml.entities._job.automl.image.image_classification_search_space import ImageClassificationSearchSpace
+from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings
+from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsClassification
+from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings
+from azure.ai.ml.entities._job.automl.search_space import SearchSpace
+from azure.ai.ml.entities._job.automl.utils import cast_to_specific_search_space
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class AutoMLImageClassificationBase(AutoMLImage):
+ """Base class for AutoML Image Classification and Image Classification Multilabel tasks.
+ Please do not instantiate this class directly. Instantiate one of the child classes instead.
+
+ :keyword task_type: Type of task to run.
+ Possible values include: "ImageClassification", "ImageClassificationMultilabel".
+ :paramtype task_type: str
+ :keyword limits: Limits for Automl image classification jobs. Defaults to None.
+ :paramtype limits: Optional[~azure.ai.ml.automl.ImageLimitSettings]
+ :keyword sweep: Sweep settings for Automl image classification jobs. Defaults to None.
+ :paramtype sweep: Optional[~azure.ai.ml.automl.ImageSweepSettings]
+ :keyword training_parameters: Training parameters for Automl image classification jobs. Defaults to None.
+ :paramtype training_parameters: Optional[~azure.ai.ml.automl.ImageModelSettingsClassification]
+ :keyword search_space: Search space for Automl image classification jobs. Defaults to None.
+ :paramtype search_space: Optional[List[~azure.ai.ml.automl.ImageClassificationSearchSpace]]
+ :keyword kwargs: Other Keyword arguments for AutoMLImageClassificationBase class.
+ :paramtype kwargs: Dict[str, Any]
+ """
+
+ def __init__(
+ self,
+ *,
+ task_type: str,
+ limits: Optional[ImageLimitSettings] = None,
+ sweep: Optional[ImageSweepSettings] = None,
+ training_parameters: Optional[ImageModelSettingsClassification] = None,
+ search_space: Optional[List[ImageClassificationSearchSpace]] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._training_parameters: Optional[ImageModelSettingsClassification] = None
+
+ super().__init__(task_type=task_type, limits=limits, sweep=sweep, **kwargs)
+ self.training_parameters = training_parameters # Assigning training_parameters through setter method.
+ self._search_space = search_space
+
+ @property
+ def training_parameters(self) -> Optional[ImageModelSettingsClassification]:
+ """
+ :rtype: ~azure.ai.ml.automl.ImageModelSettingsClassification
+ :return: Training parameters for AutoML Image Classification and Image Classification Multilabel tasks.
+ """
+ return self._training_parameters
+
+ @training_parameters.setter
+ def training_parameters(self, value: Union[Dict, ImageModelSettingsClassification]) -> None:
+ """Setting Image training parameters for AutoML Image Classification and Image Classification Multilabel tasks.
+
+ :param value: Training parameters for AutoML Image Classification and Image Classification Multilabel tasks.
+ :type value: Union[Dict, ~azure.ai.ml.automl.ImageModelSettingsClassification]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if value is not a dictionary or
+ ImageModelSettingsClassification.
+ :return: None
+ """
+ if value is None:
+ self._training_parameters = None
+ elif isinstance(value, ImageModelSettingsClassification):
+ self._training_parameters = value
+            # set_training_parameters converts parameter values from snake-case strings to enums.
+            # Any future enum parameters must be added to this call to keep supporting snake-case strings.
+ self.set_training_parameters(
+ optimizer=value.optimizer,
+ learning_rate_scheduler=value.learning_rate_scheduler,
+ )
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for model settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_training_parameters(**value)
+
+ @property
+ def search_space(self) -> Optional[List[ImageClassificationSearchSpace]]:
+ """
+ :rtype: List[~azure.ai.ml.automl.ImageClassificationSearchSpace]
+ :return: Search space for AutoML Image Classification and Image Classification Multilabel tasks.
+ """
+ return self._search_space
+
+ @search_space.setter
+ def search_space(self, value: Union[List[Dict], List[SearchSpace]]) -> None:
+ """Setting Image search space for AutoML Image Classification and Image Classification Multilabel tasks.
+
+ :param value: Search space for AutoML Image Classification and Image Classification Multilabel tasks.
+ :type value: Union[List[Dict], List[~azure.ai.ml.automl.ImageClassificationSearchSpace]]
+        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if value is not a list of dictionaries or
+            ImageClassificationSearchSpace objects.
+ """
+ if not isinstance(value, list):
+ msg = "Expected a list for search space."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ all_dict_type = all(isinstance(item, dict) for item in value)
+ all_search_space_type = all(isinstance(item, SearchSpace) for item in value)
+
+ if all_search_space_type or all_dict_type:
+ self._search_space = [
+ cast_to_specific_search_space(item, ImageClassificationSearchSpace, self.task_type) # type: ignore
+ for item in value
+ ]
+ else:
+ msg = "Expected all items in the list to be either dictionaries or ImageClassificationSearchSpace objects."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ # pylint: disable=too-many-locals
+ def set_training_parameters(
+ self,
+ *,
+ advanced_settings: Optional[str] = None,
+ ams_gradient: Optional[bool] = None,
+ beta1: Optional[float] = None,
+ beta2: Optional[float] = None,
+ checkpoint_frequency: Optional[int] = None,
+ checkpoint_run_id: Optional[str] = None,
+ distributed: Optional[bool] = None,
+ early_stopping: Optional[bool] = None,
+ early_stopping_delay: Optional[int] = None,
+ early_stopping_patience: Optional[int] = None,
+ enable_onnx_normalization: Optional[bool] = None,
+ evaluation_frequency: Optional[int] = None,
+ gradient_accumulation_step: Optional[int] = None,
+ layers_to_freeze: Optional[int] = None,
+ learning_rate: Optional[float] = None,
+ learning_rate_scheduler: Optional[Union[str, LearningRateScheduler]] = None,
+ model_name: Optional[str] = None,
+ momentum: Optional[float] = None,
+ nesterov: Optional[bool] = None,
+ number_of_epochs: Optional[int] = None,
+ number_of_workers: Optional[int] = None,
+ optimizer: Optional[Union[str, StochasticOptimizer]] = None,
+ random_seed: Optional[int] = None,
+ step_lr_gamma: Optional[float] = None,
+ step_lr_step_size: Optional[int] = None,
+ training_batch_size: Optional[int] = None,
+ validation_batch_size: Optional[int] = None,
+ warmup_cosine_lr_cycles: Optional[float] = None,
+ warmup_cosine_lr_warmup_epochs: Optional[int] = None,
+ weight_decay: Optional[float] = None,
+ training_crop_size: Optional[int] = None,
+ validation_crop_size: Optional[int] = None,
+ validation_resize_size: Optional[int] = None,
+ weighted_loss: Optional[int] = None,
+ ) -> None:
+ """Setting Image training parameters for AutoML Image Classification and Image Classification Multilabel tasks.
+
+ :keyword advanced_settings: Settings for advanced scenarios.
+ :paramtype advanced_settings: str
+ :keyword ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'.
+ :paramtype ams_gradient: bool
+ :keyword beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :paramtype beta1: float
+ :keyword beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :paramtype beta2: float
+ :keyword checkpoint_frequency: Frequency to store model checkpoints. Must be a positive
+ integer.
+ :paramtype checkpoint_frequency: int
+ :keyword checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for
+ incremental training.
+ :paramtype checkpoint_run_id: str
+ :keyword distributed: Whether to use distributed training.
+ :paramtype distributed: bool
+ :keyword early_stopping: Enable early stopping logic during training.
+ :paramtype early_stopping: bool
+ :keyword early_stopping_delay: Minimum number of epochs or validation evaluations to wait
+ before primary metric improvement
+ is tracked for early stopping. Must be a positive integer.
+ :paramtype early_stopping_delay: int
+ :keyword early_stopping_patience: Minimum number of epochs or validation evaluations with no
+ primary metric improvement before
+ the run is stopped. Must be a positive integer.
+ :paramtype early_stopping_patience: int
+ :keyword enable_onnx_normalization: Enable normalization when exporting ONNX model.
+ :paramtype enable_onnx_normalization: bool
+ :keyword evaluation_frequency: Frequency to evaluate validation dataset to get metric scores.
+ Must be a positive integer.
+ :paramtype evaluation_frequency: int
+ :keyword gradient_accumulation_step: Gradient accumulation means running a configured number of
+ "GradAccumulationStep" steps without
+ updating the model weights while accumulating the gradients of those steps, and then using
+ the accumulated gradients to compute the weight updates. Must be a positive integer.
+ :paramtype gradient_accumulation_step: int
+ :keyword layers_to_freeze: Number of layers to freeze for the model. Must be a positive
+ integer.
+ For instance, passing 2 as value for 'seresnext' means
+ freezing layer0 and layer1. For a full list of models supported and details on layer freeze,
+ please
+ see: https://learn.microsoft.com/azure/machine-learning/reference-automl-images-hyperparameters#model-agnostic-hyperparameters. # pylint: disable=line-too-long
+ :type layers_to_freeze: int
+ :keyword learning_rate: Initial learning rate. Must be a float in the range [0, 1].
+ :paramtype learning_rate: float
+ :keyword learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or
+ 'step'. Possible values include: "None", "WarmupCosine", "Step".
+ :type learning_rate_scheduler: str or
+ ~azure.mgmt.machinelearningservices.models.LearningRateScheduler
+ :keyword model_name: Name of the model to use for training.
+ For more information on the available models please visit the official documentation:
+ https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type model_name: str
+ :keyword momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0,
+ 1].
+ :paramtype momentum: float
+ :keyword nesterov: Enable nesterov when optimizer is 'sgd'.
+ :paramtype nesterov: bool
+ :keyword number_of_epochs: Number of training epochs. Must be a positive integer.
+ :paramtype number_of_epochs: int
+ :keyword number_of_workers: Number of data loader workers. Must be a non-negative integer.
+ :paramtype number_of_workers: int
+ :keyword optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw".
+ :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer
+ :keyword random_seed: Random seed to be used when using deterministic training.
+ :paramtype random_seed: int
+ :keyword step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float
+ in the range [0, 1].
+ :paramtype step_lr_gamma: float
+ :keyword step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be
+ a positive integer.
+ :paramtype step_lr_step_size: int
+ :keyword training_batch_size: Training batch size. Must be a positive integer.
+ :paramtype training_batch_size: int
+ :keyword validation_batch_size: Validation batch size. Must be a positive integer.
+ :paramtype validation_batch_size: int
+ :keyword warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+ 'warmup_cosine'. Must be a float in the range [0, 1].
+ :paramtype warmup_cosine_lr_cycles: float
+ :keyword warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+ 'warmup_cosine'. Must be a positive integer.
+ :paramtype warmup_cosine_lr_warmup_epochs: int
+ :keyword weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must
+ be a float in the range[0, 1].
+ :paramtype weight_decay: float
+ :keyword training_crop_size: Image crop size that is input to the neural network for the
+ training dataset. Must be a positive integer.
+ :paramtype training_crop_size: int
+ :keyword validation_crop_size: Image crop size that is input to the neural network for the
+ validation dataset. Must be a positive integer.
+ :paramtype validation_crop_size: int
+ :keyword validation_resize_size: Image size to which to resize before cropping for validation
+ dataset. Must be a positive integer.
+ :paramtype validation_resize_size: int
+ :keyword weighted_loss: Weighted loss. The accepted values are 0 for no weighted loss.
+ 1 for weighted loss with sqrt.(class_weights). 2 for weighted loss with class_weights. Must be
+ 0 or 1 or 2.
+ :paramtype weighted_loss: int
+ """
+ self._training_parameters = self._training_parameters or ImageModelSettingsClassification()
+
+ self._training_parameters.advanced_settings = (
+ advanced_settings if advanced_settings is not None else self._training_parameters.advanced_settings
+ )
+ self._training_parameters.ams_gradient = (
+ ams_gradient if ams_gradient is not None else self._training_parameters.ams_gradient
+ )
+ self._training_parameters.beta1 = beta1 if beta1 is not None else self._training_parameters.beta1
+ self._training_parameters.beta2 = beta2 if beta2 is not None else self._training_parameters.beta2
+ self._training_parameters.checkpoint_frequency = (
+ checkpoint_frequency if checkpoint_frequency is not None else self._training_parameters.checkpoint_frequency
+ )
+ self._training_parameters.checkpoint_run_id = (
+ checkpoint_run_id if checkpoint_run_id is not None else self._training_parameters.checkpoint_run_id
+ )
+ self._training_parameters.distributed = (
+ distributed if distributed is not None else self._training_parameters.distributed
+ )
+ self._training_parameters.early_stopping = (
+ early_stopping if early_stopping is not None else self._training_parameters.early_stopping
+ )
+ self._training_parameters.early_stopping_delay = (
+ early_stopping_delay if early_stopping_delay is not None else self._training_parameters.early_stopping_delay
+ )
+ self._training_parameters.early_stopping_patience = (
+ early_stopping_patience
+ if early_stopping_patience is not None
+ else self._training_parameters.early_stopping_patience
+ )
+ self._training_parameters.enable_onnx_normalization = (
+ enable_onnx_normalization
+ if enable_onnx_normalization is not None
+ else self._training_parameters.enable_onnx_normalization
+ )
+ self._training_parameters.evaluation_frequency = (
+ evaluation_frequency if evaluation_frequency is not None else self._training_parameters.evaluation_frequency
+ )
+ self._training_parameters.gradient_accumulation_step = (
+ gradient_accumulation_step
+ if gradient_accumulation_step is not None
+ else self._training_parameters.gradient_accumulation_step
+ )
+ self._training_parameters.layers_to_freeze = (
+ layers_to_freeze if layers_to_freeze is not None else self._training_parameters.layers_to_freeze
+ )
+ self._training_parameters.learning_rate = (
+ learning_rate if learning_rate is not None else self._training_parameters.learning_rate
+ )
+ self._training_parameters.learning_rate_scheduler = (
+ LearningRateScheduler[camel_to_snake(learning_rate_scheduler).upper()]
+ if learning_rate_scheduler is not None
+ else self._training_parameters.learning_rate_scheduler
+ )
+ self._training_parameters.model_name = (
+ model_name if model_name is not None else self._training_parameters.model_name
+ )
+ self._training_parameters.momentum = momentum if momentum is not None else self._training_parameters.momentum
+ self._training_parameters.nesterov = nesterov if nesterov is not None else self._training_parameters.nesterov
+ self._training_parameters.number_of_epochs = (
+ number_of_epochs if number_of_epochs is not None else self._training_parameters.number_of_epochs
+ )
+ self._training_parameters.number_of_workers = (
+ number_of_workers if number_of_workers is not None else self._training_parameters.number_of_workers
+ )
+ self._training_parameters.optimizer = (
+ StochasticOptimizer[camel_to_snake(optimizer).upper()]
+ if optimizer is not None
+ else self._training_parameters.optimizer
+ )
+ self._training_parameters.random_seed = (
+ random_seed if random_seed is not None else self._training_parameters.random_seed
+ )
+ self._training_parameters.step_lr_gamma = (
+ step_lr_gamma if step_lr_gamma is not None else self._training_parameters.step_lr_gamma
+ )
+ self._training_parameters.step_lr_step_size = (
+ step_lr_step_size if step_lr_step_size is not None else self._training_parameters.step_lr_step_size
+ )
+ self._training_parameters.training_batch_size = (
+ training_batch_size if training_batch_size is not None else self._training_parameters.training_batch_size
+ )
+ self._training_parameters.validation_batch_size = (
+ validation_batch_size
+ if validation_batch_size is not None
+ else self._training_parameters.validation_batch_size
+ )
+ self._training_parameters.warmup_cosine_lr_cycles = (
+ warmup_cosine_lr_cycles
+ if warmup_cosine_lr_cycles is not None
+ else self._training_parameters.warmup_cosine_lr_cycles
+ )
+ self._training_parameters.warmup_cosine_lr_warmup_epochs = (
+ warmup_cosine_lr_warmup_epochs
+ if warmup_cosine_lr_warmup_epochs is not None
+ else self._training_parameters.warmup_cosine_lr_warmup_epochs
+ )
+ self._training_parameters.weight_decay = (
+ weight_decay if weight_decay is not None else self._training_parameters.weight_decay
+ )
+ self._training_parameters.training_crop_size = (
+ training_crop_size if training_crop_size is not None else self._training_parameters.training_crop_size
+ )
+ self._training_parameters.validation_crop_size = (
+ validation_crop_size if validation_crop_size is not None else self._training_parameters.validation_crop_size
+ )
+ self._training_parameters.validation_resize_size = (
+ validation_resize_size
+ if validation_resize_size is not None
+ else self._training_parameters.validation_resize_size
+ )
+ self._training_parameters.weighted_loss = (
+ weighted_loss if weighted_loss is not None else self._training_parameters.weighted_loss
+ )
+
+ # pylint: enable=too-many-locals
+
+ def extend_search_space(
+ self,
+ value: Union[SearchSpace, List[SearchSpace]],
+ ) -> None:
+ """Add Search space for AutoML Image Classification and Image Classification Multilabel tasks.
+
+ :param value: specify either an instance of ImageClassificationSearchSpace or list of
+ ImageClassificationSearchSpace for searching through the parameter space
+ :type value: Union[ImageClassificationSearchSpace, List[ImageClassificationSearchSpace]]
+ """
+ self._search_space = self._search_space or []
+
+ if isinstance(value, list):
+ self._search_space.extend(
+ [
+ cast_to_specific_search_space(item, ImageClassificationSearchSpace, self.task_type) # type: ignore
+ for item in value
+ ]
+ )
+ else:
+ self._search_space.append(
+ cast_to_specific_search_space(value, ImageClassificationSearchSpace, self.task_type) # type: ignore
+ )
+
+ @classmethod
+ def _get_search_space_from_str(cls, search_space_str: str) -> Optional[List[ImageClassificationSearchSpace]]:
+ return (
+ [ImageClassificationSearchSpace._from_rest_object(entry) for entry in search_space_str if entry is not None]
+ if search_space_str is not None
+ else None
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, AutoMLImageClassificationBase):
+ return NotImplemented
+
+ if not super().__eq__(other):
+ return False
+
+ return self._training_parameters == other._training_parameters and self._search_space == other._search_space
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
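A minimal usage sketch of the setters above, assuming the concrete ImageClassificationJob subclass and the public azure.ai.ml import surface; the parameter values are illustrative only. Assigning a dict to training_parameters routes through set_training_parameters(), which turns snake-case strings such as "adamw" and "warmup_cosine" into the corresponding REST enums, and extend_search_space() casts each SearchSpace entry to ImageClassificationSearchSpace.

    from azure.ai.ml import automl
    from azure.ai.ml.sweep import Choice, Uniform

    job = automl.ImageClassificationJob()

    # Dict form is forwarded to set_training_parameters(); enum-valued fields
    # accept snake-case strings ("adamw" -> StochasticOptimizer.ADAMW, etc.).
    job.training_parameters = {
        "optimizer": "adamw",
        "learning_rate_scheduler": "warmup_cosine",
        "number_of_epochs": 10,
    }

    # Each SearchSpace entry is cast to ImageClassificationSearchSpace.
    job.extend_search_space(
        automl.SearchSpace(
            model_name=Choice(["vitb16r224", "seresnext"]),
            learning_rate=Uniform(0.001, 0.01),
        )
    )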
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_object_detection_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_object_detection_base.py
new file mode 100644
index 00000000..db0c7bc6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/automl_image_object_detection_base.py
@@ -0,0 +1,524 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ LearningRateScheduler,
+ LogTrainingMetrics,
+ LogValidationLoss,
+ ModelSize,
+ StochasticOptimizer,
+ ValidationMetricType,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._job.automl import SearchSpace
+from azure.ai.ml.entities._job.automl.image.automl_image import AutoMLImage
+from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings
+from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsObjectDetection
+from azure.ai.ml.entities._job.automl.image.image_object_detection_search_space import ImageObjectDetectionSearchSpace
+from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings
+from azure.ai.ml.entities._job.automl.utils import cast_to_specific_search_space
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class AutoMLImageObjectDetectionBase(AutoMLImage):
+ """Base class for AutoML Image Object Detection and Image Instance Segmentation tasks.
+
+ :keyword task_type: Type of task to run. Possible values include: "ImageObjectDetection",
+ "ImageInstanceSegmentation".
+ :paramtype task_type: str
+ :keyword limits: The resource limits for the job.
+    :paramtype limits: Optional[~azure.ai.ml.automl.ImageLimitSettings]
+    :keyword sweep: The sweep settings for the job.
+    :paramtype sweep: Optional[~azure.ai.ml.automl.ImageSweepSettings]
+ :keyword training_parameters: The training parameters for the job.
+ :paramtype training_parameters: Optional[~azure.ai.ml.automl.ImageModelSettingsObjectDetection]
+ :keyword search_space: The search space for the job.
+ :paramtype search_space: Optional[List[~azure.ai.ml.automl.ImageObjectDetectionSearchSpace]]
+ """
+
+ def __init__(
+ self,
+ *,
+ task_type: str,
+ limits: Optional[ImageLimitSettings] = None,
+ sweep: Optional[ImageSweepSettings] = None,
+ training_parameters: Optional[ImageModelSettingsObjectDetection] = None,
+ search_space: Optional[List[ImageObjectDetectionSearchSpace]] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._training_parameters: Optional[ImageModelSettingsObjectDetection] = None
+
+ super().__init__(task_type=task_type, limits=limits, sweep=sweep, **kwargs)
+
+ self.training_parameters = training_parameters # Assigning training_parameters through setter method.
+
+ self._search_space = search_space
+
+ @property
+ def training_parameters(self) -> Optional[ImageModelSettingsObjectDetection]:
+ return self._training_parameters
+
+ @training_parameters.setter
+ def training_parameters(self, value: Union[Dict, ImageModelSettingsObjectDetection]) -> None:
+ if value is None:
+ self._training_parameters = None
+ elif isinstance(value, ImageModelSettingsObjectDetection):
+ self._training_parameters = value
+            # set_training_parameters converts parameter values from snake-case strings to enums.
+            # Any future enum parameters must be added to this call to keep supporting snake-case strings.
+ self.set_training_parameters(
+ optimizer=value.optimizer,
+ learning_rate_scheduler=value.learning_rate_scheduler,
+ model_size=value.model_size,
+ validation_metric_type=value.validation_metric_type,
+ log_training_metrics=value.log_training_metrics,
+ log_validation_loss=value.log_validation_loss,
+ )
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for model settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_training_parameters(**value)
+
+ @property
+ def search_space(self) -> Optional[List[ImageObjectDetectionSearchSpace]]:
+ return self._search_space
+
+ @search_space.setter
+ def search_space(self, value: Union[List[Dict], List[SearchSpace]]) -> None:
+ if not isinstance(value, list):
+ msg = "Expected a list for search space."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ all_dict_type = all(isinstance(item, dict) for item in value)
+ all_search_space_type = all(isinstance(item, SearchSpace) for item in value)
+
+ if all_search_space_type or all_dict_type:
+ self._search_space = [
+ cast_to_specific_search_space(item, ImageObjectDetectionSearchSpace, self.task_type) # type: ignore
+ for item in value
+ ]
+ else:
+ msg = "Expected all items in the list to be either dictionaries or SearchSpace objects."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ # pylint: disable=too-many-locals
+ def set_training_parameters(
+ self,
+ *,
+ advanced_settings: Optional[str] = None,
+ ams_gradient: Optional[bool] = None,
+ beta1: Optional[float] = None,
+ beta2: Optional[float] = None,
+ checkpoint_frequency: Optional[int] = None,
+ checkpoint_run_id: Optional[str] = None,
+ distributed: Optional[bool] = None,
+ early_stopping: Optional[bool] = None,
+ early_stopping_delay: Optional[int] = None,
+ early_stopping_patience: Optional[int] = None,
+ enable_onnx_normalization: Optional[bool] = None,
+ evaluation_frequency: Optional[int] = None,
+ gradient_accumulation_step: Optional[int] = None,
+ layers_to_freeze: Optional[int] = None,
+ learning_rate: Optional[float] = None,
+ learning_rate_scheduler: Optional[Union[str, LearningRateScheduler]] = None,
+ model_name: Optional[str] = None,
+ momentum: Optional[float] = None,
+ nesterov: Optional[bool] = None,
+ number_of_epochs: Optional[int] = None,
+ number_of_workers: Optional[int] = None,
+ optimizer: Optional[Union[str, StochasticOptimizer]] = None,
+ random_seed: Optional[int] = None,
+ step_lr_gamma: Optional[float] = None,
+ step_lr_step_size: Optional[int] = None,
+ training_batch_size: Optional[int] = None,
+ validation_batch_size: Optional[int] = None,
+ warmup_cosine_lr_cycles: Optional[float] = None,
+ warmup_cosine_lr_warmup_epochs: Optional[int] = None,
+ weight_decay: Optional[float] = None,
+ box_detections_per_image: Optional[int] = None,
+ box_score_threshold: Optional[float] = None,
+ image_size: Optional[int] = None,
+ max_size: Optional[int] = None,
+ min_size: Optional[int] = None,
+ model_size: Optional[Union[str, ModelSize]] = None,
+ multi_scale: Optional[bool] = None,
+ nms_iou_threshold: Optional[float] = None,
+ tile_grid_size: Optional[str] = None,
+ tile_overlap_ratio: Optional[float] = None,
+ tile_predictions_nms_threshold: Optional[float] = None,
+ validation_iou_threshold: Optional[float] = None,
+ validation_metric_type: Optional[Union[str, ValidationMetricType]] = None,
+ log_training_metrics: Optional[Union[str, LogTrainingMetrics]] = None,
+ log_validation_loss: Optional[Union[str, LogValidationLoss]] = None,
+ ) -> None:
+        """Setting Image training parameters for AutoML Image Object Detection and Image Instance Segmentation
+ tasks.
+
+ :keyword advanced_settings: Settings for advanced scenarios.
+ :paramtype advanced_settings: str
+ :keyword ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'.
+ :paramtype ams_gradient: bool
+ :keyword beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :paramtype beta1: float
+ :keyword beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :paramtype beta2: float
+ :keyword checkpoint_frequency: Frequency to store model checkpoints. Must be a positive
+ integer.
+ :paramtype checkpoint_frequency: int
+ :keyword checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for
+ incremental training.
+ :paramtype checkpoint_run_id: str
+ :keyword distributed: Whether to use distributed training.
+ :paramtype distributed: bool
+ :keyword early_stopping: Enable early stopping logic during training.
+ :paramtype early_stopping: bool
+ :keyword early_stopping_delay: Minimum number of epochs or validation evaluations to wait
+ before primary metric improvement
+ is tracked for early stopping. Must be a positive integer.
+ :paramtype early_stopping_delay: int
+ :keyword early_stopping_patience: Minimum number of epochs or validation evaluations with no
+ primary metric improvement before
+ the run is stopped. Must be a positive integer.
+ :paramtype early_stopping_patience: int
+ :keyword enable_onnx_normalization: Enable normalization when exporting ONNX model.
+ :paramtype enable_onnx_normalization: bool
+ :keyword evaluation_frequency: Frequency to evaluate validation dataset to get metric scores.
+ Must be a positive integer.
+ :paramtype evaluation_frequency: int
+ :keyword gradient_accumulation_step: Gradient accumulation means running a configured number of
+ "GradAccumulationStep" steps without
+ updating the model weights while accumulating the gradients of those steps, and then using
+ the accumulated gradients to compute the weight updates. Must be a positive integer.
+ :paramtype gradient_accumulation_step: int
+ :keyword layers_to_freeze: Number of layers to freeze for the model. Must be a positive
+ integer.
+ For instance, passing 2 as value for 'seresnext' means
+ freezing layer0 and layer1. For a full list of models supported and details on layer freeze,
+ please
+ see: https://learn.microsoft.com/azure/machine-learning/reference-automl-images-hyperparameters#model-agnostic-hyperparameters. # pylint: disable=line-too-long
+ :type layers_to_freeze: int
+ :keyword learning_rate: Initial learning rate. Must be a float in the range [0, 1].
+ :paramtype learning_rate: float
+ :keyword learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or
+ 'step'. Possible values include: "None", "WarmupCosine", "Step".
+ :type learning_rate_scheduler: str or
+ ~azure.mgmt.machinelearningservices.models.LearningRateScheduler
+ :keyword model_name: Name of the model to use for training.
+ For more information on the available models please visit the official documentation:
+ https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type model_name: str
+ :keyword momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0,
+ 1].
+ :paramtype momentum: float
+ :keyword nesterov: Enable nesterov when optimizer is 'sgd'.
+ :paramtype nesterov: bool
+ :keyword number_of_epochs: Number of training epochs. Must be a positive integer.
+ :paramtype number_of_epochs: int
+ :keyword number_of_workers: Number of data loader workers. Must be a non-negative integer.
+ :paramtype number_of_workers: int
+ :keyword optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw".
+ :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer
+ :keyword random_seed: Random seed to be used when using deterministic training.
+ :paramtype random_seed: int
+ :keyword step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float
+ in the range [0, 1].
+ :paramtype step_lr_gamma: float
+ :keyword step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be
+ a positive integer.
+ :paramtype step_lr_step_size: int
+ :keyword training_batch_size: Training batch size. Must be a positive integer.
+ :paramtype training_batch_size: int
+ :keyword validation_batch_size: Validation batch size. Must be a positive integer.
+ :paramtype validation_batch_size: int
+ :keyword warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+ 'warmup_cosine'. Must be a float in the range [0, 1].
+ :paramtype warmup_cosine_lr_cycles: float
+ :keyword warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+ 'warmup_cosine'. Must be a positive integer.
+ :paramtype warmup_cosine_lr_warmup_epochs: int
+ :keyword weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must
+ be a float in the range[0, 1].
+ :paramtype weight_decay: float
+ :keyword box_detections_per_image: Maximum number of detections per image, for all classes.
+ Must be a positive integer.
+            Note: This setting is not supported for the 'yolov5' algorithm.
+ :type box_detections_per_image: int
+ :keyword box_score_threshold: During inference, only return proposals with a classification
+ score greater than
+ BoxScoreThreshold. Must be a float in the range[0, 1].
+ :paramtype box_score_threshold: float
+ :keyword image_size: Image size for training and validation. Must be a positive integer.
+ Note: The training run may get into CUDA OOM if the size is too big.
+            Note: This setting is only supported for the 'yolov5' algorithm.
+ :type image_size: int
+ :keyword max_size: Maximum size of the image to be rescaled before feeding it to the backbone.
+ Must be a positive integer. Note: training run may get into CUDA OOM if the size is too big.
+            Note: This setting is not supported for the 'yolov5' algorithm.
+ :type max_size: int
+ :keyword min_size: Minimum size of the image to be rescaled before feeding it to the backbone.
+ Must be a positive integer. Note: training run may get into CUDA OOM if the size is too big.
+            Note: This setting is not supported for the 'yolov5' algorithm.
+ :type min_size: int
+ :keyword model_size: Model size. Must be 'small', 'medium', 'large', or 'extra_large'.
+ Note: training run may get into CUDA OOM if the model size is too big.
+            Note: This setting is only supported for the 'yolov5' algorithm.
+ :type model_size: str or ~azure.mgmt.machinelearningservices.models.ModelSize
+ :keyword multi_scale: Enable multi-scale image by varying image size by +/- 50%.
+ Note: training run may get into CUDA OOM if no sufficient GPU memory.
+            Note: This setting is only supported for the 'yolov5' algorithm.
+ :type multi_scale: bool
+ :keyword nms_iou_threshold: IOU threshold used during inference in NMS post processing. Must be
+ float in the range [0, 1].
+ :paramtype nms_iou_threshold: float
+ :keyword tile_grid_size: The grid size to use for tiling each image. Note: TileGridSize must
+ not be
+ None to enable small object detection logic. A string containing two integers in mxn format.
+ :type tile_grid_size: str
+ :keyword tile_overlap_ratio: Overlap ratio between adjacent tiles in each dimension. Must be
+ float in the range [0, 1).
+ :paramtype tile_overlap_ratio: float
+ :keyword tile_predictions_nms_threshold: The IOU threshold to use to perform NMS while merging
+ predictions from tiles and image.
+ Used in validation/ inference. Must be float in the range [0, 1].
+ NMS: Non-maximum suppression.
+        :paramtype tile_predictions_nms_threshold: float
+ :keyword validation_iou_threshold: IOU threshold to use when computing validation metric. Must
+ be float in the range [0, 1].
+ :paramtype validation_iou_threshold: float
+ :keyword validation_metric_type: Metric computation method to use for validation metrics. Must
+ be 'none', 'coco', 'voc', or 'coco_voc'.
+ :paramtype validation_metric_type: str or ~azure.mgmt.machinelearningservices.models.ValidationMetricType
+        :keyword log_training_metrics: Indicates whether or not to log training metrics. Must
+            be 'Enable' or 'Disable'.
+        :paramtype log_training_metrics: str or ~azure.mgmt.machinelearningservices.models.LogTrainingMetrics
+        :keyword log_validation_loss: Indicates whether or not to log validation loss. Must
+            be 'Enable' or 'Disable'.
+        :paramtype log_validation_loss: str or ~azure.mgmt.machinelearningservices.models.LogValidationLoss
+ """
+ self._training_parameters = self._training_parameters or ImageModelSettingsObjectDetection()
+
+ self._training_parameters.advanced_settings = (
+ advanced_settings if advanced_settings is not None else self._training_parameters.advanced_settings
+ )
+ self._training_parameters.ams_gradient = (
+ ams_gradient if ams_gradient is not None else self._training_parameters.ams_gradient
+ )
+ self._training_parameters.beta1 = beta1 if beta1 is not None else self._training_parameters.beta1
+ self._training_parameters.beta2 = beta2 if beta2 is not None else self._training_parameters.beta2
+ self._training_parameters.checkpoint_frequency = (
+ checkpoint_frequency if checkpoint_frequency is not None else self._training_parameters.checkpoint_frequency
+ )
+ self._training_parameters.checkpoint_run_id = (
+ checkpoint_run_id if checkpoint_run_id is not None else self._training_parameters.checkpoint_run_id
+ )
+ self._training_parameters.distributed = (
+ distributed if distributed is not None else self._training_parameters.distributed
+ )
+ self._training_parameters.early_stopping = (
+ early_stopping if early_stopping is not None else self._training_parameters.early_stopping
+ )
+ self._training_parameters.early_stopping_delay = (
+ early_stopping_delay if early_stopping_delay is not None else self._training_parameters.early_stopping_delay
+ )
+ self._training_parameters.early_stopping_patience = (
+ early_stopping_patience
+ if early_stopping_patience is not None
+ else self._training_parameters.early_stopping_patience
+ )
+ self._training_parameters.enable_onnx_normalization = (
+ enable_onnx_normalization
+ if enable_onnx_normalization is not None
+ else self._training_parameters.enable_onnx_normalization
+ )
+ self._training_parameters.evaluation_frequency = (
+ evaluation_frequency if evaluation_frequency is not None else self._training_parameters.evaluation_frequency
+ )
+ self._training_parameters.gradient_accumulation_step = (
+ gradient_accumulation_step
+ if gradient_accumulation_step is not None
+ else self._training_parameters.gradient_accumulation_step
+ )
+ self._training_parameters.layers_to_freeze = (
+ layers_to_freeze if layers_to_freeze is not None else self._training_parameters.layers_to_freeze
+ )
+ self._training_parameters.learning_rate = (
+ learning_rate if learning_rate is not None else self._training_parameters.learning_rate
+ )
+ self._training_parameters.learning_rate_scheduler = (
+ LearningRateScheduler[camel_to_snake(learning_rate_scheduler)]
+ if learning_rate_scheduler is not None
+ else self._training_parameters.learning_rate_scheduler
+ )
+ self._training_parameters.model_name = (
+ model_name if model_name is not None else self._training_parameters.model_name
+ )
+ self._training_parameters.momentum = momentum if momentum is not None else self._training_parameters.momentum
+ self._training_parameters.nesterov = nesterov if nesterov is not None else self._training_parameters.nesterov
+ self._training_parameters.number_of_epochs = (
+ number_of_epochs if number_of_epochs is not None else self._training_parameters.number_of_epochs
+ )
+ self._training_parameters.number_of_workers = (
+ number_of_workers if number_of_workers is not None else self._training_parameters.number_of_workers
+ )
+ self._training_parameters.optimizer = (
+ StochasticOptimizer[camel_to_snake(optimizer)]
+ if optimizer is not None
+ else self._training_parameters.optimizer
+ )
+ self._training_parameters.random_seed = (
+ random_seed if random_seed is not None else self._training_parameters.random_seed
+ )
+ self._training_parameters.step_lr_gamma = (
+ step_lr_gamma if step_lr_gamma is not None else self._training_parameters.step_lr_gamma
+ )
+ self._training_parameters.step_lr_step_size = (
+ step_lr_step_size if step_lr_step_size is not None else self._training_parameters.step_lr_step_size
+ )
+ self._training_parameters.training_batch_size = (
+ training_batch_size if training_batch_size is not None else self._training_parameters.training_batch_size
+ )
+ self._training_parameters.validation_batch_size = (
+ validation_batch_size
+ if validation_batch_size is not None
+ else self._training_parameters.validation_batch_size
+ )
+ self._training_parameters.warmup_cosine_lr_cycles = (
+ warmup_cosine_lr_cycles
+ if warmup_cosine_lr_cycles is not None
+ else self._training_parameters.warmup_cosine_lr_cycles
+ )
+ self._training_parameters.warmup_cosine_lr_warmup_epochs = (
+ warmup_cosine_lr_warmup_epochs
+ if warmup_cosine_lr_warmup_epochs is not None
+ else self._training_parameters.warmup_cosine_lr_warmup_epochs
+ )
+ self._training_parameters.weight_decay = (
+ weight_decay if weight_decay is not None else self._training_parameters.weight_decay
+ )
+ self._training_parameters.box_detections_per_image = (
+ box_detections_per_image
+ if box_detections_per_image is not None
+ else self._training_parameters.box_detections_per_image
+ )
+ self._training_parameters.box_score_threshold = (
+ box_score_threshold if box_score_threshold is not None else self._training_parameters.box_score_threshold
+ )
+ self._training_parameters.image_size = (
+ image_size if image_size is not None else self._training_parameters.image_size
+ )
+ self._training_parameters.max_size = max_size if max_size is not None else self._training_parameters.max_size
+ self._training_parameters.min_size = min_size if min_size is not None else self._training_parameters.min_size
+ self._training_parameters.model_size = (
+ ModelSize[camel_to_snake(model_size)] if model_size is not None else self._training_parameters.model_size
+ )
+ self._training_parameters.multi_scale = (
+ multi_scale if multi_scale is not None else self._training_parameters.multi_scale
+ )
+ self._training_parameters.nms_iou_threshold = (
+ nms_iou_threshold if nms_iou_threshold is not None else self._training_parameters.nms_iou_threshold
+ )
+ self._training_parameters.tile_grid_size = (
+ tile_grid_size if tile_grid_size is not None else self._training_parameters.tile_grid_size
+ )
+ self._training_parameters.tile_overlap_ratio = (
+ tile_overlap_ratio if tile_overlap_ratio is not None else self._training_parameters.tile_overlap_ratio
+ )
+ self._training_parameters.tile_predictions_nms_threshold = (
+ tile_predictions_nms_threshold
+ if tile_predictions_nms_threshold is not None
+ else self._training_parameters.tile_predictions_nms_threshold
+ )
+ self._training_parameters.validation_iou_threshold = (
+ validation_iou_threshold
+ if validation_iou_threshold is not None
+ else self._training_parameters.validation_iou_threshold
+ )
+ self._training_parameters.validation_metric_type = (
+ ValidationMetricType[camel_to_snake(validation_metric_type)]
+ if validation_metric_type is not None
+ else self._training_parameters.validation_metric_type
+ )
+ self._training_parameters.log_training_metrics = (
+ LogTrainingMetrics[camel_to_snake(log_training_metrics)]
+ if log_training_metrics is not None
+ else self._training_parameters.log_training_metrics
+ )
+ self._training_parameters.log_validation_loss = (
+ LogValidationLoss[camel_to_snake(log_validation_loss)]
+ if log_validation_loss is not None
+ else self._training_parameters.log_validation_loss
+ )
+
+ # pylint: enable=too-many-locals
+
+ def extend_search_space(
+ self,
+ value: Union[SearchSpace, List[SearchSpace]],
+ ) -> None:
+ """Add search space for AutoML Image Object Detection and Image Instance Segmentation tasks.
+
+ :param value: Search through the parameter space
+ :type value: Union[SearchSpace, List[SearchSpace]]
+ """
+ self._search_space = self._search_space or []
+
+ if isinstance(value, list):
+ self._search_space.extend(
+ [
+ cast_to_specific_search_space(item, ImageObjectDetectionSearchSpace, self.task_type) # type: ignore
+ for item in value
+ ]
+ )
+ else:
+ self._search_space.append(
+ cast_to_specific_search_space(value, ImageObjectDetectionSearchSpace, self.task_type) # type: ignore
+ )
+
+ @classmethod
+ def _get_search_space_from_str(cls, search_space_str: str) -> Optional[List[ImageObjectDetectionSearchSpace]]:
+ return (
+ [
+ ImageObjectDetectionSearchSpace._from_rest_object(entry)
+ for entry in search_space_str
+ if entry is not None
+ ]
+ if search_space_str is not None
+ else None
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, AutoMLImageObjectDetectionBase):
+ return NotImplemented
+
+ if not super().__eq__(other):
+ return False
+
+ return self._training_parameters == other._training_parameters and self._search_space == other._search_space
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
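A minimal sketch of the detection-specific keywords accepted by set_training_parameters() above, assuming the concrete ImageObjectDetectionJob subclass from the public azure.ai.ml.automl surface; values are illustrative and the string-to-enum lookups rely on the case-insensitive REST enums.

    from azure.ai.ml import automl

    job = automl.ImageObjectDetectionJob()

    # Strings are mapped onto the REST enums (ModelSize, ValidationMetricType,
    # LogTrainingMetrics) via camel_to_snake plus a case-insensitive enum lookup.
    job.set_training_parameters(
        model_name="yolov5",
        model_size="medium",
        validation_metric_type="coco",
        log_training_metrics="Enable",
        nms_iou_threshold=0.5,
    )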
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_job.py
new file mode 100644
index 00000000..a1b9dbc3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_job.py
@@ -0,0 +1,244 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageClassification as RestImageClassification
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, TaskType
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.image.automl_image_classification_base import AutoMLImageClassificationBase
+from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings
+from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsClassification
+from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class ImageClassificationJob(AutoMLImageClassificationBase):
+ """Configuration for AutoML multi-class Image Classification job.
+
+ :param primary_metric: The primary metric to use for optimization.
+    :type primary_metric: Optional[Union[str, ~azure.ai.ml.automl.ClassificationPrimaryMetrics]]
+ :param kwargs: Job-specific arguments.
+ :type kwargs: Dict[str, Any]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_classification_job]
+ :end-before: [END automl.automl_image_job.image_classification_job]
+ :language: python
+ :dedent: 8
+ :caption: creating an automl image classification job
+ """
+
+ _DEFAULT_PRIMARY_METRIC = ClassificationPrimaryMetrics.ACCURACY
+
+ def __init__(
+ self,
+ *,
+ primary_metric: Optional[Union[str, ClassificationPrimaryMetrics]] = None,
+ **kwargs: Any,
+ ) -> None:
+
+ # Extract any super class init settings
+ limits = kwargs.pop("limits", None)
+ sweep = kwargs.pop("sweep", None)
+ training_parameters = kwargs.pop("training_parameters", None)
+ search_space = kwargs.pop("search_space", None)
+
+ super().__init__(
+ task_type=TaskType.IMAGE_CLASSIFICATION,
+ limits=limits,
+ sweep=sweep,
+ training_parameters=training_parameters,
+ search_space=search_space,
+ **kwargs,
+ )
+
+ self.primary_metric = primary_metric or ImageClassificationJob._DEFAULT_PRIMARY_METRIC
+
+ @property
+ def primary_metric(self) -> Optional[Union[str, ClassificationPrimaryMetrics]]:
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, ClassificationPrimaryMetrics]) -> None:
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+ self._primary_metric = (
+ ImageClassificationJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else ClassificationPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
+ def _to_rest_object(self) -> JobBase:
+ image_classification_task = RestImageClassification(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ validation_data_size=self.validation_data_size,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ sweep_settings=self._sweep._to_rest_object() if self._sweep else None,
+ model_settings=self._training_parameters._to_rest_object() if self._training_parameters else None,
+ search_space=(
+ [entry._to_rest_object() for entry in self._search_space if entry is not None]
+ if self._search_space is not None
+ else None
+ ),
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ )
+ # resolve data inputs in rest obj
+ self._resolve_data_inputs(image_classification_task)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=image_classification_task,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "ImageClassificationJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestImageClassification = properties.task_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "services": properties.services,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "compute": properties.compute_id,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ "resources": properties.resources,
+ "identity": (
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ "queue_settings": properties.queue_settings,
+ }
+
+ image_classification_job = cls(
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ validation_data_size=task_details.validation_data_size,
+ limits=(
+ ImageLimitSettings._from_rest_object(task_details.limit_settings)
+ if task_details.limit_settings
+ else None
+ ),
+ sweep=(
+ ImageSweepSettings._from_rest_object(task_details.sweep_settings)
+ if task_details.sweep_settings
+ else None
+ ),
+ training_parameters=(
+ ImageModelSettingsClassification._from_rest_object(task_details.model_settings)
+ if task_details.model_settings
+ else None
+ ),
+ search_space=cls._get_search_space_from_str(task_details.search_space),
+ primary_metric=task_details.primary_metric,
+ log_verbosity=task_details.log_verbosity,
+ **job_args_dict,
+ )
+
+ image_classification_job._restore_data_inputs()
+
+ return image_classification_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "ImageClassificationJob":
+ from azure.ai.ml._schema.automl.image_vertical.image_classification import ImageClassificationSchema
+ from azure.ai.ml._schema.pipeline.automl_node import ImageClassificationMulticlassNodeSchema
+
+ inside_pipeline = kwargs.pop("inside_pipeline", False)
+ if inside_pipeline:
+ if context.get("inside_pipeline", None) is None:
+ context["inside_pipeline"] = True
+ loaded_data = load_from_dict(
+ ImageClassificationMulticlassNodeSchema,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ else:
+ loaded_data = load_from_dict(ImageClassificationSchema, data, context, additional_message, **kwargs)
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ImageClassificationJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ data_settings = {
+ "training_data": loaded_data.pop("training_data"),
+ "target_column_name": loaded_data.pop("target_column_name"),
+ "validation_data": loaded_data.pop("validation_data", None),
+ "validation_data_size": loaded_data.pop("validation_data_size", None),
+ }
+ job = ImageClassificationJob(**loaded_data)
+ job.set_data(**data_settings)
+ return job
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.image_vertical.image_classification import ImageClassificationSchema
+ from azure.ai.ml._schema.pipeline.automl_node import ImageClassificationMulticlassNodeSchema
+
+ schema_dict: dict = {}
+ if inside_pipeline:
+ schema_dict = ImageClassificationMulticlassNodeSchema(
+ context={BASE_PATH_CONTEXT_KEY: "./", "inside_pipeline": True}
+ ).dump(self)
+ else:
+ schema_dict = ImageClassificationSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageClassificationJob):
+ return NotImplemented
+
+ if not super().__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
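A minimal end-to-end sketch of configuring this job through the public setters inherited from AutoMLImage (set_data and set_limits); the MLTable asset path is hypothetical and submission through an MLClient is omitted.

    from azure.ai.ml import Input, automl
    from azure.ai.ml.constants import AssetTypes

    job = automl.ImageClassificationJob(primary_metric="accuracy")
    job.set_data(
        training_data=Input(type=AssetTypes.MLTABLE, path="azureml:fridge-items-train:1"),
        target_column_name="label",
        validation_data_size=0.2,
    )
    job.set_limits(max_trials=4, max_concurrent_trials=2, timeout_minutes=60)
    # ml_client.jobs.create_or_update(job) would submit it; _to_rest_object() above
    # produces the REST payload for that call.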
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_multilabel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_multilabel_job.py
new file mode 100644
index 00000000..541f41c7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_multilabel_job.py
@@ -0,0 +1,252 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationMultilabelPrimaryMetrics
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ImageClassificationMultilabel as RestImageClassificationMultilabel,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, TaskType
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.image.automl_image_classification_base import AutoMLImageClassificationBase
+from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings
+from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsClassification
+from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class ImageClassificationMultilabelJob(AutoMLImageClassificationBase):
+ """Configuration for AutoML multi-label Image Classification job.
+
+ :param primary_metric: The primary metric to use for optimization.
+    :type primary_metric: Optional[Union[str, ~azure.ai.ml.automl.ClassificationMultilabelPrimaryMetrics]]
+ :param kwargs: Job-specific arguments.
+ :type kwargs: Dict[str, Any]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_classification_multilabel_job]
+ :end-before: [END automl.automl_image_job.image_classification_multilabel_job]
+ :language: python
+ :dedent: 8
+ :caption: creating an automl image classification multilabel job
+ """
+
+ _DEFAULT_PRIMARY_METRIC = ClassificationMultilabelPrimaryMetrics.IOU
+
+ def __init__(
+ self,
+ *,
+ primary_metric: Optional[Union[str, ClassificationMultilabelPrimaryMetrics]] = None,
+ **kwargs: Any,
+ ) -> None:
+
+ # Extract any super class init settings
+ limits = kwargs.pop("limits", None)
+ sweep = kwargs.pop("sweep", None)
+ training_parameters = kwargs.pop("training_parameters", None)
+ search_space = kwargs.pop("search_space", None)
+
+ super().__init__(
+ task_type=TaskType.IMAGE_CLASSIFICATION_MULTILABEL,
+ limits=limits,
+ sweep=sweep,
+ training_parameters=training_parameters,
+ search_space=search_space,
+ **kwargs,
+ )
+
+ self.primary_metric = primary_metric or ImageClassificationMultilabelJob._DEFAULT_PRIMARY_METRIC
+
+ @property
+ def primary_metric(self) -> Union[str, ClassificationMultilabelPrimaryMetrics]:
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, ClassificationMultilabelPrimaryMetrics]) -> None:
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+ self._primary_metric = (
+ ImageClassificationMultilabelJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else ClassificationMultilabelPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
+ def _to_rest_object(self) -> JobBase:
+ image_classification_multilabel_task = RestImageClassificationMultilabel(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ validation_data_size=self.validation_data_size,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ sweep_settings=self._sweep._to_rest_object() if self._sweep else None,
+ model_settings=self._training_parameters._to_rest_object() if self._training_parameters else None,
+ search_space=(
+ [entry._to_rest_object() for entry in self._search_space if entry is not None]
+ if self._search_space is not None
+ else None
+ ),
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ )
+ # resolve data inputs in rest obj
+ self._resolve_data_inputs(image_classification_multilabel_task)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=image_classification_multilabel_task,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "ImageClassificationMultilabelJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestImageClassificationMultilabel = properties.task_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "services": properties.services,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "compute": properties.compute_id,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ "resources": properties.resources,
+ "identity": (
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ "queue_settings": properties.queue_settings,
+ }
+
+ image_classification_multilabel_job = cls(
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ validation_data_size=task_details.validation_data_size,
+ limits=(
+ ImageLimitSettings._from_rest_object(task_details.limit_settings)
+ if task_details.limit_settings
+ else None
+ ),
+ sweep=(
+ ImageSweepSettings._from_rest_object(task_details.sweep_settings)
+ if task_details.sweep_settings
+ else None
+ ),
+ training_parameters=(
+ ImageModelSettingsClassification._from_rest_object(task_details.model_settings)
+ if task_details.model_settings
+ else None
+ ),
+ search_space=cls._get_search_space_from_str(task_details.search_space),
+ primary_metric=task_details.primary_metric,
+ log_verbosity=task_details.log_verbosity,
+ **job_args_dict,
+ )
+
+ image_classification_multilabel_job._restore_data_inputs()
+
+ return image_classification_multilabel_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "ImageClassificationMultilabelJob":
+ from azure.ai.ml._schema.automl.image_vertical.image_classification import ImageClassificationMultilabelSchema
+ from azure.ai.ml._schema.pipeline.automl_node import ImageClassificationMultilabelNodeSchema
+
+ inside_pipeline = kwargs.pop("inside_pipeline", False)
+ if inside_pipeline:
+ if context.get("inside_pipeline", None) is None:
+ context["inside_pipeline"] = True
+ loaded_data = load_from_dict(
+ ImageClassificationMultilabelNodeSchema,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ else:
+ loaded_data = load_from_dict(
+ ImageClassificationMultilabelSchema,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ImageClassificationMultilabelJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ data_settings = {
+ "training_data": loaded_data.pop("training_data"),
+ "target_column_name": loaded_data.pop("target_column_name"),
+ "validation_data": loaded_data.pop("validation_data", None),
+ "validation_data_size": loaded_data.pop("validation_data_size", None),
+ }
+ job = ImageClassificationMultilabelJob(**loaded_data)
+ job.set_data(**data_settings)
+ return job
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.image_vertical.image_classification import ImageClassificationMultilabelSchema
+ from azure.ai.ml._schema.pipeline.automl_node import ImageClassificationMultilabelNodeSchema
+
+ schema_dict: dict = {}
+ if inside_pipeline:
+ schema_dict = ImageClassificationMultilabelNodeSchema(
+ context={BASE_PATH_CONTEXT_KEY: "./", "inside_pipeline": True}
+ ).dump(self)
+ else:
+ schema_dict = ImageClassificationMultilabelSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageClassificationMultilabelJob):
+ return NotImplemented
+
+ if not super().__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
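
For readers skimming the diff, the class above is normally constructed empty and then fed data and limits through its setter surface (set_data is used by the schema loader above; set_limits is assumed here from the image AutoML base-class surface). The sketch below is illustrative only and not part of the diff: the import path assumes the public azure.ai.ml.automl namespace re-exports the class, and the MLTable paths, compute name, and experiment name are placeholders.

    from azure.ai.ml import Input
    from azure.ai.ml.constants import AssetTypes
    from azure.ai.ml.automl import ImageClassificationMultilabelJob

    # Placeholder MLTable folders; point these at real assets in your workspace.
    training_mltable = Input(type=AssetTypes.MLTABLE, path="./data/training-mltable-folder")
    validation_mltable = Input(type=AssetTypes.MLTABLE, path="./data/validation-mltable-folder")

    # primary_metric falls back to ClassificationMultilabelPrimaryMetrics.IOU when omitted.
    job = ImageClassificationMultilabelJob(
        compute="gpu-cluster",                  # placeholder compute target
        experiment_name="automl-image-example",
    )
    job.set_data(
        training_data=training_mltable,
        validation_data=validation_mltable,
        target_column_name="label",
    )
    job.set_limits(max_trials=10, max_concurrent_trials=2, timeout_minutes=60)

    # Submission would go through a configured MLClient, e.g. (commented placeholder):
    # returned_job = ml_client.jobs.create_or_update(job)
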
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_search_space.py
new file mode 100644
index 00000000..0691f243
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_classification_search_space.py
@@ -0,0 +1,437 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=R0902,too-many-locals
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageModelDistributionSettingsClassification
+from azure.ai.ml.entities._job.automl.search_space import SearchSpace
+from azure.ai.ml.entities._job.automl.search_space_utils import _convert_from_rest_object, _convert_to_rest_object
+from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class ImageClassificationSearchSpace(RestTranslatableMixin):
+ """Search space for AutoML Image Classification and Image Classification
+ Multilabel tasks.
+
+ :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'.
+ :type ams_gradient: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :type beta1: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :type beta2: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+    :param distributed: Whether to use distributed training.
+ :type distributed: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param early_stopping: Enable early stopping logic during training.
+ :type early_stopping: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait
+ before primary metric improvement
+ is tracked for early stopping. Must be a positive integer.
+ :type early_stopping_delay: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param early_stopping_patience: Minimum number of epochs or validation evaluations with no
+ primary metric improvement before
+ the run is stopped. Must be a positive integer.
+ :type early_stopping_patience: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param enable_onnx_normalization: Enable normalization when exporting ONNX model.
+ :type enable_onnx_normalization: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores.
+ Must be a positive integer.
+ :type evaluation_frequency: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param gradient_accumulation_step: Gradient accumulation means running a configured number of
+ "GradAccumulationStep" steps without
+ updating the model weights while accumulating the gradients of those steps, and then using
+ the accumulated gradients to compute the weight updates. Must be a positive integer.
+ :type gradient_accumulation_step: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+    :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive
+     integer.
+     For instance, passing 2 as the value for 'seresnext' means
+     freezing layer0 and layer1. For a full list of supported models and details on layer freezing,
+     please see:
+     https://learn.microsoft.com/azure/machine-learning/reference-automl-images-hyperparameters#model-agnostic-hyperparameters.
+ :type layers_to_freeze: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param learning_rate: Initial learning rate. Must be a float in the range [0, 1].
+ :type learning_rate: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or
+ 'step'.
+ :type learning_rate_scheduler: str or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param model_name: Name of the model to use for training.
+ For more information on the available models please visit the official documentation:
+ https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type model_name: str or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0,
+ 1].
+ :type momentum: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param nesterov: Enable nesterov when optimizer is 'sgd'.
+ :type nesterov: bool or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param number_of_epochs: Number of training epochs. Must be a positive integer.
+ :type number_of_epochs: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param number_of_workers: Number of data loader workers. Must be a non-negative integer.
+ :type number_of_workers: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param optimizer: Type of optimizer. Must be either 'sgd', 'adam', or 'adamw'.
+ :type optimizer: str or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param random_seed: Random seed to be used when using deterministic training.
+ :type random_seed: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float
+ in the range [0, 1].
+ :type step_lr_gamma: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be
+ a positive integer.
+ :type step_lr_step_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param training_batch_size: Training batch size. Must be a positive integer.
+ :type training_batch_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param validation_batch_size: Validation batch size. Must be a positive integer.
+ :type validation_batch_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+ 'warmup_cosine'. Must be a float in the range [0, 1].
+ :type warmup_cosine_lr_cycles: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+ 'warmup_cosine'. Must be a positive integer.
+ :type warmup_cosine_lr_warmup_epochs: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must
+     be a float in the range [0, 1].
+ :type weight_decay: float or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param training_crop_size: Image crop size that is input to the neural network for the
+ training dataset. Must be a positive integer.
+ :type training_crop_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param validation_crop_size: Image crop size that is input to the neural network for the
+ validation dataset. Must be a positive integer.
+ :type validation_crop_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+ :param validation_resize_size: Image size to which to resize before cropping for validation
+ dataset. Must be a positive integer.
+ :type validation_resize_size: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+    :param weighted_loss: Weighted loss. The accepted values are 0 for no weighted loss,
+     1 for weighted loss with sqrt(class_weights), and 2 for weighted loss with class_weights. Must be
+     0, 1, or 2.
+ :type weighted_loss: int or ~azure.ai.ml.entities._job.sweep.search_space.SweepDistribution
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_classification_search_space]
+ :end-before: [END automl.automl_image_job.image_classification_search_space]
+ :language: python
+ :dedent: 8
+ :caption: Defining an automl image classification search space
+ """
+
+ def __init__(
+ self,
+ *,
+ ams_gradient: Optional[Union[bool, SweepDistribution]] = None,
+ beta1: Optional[Union[float, SweepDistribution]] = None,
+ beta2: Optional[Union[float, SweepDistribution]] = None,
+ distributed: Optional[Union[bool, SweepDistribution]] = None,
+ early_stopping: Optional[Union[bool, SweepDistribution]] = None,
+ early_stopping_delay: Optional[Union[int, SweepDistribution]] = None,
+ early_stopping_patience: Optional[Union[int, SweepDistribution]] = None,
+ enable_onnx_normalization: Optional[Union[bool, SweepDistribution]] = None,
+ evaluation_frequency: Optional[Union[int, SweepDistribution]] = None,
+ gradient_accumulation_step: Optional[Union[int, SweepDistribution]] = None,
+ layers_to_freeze: Optional[Union[int, SweepDistribution]] = None,
+ learning_rate: Optional[Union[float, SweepDistribution]] = None,
+ learning_rate_scheduler: Optional[Union[str, SweepDistribution]] = None,
+ model_name: Optional[Union[str, SweepDistribution]] = None,
+ momentum: Optional[Union[float, SweepDistribution]] = None,
+ nesterov: Optional[Union[bool, SweepDistribution]] = None,
+ number_of_epochs: Optional[Union[int, SweepDistribution]] = None,
+ number_of_workers: Optional[Union[int, SweepDistribution]] = None,
+ optimizer: Optional[Union[str, SweepDistribution]] = None,
+ random_seed: Optional[Union[int, SweepDistribution]] = None,
+ step_lr_gamma: Optional[Union[float, SweepDistribution]] = None,
+ step_lr_step_size: Optional[Union[int, SweepDistribution]] = None,
+ training_batch_size: Optional[Union[int, SweepDistribution]] = None,
+ validation_batch_size: Optional[Union[int, SweepDistribution]] = None,
+ warmup_cosine_lr_cycles: Optional[Union[float, SweepDistribution]] = None,
+ warmup_cosine_lr_warmup_epochs: Optional[Union[int, SweepDistribution]] = None,
+ weight_decay: Optional[Union[float, SweepDistribution]] = None,
+ training_crop_size: Optional[Union[int, SweepDistribution]] = None,
+ validation_crop_size: Optional[Union[int, SweepDistribution]] = None,
+ validation_resize_size: Optional[Union[int, SweepDistribution]] = None,
+ weighted_loss: Optional[Union[int, SweepDistribution]] = None,
+ ) -> None:
+ self.ams_gradient = ams_gradient
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.distributed = distributed
+ self.early_stopping = early_stopping
+ self.early_stopping_delay = early_stopping_delay
+ self.early_stopping_patience = early_stopping_patience
+ self.enable_onnx_normalization = enable_onnx_normalization
+ self.evaluation_frequency = evaluation_frequency
+ self.gradient_accumulation_step = gradient_accumulation_step
+ self.layers_to_freeze = layers_to_freeze
+ self.learning_rate = learning_rate
+ self.learning_rate_scheduler = learning_rate_scheduler
+ self.model_name = model_name
+ self.momentum = momentum
+ self.nesterov = nesterov
+ self.number_of_epochs = number_of_epochs
+ self.number_of_workers = number_of_workers
+ self.optimizer = optimizer
+ self.random_seed = random_seed
+ self.step_lr_gamma = step_lr_gamma
+ self.step_lr_step_size = step_lr_step_size
+ self.training_batch_size = training_batch_size
+ self.validation_batch_size = validation_batch_size
+ self.warmup_cosine_lr_cycles = warmup_cosine_lr_cycles
+ self.warmup_cosine_lr_warmup_epochs = warmup_cosine_lr_warmup_epochs
+ self.weight_decay = weight_decay
+ self.training_crop_size = training_crop_size
+ self.validation_crop_size = validation_crop_size
+ self.validation_resize_size = validation_resize_size
+ self.weighted_loss = weighted_loss
+
+ def _to_rest_object(self) -> ImageModelDistributionSettingsClassification:
+ return ImageModelDistributionSettingsClassification(
+ ams_gradient=_convert_to_rest_object(self.ams_gradient) if self.ams_gradient is not None else None,
+ beta1=_convert_to_rest_object(self.beta1) if self.beta1 is not None else None,
+ beta2=_convert_to_rest_object(self.beta2) if self.beta2 is not None else None,
+ distributed=_convert_to_rest_object(self.distributed) if self.distributed is not None else None,
+ early_stopping=_convert_to_rest_object(self.early_stopping) if self.early_stopping is not None else None,
+ early_stopping_delay=(
+ _convert_to_rest_object(self.early_stopping_delay) if self.early_stopping_delay is not None else None
+ ),
+ early_stopping_patience=(
+ _convert_to_rest_object(self.early_stopping_patience)
+ if self.early_stopping_patience is not None
+ else None
+ ),
+ enable_onnx_normalization=(
+ _convert_to_rest_object(self.enable_onnx_normalization)
+ if self.enable_onnx_normalization is not None
+ else None
+ ),
+ evaluation_frequency=(
+ _convert_to_rest_object(self.evaluation_frequency) if self.evaluation_frequency is not None else None
+ ),
+ gradient_accumulation_step=(
+ _convert_to_rest_object(self.gradient_accumulation_step)
+ if self.gradient_accumulation_step is not None
+ else None
+ ),
+ layers_to_freeze=(
+ _convert_to_rest_object(self.layers_to_freeze) if self.layers_to_freeze is not None else None
+ ),
+ learning_rate=_convert_to_rest_object(self.learning_rate) if self.learning_rate is not None else None,
+ learning_rate_scheduler=(
+ _convert_to_rest_object(self.learning_rate_scheduler)
+ if self.learning_rate_scheduler is not None
+ else None
+ ),
+ model_name=_convert_to_rest_object(self.model_name) if self.model_name is not None else None,
+ momentum=_convert_to_rest_object(self.momentum) if self.momentum is not None else None,
+ nesterov=_convert_to_rest_object(self.nesterov) if self.nesterov is not None else None,
+ number_of_epochs=(
+ _convert_to_rest_object(self.number_of_epochs) if self.number_of_epochs is not None else None
+ ),
+ number_of_workers=(
+ _convert_to_rest_object(self.number_of_workers) if self.number_of_workers is not None else None
+ ),
+ optimizer=_convert_to_rest_object(self.optimizer) if self.optimizer is not None else None,
+ random_seed=_convert_to_rest_object(self.random_seed) if self.random_seed is not None else None,
+ step_lr_gamma=_convert_to_rest_object(self.step_lr_gamma) if self.step_lr_gamma is not None else None,
+ step_lr_step_size=(
+ _convert_to_rest_object(self.step_lr_step_size) if self.step_lr_step_size is not None else None
+ ),
+ training_batch_size=(
+ _convert_to_rest_object(self.training_batch_size) if self.training_batch_size is not None else None
+ ),
+ validation_batch_size=(
+ _convert_to_rest_object(self.validation_batch_size) if self.validation_batch_size is not None else None
+ ),
+ warmup_cosine_lr_cycles=(
+ _convert_to_rest_object(self.warmup_cosine_lr_cycles)
+ if self.warmup_cosine_lr_cycles is not None
+ else None
+ ),
+ warmup_cosine_lr_warmup_epochs=(
+ _convert_to_rest_object(self.warmup_cosine_lr_warmup_epochs)
+ if self.warmup_cosine_lr_warmup_epochs is not None
+ else None
+ ),
+ weight_decay=_convert_to_rest_object(self.weight_decay) if self.weight_decay is not None else None,
+ training_crop_size=(
+ _convert_to_rest_object(self.training_crop_size) if self.training_crop_size is not None else None
+ ),
+ validation_crop_size=(
+ _convert_to_rest_object(self.validation_crop_size) if self.validation_crop_size is not None else None
+ ),
+ validation_resize_size=(
+ _convert_to_rest_object(self.validation_resize_size)
+ if self.validation_resize_size is not None
+ else None
+ ),
+ weighted_loss=_convert_to_rest_object(self.weighted_loss) if self.weighted_loss is not None else None,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: ImageModelDistributionSettingsClassification) -> "ImageClassificationSearchSpace":
+ return cls(
+ ams_gradient=_convert_from_rest_object(obj.ams_gradient) if obj.ams_gradient is not None else None,
+ beta1=_convert_from_rest_object(obj.beta1) if obj.beta1 is not None else None,
+ beta2=_convert_from_rest_object(obj.beta2) if obj.beta2 is not None else None,
+ distributed=_convert_from_rest_object(obj.distributed) if obj.distributed is not None else None,
+ early_stopping=_convert_from_rest_object(obj.early_stopping) if obj.early_stopping is not None else None,
+ early_stopping_delay=(
+ _convert_from_rest_object(obj.early_stopping_delay) if obj.early_stopping_delay is not None else None
+ ),
+ early_stopping_patience=(
+ _convert_from_rest_object(obj.early_stopping_patience)
+ if obj.early_stopping_patience is not None
+ else None
+ ),
+ enable_onnx_normalization=(
+ _convert_from_rest_object(obj.enable_onnx_normalization)
+ if obj.enable_onnx_normalization is not None
+ else None
+ ),
+ evaluation_frequency=(
+ _convert_from_rest_object(obj.evaluation_frequency) if obj.evaluation_frequency is not None else None
+ ),
+ gradient_accumulation_step=(
+ _convert_from_rest_object(obj.gradient_accumulation_step)
+ if obj.gradient_accumulation_step is not None
+ else None
+ ),
+ layers_to_freeze=(
+ _convert_from_rest_object(obj.layers_to_freeze) if obj.layers_to_freeze is not None else None
+ ),
+ learning_rate=_convert_from_rest_object(obj.learning_rate) if obj.learning_rate is not None else None,
+ learning_rate_scheduler=(
+ _convert_from_rest_object(obj.learning_rate_scheduler)
+ if obj.learning_rate_scheduler is not None
+ else None
+ ),
+ model_name=_convert_from_rest_object(obj.model_name) if obj.model_name is not None else None,
+ momentum=_convert_from_rest_object(obj.momentum) if obj.momentum is not None else None,
+ nesterov=_convert_from_rest_object(obj.nesterov) if obj.nesterov is not None else None,
+ number_of_epochs=(
+ _convert_from_rest_object(obj.number_of_epochs) if obj.number_of_epochs is not None else None
+ ),
+ number_of_workers=(
+ _convert_from_rest_object(obj.number_of_workers) if obj.number_of_workers is not None else None
+ ),
+ optimizer=_convert_from_rest_object(obj.optimizer) if obj.optimizer is not None else None,
+ random_seed=_convert_from_rest_object(obj.random_seed) if obj.random_seed is not None else None,
+ step_lr_gamma=_convert_from_rest_object(obj.step_lr_gamma) if obj.step_lr_gamma is not None else None,
+ step_lr_step_size=(
+ _convert_from_rest_object(obj.step_lr_step_size) if obj.step_lr_step_size is not None else None
+ ),
+ training_batch_size=(
+ _convert_from_rest_object(obj.training_batch_size) if obj.training_batch_size is not None else None
+ ),
+ validation_batch_size=(
+ _convert_from_rest_object(obj.validation_batch_size) if obj.validation_batch_size is not None else None
+ ),
+ warmup_cosine_lr_cycles=(
+ _convert_from_rest_object(obj.warmup_cosine_lr_cycles)
+ if obj.warmup_cosine_lr_cycles is not None
+ else None
+ ),
+ warmup_cosine_lr_warmup_epochs=(
+ _convert_from_rest_object(obj.warmup_cosine_lr_warmup_epochs)
+ if obj.warmup_cosine_lr_warmup_epochs is not None
+ else None
+ ),
+ weight_decay=_convert_from_rest_object(obj.weight_decay) if obj.weight_decay is not None else None,
+ training_crop_size=(
+ _convert_from_rest_object(obj.training_crop_size) if obj.training_crop_size is not None else None
+ ),
+ validation_crop_size=(
+ _convert_from_rest_object(obj.validation_crop_size) if obj.validation_crop_size is not None else None
+ ),
+ validation_resize_size=(
+ _convert_from_rest_object(obj.validation_resize_size)
+ if obj.validation_resize_size is not None
+ else None
+ ),
+ weighted_loss=_convert_from_rest_object(obj.weighted_loss) if obj.weighted_loss is not None else None,
+ )
+
+ @classmethod
+ def _from_search_space_object(cls, obj: SearchSpace) -> "ImageClassificationSearchSpace":
+ return cls(
+ ams_gradient=obj.ams_gradient if hasattr(obj, "ams_gradient") else None,
+ beta1=obj.beta1 if hasattr(obj, "beta1") else None,
+ beta2=obj.beta2 if hasattr(obj, "beta2") else None,
+ distributed=obj.distributed if hasattr(obj, "distributed") else None,
+ early_stopping=obj.early_stopping if hasattr(obj, "early_stopping") else None,
+ early_stopping_delay=obj.early_stopping_delay if hasattr(obj, "early_stopping_delay") else None,
+ early_stopping_patience=obj.early_stopping_patience if hasattr(obj, "early_stopping_patience") else None,
+ enable_onnx_normalization=(
+ obj.enable_onnx_normalization if hasattr(obj, "enable_onnx_normalization") else None
+ ),
+ evaluation_frequency=obj.evaluation_frequency if hasattr(obj, "evaluation_frequency") else None,
+ gradient_accumulation_step=(
+ obj.gradient_accumulation_step if hasattr(obj, "gradient_accumulation_step") else None
+ ),
+ layers_to_freeze=obj.layers_to_freeze if hasattr(obj, "layers_to_freeze") else None,
+ learning_rate=obj.learning_rate if hasattr(obj, "learning_rate") else None,
+ learning_rate_scheduler=obj.learning_rate_scheduler if hasattr(obj, "learning_rate_scheduler") else None,
+ model_name=obj.model_name if hasattr(obj, "model_name") else None,
+ momentum=obj.momentum if hasattr(obj, "momentum") else None,
+ nesterov=obj.nesterov if hasattr(obj, "nesterov") else None,
+ number_of_epochs=obj.number_of_epochs if hasattr(obj, "number_of_epochs") else None,
+ number_of_workers=obj.number_of_workers if hasattr(obj, "number_of_workers") else None,
+ optimizer=obj.optimizer if hasattr(obj, "optimizer") else None,
+ random_seed=obj.random_seed if hasattr(obj, "random_seed") else None,
+ step_lr_gamma=obj.step_lr_gamma if hasattr(obj, "step_lr_gamma") else None,
+ step_lr_step_size=obj.step_lr_step_size if hasattr(obj, "step_lr_step_size") else None,
+ training_batch_size=obj.training_batch_size if hasattr(obj, "training_batch_size") else None,
+ validation_batch_size=obj.validation_batch_size if hasattr(obj, "validation_batch_size") else None,
+ warmup_cosine_lr_cycles=obj.warmup_cosine_lr_cycles if hasattr(obj, "warmup_cosine_lr_cycles") else None,
+ warmup_cosine_lr_warmup_epochs=(
+ obj.warmup_cosine_lr_warmup_epochs if hasattr(obj, "warmup_cosine_lr_warmup_epochs") else None
+ ),
+ weight_decay=obj.weight_decay if hasattr(obj, "weight_decay") else None,
+ training_crop_size=obj.training_crop_size if hasattr(obj, "training_crop_size") else None,
+ validation_crop_size=obj.validation_crop_size if hasattr(obj, "validation_crop_size") else None,
+ validation_resize_size=obj.validation_resize_size if hasattr(obj, "validation_resize_size") else None,
+ weighted_loss=obj.weighted_loss if hasattr(obj, "weighted_loss") else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageClassificationSearchSpace):
+ return NotImplemented
+
+ return (
+ self.ams_gradient == other.ams_gradient
+ and self.beta1 == other.beta1
+ and self.beta2 == other.beta2
+ and self.distributed == other.distributed
+ and self.early_stopping == other.early_stopping
+ and self.early_stopping_delay == other.early_stopping_delay
+ and self.early_stopping_patience == other.early_stopping_patience
+ and self.enable_onnx_normalization == other.enable_onnx_normalization
+ and self.evaluation_frequency == other.evaluation_frequency
+ and self.gradient_accumulation_step == other.gradient_accumulation_step
+ and self.layers_to_freeze == other.layers_to_freeze
+ and self.learning_rate == other.learning_rate
+ and self.learning_rate_scheduler == other.learning_rate_scheduler
+ and self.model_name == other.model_name
+ and self.momentum == other.momentum
+ and self.nesterov == other.nesterov
+ and self.number_of_epochs == other.number_of_epochs
+ and self.number_of_workers == other.number_of_workers
+ and self.optimizer == other.optimizer
+ and self.random_seed == other.random_seed
+ and self.step_lr_gamma == other.step_lr_gamma
+ and self.step_lr_step_size == other.step_lr_step_size
+ and self.training_batch_size == other.training_batch_size
+ and self.validation_batch_size == other.validation_batch_size
+ and self.warmup_cosine_lr_cycles == other.warmup_cosine_lr_cycles
+ and self.warmup_cosine_lr_warmup_epochs == other.warmup_cosine_lr_warmup_epochs
+ and self.weight_decay == other.weight_decay
+ and self.training_crop_size == other.training_crop_size
+ and self.validation_crop_size == other.validation_crop_size
+ and self.validation_resize_size == other.validation_resize_size
+ and self.weighted_loss == other.weighted_loss
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
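
Each attribute of the search space above accepts either a fixed literal or a SweepDistribution, which _to_rest_object serializes via _convert_to_rest_object. Below is a minimal sketch of defining one, assuming the class and the sweep distributions are available from the public azure.ai.ml.automl and azure.ai.ml.sweep namespaces; the model names and ranges are illustrative, not recommendations.

    from azure.ai.ml.automl import ImageClassificationSearchSpace
    from azure.ai.ml.sweep import Choice, Uniform

    search_space = ImageClassificationSearchSpace(
        model_name=Choice(["vitb16r224", "seresnext"]),  # sweep over candidate architectures
        learning_rate=Uniform(0.001, 0.01),              # continuous distribution
        number_of_epochs=Choice([15, 30]),
        optimizer="adamw",                               # fixed literals are also accepted
    )

    # Attaching it to an image job (method name assumed from the base-class surface;
    # verify against your installed SDK version):
    # image_job.extend_search_space([search_space])
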
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_instance_segmentation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_instance_segmentation_job.py
new file mode 100644
index 00000000..c97d3c11
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_instance_segmentation_job.py
@@ -0,0 +1,249 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ImageInstanceSegmentation as RestImageInstanceSegmentation,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import InstanceSegmentationPrimaryMetrics, JobBase, TaskType
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.image.automl_image_object_detection_base import AutoMLImageObjectDetectionBase
+from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings
+from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsObjectDetection
+from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class ImageInstanceSegmentationJob(AutoMLImageObjectDetectionBase):
+ """Configuration for AutoML Image Instance Segmentation job.
+
+ :keyword primary_metric: The primary metric to use for optimization.
+    :paramtype primary_metric: Optional[Union[str, ~azure.ai.ml.automl.InstanceSegmentationPrimaryMetrics]]
+ :keyword kwargs: Job-specific arguments.
+ :paramtype kwargs: Dict[str, Any]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_instance_segmentation_job]
+ :end-before: [END automl.automl_image_job.image_instance_segmentation_job]
+ :language: python
+ :dedent: 8
+            :caption: Creating an automl image instance segmentation job
+ """
+
+ _DEFAULT_PRIMARY_METRIC = InstanceSegmentationPrimaryMetrics.MEAN_AVERAGE_PRECISION
+
+ def __init__(
+ self,
+ *,
+ primary_metric: Optional[Union[str, InstanceSegmentationPrimaryMetrics]] = None,
+ **kwargs: Any,
+ ) -> None:
+ # Extract any super class init settings
+ limits = kwargs.pop("limits", None)
+ sweep = kwargs.pop("sweep", None)
+ training_parameters = kwargs.pop("training_parameters", None)
+ search_space = kwargs.pop("search_space", None)
+
+ super().__init__(
+ task_type=TaskType.IMAGE_INSTANCE_SEGMENTATION,
+ limits=limits,
+ sweep=sweep,
+ training_parameters=training_parameters,
+ search_space=search_space,
+ **kwargs,
+ )
+ self.primary_metric = primary_metric or ImageInstanceSegmentationJob._DEFAULT_PRIMARY_METRIC
+
+ @property
+ def primary_metric(self) -> Union[str, InstanceSegmentationPrimaryMetrics]:
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, InstanceSegmentationPrimaryMetrics]) -> None:
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+ self._primary_metric = (
+ ImageInstanceSegmentationJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else InstanceSegmentationPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
+ def _to_rest_object(self) -> JobBase:
+ image_instance_segmentation_task = RestImageInstanceSegmentation(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ validation_data_size=self.validation_data_size,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ sweep_settings=self._sweep._to_rest_object() if self._sweep else None,
+ model_settings=self._training_parameters._to_rest_object() if self._training_parameters else None,
+ search_space=(
+ [entry._to_rest_object() for entry in self._search_space if entry is not None]
+ if self._search_space is not None
+ else None
+ ),
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ )
+ # resolve data inputs in rest obj
+ self._resolve_data_inputs(image_instance_segmentation_task)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=image_instance_segmentation_task,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "ImageInstanceSegmentationJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestImageInstanceSegmentation = properties.task_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "services": properties.services,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "compute": properties.compute_id,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ "resources": properties.resources,
+ "identity": (
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ "queue_settings": properties.queue_settings,
+ }
+
+ image_instance_segmentation_job = cls(
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ validation_data_size=task_details.validation_data_size,
+ limits=(
+ ImageLimitSettings._from_rest_object(task_details.limit_settings)
+ if task_details.limit_settings
+ else None
+ ),
+ sweep=(
+ ImageSweepSettings._from_rest_object(task_details.sweep_settings)
+ if task_details.sweep_settings
+ else None
+ ),
+ training_parameters=(
+ ImageModelSettingsObjectDetection._from_rest_object(task_details.model_settings)
+ if task_details.model_settings
+ else None
+ ),
+ search_space=cls._get_search_space_from_str(task_details.search_space),
+ primary_metric=task_details.primary_metric,
+ log_verbosity=task_details.log_verbosity,
+ **job_args_dict,
+ )
+
+ image_instance_segmentation_job._restore_data_inputs()
+
+ return image_instance_segmentation_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "ImageInstanceSegmentationJob":
+ from azure.ai.ml._schema.automl.image_vertical.image_object_detection import ImageInstanceSegmentationSchema
+ from azure.ai.ml._schema.pipeline.automl_node import ImageInstanceSegmentationNodeSchema
+
+ inside_pipeline = kwargs.pop("inside_pipeline", False)
+ if inside_pipeline:
+ if context.get("inside_pipeline", None) is None:
+ context["inside_pipeline"] = True
+ loaded_data = load_from_dict(
+ ImageInstanceSegmentationNodeSchema,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ else:
+ loaded_data = load_from_dict(
+ ImageInstanceSegmentationSchema,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ImageInstanceSegmentationJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ data_settings = {
+ "training_data": loaded_data.pop("training_data"),
+ "target_column_name": loaded_data.pop("target_column_name"),
+ "validation_data": loaded_data.pop("validation_data", None),
+ "validation_data_size": loaded_data.pop("validation_data_size", None),
+ }
+ job = ImageInstanceSegmentationJob(**loaded_data)
+ job.set_data(**data_settings)
+ return job
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.image_vertical.image_object_detection import ImageInstanceSegmentationSchema
+ from azure.ai.ml._schema.pipeline.automl_node import ImageInstanceSegmentationNodeSchema
+
+ schema_dict: dict = {}
+ if inside_pipeline:
+ schema_dict = ImageInstanceSegmentationNodeSchema(
+ context={BASE_PATH_CONTEXT_KEY: "./", "inside_pipeline": True}
+ ).dump(self)
+ else:
+ schema_dict = ImageInstanceSegmentationSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageInstanceSegmentationJob):
+ return NotImplemented
+
+ if not super().__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
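
The primary_metric property above normalizes user input: None falls back to the class default, strings are converted to the InstanceSegmentationPrimaryMetrics enum via camel_to_snake(...).upper(), and pipeline data-binding expressions are passed through untouched. A small illustrative sketch of that behavior follows; the import paths assume the public azure.ai.ml.automl namespace re-exports these names, and the asserts are there only to show the expected values.

    from azure.ai.ml.automl import ImageInstanceSegmentationJob, InstanceSegmentationPrimaryMetrics

    job = ImageInstanceSegmentationJob()

    # No value supplied: the class default applies.
    assert job.primary_metric == InstanceSegmentationPrimaryMetrics.MEAN_AVERAGE_PRECISION

    # CamelCase strings are normalized to the enum member.
    job.primary_metric = "MeanAveragePrecision"
    assert job.primary_metric == InstanceSegmentationPrimaryMetrics.MEAN_AVERAGE_PRECISION

    # Data-binding expressions used inside pipelines are stored as-is.
    job.primary_metric = "${{parent.inputs.metric}}"
    assert job.primary_metric == "${{parent.inputs.metric}}"
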
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_limit_settings.py
new file mode 100644
index 00000000..12ec8b57
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_limit_settings.py
@@ -0,0 +1,117 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageLimitSettings as RestImageLimitSettings
+from azure.ai.ml._utils.utils import from_iso_duration_format_mins, to_iso_duration_format_mins
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class ImageLimitSettings(RestTranslatableMixin):
+ r"""Limit settings for AutoML Image Verticals.
+
+ ImageLimitSettings is a class that contains the following parameters: max_concurrent_trials, max_trials, and \
+ timeout_minutes.
+
+    This is an optional configuration used to set limit parameters such as timeouts.
+
+ .. note::
+
+ The number of concurrent runs is gated on the resources available in the specified compute target.
+ Ensure that the compute target has the available resources for the desired concurrency.
+
+ :keyword max_concurrent_trials: Maximum number of concurrent AutoML iterations, defaults to None.
+ :paramtype max_concurrent_trials: typing.Optional[int]
+    :keyword max_trials: Represents the maximum number of trials (child jobs).
+ :paramtype max_trials: typing.Optional[int]
+    :keyword timeout_minutes: AutoML job timeout. Defaults to None.
+ :paramtype timeout_minutes: typing.Optional[int]
+ :raises ValueError: If max_concurrent_trials is not None and is not a positive integer.
+ :raises ValueError: If max_trials is not None and is not a positive integer.
+ :raises ValueError: If timeout_minutes is not None and is not a positive integer.
+ :return: ImageLimitSettings object.
+ :rtype: ImageLimitSettings
+
+ .. tip::
+ It's a good practice to match max_concurrent_trials count with the number of nodes in the cluster.
+ For example, if you have a cluster with 4 nodes, set max_concurrent_trials to 4.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_limit_settings]
+ :end-before: [END automl.automl_image_job.image_limit_settings]
+ :language: python
+ :dedent: 8
+ :caption: Defining the limit settings for an automl image job.
+ """
+
+ def __init__(
+ self,
+ *,
+ max_concurrent_trials: Optional[int] = None,
+ max_trials: Optional[int] = None,
+ timeout_minutes: Optional[int] = None,
+ ) -> None:
+ self.max_concurrent_trials = max_concurrent_trials
+ self.max_trials = max_trials
+ self.timeout_minutes = timeout_minutes
+
+ def _to_rest_object(self) -> RestImageLimitSettings:
+ """Convert ImageLimitSettings objects to a rest object.
+
+ :return: A rest object of ImageLimitSettings objects.
+ :rtype: RestImageLimitSettings
+ """
+ return RestImageLimitSettings(
+ max_concurrent_trials=self.max_concurrent_trials,
+ max_trials=self.max_trials,
+ timeout=to_iso_duration_format_mins(self.timeout_minutes),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestImageLimitSettings) -> "ImageLimitSettings":
+ """Convert the rest object to a dict containing items to init the ImageLimitSettings objects.
+
+ :param obj: Limit settings for the AutoML job in Rest format.
+ :type obj: RestImageLimitSettings
+ :return: Limit settings for an AutoML Image Vertical.
+ :rtype: ImageLimitSettings
+ """
+ return cls(
+ max_concurrent_trials=obj.max_concurrent_trials,
+ max_trials=obj.max_trials,
+ timeout_minutes=from_iso_duration_format_mins(obj.timeout),
+ )
+
+ def __eq__(self, other: object) -> bool:
+ """Check equality between two ImageLimitSettings objects.
+
+        This method checks instance equality and returns True if both
+        instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, ImageLimitSettings):
+ return NotImplemented
+
+ return (
+ self.max_concurrent_trials == other.max_concurrent_trials
+ and self.max_trials == other.max_trials
+ and self.timeout_minutes == other.timeout_minutes
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two ImageLimitSettings objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
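
As the helpers above show, timeout_minutes is kept as minutes on the Python side and converted to an ISO-8601 duration only when the settings are translated to their REST form. A short sketch of that round trip follows, assuming azure.ai.ml.automl re-exports ImageLimitSettings; the private _to_rest_object/_from_rest_object calls are used purely to illustrate the conversion above and are not public API.

    from azure.ai.ml.automl import ImageLimitSettings

    # Per the docstring tip, match max_concurrent_trials to the node count of the compute target.
    limits = ImageLimitSettings(max_trials=8, max_concurrent_trials=4, timeout_minutes=120)

    rest_limits = limits._to_rest_object()   # timeout becomes an ISO-8601 duration, e.g. "PT2H"
    assert ImageLimitSettings._from_rest_object(rest_limits) == limits
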
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_model_settings.py
new file mode 100644
index 00000000..890f987a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_model_settings.py
@@ -0,0 +1,876 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Optional
+
+# pylint: disable=R0902,too-many-locals
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ImageModelSettingsClassification as RestImageModelSettingsClassification,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ImageModelSettingsObjectDetection as RestImageModelSettingsObjectDetection,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ LearningRateScheduler,
+ LogTrainingMetrics,
+ LogValidationLoss,
+ ModelSize,
+ StochasticOptimizer,
+ ValidationMetricType,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class ImageModelDistributionSettings(RestTranslatableMixin):
+ """Model settings for all AutoML Image Verticals.
+ Please do not instantiate directly. Use the child classes instead.
+
+ :param advanced_settings: Settings for advanced scenarios.
+ :type advanced_settings: str
+ :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'.
+ :type ams_gradient: bool
+ :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the range
+ [0, 1].
+ :type beta1: float
+ :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the range
+ [0, 1].
+ :type beta2: float
+ :param checkpoint_frequency: Frequency to store model checkpoints. Must be a positive integer.
+ :type checkpoint_frequency: int
+ :param checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for
+ incremental training.
+ :type checkpoint_run_id: str
+ :param distributed: Whether to use distributed training.
+ :type distributed: bool
+ :param early_stopping: Enable early stopping logic during training.
+ :type early_stopping: bool
+ :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait before
+ primary metric improvement
+ is tracked for early stopping. Must be a positive integer.
+ :type early_stopping_delay: int
+ :param early_stopping_patience: Minimum number of epochs or validation evaluations with no
+ primary metric improvement before
+ the run is stopped. Must be a positive integer.
+ :type early_stopping_patience: int
+ :param enable_onnx_normalization: Enable normalization when exporting ONNX model.
+ :type enable_onnx_normalization: bool
+ :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. Must
+ be a positive integer.
+ :type evaluation_frequency: int
+ :param gradient_accumulation_step: Gradient accumulation means running a configured number of
+ "GradAccumulationStep" steps without
+ updating the model weights while accumulating the gradients of those steps, and then using
+ the accumulated gradients to compute the weight updates. Must be a positive integer.
+ :type gradient_accumulation_step: int
+ :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive integer.
+     For instance, passing 2 as the value for 'seresnext' means
+     freezing layer0 and layer1. For a full list of supported models and details on layer freezing,
+     please see:
+     https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type layers_to_freeze: int
+ :param learning_rate: Initial learning rate. Must be a float in the range [0, 1].
+ :type learning_rate: float
+ :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or
+ 'step'. Possible values include: "None", "WarmupCosine", "Step".
+ :type learning_rate_scheduler: str or
+ ~azure.mgmt.machinelearningservices.models.LearningRateScheduler
+ :param model_name: Name of the model to use for training.
+ For more information on the available models please visit the official documentation:
+ https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type model_name: str
+ :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, 1].
+ :type momentum: float
+ :param nesterov: Enable nesterov when optimizer is 'sgd'.
+ :type nesterov: bool
+ :param number_of_epochs: Number of training epochs. Must be a positive integer.
+ :type number_of_epochs: int
+ :param number_of_workers: Number of data loader workers. Must be a non-negative integer.
+ :type number_of_workers: int
+ :param optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw".
+ :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer
+ :param random_seed: Random seed to be used when using deterministic training.
+ :type random_seed: int
+ :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float in
+ the range [0, 1].
+ :type step_lr_gamma: float
+ :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be a
+ positive integer.
+ :type step_lr_step_size: int
+ :param training_batch_size: Training batch size. Must be a positive integer.
+ :type training_batch_size: int
+ :param validation_batch_size: Validation batch size. Must be a positive integer.
+ :type validation_batch_size: int
+ :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+ 'warmup_cosine'. Must be a float in the range [0, 1].
+ :type warmup_cosine_lr_cycles: float
+ :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+ 'warmup_cosine'. Must be a positive integer.
+ :type warmup_cosine_lr_warmup_epochs: int
+ :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must be
+     a float in the range [0, 1].
+ :type weight_decay: float
+ """
+
+ def __init__(
+ self,
+ *,
+ advanced_settings: Optional[str] = None,
+ ams_gradient: Optional[bool] = None,
+ beta1: Optional[float] = None,
+ beta2: Optional[float] = None,
+ checkpoint_frequency: Optional[int] = None,
+ checkpoint_run_id: Optional[str] = None,
+ distributed: Optional[bool] = None,
+ early_stopping: Optional[bool] = None,
+ early_stopping_delay: Optional[int] = None,
+ early_stopping_patience: Optional[int] = None,
+ enable_onnx_normalization: Optional[bool] = None,
+ evaluation_frequency: Optional[int] = None,
+ gradient_accumulation_step: Optional[int] = None,
+ layers_to_freeze: Optional[int] = None,
+ learning_rate: Optional[float] = None,
+ learning_rate_scheduler: Optional[LearningRateScheduler] = None,
+ model_name: Optional[str] = None,
+ momentum: Optional[float] = None,
+ nesterov: Optional[bool] = None,
+ number_of_epochs: Optional[int] = None,
+ number_of_workers: Optional[int] = None,
+ optimizer: Optional[StochasticOptimizer] = None,
+ random_seed: Optional[int] = None,
+ step_lr_gamma: Optional[float] = None,
+ step_lr_step_size: Optional[int] = None,
+ training_batch_size: Optional[int] = None,
+ validation_batch_size: Optional[int] = None,
+ warmup_cosine_lr_cycles: Optional[float] = None,
+ warmup_cosine_lr_warmup_epochs: Optional[int] = None,
+ weight_decay: Optional[float] = None,
+ ):
+ self.advanced_settings = advanced_settings
+ self.ams_gradient = ams_gradient
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.checkpoint_frequency = checkpoint_frequency
+ self.checkpoint_run_id = checkpoint_run_id
+ self.distributed = distributed
+ self.early_stopping = early_stopping
+ self.early_stopping_delay = early_stopping_delay
+ self.early_stopping_patience = early_stopping_patience
+ self.enable_onnx_normalization = enable_onnx_normalization
+ self.evaluation_frequency = evaluation_frequency
+ self.gradient_accumulation_step = gradient_accumulation_step
+ self.layers_to_freeze = layers_to_freeze
+ self.learning_rate = learning_rate
+ self.learning_rate_scheduler = learning_rate_scheduler
+ self.model_name = model_name
+ self.momentum = momentum
+ self.nesterov = nesterov
+ self.number_of_epochs = number_of_epochs
+ self.number_of_workers = number_of_workers
+ self.optimizer = optimizer
+ self.random_seed = random_seed
+ self.step_lr_gamma = step_lr_gamma
+ self.step_lr_step_size = step_lr_step_size
+ self.training_batch_size = training_batch_size
+ self.validation_batch_size = validation_batch_size
+ self.warmup_cosine_lr_cycles = warmup_cosine_lr_cycles
+ self.warmup_cosine_lr_warmup_epochs = warmup_cosine_lr_warmup_epochs
+ self.weight_decay = weight_decay
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageModelDistributionSettings):
+ return NotImplemented
+
+ return (
+ self.advanced_settings == other.advanced_settings
+ and self.ams_gradient == other.ams_gradient
+ and self.beta1 == other.beta1
+ and self.beta2 == other.beta2
+ and self.checkpoint_frequency == other.checkpoint_frequency
+ and self.checkpoint_run_id == other.checkpoint_run_id
+ and self.distributed == other.distributed
+ and self.early_stopping == other.early_stopping
+ and self.early_stopping_delay == other.early_stopping_delay
+ and self.early_stopping_patience == other.early_stopping_patience
+ and self.enable_onnx_normalization == other.enable_onnx_normalization
+ and self.evaluation_frequency == other.evaluation_frequency
+ and self.gradient_accumulation_step == other.gradient_accumulation_step
+ and self.layers_to_freeze == other.layers_to_freeze
+ and self.learning_rate == other.learning_rate
+ and self.learning_rate_scheduler == other.learning_rate_scheduler
+ and self.model_name == other.model_name
+ and self.momentum == other.momentum
+ and self.nesterov == other.nesterov
+ and self.number_of_epochs == other.number_of_epochs
+ and self.number_of_workers == other.number_of_workers
+ and self.optimizer == other.optimizer
+ and self.random_seed == other.random_seed
+ and self.step_lr_gamma == other.step_lr_gamma
+ and self.step_lr_step_size == other.step_lr_step_size
+ and self.training_batch_size == other.training_batch_size
+ and self.validation_batch_size == other.validation_batch_size
+ and self.warmup_cosine_lr_cycles == other.warmup_cosine_lr_cycles
+ and self.warmup_cosine_lr_warmup_epochs == other.warmup_cosine_lr_warmup_epochs
+ and self.weight_decay == other.weight_decay
+ )
+
+
+class ImageModelSettingsClassification(ImageModelDistributionSettings):
+ """Model settings for AutoML Image Classification tasks.
+
+ :param advanced_settings: Settings for advanced scenarios.
+ :type advanced_settings: str
+ :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'.
+ :type ams_gradient: bool
+ :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the range
+ [0, 1].
+ :type beta1: float
+ :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the range
+ [0, 1].
+ :type beta2: float
+ :param checkpoint_frequency: Frequency to store model checkpoints. Must be a positive integer.
+ :type checkpoint_frequency: int
+ :param checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for
+ incremental training.
+ :type checkpoint_run_id: str
+ :param distributed: Whether to use distributed training.
+ :type distributed: bool
+ :param early_stopping: Enable early stopping logic during training.
+ :type early_stopping: bool
+ :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait before
+ primary metric improvement
+ is tracked for early stopping. Must be a positive integer.
+ :type early_stopping_delay: int
+ :param early_stopping_patience: Minimum number of epochs or validation evaluations with no
+ primary metric improvement before
+ the run is stopped. Must be a positive integer.
+ :type early_stopping_patience: int
+ :param enable_onnx_normalization: Enable normalization when exporting ONNX model.
+ :type enable_onnx_normalization: bool
+ :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. Must
+ be a positive integer.
+ :type evaluation_frequency: int
+ :param gradient_accumulation_step: Gradient accumulation means running a configured number of
+ "GradAccumulationStep" steps without
+ updating the model weights while accumulating the gradients of those steps, and then using
+ the accumulated gradients to compute the weight updates. Must be a positive integer.
+ :type gradient_accumulation_step: int
+ :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive integer.
+     For instance, passing 2 as the value for 'seresnext' means
+     freezing layer0 and layer1. For a full list of supported models and details on layer freezing,
+     please see:
+     https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type layers_to_freeze: int
+ :param learning_rate: Initial learning rate. Must be a float in the range [0, 1].
+ :type learning_rate: float
+ :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or
+ 'step'. Possible values include: "None", "WarmupCosine", "Step".
+ :type learning_rate_scheduler: str or
+ ~azure.mgmt.machinelearningservices.models.LearningRateScheduler
+ :param model_name: Name of the model to use for training.
+ For more information on the available models please visit the official documentation:
+ https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type model_name: str
+ :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, 1].
+ :type momentum: float
+ :param nesterov: Enable nesterov when optimizer is 'sgd'.
+ :type nesterov: bool
+ :param number_of_epochs: Number of training epochs. Must be a positive integer.
+ :type number_of_epochs: int
+ :param number_of_workers: Number of data loader workers. Must be a non-negative integer.
+ :type number_of_workers: int
+ :param optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw".
+ :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer
+ :param random_seed: Random seed to be used when using deterministic training.
+ :type random_seed: int
+ :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float in
+ the range [0, 1].
+ :type step_lr_gamma: float
+ :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be a
+ positive integer.
+ :type step_lr_step_size: int
+ :param training_batch_size: Training batch size. Must be a positive integer.
+ :type training_batch_size: int
+ :param validation_batch_size: Validation batch size. Must be a positive integer.
+ :type validation_batch_size: int
+ :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+ 'warmup_cosine'. Must be a float in the range [0, 1].
+ :type warmup_cosine_lr_cycles: float
+ :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+ 'warmup_cosine'. Must be a positive integer.
+ :type warmup_cosine_lr_warmup_epochs: int
+ :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must be
+ a float in the range [0, 1].
+ :type weight_decay: float
+ :param training_crop_size: Image crop size that is input to the neural network for the training
+ dataset. Must be a positive integer.
+ :type training_crop_size: int
+ :param validation_crop_size: Image crop size that is input to the neural network for the
+ validation dataset. Must be a positive integer.
+ :type validation_crop_size: int
+ :param validation_resize_size: Image size to which to resize before cropping for validation
+ dataset. Must be a positive integer.
+ :type validation_resize_size: int
+ :param weighted_loss: Weighted loss. The accepted values are 0 for no weighted loss,
+ 1 for weighted loss with sqrt(class_weights), and 2 for weighted loss with class_weights. Must be
+ 0, 1, or 2.
+ :type weighted_loss: int
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_classification_model_settings]
+ :end-before: [END automl.automl_image_job.image_classification_model_settings]
+ :language: python
+ :dedent: 8
+ :caption: Defining the automl image classification model settings.
+ """
+
+ def __init__(
+ self,
+ *,
+ advanced_settings: Optional[str] = None,
+ ams_gradient: Optional[bool] = None,
+ beta1: Optional[float] = None,
+ beta2: Optional[float] = None,
+ checkpoint_frequency: Optional[int] = None,
+ checkpoint_run_id: Optional[str] = None,
+ distributed: Optional[bool] = None,
+ early_stopping: Optional[bool] = None,
+ early_stopping_delay: Optional[int] = None,
+ early_stopping_patience: Optional[int] = None,
+ enable_onnx_normalization: Optional[bool] = None,
+ evaluation_frequency: Optional[int] = None,
+ gradient_accumulation_step: Optional[int] = None,
+ layers_to_freeze: Optional[int] = None,
+ learning_rate: Optional[float] = None,
+ learning_rate_scheduler: Optional[LearningRateScheduler] = None,
+ model_name: Optional[str] = None,
+ momentum: Optional[float] = None,
+ nesterov: Optional[bool] = None,
+ number_of_epochs: Optional[int] = None,
+ number_of_workers: Optional[int] = None,
+ optimizer: Optional[StochasticOptimizer] = None,
+ random_seed: Optional[int] = None,
+ step_lr_gamma: Optional[float] = None,
+ step_lr_step_size: Optional[int] = None,
+ training_batch_size: Optional[int] = None,
+ validation_batch_size: Optional[int] = None,
+ warmup_cosine_lr_cycles: Optional[float] = None,
+ warmup_cosine_lr_warmup_epochs: Optional[int] = None,
+ weight_decay: Optional[float] = None,
+ training_crop_size: Optional[int] = None,
+ validation_crop_size: Optional[int] = None,
+ validation_resize_size: Optional[int] = None,
+ weighted_loss: Optional[int] = None,
+ **kwargs: Any,
+ ):
+ super(ImageModelSettingsClassification, self).__init__(
+ advanced_settings=advanced_settings,
+ ams_gradient=ams_gradient,
+ beta1=beta1,
+ beta2=beta2,
+ checkpoint_frequency=checkpoint_frequency,
+ checkpoint_run_id=checkpoint_run_id,
+ distributed=distributed,
+ early_stopping=early_stopping,
+ early_stopping_delay=early_stopping_delay,
+ early_stopping_patience=early_stopping_patience,
+ enable_onnx_normalization=enable_onnx_normalization,
+ evaluation_frequency=evaluation_frequency,
+ gradient_accumulation_step=gradient_accumulation_step,
+ layers_to_freeze=layers_to_freeze,
+ learning_rate=learning_rate,
+ learning_rate_scheduler=learning_rate_scheduler,
+ model_name=model_name,
+ momentum=momentum,
+ nesterov=nesterov,
+ number_of_epochs=number_of_epochs,
+ number_of_workers=number_of_workers,
+ optimizer=optimizer,
+ random_seed=random_seed,
+ step_lr_gamma=step_lr_gamma,
+ step_lr_step_size=step_lr_step_size,
+ training_batch_size=training_batch_size,
+ validation_batch_size=validation_batch_size,
+ warmup_cosine_lr_cycles=warmup_cosine_lr_cycles,
+ warmup_cosine_lr_warmup_epochs=warmup_cosine_lr_warmup_epochs,
+ weight_decay=weight_decay,
+ **kwargs,
+ )
+ self.training_crop_size = training_crop_size
+ self.validation_crop_size = validation_crop_size
+ self.validation_resize_size = validation_resize_size
+ self.weighted_loss = weighted_loss
+
+ def _to_rest_object(self) -> RestImageModelSettingsClassification:
+ return RestImageModelSettingsClassification(
+ advanced_settings=self.advanced_settings,
+ ams_gradient=self.ams_gradient,
+ beta1=self.beta1,
+ beta2=self.beta2,
+ checkpoint_frequency=self.checkpoint_frequency,
+ checkpoint_run_id=self.checkpoint_run_id,
+ distributed=self.distributed,
+ early_stopping=self.early_stopping,
+ early_stopping_delay=self.early_stopping_delay,
+ early_stopping_patience=self.early_stopping_patience,
+ enable_onnx_normalization=self.enable_onnx_normalization,
+ evaluation_frequency=self.evaluation_frequency,
+ gradient_accumulation_step=self.gradient_accumulation_step,
+ layers_to_freeze=self.layers_to_freeze,
+ learning_rate=self.learning_rate,
+ learning_rate_scheduler=self.learning_rate_scheduler,
+ model_name=self.model_name,
+ momentum=self.momentum,
+ nesterov=self.nesterov,
+ number_of_epochs=self.number_of_epochs,
+ number_of_workers=self.number_of_workers,
+ optimizer=self.optimizer,
+ random_seed=self.random_seed,
+ step_lr_gamma=self.step_lr_gamma,
+ step_lr_step_size=self.step_lr_step_size,
+ training_batch_size=self.training_batch_size,
+ validation_batch_size=self.validation_batch_size,
+ warmup_cosine_lr_cycles=self.warmup_cosine_lr_cycles,
+ warmup_cosine_lr_warmup_epochs=self.warmup_cosine_lr_warmup_epochs,
+ weight_decay=self.weight_decay,
+ training_crop_size=self.training_crop_size,
+ validation_crop_size=self.validation_crop_size,
+ validation_resize_size=self.validation_resize_size,
+ weighted_loss=self.weighted_loss,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestImageModelSettingsClassification) -> "ImageModelSettingsClassification":
+ return cls(
+ advanced_settings=obj.advanced_settings,
+ ams_gradient=obj.ams_gradient,
+ beta1=obj.beta1,
+ beta2=obj.beta2,
+ checkpoint_frequency=obj.checkpoint_frequency,
+ checkpoint_run_id=obj.checkpoint_run_id,
+ distributed=obj.distributed,
+ early_stopping=obj.early_stopping,
+ early_stopping_delay=obj.early_stopping_delay,
+ early_stopping_patience=obj.early_stopping_patience,
+ enable_onnx_normalization=obj.enable_onnx_normalization,
+ evaluation_frequency=obj.evaluation_frequency,
+ gradient_accumulation_step=obj.gradient_accumulation_step,
+ layers_to_freeze=obj.layers_to_freeze,
+ learning_rate=obj.learning_rate,
+ learning_rate_scheduler=obj.learning_rate_scheduler,
+ model_name=obj.model_name,
+ momentum=obj.momentum,
+ nesterov=obj.nesterov,
+ number_of_epochs=obj.number_of_epochs,
+ number_of_workers=obj.number_of_workers,
+ optimizer=obj.optimizer,
+ random_seed=obj.random_seed,
+ step_lr_gamma=obj.step_lr_gamma,
+ step_lr_step_size=obj.step_lr_step_size,
+ training_batch_size=obj.training_batch_size,
+ validation_batch_size=obj.validation_batch_size,
+ warmup_cosine_lr_cycles=obj.warmup_cosine_lr_cycles,
+ warmup_cosine_lr_warmup_epochs=obj.warmup_cosine_lr_warmup_epochs,
+ weight_decay=obj.weight_decay,
+ training_crop_size=obj.training_crop_size,
+ validation_crop_size=obj.validation_crop_size,
+ validation_resize_size=obj.validation_resize_size,
+ weighted_loss=obj.weighted_loss,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageModelSettingsClassification):
+ return NotImplemented
+
+ return (
+ super().__eq__(other)
+ and self.training_crop_size == other.training_crop_size
+ and self.validation_crop_size == other.validation_crop_size
+ and self.validation_resize_size == other.validation_resize_size
+ and self.weighted_loss == other.weighted_loss
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
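+
+
+# A minimal usage sketch (illustrative only, not part of the upstream module).
+# It builds classification model settings that could be attached to an AutoML
+# image classification job as training parameters; the model name and the
+# hyperparameter values below are assumptions chosen for demonstration.
+def _example_classification_model_settings() -> "ImageModelSettingsClassification":
+ return ImageModelSettingsClassification(
+ model_name="vitb16r224",  # assumed model identifier, for illustration
+ number_of_epochs=10,
+ training_batch_size=16,
+ learning_rate=0.01,
+ early_stopping=True,
+ weighted_loss=1,  # weighted loss with sqrt(class_weights)
+ )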
+
+
+class ImageModelSettingsObjectDetection(ImageModelDistributionSettings):
+ """Model settings for AutoML Image Object Detection Task.
+
+ :param advanced_settings: Settings for advanced scenarios.
+ :type advanced_settings: str
+ :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'.
+ :type ams_gradient: bool
+ :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the range
+ [0, 1].
+ :type beta1: float
+ :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the range
+ [0, 1].
+ :type beta2: float
+ :param checkpoint_frequency: Frequency to store model checkpoints. Must be a positive integer.
+ :type checkpoint_frequency: int
+ :param checkpoint_run_id: The id of a previous run that has a pretrained checkpoint for
+ incremental training.
+ :type checkpoint_run_id: str
+ :param distributed: Whether to use distributed training.
+ :type distributed: bool
+ :param early_stopping: Enable early stopping logic during training.
+ :type early_stopping: bool
+ :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait before
+ primary metric improvement
+ is tracked for early stopping. Must be a positive integer.
+ :type early_stopping_delay: int
+ :param early_stopping_patience: Minimum number of epochs or validation evaluations with no
+ primary metric improvement before
+ the run is stopped. Must be a positive integer.
+ :type early_stopping_patience: int
+ :param enable_onnx_normalization: Enable normalization when exporting ONNX model.
+ :type enable_onnx_normalization: bool
+ :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores. Must
+ be a positive integer.
+ :type evaluation_frequency: int
+ :param gradient_accumulation_step: Gradient accumulation means running a configured number of
+ "GradAccumulationStep" steps without
+ updating the model weights while accumulating the gradients of those steps, and then using
+ the accumulated gradients to compute the weight updates. Must be a positive integer.
+ :type gradient_accumulation_step: int
+ :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive integer.
+ For instance, passing 2 as value for 'seresnext' means
+ freezing layer0 and layer1. For a full list of models supported and details on layer freeze,
+ please
+ see: https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type layers_to_freeze: int
+ :param learning_rate: Initial learning rate. Must be a float in the range [0, 1].
+ :type learning_rate: float
+ :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or
+ 'step'. Possible values include: "None", "WarmupCosine", "Step".
+ :type learning_rate_scheduler: str or
+ ~azure.mgmt.machinelearningservices.models.LearningRateScheduler
+ :param model_name: Name of the model to use for training.
+ For more information on the available models please visit the official documentation:
+ https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type model_name: str
+ :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0, 1].
+ :type momentum: float
+ :param nesterov: Enable nesterov when optimizer is 'sgd'.
+ :type nesterov: bool
+ :param number_of_epochs: Number of training epochs. Must be a positive integer.
+ :type number_of_epochs: int
+ :param number_of_workers: Number of data loader workers. Must be a non-negative integer.
+ :type number_of_workers: int
+ :param optimizer: Type of optimizer. Possible values include: "None", "Sgd", "Adam", "Adamw".
+ :type optimizer: str or ~azure.mgmt.machinelearningservices.models.StochasticOptimizer
+ :param random_seed: Random seed to be used when using deterministic training.
+ :type random_seed: int
+ :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float in
+ the range [0, 1].
+ :type step_lr_gamma: float
+ :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be a
+ positive integer.
+ :type step_lr_step_size: int
+ :param training_batch_size: Training batch size. Must be a positive integer.
+ :type training_batch_size: int
+ :param validation_batch_size: Validation batch size. Must be a positive integer.
+ :type validation_batch_size: int
+ :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+ 'warmup_cosine'. Must be a float in the range [0, 1].
+ :type warmup_cosine_lr_cycles: float
+ :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+ 'warmup_cosine'. Must be a positive integer.
+ :type warmup_cosine_lr_warmup_epochs: int
+ :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must be
+ a float in the range [0, 1].
+ :type weight_decay: float
+ :param box_detections_per_image: Maximum number of detections per image, for all classes. Must
+ be a positive integer.
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type box_detections_per_image: int
+ :param box_score_threshold: During inference, only return proposals with a classification score
+ greater than
+ BoxScoreThreshold. Must be a float in the range [0, 1].
+ :type box_score_threshold: float
+ :param image_size: Image size for train and validation. Must be a positive integer.
+ Note: The training run may get into CUDA OOM if the size is too big.
+ Note: This setting is only supported for the 'yolov5' algorithm.
+ :type image_size: int
+ :param max_size: Maximum size of the image to be rescaled before feeding it to the backbone.
+ Must be a positive integer. Note: training run may get into CUDA OOM if the size is too big.
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type max_size: int
+ :param min_size: Minimum size of the image to be rescaled before feeding it to the backbone.
+ Must be a positive integer. Note: training run may get into CUDA OOM if the size is too big.
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type min_size: int
+ :param model_size: Model size. Must be 'small', 'medium', 'large', or 'extra_large'.
+ Note: training run may get into CUDA OOM if the model size is too big.
+ Note: This setting is only supported for the 'yolov5' algorithm. Possible values include:
+ "None", "Small", "Medium", "Large", "ExtraLarge".
+ :type model_size: str or ~azure.mgmt.machinelearningservices.models.ModelSize
+ :param multi_scale: Enable multi-scale image by varying image size by +/- 50%.
+ Note: training run may get into CUDA OOM if there is insufficient GPU memory.
+ Note: This setting is only supported for the 'yolov5' algorithm.
+ :type multi_scale: bool
+ :param nms_iou_threshold: IOU threshold used during inference in NMS post processing. Must be a
+ float in the range [0, 1].
+ :type nms_iou_threshold: float
+ :param tile_grid_size: The grid size to use for tiling each image. Note: TileGridSize must not
+ be None to enable small object detection logic. A string containing two integers in mxn format.
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type tile_grid_size: str
+ :param tile_overlap_ratio: Overlap ratio between adjacent tiles in each dimension. Must be float
+ in the range [0, 1).
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type tile_overlap_ratio: float
+ :param tile_predictions_nms_threshold: The IOU threshold to use to perform NMS while merging
+ predictions from tiles and image.
+ Used in validation/inference. Must be float in the range [0, 1].
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type tile_predictions_nms_threshold: float
+ :param validation_iou_threshold: IOU threshold to use when computing validation metric. Must be
+ float in the range [0, 1].
+ :type validation_iou_threshold: float
+ :param validation_metric_type: Metric computation method to use for validation metrics. Possible
+ values include: "None", "Coco", "Voc", "CocoVoc".
+ :type validation_metric_type: str or
+ ~azure.mgmt.machinelearningservices.models.ValidationMetricType
+ :param log_training_metrics: Indicates whether or not to log training metrics.
+ :type log_training_metrics: str or
+ ~azure.mgmt.machinelearningservices.models.LogTrainingMetrics
+ :param log_validation_loss: Indicates whether or not to log validation loss.
+ :type log_validation_loss: str or
+ ~azure.mgmt.machinelearningservices.models.LogValidationLoss
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_object_detection_model_settings]
+ :end-before: [END automl.automl_image_job.image_object_detection_model_settings]
+ :language: python
+ :dedent: 8
+ :caption: Defining the automl image object detection or instance segmentation model settings.
+ """
+
+ def __init__(
+ self,
+ *,
+ advanced_settings: Optional[str] = None,
+ ams_gradient: Optional[bool] = None,
+ beta1: Optional[float] = None,
+ beta2: Optional[float] = None,
+ checkpoint_frequency: Optional[int] = None,
+ checkpoint_run_id: Optional[str] = None,
+ distributed: Optional[bool] = None,
+ early_stopping: Optional[bool] = None,
+ early_stopping_delay: Optional[int] = None,
+ early_stopping_patience: Optional[int] = None,
+ enable_onnx_normalization: Optional[bool] = None,
+ evaluation_frequency: Optional[int] = None,
+ gradient_accumulation_step: Optional[int] = None,
+ layers_to_freeze: Optional[int] = None,
+ learning_rate: Optional[float] = None,
+ learning_rate_scheduler: Optional[LearningRateScheduler] = None,
+ model_name: Optional[str] = None,
+ momentum: Optional[float] = None,
+ nesterov: Optional[bool] = None,
+ number_of_epochs: Optional[int] = None,
+ number_of_workers: Optional[int] = None,
+ optimizer: Optional[StochasticOptimizer] = None,
+ random_seed: Optional[int] = None,
+ step_lr_gamma: Optional[float] = None,
+ step_lr_step_size: Optional[int] = None,
+ training_batch_size: Optional[int] = None,
+ validation_batch_size: Optional[int] = None,
+ warmup_cosine_lr_cycles: Optional[float] = None,
+ warmup_cosine_lr_warmup_epochs: Optional[int] = None,
+ weight_decay: Optional[float] = None,
+ box_detections_per_image: Optional[int] = None,
+ box_score_threshold: Optional[float] = None,
+ image_size: Optional[int] = None,
+ max_size: Optional[int] = None,
+ min_size: Optional[int] = None,
+ model_size: Optional[ModelSize] = None,
+ multi_scale: Optional[bool] = None,
+ nms_iou_threshold: Optional[float] = None,
+ tile_grid_size: Optional[str] = None,
+ tile_overlap_ratio: Optional[float] = None,
+ tile_predictions_nms_threshold: Optional[float] = None,
+ validation_iou_threshold: Optional[float] = None,
+ validation_metric_type: Optional[ValidationMetricType] = None,
+ log_training_metrics: Optional[LogTrainingMetrics] = None,
+ log_validation_loss: Optional[LogValidationLoss] = None,
+ **kwargs: Any,
+ ):
+ super(ImageModelSettingsObjectDetection, self).__init__(
+ advanced_settings=advanced_settings,
+ ams_gradient=ams_gradient,
+ beta1=beta1,
+ beta2=beta2,
+ checkpoint_frequency=checkpoint_frequency,
+ checkpoint_run_id=checkpoint_run_id,
+ distributed=distributed,
+ early_stopping=early_stopping,
+ early_stopping_delay=early_stopping_delay,
+ early_stopping_patience=early_stopping_patience,
+ enable_onnx_normalization=enable_onnx_normalization,
+ evaluation_frequency=evaluation_frequency,
+ gradient_accumulation_step=gradient_accumulation_step,
+ layers_to_freeze=layers_to_freeze,
+ learning_rate=learning_rate,
+ learning_rate_scheduler=learning_rate_scheduler,
+ model_name=model_name,
+ momentum=momentum,
+ nesterov=nesterov,
+ number_of_epochs=number_of_epochs,
+ number_of_workers=number_of_workers,
+ optimizer=optimizer,
+ random_seed=random_seed,
+ step_lr_gamma=step_lr_gamma,
+ step_lr_step_size=step_lr_step_size,
+ training_batch_size=training_batch_size,
+ validation_batch_size=validation_batch_size,
+ warmup_cosine_lr_cycles=warmup_cosine_lr_cycles,
+ warmup_cosine_lr_warmup_epochs=warmup_cosine_lr_warmup_epochs,
+ weight_decay=weight_decay,
+ **kwargs,
+ )
+ self.box_detections_per_image = box_detections_per_image
+ self.box_score_threshold = box_score_threshold
+ self.image_size = image_size
+ self.max_size = max_size
+ self.min_size = min_size
+ self.model_size = model_size
+ self.multi_scale = multi_scale
+ self.nms_iou_threshold = nms_iou_threshold
+ self.tile_grid_size = tile_grid_size
+ self.tile_overlap_ratio = tile_overlap_ratio
+ self.tile_predictions_nms_threshold = tile_predictions_nms_threshold
+ self.validation_iou_threshold = validation_iou_threshold
+ self.validation_metric_type = validation_metric_type
+ self.log_training_metrics = log_training_metrics
+ self.log_validation_loss = log_validation_loss
+
+ def _to_rest_object(self) -> RestImageModelSettingsObjectDetection:
+ return RestImageModelSettingsObjectDetection(
+ advanced_settings=self.advanced_settings,
+ ams_gradient=self.ams_gradient,
+ beta1=self.beta1,
+ beta2=self.beta2,
+ checkpoint_frequency=self.checkpoint_frequency,
+ checkpoint_run_id=self.checkpoint_run_id,
+ distributed=self.distributed,
+ early_stopping=self.early_stopping,
+ early_stopping_delay=self.early_stopping_delay,
+ early_stopping_patience=self.early_stopping_patience,
+ enable_onnx_normalization=self.enable_onnx_normalization,
+ evaluation_frequency=self.evaluation_frequency,
+ gradient_accumulation_step=self.gradient_accumulation_step,
+ layers_to_freeze=self.layers_to_freeze,
+ learning_rate=self.learning_rate,
+ learning_rate_scheduler=self.learning_rate_scheduler,
+ model_name=self.model_name,
+ momentum=self.momentum,
+ nesterov=self.nesterov,
+ number_of_epochs=self.number_of_epochs,
+ number_of_workers=self.number_of_workers,
+ optimizer=self.optimizer,
+ random_seed=self.random_seed,
+ step_lr_gamma=self.step_lr_gamma,
+ step_lr_step_size=self.step_lr_step_size,
+ training_batch_size=self.training_batch_size,
+ validation_batch_size=self.validation_batch_size,
+ warmup_cosine_lr_cycles=self.warmup_cosine_lr_cycles,
+ warmup_cosine_lr_warmup_epochs=self.warmup_cosine_lr_warmup_epochs,
+ weight_decay=self.weight_decay,
+ box_detections_per_image=self.box_detections_per_image,
+ box_score_threshold=self.box_score_threshold,
+ image_size=self.image_size,
+ max_size=self.max_size,
+ min_size=self.min_size,
+ model_size=self.model_size,
+ multi_scale=self.multi_scale,
+ nms_iou_threshold=self.nms_iou_threshold,
+ tile_grid_size=self.tile_grid_size,
+ tile_overlap_ratio=self.tile_overlap_ratio,
+ tile_predictions_nms_threshold=self.tile_predictions_nms_threshold,
+ validation_iou_threshold=self.validation_iou_threshold,
+ validation_metric_type=self.validation_metric_type,
+ log_training_metrics=self.log_training_metrics,
+ log_validation_loss=self.log_validation_loss,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestImageModelSettingsObjectDetection) -> "ImageModelSettingsObjectDetection":
+ return cls(
+ advanced_settings=obj.advanced_settings,
+ ams_gradient=obj.ams_gradient,
+ beta1=obj.beta1,
+ beta2=obj.beta2,
+ checkpoint_frequency=obj.checkpoint_frequency,
+ checkpoint_run_id=obj.checkpoint_run_id,
+ distributed=obj.distributed,
+ early_stopping=obj.early_stopping,
+ early_stopping_delay=obj.early_stopping_delay,
+ early_stopping_patience=obj.early_stopping_patience,
+ enable_onnx_normalization=obj.enable_onnx_normalization,
+ evaluation_frequency=obj.evaluation_frequency,
+ gradient_accumulation_step=obj.gradient_accumulation_step,
+ layers_to_freeze=obj.layers_to_freeze,
+ learning_rate=obj.learning_rate,
+ learning_rate_scheduler=obj.learning_rate_scheduler,
+ model_name=obj.model_name,
+ momentum=obj.momentum,
+ nesterov=obj.nesterov,
+ number_of_epochs=obj.number_of_epochs,
+ number_of_workers=obj.number_of_workers,
+ optimizer=obj.optimizer,
+ random_seed=obj.random_seed,
+ step_lr_gamma=obj.step_lr_gamma,
+ step_lr_step_size=obj.step_lr_step_size,
+ training_batch_size=obj.training_batch_size,
+ validation_batch_size=obj.validation_batch_size,
+ warmup_cosine_lr_cycles=obj.warmup_cosine_lr_cycles,
+ warmup_cosine_lr_warmup_epochs=obj.warmup_cosine_lr_warmup_epochs,
+ weight_decay=obj.weight_decay,
+ box_detections_per_image=obj.box_detections_per_image,
+ box_score_threshold=obj.box_score_threshold,
+ image_size=obj.image_size,
+ max_size=obj.max_size,
+ min_size=obj.min_size,
+ model_size=obj.model_size,
+ multi_scale=obj.multi_scale,
+ nms_iou_threshold=obj.nms_iou_threshold,
+ tile_grid_size=obj.tile_grid_size,
+ tile_overlap_ratio=obj.tile_overlap_ratio,
+ tile_predictions_nms_threshold=obj.tile_predictions_nms_threshold,
+ validation_iou_threshold=obj.validation_iou_threshold,
+ validation_metric_type=obj.validation_metric_type,
+ log_training_metrics=obj.log_training_metrics,
+ log_validation_loss=obj.log_validation_loss,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageModelSettingsObjectDetection):
+ return NotImplemented
+
+ return (
+ super().__eq__(other)
+ and self.box_detections_per_image == other.box_detections_per_image
+ and self.box_score_threshold == other.box_score_threshold
+ and self.image_size == other.image_size
+ and self.max_size == other.max_size
+ and self.min_size == other.min_size
+ and self.model_size == other.model_size
+ and self.multi_scale == other.multi_scale
+ and self.nms_iou_threshold == other.nms_iou_threshold
+ and self.tile_grid_size == other.tile_grid_size
+ and self.tile_overlap_ratio == other.tile_overlap_ratio
+ and self.tile_predictions_nms_threshold == other.tile_predictions_nms_threshold
+ and self.validation_iou_threshold == other.validation_iou_threshold
+ and self.validation_metric_type == other.validation_metric_type
+ and self.log_training_metrics == other.log_training_metrics
+ and self.log_validation_loss == other.log_validation_loss
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
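+
+
+# A minimal usage sketch (illustrative only, not part of the upstream module).
+# It shows object detection model settings with image tiling enabled; the
+# tile_grid_size value uses the "mxn" string format described in the docstring,
+# and every value here is an assumption chosen for demonstration.
+def _example_object_detection_model_settings() -> "ImageModelSettingsObjectDetection":
+ return ImageModelSettingsObjectDetection(
+ model_name="fasterrcnn_resnet50_fpn",  # assumed model identifier
+ min_size=600,
+ max_size=1333,
+ box_score_threshold=0.3,
+ nms_iou_threshold=0.5,
+ tile_grid_size="3x2",  # enables the small-object detection (tiling) logic
+ )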
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_job.py
new file mode 100644
index 00000000..f8d070d2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_job.py
@@ -0,0 +1,240 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageObjectDetection as RestImageObjectDetection
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, ObjectDetectionPrimaryMetrics, TaskType
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.image.automl_image_object_detection_base import AutoMLImageObjectDetectionBase
+from azure.ai.ml.entities._job.automl.image.image_limit_settings import ImageLimitSettings
+from azure.ai.ml.entities._job.automl.image.image_model_settings import ImageModelSettingsObjectDetection
+from azure.ai.ml.entities._job.automl.image.image_sweep_settings import ImageSweepSettings
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class ImageObjectDetectionJob(AutoMLImageObjectDetectionBase):
+ """Configuration for AutoML Image Object Detection job.
+
+ :keyword primary_metric: The primary metric to use for optimization.
+ :paramtype primary_metric: Optional[Union[str, ~azure.ai.ml.ObjectDetectionPrimaryMetrics]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_object_detection_job]
+ :end-before: [END automl.automl_image_job.image_object_detection_job]
+ :language: python
+ :dedent: 8
+ :caption: Creating an automl image object detection job.
+ """
+
+ _DEFAULT_PRIMARY_METRIC = ObjectDetectionPrimaryMetrics.MEAN_AVERAGE_PRECISION
+
+ def __init__(
+ self,
+ *,
+ primary_metric: Optional[Union[str, ObjectDetectionPrimaryMetrics]] = None,
+ **kwargs: Any,
+ ) -> None:
+
+ # Extract any super class init settings
+ limits = kwargs.pop("limits", None)
+ sweep = kwargs.pop("sweep", None)
+ training_parameters = kwargs.pop("training_parameters", None)
+ search_space = kwargs.pop("search_space", None)
+
+ super().__init__(
+ task_type=TaskType.IMAGE_OBJECT_DETECTION,
+ limits=limits,
+ sweep=sweep,
+ training_parameters=training_parameters,
+ search_space=search_space,
+ **kwargs,
+ )
+
+ self.primary_metric = primary_metric or ImageObjectDetectionJob._DEFAULT_PRIMARY_METRIC
+
+ @property
+ def primary_metric(self) -> Union[str, ObjectDetectionPrimaryMetrics]:
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, ObjectDetectionPrimaryMetrics]) -> None:
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+ self._primary_metric = (
+ ImageObjectDetectionJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else ObjectDetectionPrimaryMetrics[camel_to_snake(value).upper()]
+ )
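+ # Illustrative note (not upstream code): camel_to_snake normalizes either the
+ # REST casing or the snake_case form, so both "MeanAveragePrecision" and
+ # "mean_average_precision" resolve to
+ # ObjectDetectionPrimaryMetrics.MEAN_AVERAGE_PRECISION via
+ # ObjectDetectionPrimaryMetrics[camel_to_snake(value).upper()].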
+
+ def _to_rest_object(self) -> JobBase:
+ image_object_detection_task = RestImageObjectDetection(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ validation_data_size=self.validation_data_size,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ sweep_settings=self._sweep._to_rest_object() if self._sweep else None,
+ model_settings=self._training_parameters._to_rest_object() if self._training_parameters else None,
+ search_space=(
+ [entry._to_rest_object() for entry in self._search_space if entry is not None]
+ if self._search_space is not None
+ else None
+ ),
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ )
+ # resolve data inputs in rest object
+ self._resolve_data_inputs(image_object_detection_task)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=image_object_detection_task,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "ImageObjectDetectionJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestImageObjectDetection = properties.task_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "services": properties.services,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "compute": properties.compute_id,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ "resources": properties.resources,
+ "identity": (
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ "queue_settings": properties.queue_settings,
+ }
+
+ image_object_detection_job = cls(
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ validation_data_size=task_details.validation_data_size,
+ limits=(
+ ImageLimitSettings._from_rest_object(task_details.limit_settings)
+ if task_details.limit_settings
+ else None
+ ),
+ sweep=(
+ ImageSweepSettings._from_rest_object(task_details.sweep_settings)
+ if task_details.sweep_settings
+ else None
+ ),
+ training_parameters=(
+ ImageModelSettingsObjectDetection._from_rest_object(task_details.model_settings)
+ if task_details.model_settings
+ else None
+ ),
+ search_space=cls._get_search_space_from_str(task_details.search_space),
+ primary_metric=task_details.primary_metric,
+ log_verbosity=task_details.log_verbosity,
+ **job_args_dict,
+ )
+
+ image_object_detection_job._restore_data_inputs()
+
+ return image_object_detection_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "ImageObjectDetectionJob":
+ from azure.ai.ml._schema.automl.image_vertical.image_object_detection import ImageObjectDetectionSchema
+ from azure.ai.ml._schema.pipeline.automl_node import ImageObjectDetectionNodeSchema
+
+ if kwargs.pop("inside_pipeline", False):
+ if context.get("inside_pipeline", None) is None:
+ context["inside_pipeline"] = True
+ loaded_data = load_from_dict(
+ ImageObjectDetectionNodeSchema,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ else:
+ loaded_data = load_from_dict(ImageObjectDetectionSchema, data, context, additional_message, **kwargs)
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ImageObjectDetectionJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ data_settings = {
+ "training_data": loaded_data.pop("training_data"),
+ "target_column_name": loaded_data.pop("target_column_name"),
+ "validation_data": loaded_data.pop("validation_data", None),
+ "validation_data_size": loaded_data.pop("validation_data_size", None),
+ }
+ job = ImageObjectDetectionJob(**loaded_data)
+ job.set_data(**data_settings)
+ return job
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.image_vertical.image_object_detection import ImageObjectDetectionSchema
+ from azure.ai.ml._schema.pipeline.automl_node import ImageObjectDetectionNodeSchema
+
+ schema_dict: dict = {}
+ if inside_pipeline:
+ schema_dict = ImageObjectDetectionNodeSchema(
+ context={BASE_PATH_CONTEXT_KEY: "./", "inside_pipeline": True}
+ ).dump(self)
+ else:
+ schema_dict = ImageObjectDetectionSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageObjectDetectionJob):
+ return NotImplemented
+
+ if not super().__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
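+
+
+# A minimal usage sketch (illustrative only, not part of the upstream module).
+# It constructs an ImageObjectDetectionJob directly with fixed training
+# parameters; the compute name, experiment name, and MLTable path are
+# placeholder assumptions, and imports are kept local to avoid circular imports.
+def _example_image_object_detection_job() -> "ImageObjectDetectionJob":
+ from azure.ai.ml import Input
+ from azure.ai.ml.constants import AssetTypes
+
+ return ImageObjectDetectionJob(
+ compute="gpu-cluster",  # placeholder compute target
+ experiment_name="example-object-detection",
+ training_data=Input(type=AssetTypes.MLTABLE, path="./data/training-mltable-folder"),
+ target_column_name="label",
+ primary_metric="mean_average_precision",
+ training_parameters=ImageModelSettingsObjectDetection(model_name="yolov5"),
+ )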
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_search_space.py
new file mode 100644
index 00000000..a9004d1e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_object_detection_search_space.py
@@ -0,0 +1,899 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=R0902,too-many-locals
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageModelDistributionSettingsObjectDetection
+from azure.ai.ml.entities._job.automl.search_space import SearchSpace
+from azure.ai.ml.entities._job.automl.search_space_utils import _convert_from_rest_object, _convert_to_rest_object
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.sweep import (
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+)
+
+
+class ImageObjectDetectionSearchSpace(RestTranslatableMixin):
+ """Search space for AutoML Image Object Detection and Image Instance Segmentation tasks.
+
+ :param ams_gradient: Enable AMSGrad when optimizer is 'adam' or 'adamw'.
+ :type ams_gradient: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param beta1: Value of 'beta1' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :type beta1: float or ~azure.ai.ml.entities.SweepDistribution
+ :param beta2: Value of 'beta2' when optimizer is 'adam' or 'adamw'. Must be a float in the
+ range [0, 1].
+ :type beta2: float or ~azure.ai.ml.entities.SweepDistribution
+ :param distributed: Whether to use distributed training.
+ :type distributed: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param early_stopping: Enable early stopping logic during training.
+ :type early_stopping: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param early_stopping_delay: Minimum number of epochs or validation evaluations to wait
+ before primary metric improvement
+ is tracked for early stopping. Must be a positive integer.
+ :type early_stopping_delay: int or ~azure.ai.ml.entities.SweepDistribution
+ :param early_stopping_patience: Minimum number of epochs or validation evaluations with no
+ primary metric improvement before the run is stopped. Must be a positive integer.
+ :type early_stopping_patience: int or ~azure.ai.ml.entities.SweepDistribution
+ :param enable_onnx_normalization: Enable normalization when exporting ONNX model.
+ :type enable_onnx_normalization: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param evaluation_frequency: Frequency to evaluate validation dataset to get metric scores.
+ Must be a positive integer.
+ :type evaluation_frequency: int or ~azure.ai.ml.entities.SweepDistribution
+ :param gradient_accumulation_step: Gradient accumulation means running a configured number of
+ "GradAccumulationStep" steps without updating the model weights while accumulating the gradients of those steps,
+ and then using the accumulated gradients to compute the weight updates. Must be a positive integer.
+ :type gradient_accumulation_step: int or ~azure.ai.ml.entities.SweepDistribution
+ :param layers_to_freeze: Number of layers to freeze for the model. Must be a positive
+ integer. For instance, passing 2 as value for 'seresnext' means freezing layer0 and layer1.
+ For a full list of models supported and details on layer freeze, please
+ see: https://learn.microsoft.com/azure/machine-learning/reference-automl-images-hyperparameters#model-agnostic-hyperparameters. # pylint: disable=line-too-long
+ :type layers_to_freeze: int or ~azure.ai.ml.entities.SweepDistribution
+ :param learning_rate: Initial learning rate. Must be a float in the range [0, 1].
+ :type learning_rate: float or ~azure.ai.ml.entities.SweepDistribution
+ :param learning_rate_scheduler: Type of learning rate scheduler. Must be 'warmup_cosine' or
+ 'step'.
+ :type learning_rate_scheduler: str or ~azure.ai.ml.entities.SweepDistribution
+ :param model_name: Name of the model to use for training.
+ For more information on the available models please visit the official documentation:
+ https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-image-models.
+ :type model_name: str or ~azure.ai.ml.entities.SweepDistribution
+ :param momentum: Value of momentum when optimizer is 'sgd'. Must be a float in the range [0,
+ 1].
+ :type momentum: float or ~azure.ai.ml.entities.SweepDistribution
+ :param nesterov: Enable nesterov when optimizer is 'sgd'.
+ :type nesterov: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param number_of_epochs: Number of training epochs. Must be a positive integer.
+ :type number_of_epochs: int or ~azure.ai.ml.entities.SweepDistribution
+ :param number_of_workers: Number of data loader workers. Must be a non-negative integer.
+ :type number_of_workers: int or ~azure.ai.ml.entities.SweepDistribution
+ :param optimizer: Type of optimizer. Must be either 'sgd', 'adam', or 'adamw'.
+ :type optimizer: str or ~azure.ai.ml.entities.SweepDistribution
+ :param random_seed: Random seed to be used when using deterministic training.
+ :type random_seed: int or ~azure.ai.ml.entities.SweepDistribution
+ :param step_lr_gamma: Value of gamma when learning rate scheduler is 'step'. Must be a float
+ in the range [0, 1].
+ :type step_lr_gamma: float or ~azure.ai.ml.entities.SweepDistribution
+ :param step_lr_step_size: Value of step size when learning rate scheduler is 'step'. Must be
+ a positive integer.
+ :type step_lr_step_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param training_batch_size: Training batch size. Must be a positive integer.
+ :type training_batch_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param validation_batch_size: Validation batch size. Must be a positive integer.
+ :type validation_batch_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param warmup_cosine_lr_cycles: Value of cosine cycle when learning rate scheduler is
+ 'warmup_cosine'. Must be a float in the range [0, 1].
+ :type warmup_cosine_lr_cycles: float or ~azure.ai.ml.entities.SweepDistribution
+ :param warmup_cosine_lr_warmup_epochs: Value of warmup epochs when learning rate scheduler is
+ 'warmup_cosine'. Must be a positive integer.
+ :type warmup_cosine_lr_warmup_epochs: int or ~azure.ai.ml.entities.SweepDistribution
+ :param weight_decay: Value of weight decay when optimizer is 'sgd', 'adam', or 'adamw'. Must
+ be a float in the range [0, 1].
+ :type weight_decay: float or ~azure.ai.ml.entities.SweepDistribution
+ :param box_detections_per_image: Maximum number of detections per image, for all classes.
+ Must be a positive integer. Note: This setting is not supported for the 'yolov5' algorithm.
+ :type box_detections_per_image: int or ~azure.ai.ml.entities.SweepDistribution
+ :param box_score_threshold: During inference, only return proposals with a classification
+ score greater than BoxScoreThreshold. Must be a float in the range [0, 1].
+ :type box_score_threshold: float or ~azure.ai.ml.entities.SweepDistribution
+ :param image_size: Image size for train and validation. Must be a positive integer.
+ Note: The training run may get into CUDA OOM if the size is too big.
+ Note: This setting is only supported for the 'yolov5' algorithm.
+ :type image_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param max_size: Maximum size of the image to be rescaled before feeding it to the backbone.
+ Must be a positive integer. Note: training run may get into CUDA OOM if the size is too big.
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type max_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param min_size: Minimum size of the image to be rescaled before feeding it to the backbone.
+ Must be a positive integer. Note: training run may get into CUDA OOM if the size is too big.
+ Note: This setting is not supported for the 'yolov5' algorithm.
+ :type min_size: int or ~azure.ai.ml.entities.SweepDistribution
+ :param model_size: Model size. Must be 'small', 'medium', 'large', or 'extra_large'.
+ Note: training run may get into CUDA OOM if the model size is too big.
+ Note: This setting is only supported for the 'yolov5' algorithm.
+ :type model_size: str or ~azure.ai.ml.entities.SweepDistribution
+ :param multi_scale: Enable multi-scale image by varying image size by +/- 50%.
+ Note: training run may get into CUDA OOM if there is insufficient GPU memory.
+ Note: This setting is only supported for the 'yolov5' algorithm.
+ :type multi_scale: bool or ~azure.ai.ml.entities.SweepDistribution
+ :param nms_iou_threshold: IOU threshold used during inference in NMS post processing. Must be
+ float in the range [0, 1].
+ :type nms_iou_threshold: float or ~azure.ai.ml.entities.SweepDistribution
+ :param tile_grid_size: The grid size to use for tiling each image. Note: TileGridSize must
+ not be None to enable small object detection logic. A string containing two integers in mxn format.
+ :type tile_grid_size: str or ~azure.ai.ml.entities.SweepDistribution
+ :param tile_overlap_ratio: Overlap ratio between adjacent tiles in each dimension. Must be
+ float in the range [0, 1).
+ :type tile_overlap_ratio: float or ~azure.ai.ml.entities.SweepDistribution
+ :param tile_predictions_nms_threshold: The IOU threshold to use to perform NMS while merging
+ predictions from tiles and image. Used in validation/inference. Must be float in the range [0, 1].
+ NMS: Non-maximum suppression.
+ :type tile_predictions_nms_threshold: float or ~azure.ai.ml.entities.SweepDistribution
+ :param validation_iou_threshold: IOU threshold to use when computing validation metric. Must
+ be float in the range [0, 1].
+ :type validation_iou_threshold: float or ~azure.ai.ml.entities.SweepDistribution
+ :param validation_metric_type: Metric computation method to use for validation metrics. Must
+ be 'none', 'coco', 'voc', or 'coco_voc'.
+ :type validation_metric_type: str or ~azure.ai.ml.entities.SweepDistribution
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_object_detection_search_space]
+ :end-before: [END automl.automl_image_job.image_object_detection_search_space]
+ :language: python
+ :dedent: 8
+ :caption: Defining an automl image object detection or instance segmentation search space
+ """
+
+ def __init__(
+ self,
+ *,
+ ams_gradient: Optional[
+ Union[
+ bool,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ beta1: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ beta2: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ distributed: Optional[
+ Union[
+ bool,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ early_stopping: Optional[
+ Union[
+ bool,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ early_stopping_delay: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ early_stopping_patience: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ enable_onnx_normalization: Optional[
+ Union[
+ bool,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ evaluation_frequency: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ gradient_accumulation_step: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ layers_to_freeze: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ learning_rate: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ learning_rate_scheduler: Optional[
+ Union[
+ str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ model_name: Optional[
+ Union[
+ str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ momentum: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ nesterov: Optional[
+ Union[
+ bool,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ number_of_epochs: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ number_of_workers: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ optimizer: Optional[
+ Union[
+ str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ random_seed: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ step_lr_gamma: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ step_lr_step_size: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ training_batch_size: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ validation_batch_size: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ warmup_cosine_lr_cycles: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ warmup_cosine_lr_warmup_epochs: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ weight_decay: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ box_detections_per_image: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ box_score_threshold: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ image_size: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ max_size: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ min_size: Optional[
+ Union[
+ int, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ model_size: Optional[
+ Union[
+ str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ multi_scale: Optional[
+ Union[
+ bool,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ nms_iou_threshold: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ tile_grid_size: Optional[
+ Union[
+ str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ tile_overlap_ratio: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ tile_predictions_nms_threshold: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ validation_iou_threshold: Optional[
+ Union[
+ float,
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ Uniform,
+ ]
+ ] = None,
+ validation_metric_type: Optional[
+ Union[
+ str, Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ]
+ ] = None,
+ ) -> None:
+ self.ams_gradient = ams_gradient
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.distributed = distributed
+ self.early_stopping = early_stopping
+ self.early_stopping_delay = early_stopping_delay
+ self.early_stopping_patience = early_stopping_patience
+ self.enable_onnx_normalization = enable_onnx_normalization
+ self.evaluation_frequency = evaluation_frequency
+ self.gradient_accumulation_step = gradient_accumulation_step
+ self.layers_to_freeze = layers_to_freeze
+ self.learning_rate = learning_rate
+ self.learning_rate_scheduler = learning_rate_scheduler
+ self.model_name = model_name
+ self.momentum = momentum
+ self.nesterov = nesterov
+ self.number_of_epochs = number_of_epochs
+ self.number_of_workers = number_of_workers
+ self.optimizer = optimizer
+ self.random_seed = random_seed
+ self.step_lr_gamma = step_lr_gamma
+ self.step_lr_step_size = step_lr_step_size
+ self.training_batch_size = training_batch_size
+ self.validation_batch_size = validation_batch_size
+ self.warmup_cosine_lr_cycles = warmup_cosine_lr_cycles
+ self.warmup_cosine_lr_warmup_epochs = warmup_cosine_lr_warmup_epochs
+ self.weight_decay = weight_decay
+ self.box_detections_per_image = box_detections_per_image
+ self.box_score_threshold = box_score_threshold
+ self.image_size = image_size
+ self.max_size = max_size
+ self.min_size = min_size
+ self.model_size = model_size
+ self.multi_scale = multi_scale
+ self.nms_iou_threshold = nms_iou_threshold
+ self.tile_grid_size = tile_grid_size
+ self.tile_overlap_ratio = tile_overlap_ratio
+ self.tile_predictions_nms_threshold = tile_predictions_nms_threshold
+ self.validation_iou_threshold = validation_iou_threshold
+ self.validation_metric_type = validation_metric_type
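+
+ # Illustrative note (not upstream code): callers typically pass sweep
+ # distributions rather than fixed values, e.g.
+ #     ImageObjectDetectionSearchSpace(
+ #         learning_rate=Uniform(0.0001, 0.01),
+ #         model_name=Choice(["yolov5", "fasterrcnn_resnet50_fpn"]),
+ #         model_size=Choice(["small", "medium"]),
+ #     )
+ # Each attribute set above is converted one-for-one to its REST distribution
+ # expression by _to_rest_object below.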
+
+ def _to_rest_object(self) -> ImageModelDistributionSettingsObjectDetection:
+ return ImageModelDistributionSettingsObjectDetection(
+ ams_gradient=_convert_to_rest_object(self.ams_gradient) if self.ams_gradient is not None else None,
+ beta1=_convert_to_rest_object(self.beta1) if self.beta1 is not None else None,
+ beta2=_convert_to_rest_object(self.beta2) if self.beta2 is not None else None,
+ distributed=_convert_to_rest_object(self.distributed) if self.distributed is not None else None,
+ early_stopping=_convert_to_rest_object(self.early_stopping) if self.early_stopping is not None else None,
+ early_stopping_delay=(
+ _convert_to_rest_object(self.early_stopping_delay) if self.early_stopping_delay is not None else None
+ ),
+ early_stopping_patience=(
+ _convert_to_rest_object(self.early_stopping_patience)
+ if self.early_stopping_patience is not None
+ else None
+ ),
+ enable_onnx_normalization=(
+ _convert_to_rest_object(self.enable_onnx_normalization)
+ if self.enable_onnx_normalization is not None
+ else None
+ ),
+ evaluation_frequency=(
+ _convert_to_rest_object(self.evaluation_frequency) if self.evaluation_frequency is not None else None
+ ),
+ gradient_accumulation_step=(
+ _convert_to_rest_object(self.gradient_accumulation_step)
+ if self.gradient_accumulation_step is not None
+ else None
+ ),
+ layers_to_freeze=(
+ _convert_to_rest_object(self.layers_to_freeze) if self.layers_to_freeze is not None else None
+ ),
+ learning_rate=_convert_to_rest_object(self.learning_rate) if self.learning_rate is not None else None,
+ learning_rate_scheduler=(
+ _convert_to_rest_object(self.learning_rate_scheduler)
+ if self.learning_rate_scheduler is not None
+ else None
+ ),
+ model_name=_convert_to_rest_object(self.model_name) if self.model_name is not None else None,
+ momentum=_convert_to_rest_object(self.momentum) if self.momentum is not None else None,
+ nesterov=_convert_to_rest_object(self.nesterov) if self.nesterov is not None else None,
+ number_of_epochs=(
+ _convert_to_rest_object(self.number_of_epochs) if self.number_of_epochs is not None else None
+ ),
+ number_of_workers=(
+ _convert_to_rest_object(self.number_of_workers) if self.number_of_workers is not None else None
+ ),
+ optimizer=_convert_to_rest_object(self.optimizer) if self.optimizer is not None else None,
+ random_seed=_convert_to_rest_object(self.random_seed) if self.random_seed is not None else None,
+ step_lr_gamma=_convert_to_rest_object(self.step_lr_gamma) if self.step_lr_gamma is not None else None,
+ step_lr_step_size=(
+ _convert_to_rest_object(self.step_lr_step_size) if self.step_lr_step_size is not None else None
+ ),
+ training_batch_size=(
+ _convert_to_rest_object(self.training_batch_size) if self.training_batch_size is not None else None
+ ),
+ validation_batch_size=(
+ _convert_to_rest_object(self.validation_batch_size) if self.validation_batch_size is not None else None
+ ),
+ warmup_cosine_lr_cycles=(
+ _convert_to_rest_object(self.warmup_cosine_lr_cycles)
+ if self.warmup_cosine_lr_cycles is not None
+ else None
+ ),
+ warmup_cosine_lr_warmup_epochs=(
+ _convert_to_rest_object(self.warmup_cosine_lr_warmup_epochs)
+ if self.warmup_cosine_lr_warmup_epochs is not None
+ else None
+ ),
+ weight_decay=_convert_to_rest_object(self.weight_decay) if self.weight_decay is not None else None,
+ box_detections_per_image=(
+ _convert_to_rest_object(self.box_detections_per_image)
+ if self.box_detections_per_image is not None
+ else None
+ ),
+ box_score_threshold=(
+ _convert_to_rest_object(self.box_score_threshold) if self.box_score_threshold is not None else None
+ ),
+ image_size=_convert_to_rest_object(self.image_size) if self.image_size is not None else None,
+ max_size=_convert_to_rest_object(self.max_size) if self.max_size is not None else None,
+ min_size=_convert_to_rest_object(self.min_size) if self.min_size is not None else None,
+ model_size=_convert_to_rest_object(self.model_size) if self.model_size is not None else None,
+ multi_scale=_convert_to_rest_object(self.multi_scale) if self.multi_scale is not None else None,
+ nms_iou_threshold=(
+ _convert_to_rest_object(self.nms_iou_threshold) if self.nms_iou_threshold is not None else None
+ ),
+ tile_grid_size=_convert_to_rest_object(self.tile_grid_size) if self.tile_grid_size is not None else None,
+ tile_overlap_ratio=(
+ _convert_to_rest_object(self.tile_overlap_ratio) if self.tile_overlap_ratio is not None else None
+ ),
+ tile_predictions_nms_threshold=(
+ _convert_to_rest_object(self.tile_predictions_nms_threshold)
+ if self.tile_predictions_nms_threshold is not None
+ else None
+ ),
+ validation_iou_threshold=(
+ _convert_to_rest_object(self.validation_iou_threshold)
+ if self.validation_iou_threshold is not None
+ else None
+ ),
+ validation_metric_type=(
+ _convert_to_rest_object(self.validation_metric_type)
+ if self.validation_metric_type is not None
+ else None
+ ),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: ImageModelDistributionSettingsObjectDetection) -> "ImageObjectDetectionSearchSpace":
+ return cls(
+ ams_gradient=_convert_from_rest_object(obj.ams_gradient) if obj.ams_gradient is not None else None,
+ beta1=_convert_from_rest_object(obj.beta1) if obj.beta1 is not None else None,
+ beta2=_convert_from_rest_object(obj.beta2) if obj.beta2 is not None else None,
+ distributed=_convert_from_rest_object(obj.distributed) if obj.distributed is not None else None,
+ early_stopping=_convert_from_rest_object(obj.early_stopping) if obj.early_stopping is not None else None,
+ early_stopping_delay=(
+ _convert_from_rest_object(obj.early_stopping_delay) if obj.early_stopping_delay is not None else None
+ ),
+ early_stopping_patience=(
+ _convert_from_rest_object(obj.early_stopping_patience)
+ if obj.early_stopping_patience is not None
+ else None
+ ),
+ enable_onnx_normalization=(
+ _convert_from_rest_object(obj.enable_onnx_normalization)
+ if obj.enable_onnx_normalization is not None
+ else None
+ ),
+ evaluation_frequency=(
+ _convert_from_rest_object(obj.evaluation_frequency) if obj.evaluation_frequency is not None else None
+ ),
+ gradient_accumulation_step=(
+ _convert_from_rest_object(obj.gradient_accumulation_step)
+ if obj.gradient_accumulation_step is not None
+ else None
+ ),
+ layers_to_freeze=(
+ _convert_from_rest_object(obj.layers_to_freeze) if obj.layers_to_freeze is not None else None
+ ),
+ learning_rate=_convert_from_rest_object(obj.learning_rate) if obj.learning_rate is not None else None,
+ learning_rate_scheduler=(
+ _convert_from_rest_object(obj.learning_rate_scheduler)
+ if obj.learning_rate_scheduler is not None
+ else None
+ ),
+ model_name=_convert_from_rest_object(obj.model_name) if obj.model_name is not None else None,
+ momentum=_convert_from_rest_object(obj.momentum) if obj.momentum is not None else None,
+ nesterov=_convert_from_rest_object(obj.nesterov) if obj.nesterov is not None else None,
+ number_of_epochs=(
+ _convert_from_rest_object(obj.number_of_epochs) if obj.number_of_epochs is not None else None
+ ),
+ number_of_workers=(
+ _convert_from_rest_object(obj.number_of_workers) if obj.number_of_workers is not None else None
+ ),
+ optimizer=_convert_from_rest_object(obj.optimizer) if obj.optimizer is not None else None,
+ random_seed=_convert_from_rest_object(obj.random_seed) if obj.random_seed is not None else None,
+ step_lr_gamma=_convert_from_rest_object(obj.step_lr_gamma) if obj.step_lr_gamma is not None else None,
+ step_lr_step_size=(
+ _convert_from_rest_object(obj.step_lr_step_size) if obj.step_lr_step_size is not None else None
+ ),
+ training_batch_size=(
+ _convert_from_rest_object(obj.training_batch_size) if obj.training_batch_size is not None else None
+ ),
+ validation_batch_size=(
+ _convert_from_rest_object(obj.validation_batch_size) if obj.validation_batch_size is not None else None
+ ),
+ warmup_cosine_lr_cycles=(
+ _convert_from_rest_object(obj.warmup_cosine_lr_cycles)
+ if obj.warmup_cosine_lr_cycles is not None
+ else None
+ ),
+ warmup_cosine_lr_warmup_epochs=(
+ _convert_from_rest_object(obj.warmup_cosine_lr_warmup_epochs)
+ if obj.warmup_cosine_lr_warmup_epochs is not None
+ else None
+ ),
+ weight_decay=_convert_from_rest_object(obj.weight_decay) if obj.weight_decay is not None else None,
+ box_detections_per_image=(
+ _convert_from_rest_object(obj.box_detections_per_image)
+ if obj.box_detections_per_image is not None
+ else None
+ ),
+ box_score_threshold=(
+ _convert_from_rest_object(obj.box_score_threshold) if obj.box_score_threshold is not None else None
+ ),
+ image_size=_convert_from_rest_object(obj.image_size) if obj.image_size is not None else None,
+ max_size=_convert_from_rest_object(obj.max_size) if obj.max_size is not None else None,
+ min_size=_convert_from_rest_object(obj.min_size) if obj.min_size is not None else None,
+ model_size=_convert_from_rest_object(obj.model_size) if obj.model_size is not None else None,
+ multi_scale=_convert_from_rest_object(obj.multi_scale) if obj.multi_scale is not None else None,
+ nms_iou_threshold=(
+ _convert_from_rest_object(obj.nms_iou_threshold) if obj.nms_iou_threshold is not None else None
+ ),
+ tile_grid_size=_convert_from_rest_object(obj.tile_grid_size) if obj.tile_grid_size is not None else None,
+ tile_overlap_ratio=(
+ _convert_from_rest_object(obj.tile_overlap_ratio) if obj.tile_overlap_ratio is not None else None
+ ),
+ tile_predictions_nms_threshold=(
+ _convert_from_rest_object(obj.tile_predictions_nms_threshold)
+ if obj.tile_predictions_nms_threshold is not None
+ else None
+ ),
+ validation_iou_threshold=(
+ _convert_from_rest_object(obj.validation_iou_threshold)
+ if obj.validation_iou_threshold is not None
+ else None
+ ),
+ validation_metric_type=(
+ _convert_from_rest_object(obj.validation_metric_type)
+ if obj.validation_metric_type is not None
+ else None
+ ),
+ )
+
+ @classmethod
+ def _from_search_space_object(cls, obj: SearchSpace) -> "ImageObjectDetectionSearchSpace":
+ return cls(
+ ams_gradient=obj.ams_gradient if hasattr(obj, "ams_gradient") else None,
+ beta1=obj.beta1 if hasattr(obj, "beta1") else None,
+ beta2=obj.beta2 if hasattr(obj, "beta2") else None,
+ distributed=obj.distributed if hasattr(obj, "distributed") else None,
+ early_stopping=obj.early_stopping if hasattr(obj, "early_stopping") else None,
+ early_stopping_delay=obj.early_stopping_delay if hasattr(obj, "early_stopping_delay") else None,
+ early_stopping_patience=obj.early_stopping_patience if hasattr(obj, "early_stopping_patience") else None,
+ enable_onnx_normalization=(
+ obj.enable_onnx_normalization if hasattr(obj, "enable_onnx_normalization") else None
+ ),
+ evaluation_frequency=obj.evaluation_frequency if hasattr(obj, "evaluation_frequency") else None,
+ gradient_accumulation_step=(
+ obj.gradient_accumulation_step if hasattr(obj, "gradient_accumulation_step") else None
+ ),
+ layers_to_freeze=obj.layers_to_freeze if hasattr(obj, "layers_to_freeze") else None,
+ learning_rate=obj.learning_rate if hasattr(obj, "learning_rate") else None,
+ learning_rate_scheduler=obj.learning_rate_scheduler if hasattr(obj, "learning_rate_scheduler") else None,
+ model_name=obj.model_name if hasattr(obj, "model_name") else None,
+ momentum=obj.momentum if hasattr(obj, "momentum") else None,
+ nesterov=obj.nesterov if hasattr(obj, "nesterov") else None,
+ number_of_epochs=obj.number_of_epochs if hasattr(obj, "number_of_epochs") else None,
+ number_of_workers=obj.number_of_workers if hasattr(obj, "number_of_workers") else None,
+ optimizer=obj.optimizer if hasattr(obj, "optimizer") else None,
+ random_seed=obj.random_seed if hasattr(obj, "random_seed") else None,
+ step_lr_gamma=obj.step_lr_gamma if hasattr(obj, "step_lr_gamma") else None,
+ step_lr_step_size=obj.step_lr_step_size if hasattr(obj, "step_lr_step_size") else None,
+ training_batch_size=obj.training_batch_size if hasattr(obj, "training_batch_size") else None,
+ validation_batch_size=obj.validation_batch_size if hasattr(obj, "validation_batch_size") else None,
+ warmup_cosine_lr_cycles=obj.warmup_cosine_lr_cycles if hasattr(obj, "warmup_cosine_lr_cycles") else None,
+ warmup_cosine_lr_warmup_epochs=(
+ obj.warmup_cosine_lr_warmup_epochs if hasattr(obj, "warmup_cosine_lr_warmup_epochs") else None
+ ),
+ weight_decay=obj.weight_decay if hasattr(obj, "weight_decay") else None,
+ box_detections_per_image=obj.box_detections_per_image if hasattr(obj, "box_detections_per_image") else None,
+ box_score_threshold=obj.box_score_threshold if hasattr(obj, "box_score_threshold") else None,
+ image_size=obj.image_size if hasattr(obj, "image_size") else None,
+ max_size=obj.max_size if hasattr(obj, "max_size") else None,
+ min_size=obj.min_size if hasattr(obj, "min_size") else None,
+ model_size=obj.model_size if hasattr(obj, "model_size") else None,
+ multi_scale=obj.multi_scale if hasattr(obj, "multi_scale") else None,
+ nms_iou_threshold=obj.nms_iou_threshold if hasattr(obj, "nms_iou_threshold") else None,
+ tile_grid_size=obj.tile_grid_size if hasattr(obj, "tile_grid_size") else None,
+ tile_overlap_ratio=obj.tile_overlap_ratio if hasattr(obj, "tile_overlap_ratio") else None,
+ tile_predictions_nms_threshold=(
+ obj.tile_predictions_nms_threshold if hasattr(obj, "tile_predictions_nms_threshold") else None
+ ),
+ validation_iou_threshold=obj.validation_iou_threshold if hasattr(obj, "validation_iou_threshold") else None,
+ validation_metric_type=obj.validation_metric_type if hasattr(obj, "validation_metric_type") else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageObjectDetectionSearchSpace):
+ return NotImplemented
+
+ return (
+ self.ams_gradient == other.ams_gradient
+ and self.beta1 == other.beta1
+ and self.beta2 == other.beta2
+ and self.distributed == other.distributed
+ and self.early_stopping == other.early_stopping
+ and self.early_stopping_delay == other.early_stopping_delay
+ and self.early_stopping_patience == other.early_stopping_patience
+ and self.enable_onnx_normalization == other.enable_onnx_normalization
+ and self.evaluation_frequency == other.evaluation_frequency
+ and self.gradient_accumulation_step == other.gradient_accumulation_step
+ and self.layers_to_freeze == other.layers_to_freeze
+ and self.learning_rate == other.learning_rate
+ and self.learning_rate_scheduler == other.learning_rate_scheduler
+ and self.model_name == other.model_name
+ and self.momentum == other.momentum
+ and self.nesterov == other.nesterov
+ and self.number_of_epochs == other.number_of_epochs
+ and self.number_of_workers == other.number_of_workers
+ and self.optimizer == other.optimizer
+ and self.random_seed == other.random_seed
+ and self.step_lr_gamma == other.step_lr_gamma
+ and self.step_lr_step_size == other.step_lr_step_size
+ and self.training_batch_size == other.training_batch_size
+ and self.validation_batch_size == other.validation_batch_size
+ and self.warmup_cosine_lr_cycles == other.warmup_cosine_lr_cycles
+ and self.warmup_cosine_lr_warmup_epochs == other.warmup_cosine_lr_warmup_epochs
+ and self.weight_decay == other.weight_decay
+ and self.box_detections_per_image == other.box_detections_per_image
+ and self.box_score_threshold == other.box_score_threshold
+ and self.image_size == other.image_size
+ and self.max_size == other.max_size
+ and self.min_size == other.min_size
+ and self.model_size == other.model_size
+ and self.multi_scale == other.multi_scale
+ and self.nms_iou_threshold == other.nms_iou_threshold
+ and self.tile_grid_size == other.tile_grid_size
+ and self.tile_overlap_ratio == other.tile_overlap_ratio
+ and self.tile_predictions_nms_threshold == other.tile_predictions_nms_threshold
+ and self.validation_iou_threshold == other.validation_iou_threshold
+ and self.validation_metric_type == other.validation_metric_type
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
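
Every field of ImageObjectDetectionSearchSpace can hold either a literal value or a sweep distribution, so a single object describes one subspace of the hyperparameter sweep. A minimal sketch, assuming the class and the Choice/Uniform distributions are re-exported from azure.ai.ml.automl and azure.ai.ml.sweep respectively, and using placeholder model names and ranges:

    from azure.ai.ml.automl import ImageObjectDetectionSearchSpace
    from azure.ai.ml.sweep import Choice, Uniform

    # One subspace: sweep over two candidate architectures and a learning-rate range.
    search_space = ImageObjectDetectionSearchSpace(
        model_name=Choice(values=["yolov5", "fasterrcnn_resnet50_fpn"]),  # placeholder names
        learning_rate=Uniform(min_value=0.0001, max_value=0.01),
        number_of_epochs=Choice(values=[15, 30]),
        min_size=600,  # literal values may be mixed with distributions
    )
    # Round-trip through the REST representation used by the service.
    rest_obj = search_space._to_rest_object()
    restored = ImageObjectDetectionSearchSpace._from_rest_object(rest_obj)
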
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_sweep_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_sweep_settings.py
new file mode 100644
index 00000000..b5e9ffaf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/image/image_sweep_settings.py
@@ -0,0 +1,86 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ImageSweepSettings as RestImageSweepSettings
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SamplingAlgorithmType
+from azure.ai.ml.entities._job.sweep.early_termination_policy import (
+ BanditPolicy,
+ EarlyTerminationPolicy,
+ MedianStoppingPolicy,
+ TruncationSelectionPolicy,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class ImageSweepSettings(RestTranslatableMixin):
+ """Sweep settings for all AutoML Image Verticals.
+
+    :keyword sampling_algorithm: Required. Type of the hyperparameter sampling
+        algorithm. Possible values include: "Grid", "Random", "Bayesian".
+    :paramtype sampling_algorithm: Union[
+        str,
+        ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.GRID,
+        ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.BAYESIAN,
+        ~azure.mgmt.machinelearningservices.models.SamplingAlgorithmType.RANDOM
+    ]
+    :keyword early_termination: Type of early termination policy.
+    :paramtype early_termination: Union[
+        ~azure.mgmt.machinelearningservices.models.BanditPolicy,
+        ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy,
+        ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy
+    ]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_image.py
+ :start-after: [START automl.automl_image_job.image_sweep_settings]
+ :end-before: [END automl.automl_image_job.image_sweep_settings]
+ :language: python
+ :dedent: 8
+ :caption: Defining the sweep settings for an automl image job.
+ """
+
+ def __init__(
+ self,
+ *,
+ sampling_algorithm: Union[
+ str, SamplingAlgorithmType.GRID, SamplingAlgorithmType.BAYESIAN, SamplingAlgorithmType.RANDOM
+ ],
+ early_termination: Optional[
+ Union[EarlyTerminationPolicy, BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy]
+ ] = None,
+ ):
+ self.sampling_algorithm = sampling_algorithm
+ self.early_termination = early_termination
+
+ def _to_rest_object(self) -> RestImageSweepSettings:
+ return RestImageSweepSettings(
+ sampling_algorithm=self.sampling_algorithm,
+ early_termination=self.early_termination._to_rest_object() if self.early_termination else None,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestImageSweepSettings) -> "ImageSweepSettings":
+ return cls(
+ sampling_algorithm=obj.sampling_algorithm,
+ early_termination=(
+ EarlyTerminationPolicy._from_rest_object(obj.early_termination) if obj.early_termination else None
+ ),
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ImageSweepSettings):
+ return NotImplemented
+
+ return self.sampling_algorithm == other.sampling_algorithm and self.early_termination == other.early_termination
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
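
ImageSweepSettings only carries the sampling algorithm and an optional early-termination policy; both are passed straight through to the REST layer. A minimal sketch, assuming ImageSweepSettings and BanditPolicy are re-exported from azure.ai.ml.automl and azure.ai.ml.sweep, with placeholder values:

    from azure.ai.ml.automl import ImageSweepSettings
    from azure.ai.ml.sweep import BanditPolicy

    sweep_settings = ImageSweepSettings(
        sampling_algorithm="Random",
        early_termination=BanditPolicy(evaluation_interval=2, slack_factor=0.2, delay_evaluation=6),
    )
    rest_sweep = sweep_settings._to_rest_object()  # the policy is serialized via its own _to_rest_object()
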
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/__init__.py
new file mode 100644
index 00000000..9be7b483
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/__init__.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from .automl_nlp_job import AutoMLNLPJob
+from .nlp_featurization_settings import NlpFeaturizationSettings
+from .nlp_fixed_parameters import NlpFixedParameters
+from .nlp_limit_settings import NlpLimitSettings
+from .nlp_search_space import NlpSearchSpace
+from .nlp_sweep_settings import NlpSweepSettings
+from .text_classification_job import TextClassificationJob
+from .text_classification_multilabel_job import TextClassificationMultilabelJob
+from .text_ner_job import TextNerJob
+
+__all__ = [
+ "AutoMLNLPJob",
+ "NlpFeaturizationSettings",
+ "NlpFixedParameters",
+ "NlpLimitSettings",
+ "NlpSearchSpace",
+ "NlpSweepSettings",
+ "TextClassificationJob",
+ "TextClassificationMultilabelJob",
+ "TextNerJob",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/automl_nlp_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/automl_nlp_job.py
new file mode 100644
index 00000000..f0b3baa8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/automl_nlp_job.py
@@ -0,0 +1,467 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from abc import ABC
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ LogVerbosity,
+ NlpLearningRateScheduler,
+ SamplingAlgorithmType,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.automl.automl_vertical import AutoMLVertical
+from azure.ai.ml.entities._job.automl.nlp.nlp_featurization_settings import NlpFeaturizationSettings
+from azure.ai.ml.entities._job.automl.nlp.nlp_fixed_parameters import NlpFixedParameters
+from azure.ai.ml.entities._job.automl.nlp.nlp_limit_settings import NlpLimitSettings
+from azure.ai.ml.entities._job.automl.nlp.nlp_search_space import NlpSearchSpace
+from azure.ai.ml.entities._job.automl.nlp.nlp_sweep_settings import NlpSweepSettings
+from azure.ai.ml.entities._job.automl.search_space import SearchSpace
+from azure.ai.ml.entities._job.automl.utils import cast_to_specific_search_space
+from azure.ai.ml.entities._job.sweep.early_termination_policy import EarlyTerminationPolicy
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+# pylint: disable=too-many-instance-attributes,protected-access
+class AutoMLNLPJob(AutoMLVertical, ABC):
+ """Base class for AutoML NLP jobs.
+
+    You should not instantiate this class directly. Instead, instantiate one of the
+    task-specific NLP job classes.
+
+ :param task_type: NLP task type, must be one of 'TextClassification',
+ 'TextClassificationMultilabel', or 'TextNER'
+ :type task_type: str
+ :param primary_metric: Primary metric to display from NLP job
+ :type primary_metric: str
+ :param training_data: Training data
+ :type training_data: Input
+ :param validation_data: Validation data
+ :type validation_data: Input
+ :param target_column_name: Column name of the target column, defaults to None
+ :type target_column_name: Optional[str]
+ :param log_verbosity: The degree of verbosity used in logging, defaults to None,
+ must be one of 'NotSet', 'Debug', 'Info', 'Warning', 'Error', 'Critical', or None
+ :type log_verbosity: Optional[str]
+ :param featurization: Featurization settings used for NLP job, defaults to None
+ :type featurization: Optional[~azure.ai.ml.automl.NlpFeaturizationSettings]
+ :param limits: Limit settings for NLP jobs, defaults to None
+ :type limits: Optional[~azure.ai.ml.automl.NlpLimitSettings]
+ :param sweep: Sweep settings used for NLP job, defaults to None
+ :type sweep: Optional[~azure.ai.ml.automl.NlpSweepSettings]
+    :param training_parameters: Fixed parameters for the training of all candidates,
+        defaults to None
+ :type training_parameters: Optional[~azure.ai.ml.automl.NlpFixedParameters]
+ :param search_space: Search space(s) to sweep over for NLP sweep jobs, defaults to None
+ :type search_space: Optional[List[~azure.ai.ml.automl.NlpSearchSpace]]
+ """
+
+ def __init__(
+ self,
+ *,
+ task_type: str,
+ primary_metric: str,
+ training_data: Optional[Input],
+ validation_data: Optional[Input],
+ target_column_name: Optional[str] = None,
+ log_verbosity: Optional[str] = None,
+ featurization: Optional[NlpFeaturizationSettings] = None,
+ limits: Optional[NlpLimitSettings] = None,
+ sweep: Optional[NlpSweepSettings] = None,
+ training_parameters: Optional[NlpFixedParameters] = None,
+ search_space: Optional[List[NlpSearchSpace]] = None,
+ **kwargs: Any,
+ ):
+ self._training_parameters: Optional[NlpFixedParameters] = None
+
+ super().__init__(
+ task_type, training_data=training_data, validation_data=validation_data, **kwargs # type: ignore
+ )
+ self.log_verbosity = log_verbosity
+ self._primary_metric: str = ""
+ self.primary_metric = primary_metric
+
+ self.target_column_name = target_column_name
+
+ self._featurization = featurization
+ self._limits = limits or NlpLimitSettings()
+ self._sweep = sweep
+ self.training_parameters = training_parameters # via setter method.
+ self._search_space = search_space
+
+ @property
+ def training_parameters(self) -> Optional[NlpFixedParameters]:
+ """Parameters that are used for all submitted jobs.
+
+ :return: fixed training parameters for NLP jobs
+ :rtype: ~azure.ai.ml.automl.NlpFixedParameters
+ """
+ return self._training_parameters
+
+ @training_parameters.setter
+ def training_parameters(self, value: Union[Dict, NlpFixedParameters]) -> None:
+ if value is None:
+ self._training_parameters = None
+ elif isinstance(value, NlpFixedParameters):
+ self._training_parameters = value
+            # Normalize learning_rate_scheduler from its string form to the enum via the setter helper.
+ self.set_training_parameters(learning_rate_scheduler=value.learning_rate_scheduler)
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for nlp training parameters."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_training_parameters(**value)
+
+ @property
+ def search_space(self) -> Optional[List[NlpSearchSpace]]:
+ """Search space(s) to sweep over for NLP sweep jobs
+
+ :return: list of search spaces to sweep over for NLP jobs
+ :rtype: List[~azure.ai.ml.automl.NlpSearchSpace]
+ """
+ return self._search_space
+
+ @search_space.setter
+ def search_space(self, value: Union[List[dict], List[SearchSpace]]) -> None:
+ if not isinstance(value, list):
+ msg = "Expected a list for search space."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ all_dict_type = all(isinstance(item, dict) for item in value)
+ all_search_space_type = all(isinstance(item, SearchSpace) for item in value)
+
+ if not (all_search_space_type or all_dict_type):
+ msg = "Expected all items in the list to be either dictionaries or SearchSpace objects."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ self._search_space = [
+ cast_to_specific_search_space(item, NlpSearchSpace, self.task_type) for item in value # type: ignore
+ ]
+
+ @property
+ def primary_metric(self) -> str:
+ """Primary metric to display from NLP job
+
+ :return: primary metric to display
+ :rtype: str
+ """
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: str) -> None:
+ self._primary_metric = value
+
+ @property
+ def log_verbosity(self) -> LogVerbosity:
+ """Log verbosity configuration
+
+ :return: the degree of verbosity used in logging
+ :rtype: ~azure.mgmt.machinelearningservices.models.LogVerbosity
+ """
+ return self._log_verbosity
+
+ @log_verbosity.setter
+ def log_verbosity(self, value: Union[str, LogVerbosity]) -> None:
+ self._log_verbosity = None if value is None else LogVerbosity[camel_to_snake(value).upper()]
+
+ @property
+ def limits(self) -> NlpLimitSettings:
+ """Limit settings for NLP jobs
+
+ :return: limit configuration for NLP job
+ :rtype: ~azure.ai.ml.automl.NlpLimitSettings
+ """
+ return self._limits
+
+ @limits.setter
+ def limits(self, value: Union[Dict, NlpLimitSettings]) -> None:
+ if isinstance(value, NlpLimitSettings):
+ self._limits = value
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for limit settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_limits(**value)
+
+ @property
+ def sweep(self) -> Optional[NlpSweepSettings]:
+ """Sweep settings used for NLP job
+
+ :return: sweep settings
+ :rtype: ~azure.ai.ml.automl.NlpSweepSettings
+ """
+ return self._sweep
+
+ @sweep.setter
+ def sweep(self, value: Union[Dict, NlpSweepSettings]) -> None:
+ if isinstance(value, NlpSweepSettings):
+ self._sweep = value
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for sweep settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_sweep(**value)
+
+ @property
+ def featurization(self) -> Optional[NlpFeaturizationSettings]:
+ """Featurization settings used for NLP job
+
+ :return: featurization settings
+ :rtype: ~azure.ai.ml.automl.NlpFeaturizationSettings
+ """
+ return self._featurization
+
+ @featurization.setter
+ def featurization(self, value: Union[Dict, NlpFeaturizationSettings]) -> None:
+ if isinstance(value, NlpFeaturizationSettings):
+ self._featurization = value
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for featurization settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_featurization(**value)
+
+ def set_data(self, *, training_data: Input, target_column_name: str, validation_data: Input) -> None:
+ """Define data configuration for NLP job
+
+ :keyword training_data: Training data
+ :type training_data: ~azure.ai.ml.Input
+ :keyword target_column_name: Column name of the target column.
+ :type target_column_name: str
+ :keyword validation_data: Validation data
+ :type validation_data: ~azure.ai.ml.Input
+ """
+ # Properties for NlpVerticalDataSettings
+ self.target_column_name = target_column_name
+ self.training_data = training_data
+ self.validation_data = validation_data
+
+ def set_limits(
+ self,
+ *,
+ max_trials: int = 1,
+ max_concurrent_trials: int = 1,
+ max_nodes: int = 1,
+ timeout_minutes: Optional[int] = None,
+ trial_timeout_minutes: Optional[int] = None,
+ ) -> None:
+ """Define limit configuration for AutoML NLP job
+
+ :keyword max_trials: Maximum number of AutoML iterations, defaults to 1
+ :type max_trials: int, optional
+ :keyword max_concurrent_trials: Maximum number of concurrent AutoML iterations, defaults to 1
+ :type max_concurrent_trials: int, optional
+ :keyword max_nodes: Maximum number of nodes used for sweep, defaults to 1
+ :type max_nodes: int, optional
+ :keyword timeout_minutes: Timeout for the AutoML job, defaults to None
+ :type timeout_minutes: Optional[int]
+ :keyword trial_timeout_minutes: Timeout for each AutoML trial, defaults to None
+ :type trial_timeout_minutes: Optional[int]
+ """
+ self._limits = NlpLimitSettings(
+ max_trials=max_trials,
+ max_concurrent_trials=max_concurrent_trials,
+ max_nodes=max_nodes,
+ timeout_minutes=timeout_minutes,
+ trial_timeout_minutes=trial_timeout_minutes,
+ )
+
+ def set_sweep(
+ self,
+ *,
+ sampling_algorithm: Union[str, SamplingAlgorithmType],
+ early_termination: Optional[EarlyTerminationPolicy] = None,
+ ) -> None:
+ """Define sweep configuration for AutoML NLP job
+
+ :keyword sampling_algorithm: Required. Specifies type of hyperparameter sampling algorithm.
+ Possible values include: "Grid", "Random", and "Bayesian".
+ :type sampling_algorithm: Union[str, ~azure.ai.ml.automl.SamplingAlgorithmType]
+        :keyword early_termination: Optional. Early termination policy to end poorly performing training candidates,
+ defaults to None.
+ :type early_termination: Optional[~azure.mgmt.machinelearningservices.models.EarlyTerminationPolicy]
+ """
+ if self._sweep:
+ self._sweep.sampling_algorithm = sampling_algorithm
+ else:
+ self._sweep = NlpSweepSettings(sampling_algorithm=sampling_algorithm)
+
+ self._sweep.early_termination = early_termination or self._sweep.early_termination
+
+ def set_training_parameters(
+ self,
+ *,
+ gradient_accumulation_steps: Optional[int] = None,
+ learning_rate: Optional[float] = None,
+ learning_rate_scheduler: Optional[Union[str, NlpLearningRateScheduler]] = None,
+ model_name: Optional[str] = None,
+ number_of_epochs: Optional[int] = None,
+ training_batch_size: Optional[int] = None,
+ validation_batch_size: Optional[int] = None,
+ warmup_ratio: Optional[float] = None,
+ weight_decay: Optional[float] = None,
+ ) -> None:
+ """Fix certain training parameters throughout the training procedure for all candidates.
+
+ :keyword gradient_accumulation_steps: number of steps over which to accumulate gradients before a backward
+ pass. This must be a positive integer., defaults to None
+ :type gradient_accumulation_steps: Optional[int]
+ :keyword learning_rate: initial learning rate. Must be a float in (0, 1)., defaults to None
+ :type learning_rate: Optional[float]
+ :keyword learning_rate_scheduler: the type of learning rate scheduler. Must choose from 'linear', 'cosine',
+ 'cosine_with_restarts', 'polynomial', 'constant', and 'constant_with_warmup'., defaults to None
+ :type learning_rate_scheduler: Optional[Union[str, ~azure.ai.ml.automl.NlpLearningRateScheduler]]
+ :keyword model_name: the model name to use during training. Must choose from 'bert-base-cased',
+ 'bert-base-uncased', 'bert-base-multilingual-cased', 'bert-base-german-cased', 'bert-large-cased',
+ 'bert-large-uncased', 'distilbert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'roberta-large',
+            'distilroberta-base', 'xlm-roberta-base', 'xlm-roberta-large', 'xlnet-base-cased', and 'xlnet-large-cased'.,
+ defaults to None
+ :type model_name: Optional[str]
+ :keyword number_of_epochs: the number of epochs to train with. Must be a positive integer., defaults to None
+ :type number_of_epochs: Optional[int]
+ :keyword training_batch_size: the batch size during training. Must be a positive integer., defaults to None
+ :type training_batch_size: Optional[int]
+ :keyword validation_batch_size: the batch size during validation. Must be a positive integer., defaults to None
+ :type validation_batch_size: Optional[int]
+ :keyword warmup_ratio: ratio of total training steps used for a linear warmup from 0 to learning_rate.
+ Must be a float in [0, 1]., defaults to None
+ :type warmup_ratio: Optional[float]
+ :keyword weight_decay: value of weight decay when optimizer is sgd, adam, or adamw. This must be a float in
+ the range [0, 1]., defaults to None
+ :type weight_decay: Optional[float]
+ """
+ self._training_parameters = self._training_parameters or NlpFixedParameters()
+
+ self._training_parameters.gradient_accumulation_steps = (
+ gradient_accumulation_steps
+ if gradient_accumulation_steps is not None
+ else self._training_parameters.gradient_accumulation_steps
+ )
+
+ self._training_parameters.learning_rate = (
+ learning_rate if learning_rate is not None else self._training_parameters.learning_rate
+ )
+
+ self._training_parameters.learning_rate_scheduler = (
+ NlpLearningRateScheduler[camel_to_snake(learning_rate_scheduler).upper()]
+ if learning_rate_scheduler is not None
+ else self._training_parameters.learning_rate_scheduler
+ )
+
+ self._training_parameters.model_name = (
+ model_name if model_name is not None else self._training_parameters.model_name
+ )
+
+ self._training_parameters.number_of_epochs = (
+ number_of_epochs if number_of_epochs is not None else self._training_parameters.number_of_epochs
+ )
+
+ self._training_parameters.training_batch_size = (
+ training_batch_size if training_batch_size is not None else self._training_parameters.training_batch_size
+ )
+
+ self._training_parameters.validation_batch_size = (
+ validation_batch_size
+ if validation_batch_size is not None
+ else self._training_parameters.validation_batch_size
+ )
+
+ self._training_parameters.warmup_ratio = (
+ warmup_ratio if warmup_ratio is not None else self._training_parameters.warmup_ratio
+ )
+
+ self._training_parameters.weight_decay = (
+ weight_decay if weight_decay is not None else self._training_parameters.weight_decay
+ )
+
+ def set_featurization(self, *, dataset_language: Optional[str] = None) -> None:
+ """Define featurization configuration for AutoML NLP job.
+
+ :keyword dataset_language: Language of the dataset, defaults to None
+ :type dataset_language: Optional[str]
+ """
+ self._featurization = NlpFeaturizationSettings(
+ dataset_language=dataset_language,
+ )
+
+ def extend_search_space(self, value: Union[SearchSpace, List[SearchSpace]]) -> None:
+ """Add (a) search space(s) for an AutoML NLP job.
+
+ :param value: either a SearchSpace object or a list of SearchSpace objects with nlp-specific parameters.
+ :type value: Union[~azure.ai.ml.automl.SearchSpace, List[~azure.ai.ml.automl.SearchSpace]]
+ """
+ self._search_space = self._search_space or []
+ if isinstance(value, list):
+ self._search_space.extend(
+ [cast_to_specific_search_space(item, NlpSearchSpace, self.task_type) for item in value] # type: ignore
+ )
+ else:
+ self._search_space.append(
+ cast_to_specific_search_space(value, NlpSearchSpace, self.task_type) # type: ignore
+ )
+
+ @classmethod
+ def _get_search_space_from_str(cls, search_space_str: Optional[str]) -> Optional[List]:
+ if search_space_str is not None:
+ return [NlpSearchSpace._from_rest_object(entry) for entry in search_space_str if entry is not None]
+ return None
+
+ def _restore_data_inputs(self) -> None:
+ """Restore MLTableJobInputs to Inputs within data_settings.
+
+        self.training_data and self.validation_data should reflect what the user passed in (Input). Once we get the
+        response back from the service (as MLTableJobInput), we should set the corresponding attributes back to Input.
+ """
+ super()._restore_data_inputs()
+ self.training_data = self.training_data if self.training_data else None # type: ignore
+ self.validation_data = self.validation_data if self.validation_data else None # type: ignore
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, AutoMLNLPJob):
+ return NotImplemented
+
+ return (
+ self.primary_metric == other.primary_metric
+ and self.log_verbosity == other.log_verbosity
+ and self.training_data == other.training_data
+ and self.validation_data == other.validation_data
+ and self._featurization == other._featurization
+ and self._limits == other._limits
+ and self._sweep == other._sweep
+ and self._training_parameters == other._training_parameters
+ and self._search_space == other._search_space
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
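
The set_* helpers above are the intended way to populate limits, sweep, fixed training parameters, and featurization on a concrete NLP job. A minimal sketch, assuming `job` is an instance of one of the concrete subclasses (for example TextClassificationJob) and that the values shown are placeholders:

    # Limits for the sweep: 4 trials, 2 at a time, up to 4 nodes, 2-hour overall budget.
    job.set_limits(max_trials=4, max_concurrent_trials=2, max_nodes=4, timeout_minutes=120)

    # Random sampling over the search space; no early-termination policy.
    job.set_sweep(sampling_algorithm="Random")

    # Parameters shared by every candidate; the scheduler string is mapped to the enum internally.
    job.set_training_parameters(model_name="roberta-base", learning_rate_scheduler="linear")

    # Dataset language used for featurization (placeholder language code).
    job.set_featurization(dataset_language="eng")
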
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_featurization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_featurization_settings.py
new file mode 100644
index 00000000..5649dea2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_featurization_settings.py
@@ -0,0 +1,47 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ NlpVerticalFeaturizationSettings as RestNlpVerticalFeaturizationSettings,
+)
+from azure.ai.ml.entities._job.automl.featurization_settings import FeaturizationSettings, FeaturizationSettingsType
+
+
+class NlpFeaturizationSettings(FeaturizationSettings):
+ """Featurization settings for all AutoML NLP Verticals.
+
+ :ivar type: Specifies the type of FeaturizationSettings. Set automatically to "NLP" for this class.
+ :vartype type: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_nlp.py
+ :start-after: [START automl.nlp_featurization_settings]
+ :end-before: [END automl.nlp_featurization_settings]
+ :language: python
+ :dedent: 8
+        :caption: Creating featurization settings for an AutoML NLP job.
+ """
+
+ type = FeaturizationSettingsType.NLP
+
+ def _to_rest_object(self) -> RestNlpVerticalFeaturizationSettings:
+ return RestNlpVerticalFeaturizationSettings(
+ dataset_language=self.dataset_language,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestNlpVerticalFeaturizationSettings) -> "NlpFeaturizationSettings":
+ return cls(
+ dataset_language=obj.dataset_language,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, NlpFeaturizationSettings):
+ return NotImplemented
+
+ return super().__eq__(other)
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
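
A minimal sketch of the featurization settings, assuming NlpFeaturizationSettings is re-exported from azure.ai.ml.automl and "eng" is a placeholder language code:

    from azure.ai.ml.automl import NlpFeaturizationSettings

    featurization = NlpFeaturizationSettings(dataset_language="eng")
    rest_featurization = featurization._to_rest_object()  # only dataset_language is sent to the service
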
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_fixed_parameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_fixed_parameters.py
new file mode 100644
index 00000000..13c594b6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_fixed_parameters.py
@@ -0,0 +1,117 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpFixedParameters as RestNlpFixedParameters
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class NlpFixedParameters(RestTranslatableMixin):
+ """Configuration of fixed parameters for all candidates of an AutoML NLP Job
+
+ :param gradient_accumulation_steps: number of steps over which to accumulate gradients before a backward
+ pass. This must be a positive integer, defaults to None
+ :type gradient_accumulation_steps: Optional[int]
+ :param learning_rate: initial learning rate. Must be a float in (0, 1), defaults to None
+ :type learning_rate: Optional[float]
+ :param learning_rate_scheduler: the type of learning rate scheduler. Must choose from 'linear', 'cosine',
+ 'cosine_with_restarts', 'polynomial', 'constant', and 'constant_with_warmup', defaults to None
+ :type learning_rate_scheduler: Optional[str]
+ :param model_name: the model name to use during training. Must choose from 'bert-base-cased',
+ 'bert-base-uncased', 'bert-base-multilingual-cased', 'bert-base-german-cased', 'bert-large-cased',
+ 'bert-large-uncased', 'distilbert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'roberta-large',
+        'distilroberta-base', 'xlm-roberta-base', 'xlm-roberta-large', 'xlnet-base-cased', and 'xlnet-large-cased',
+ defaults to None
+ :type model_name: Optional[str]
+ :param number_of_epochs: the number of epochs to train with. Must be a positive integer, defaults to None
+ :type number_of_epochs: Optional[int]
+ :param training_batch_size: the batch size during training. Must be a positive integer, defaults to None
+ :type training_batch_size: Optional[int]
+ :param validation_batch_size: the batch size during validation. Must be a positive integer, defaults to None
+ :type validation_batch_size: Optional[int]
+ :param warmup_ratio: ratio of total training steps used for a linear warmup from 0 to learning_rate.
+ Must be a float in [0, 1], defaults to None
+ :type warmup_ratio: Optional[float]
+ :param weight_decay: value of weight decay when optimizer is sgd, adam, or adamw. This must be a float in
+        the range [0, 1], defaults to None
+ :type weight_decay: Optional[float]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_nlp.py
+ :start-after: [START automl.nlp_fixed_parameters]
+ :end-before: [END automl.nlp_fixed_parameters]
+ :language: python
+ :dedent: 8
+        :caption: Creating fixed parameters for an AutoML NLP job.
+ """
+
+ def __init__(
+ self,
+ *,
+ gradient_accumulation_steps: Optional[int] = None,
+ learning_rate: Optional[float] = None,
+ learning_rate_scheduler: Optional[str] = None,
+ model_name: Optional[str] = None,
+ number_of_epochs: Optional[int] = None,
+ training_batch_size: Optional[int] = None,
+ validation_batch_size: Optional[int] = None,
+ warmup_ratio: Optional[float] = None,
+ weight_decay: Optional[float] = None,
+ ):
+ self.gradient_accumulation_steps = gradient_accumulation_steps
+ self.learning_rate = learning_rate
+ self.learning_rate_scheduler = learning_rate_scheduler
+ self.model_name = model_name
+ self.number_of_epochs = number_of_epochs
+ self.training_batch_size = training_batch_size
+ self.validation_batch_size = validation_batch_size
+ self.warmup_ratio = warmup_ratio
+ self.weight_decay = weight_decay
+
+ def _to_rest_object(self) -> RestNlpFixedParameters:
+ return RestNlpFixedParameters(
+ gradient_accumulation_steps=self.gradient_accumulation_steps,
+ learning_rate=self.learning_rate,
+ learning_rate_scheduler=self.learning_rate_scheduler,
+ model_name=self.model_name,
+ number_of_epochs=self.number_of_epochs,
+ training_batch_size=self.training_batch_size,
+ validation_batch_size=self.validation_batch_size,
+ warmup_ratio=self.warmup_ratio,
+ weight_decay=self.weight_decay,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestNlpFixedParameters) -> "NlpFixedParameters":
+ return cls(
+ gradient_accumulation_steps=obj.gradient_accumulation_steps,
+ learning_rate=obj.learning_rate,
+ learning_rate_scheduler=obj.learning_rate_scheduler,
+ model_name=obj.model_name,
+ number_of_epochs=obj.number_of_epochs,
+ training_batch_size=obj.training_batch_size,
+ validation_batch_size=obj.validation_batch_size,
+ warmup_ratio=obj.warmup_ratio,
+ weight_decay=obj.weight_decay,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, NlpFixedParameters):
+ return NotImplemented
+
+ return (
+ self.gradient_accumulation_steps == other.gradient_accumulation_steps
+ and self.learning_rate == other.learning_rate
+ and self.learning_rate_scheduler == other.learning_rate_scheduler
+ and self.model_name == other.model_name
+ and self.number_of_epochs == other.number_of_epochs
+ and self.training_batch_size == other.training_batch_size
+ and self.validation_batch_size == other.validation_batch_size
+ and self.warmup_ratio == other.warmup_ratio
+ and self.weight_decay == other.weight_decay
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
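
Because NlpFixedParameters pins a value for every candidate rather than sweeping over it, it is typically paired with a search space that leaves the remaining fields free. A minimal sketch, assuming the class is re-exported from azure.ai.ml.automl and the values are placeholders within the documented ranges:

    from azure.ai.ml.automl import NlpFixedParameters

    fixed_parameters = NlpFixedParameters(
        model_name="bert-base-cased",
        learning_rate=2e-5,               # must lie in (0, 1)
        learning_rate_scheduler="linear",
        number_of_epochs=3,
        warmup_ratio=0.1,                 # must lie in [0, 1]
        weight_decay=0.01,
    )
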
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_limit_settings.py
new file mode 100644
index 00000000..1e99f4f0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_limit_settings.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpVerticalLimitSettings as RestNlpLimitSettings
+from azure.ai.ml._utils.utils import from_iso_duration_format_mins, to_iso_duration_format_mins
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class NlpLimitSettings(RestTranslatableMixin):
+ """Limit settings for all AutoML NLP Verticals.
+
+    :param max_concurrent_trials: Maximum number of concurrent AutoML iterations.
+    :type max_concurrent_trials: Optional[int]
+    :param max_trials: Maximum number of AutoML iterations.
+    :type max_trials: int
+    :param max_nodes: Maximum number of nodes used for the sweep.
+    :type max_nodes: int
+    :param timeout_minutes: AutoML job timeout.
+    :type timeout_minutes: Optional[int]
+    :param trial_timeout_minutes: Timeout for each AutoML trial.
+    :type trial_timeout_minutes: Optional[int]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_nlp.py
+ :start-after: [START automl.nlp_limit_settings]
+ :end-before: [END automl.nlp_limit_settings]
+ :language: python
+ :dedent: 8
+        :caption: Creating limit settings for an AutoML NLP job.
+ """
+
+ def __init__(
+ self,
+ *,
+ max_concurrent_trials: Optional[int] = None,
+ max_trials: int = 1,
+ max_nodes: int = 1,
+ timeout_minutes: Optional[int] = None,
+ trial_timeout_minutes: Optional[int] = None,
+ ):
+ self.max_concurrent_trials = max_concurrent_trials
+ self.max_trials = max_trials
+ self.max_nodes = max_nodes
+ self.timeout_minutes = timeout_minutes
+ self.trial_timeout_minutes = trial_timeout_minutes
+
+ def _to_rest_object(self) -> RestNlpLimitSettings:
+ return RestNlpLimitSettings(
+ max_concurrent_trials=self.max_concurrent_trials,
+ max_trials=self.max_trials,
+ max_nodes=self.max_nodes,
+ timeout=to_iso_duration_format_mins(self.timeout_minutes),
+ trial_timeout=to_iso_duration_format_mins(self.trial_timeout_minutes),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestNlpLimitSettings) -> "NlpLimitSettings":
+ return cls(
+ max_concurrent_trials=obj.max_concurrent_trials,
+ max_trials=obj.max_trials,
+ max_nodes=obj.max_nodes,
+ timeout_minutes=from_iso_duration_format_mins(obj.timeout),
+ trial_timeout_minutes=from_iso_duration_format_mins(obj.trial_timeout),
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, NlpLimitSettings):
+ return NotImplemented
+
+ return (
+ self.max_concurrent_trials == other.max_concurrent_trials
+ and self.max_trials == other.max_trials
+ and self.max_nodes == other.max_nodes
+ and self.timeout_minutes == other.timeout_minutes
+ and self.trial_timeout_minutes == other.trial_timeout_minutes
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
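
The minute-based timeouts are converted to ISO-8601 durations on the way to the service and back. A minimal sketch, assuming NlpLimitSettings is re-exported from azure.ai.ml.automl and the values are placeholders:

    from azure.ai.ml.automl import NlpLimitSettings

    limits = NlpLimitSettings(
        max_trials=4,
        max_concurrent_trials=2,
        max_nodes=4,
        timeout_minutes=180,
        trial_timeout_minutes=60,
    )
    rest_limits = limits._to_rest_object()  # timeout/trial_timeout become ISO-8601 durations
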
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_search_space.py
new file mode 100644
index 00000000..e4ad435f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_search_space.py
@@ -0,0 +1,185 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpLearningRateScheduler, NlpParameterSubspace
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants import NlpModels
+from azure.ai.ml.entities._job.automl.search_space import SearchSpace
+from azure.ai.ml.entities._job.automl.search_space_utils import _convert_from_rest_object, _convert_to_rest_object
+from azure.ai.ml.entities._job.sweep.search_space import Choice, SweepDistribution
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class NlpSearchSpace(RestTranslatableMixin):
+ """Search space for AutoML NLP tasks.
+
+ :param gradient_accumulation_steps: number of steps over which to accumulate gradients before a backward
+ pass. This must be a positive integer., defaults to None
+ :type gradient_accumulation_steps: Optional[Union[int, SweepDistribution]]
+ :param learning_rate: initial learning rate. Must be a float in (0, 1), defaults to None
+ :type learning_rate: Optional[Union[float, SweepDistribution]]
+ :param learning_rate_scheduler: the type of learning rate scheduler. Must choose from 'linear', 'cosine',
+ 'cosine_with_restarts', 'polynomial', 'constant', and 'constant_with_warmup', defaults to None
+ :type learning_rate_scheduler: Optional[Union[str, SweepDistribution]]
+ :param model_name: the model name to use during training. Must choose from 'bert-base-cased',
+ 'bert-base-uncased', 'bert-base-multilingual-cased', 'bert-base-german-cased', 'bert-large-cased',
+ 'bert-large-uncased', 'distilbert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'roberta-large',
+        'distilroberta-base', 'xlm-roberta-base', 'xlm-roberta-large', 'xlnet-base-cased', and 'xlnet-large-cased',
+ defaults to None
+ :type model_name: Optional[Union[str, SweepDistribution]]
+ :param number_of_epochs: the number of epochs to train with. Must be a positive integer, defaults to None
+ :type number_of_epochs: Optional[Union[int, SweepDistribution]]
+ :param training_batch_size: the batch size during training. Must be a positive integer, defaults to None
+ :type training_batch_size: Optional[Union[int, SweepDistribution]]
+ :param validation_batch_size: the batch size during validation. Must be a positive integer, defaults to None
+ :type validation_batch_size: Optional[Union[int, SweepDistribution]]
+ :param warmup_ratio: ratio of total training steps used for a linear warmup from 0 to learning_rate.
+ Must be a float in [0, 1], defaults to None
+ :type warmup_ratio: Optional[Union[float, SweepDistribution]]
+ :param weight_decay: value of weight decay when optimizer is sgd, adam, or adamw. This must be a float in
+ the range [0, 1], defaults to None
+ :type weight_decay: Optional[Union[float, SweepDistribution]]
+
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_nlp.py
+ :start-after: [START automl.nlp_search_space]
+ :end-before: [END automl.nlp_search_space]
+ :language: python
+ :dedent: 8
+        :caption: Creating a search space for an AutoML NLP job.
+ """
+
+ def __init__(
+ self,
+ *,
+ gradient_accumulation_steps: Optional[Union[int, SweepDistribution]] = None,
+ learning_rate: Optional[Union[float, SweepDistribution]] = None,
+ learning_rate_scheduler: Optional[Union[str, SweepDistribution]] = None,
+ model_name: Optional[Union[str, SweepDistribution]] = None,
+ number_of_epochs: Optional[Union[int, SweepDistribution]] = None,
+ training_batch_size: Optional[Union[int, SweepDistribution]] = None,
+ validation_batch_size: Optional[Union[int, SweepDistribution]] = None,
+ warmup_ratio: Optional[Union[float, SweepDistribution]] = None,
+ weight_decay: Optional[Union[float, SweepDistribution]] = None
+ ):
+ # Since we want customers to be able to specify enums as well rather than just strings, we need to access
+ # the enum values here before we serialize them ('NlpModels.BERT_BASE_CASED' vs. 'bert-base-cased').
+ if isinstance(learning_rate_scheduler, NlpLearningRateScheduler):
+ learning_rate_scheduler = camel_to_snake(learning_rate_scheduler.value)
+ elif isinstance(learning_rate_scheduler, Choice):
+ if learning_rate_scheduler.values is not None:
+ learning_rate_scheduler.values = [
+ camel_to_snake(item.value) if isinstance(item, NlpLearningRateScheduler) else item
+ for item in learning_rate_scheduler.values
+ ]
+
+ if isinstance(model_name, NlpModels):
+ model_name = model_name.value
+ elif isinstance(model_name, Choice):
+ if model_name.values is not None:
+ model_name.values = [item.value if isinstance(item, NlpModels) else item for item in model_name.values]
+
+ self.gradient_accumulation_steps = gradient_accumulation_steps
+ self.learning_rate = learning_rate
+ self.learning_rate_scheduler = learning_rate_scheduler
+ self.model_name = model_name
+ self.number_of_epochs = number_of_epochs
+ self.training_batch_size = training_batch_size
+ self.validation_batch_size = validation_batch_size
+ self.warmup_ratio = warmup_ratio
+ self.weight_decay = weight_decay
+
+ def _to_rest_object(self) -> NlpParameterSubspace:
+ return NlpParameterSubspace(
+ gradient_accumulation_steps=(
+ _convert_to_rest_object(self.gradient_accumulation_steps)
+ if self.gradient_accumulation_steps is not None
+ else None
+ ),
+ learning_rate=_convert_to_rest_object(self.learning_rate) if self.learning_rate is not None else None,
+ learning_rate_scheduler=(
+ _convert_to_rest_object(self.learning_rate_scheduler)
+ if self.learning_rate_scheduler is not None
+ else None
+ ),
+ model_name=_convert_to_rest_object(self.model_name) if self.model_name is not None else None,
+ number_of_epochs=(
+ _convert_to_rest_object(self.number_of_epochs) if self.number_of_epochs is not None else None
+ ),
+ training_batch_size=(
+ _convert_to_rest_object(self.training_batch_size) if self.training_batch_size is not None else None
+ ),
+ validation_batch_size=(
+ _convert_to_rest_object(self.validation_batch_size) if self.validation_batch_size is not None else None
+ ),
+ warmup_ratio=_convert_to_rest_object(self.warmup_ratio) if self.warmup_ratio is not None else None,
+ weight_decay=_convert_to_rest_object(self.weight_decay) if self.weight_decay is not None else None,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: NlpParameterSubspace) -> "NlpSearchSpace":
+ return cls(
+ gradient_accumulation_steps=(
+ _convert_from_rest_object(obj.gradient_accumulation_steps)
+ if obj.gradient_accumulation_steps is not None
+ else None
+ ),
+ learning_rate=_convert_from_rest_object(obj.learning_rate) if obj.learning_rate is not None else None,
+ learning_rate_scheduler=(
+ _convert_from_rest_object(obj.learning_rate_scheduler)
+ if obj.learning_rate_scheduler is not None
+ else None
+ ),
+ model_name=_convert_from_rest_object(obj.model_name) if obj.model_name is not None else None,
+ number_of_epochs=(
+ _convert_from_rest_object(obj.number_of_epochs) if obj.number_of_epochs is not None else None
+ ),
+ training_batch_size=(
+ _convert_from_rest_object(obj.training_batch_size) if obj.training_batch_size is not None else None
+ ),
+ validation_batch_size=(
+ _convert_from_rest_object(obj.validation_batch_size) if obj.validation_batch_size is not None else None
+ ),
+ warmup_ratio=_convert_from_rest_object(obj.warmup_ratio) if obj.warmup_ratio is not None else None,
+ weight_decay=_convert_from_rest_object(obj.weight_decay) if obj.weight_decay is not None else None,
+ )
+
+ @classmethod
+ def _from_search_space_object(cls, obj: SearchSpace) -> "NlpSearchSpace":
+ return cls(
+ gradient_accumulation_steps=(
+ obj.gradient_accumulation_steps if hasattr(obj, "gradient_accumulation_steps") else None
+ ),
+ learning_rate=obj.learning_rate if hasattr(obj, "learning_rate") else None,
+ learning_rate_scheduler=obj.learning_rate_scheduler if hasattr(obj, "learning_rate_scheduler") else None,
+ model_name=obj.model_name if hasattr(obj, "model_name") else None,
+ number_of_epochs=obj.number_of_epochs if hasattr(obj, "number_of_epochs") else None,
+ training_batch_size=obj.training_batch_size if hasattr(obj, "training_batch_size") else None,
+ validation_batch_size=obj.validation_batch_size if hasattr(obj, "validation_batch_size") else None,
+ warmup_ratio=obj.warmup_ratio if hasattr(obj, "warmup_ratio") else None,
+ weight_decay=obj.weight_decay if hasattr(obj, "weight_decay") else None,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, NlpSearchSpace):
+ return NotImplemented
+
+ return (
+ self.gradient_accumulation_steps == other.gradient_accumulation_steps
+ and self.learning_rate == other.learning_rate
+ and self.learning_rate_scheduler == other.learning_rate_scheduler
+ and self.model_name == other.model_name
+ and self.number_of_epochs == other.number_of_epochs
+ and self.training_batch_size == other.training_batch_size
+ and self.validation_batch_size == other.validation_batch_size
+ and self.warmup_ratio == other.warmup_ratio
+ and self.weight_decay == other.weight_decay
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
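+
+
+# --- Illustrative usage sketch (editor's addition, not part of the SDK) ---
+# The constructor above accepts plain values, enum members, or sweep
+# distributions such as Choice/Uniform; enum members are normalized to their
+# snake_case string form. A minimal sketch, assuming the public helpers in
+# azure.ai.ml.sweep and using hypothetical model names:
+if __name__ == "__main__":
+    from azure.ai.ml.sweep import Choice, Uniform
+
+    space = NlpSearchSpace(
+        model_name=Choice(values=["bert-base-cased", "roberta-base"]),  # hypothetical names
+        learning_rate=Uniform(min_value=1e-5, max_value=5e-5),
+        learning_rate_scheduler="linear",
+        number_of_epochs=3,
+    )
+    rest = space._to_rest_object()                     # NlpParameterSubspace with string-encoded values
+    restored = NlpSearchSpace._from_rest_object(rest)  # back to an NlpSearchSpace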
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_sweep_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_sweep_settings.py
new file mode 100644
index 00000000..e446a30c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/nlp_sweep_settings.py
@@ -0,0 +1,65 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import NlpSweepSettings as RestNlpSweepSettings
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SamplingAlgorithmType
+from azure.ai.ml.entities._job.sweep.early_termination_policy import EarlyTerminationPolicy
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+# pylint: disable=protected-access
+class NlpSweepSettings(RestTranslatableMixin):
+ """Sweep settings for all AutoML NLP tasks.
+
+    :param sampling_algorithm: Required. Specifies the type of hyperparameter sampling algorithm.
+ Possible values include: "Grid", "Random", and "Bayesian".
+ :type sampling_algorithm: Union[str, ~azure.ai.ml.automl.SamplingAlgorithmType]
+ :param early_termination: Early termination policy to end poorly performing training candidates,
+ defaults to None.
+ :type early_termination: Optional[~azure.mgmt.machinelearningservices.models.EarlyTerminationPolicy]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_nlp.py
+ :start-after: [START automl.nlp_sweep_settings]
+ :end-before: [END automl.nlp_sweep_settings]
+ :language: python
+ :dedent: 8
+            :caption: creating NLP sweep settings
+ """
+
+ def __init__(
+ self,
+ *,
+ sampling_algorithm: Union[str, SamplingAlgorithmType],
+ early_termination: Optional[EarlyTerminationPolicy] = None,
+ ):
+ self.sampling_algorithm = sampling_algorithm
+ self.early_termination = early_termination
+
+ def _to_rest_object(self) -> RestNlpSweepSettings:
+ return RestNlpSweepSettings(
+ sampling_algorithm=self.sampling_algorithm,
+ early_termination=self.early_termination._to_rest_object() if self.early_termination else None,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestNlpSweepSettings) -> "NlpSweepSettings":
+ return cls(
+ sampling_algorithm=obj.sampling_algorithm,
+ early_termination=(
+ EarlyTerminationPolicy._from_rest_object(obj.early_termination) if obj.early_termination else None
+ ),
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, NlpSweepSettings):
+ return NotImplemented
+
+ return self.sampling_algorithm == other.sampling_algorithm and self.early_termination == other.early_termination
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
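+
+
+# --- Illustrative usage sketch (editor's addition, not part of the SDK) ---
+# A minimal sketch of building sweep settings with an early-termination policy;
+# BanditPolicy is assumed to come from the public azure.ai.ml.sweep module:
+if __name__ == "__main__":
+    from azure.ai.ml.sweep import BanditPolicy
+
+    sweep = NlpSweepSettings(
+        sampling_algorithm="Random",
+        early_termination=BanditPolicy(slack_factor=0.1, evaluation_interval=2),
+    )
+    rest = sweep._to_rest_object()                       # RestNlpSweepSettings
+    restored = NlpSweepSettings._from_rest_object(rest)  # back to NlpSweepSettings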
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_job.py
new file mode 100644
index 00000000..290f4f70
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_job.py
@@ -0,0 +1,248 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, TaskType
+from azure.ai.ml._restclient.v2023_04_01_preview.models._azure_machine_learning_workspaces_enums import (
+ ClassificationPrimaryMetrics,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import TextClassification as RestTextClassification
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.nlp.automl_nlp_job import AutoMLNLPJob
+from azure.ai.ml.entities._job.automl.nlp.nlp_featurization_settings import NlpFeaturizationSettings
+from azure.ai.ml.entities._job.automl.nlp.nlp_fixed_parameters import NlpFixedParameters
+from azure.ai.ml.entities._job.automl.nlp.nlp_limit_settings import NlpLimitSettings
+from azure.ai.ml.entities._job.automl.nlp.nlp_sweep_settings import NlpSweepSettings
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._component.component import Component
+
+
+class TextClassificationJob(AutoMLNLPJob):
+ """Configuration for AutoML Text Classification Job.
+
+ :param target_column_name: The name of the target column, defaults to None
+ :type target_column_name: Optional[str]
+ :param training_data: Training data to be used for training, defaults to None
+ :type training_data: Optional[~azure.ai.ml.Input]
+ :param validation_data: Validation data to be used for evaluating the trained model, defaults to None
+ :type validation_data: Optional[~azure.ai.ml.Input]
+ :param primary_metric: The primary metric to be displayed, defaults to None
+ :type primary_metric: Optional[~azure.ai.ml.automl.ClassificationPrimaryMetrics]
+ :param log_verbosity: Log verbosity level, defaults to None
+ :type log_verbosity: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_nlp.py
+ :start-after: [START automl.automl_nlp_job.text_classification_job]
+ :end-before: [END automl.automl_nlp_job.text_classification_job]
+ :language: python
+ :dedent: 8
+ :caption: creating an automl text classification job
+ """
+
+ _DEFAULT_PRIMARY_METRIC = ClassificationPrimaryMetrics.ACCURACY
+
+ def __init__(
+ self,
+ *,
+ target_column_name: Optional[str] = None,
+ training_data: Optional[Input] = None,
+ validation_data: Optional[Input] = None,
+ primary_metric: Optional[ClassificationPrimaryMetrics] = None,
+ log_verbosity: Optional[str] = None,
+ **kwargs: Any
+ ):
+ super().__init__(
+ task_type=TaskType.TEXT_CLASSIFICATION,
+ primary_metric=primary_metric or TextClassificationJob._DEFAULT_PRIMARY_METRIC,
+ target_column_name=target_column_name,
+ training_data=training_data,
+ validation_data=validation_data,
+ log_verbosity=log_verbosity,
+ **kwargs,
+ )
+
+ @property
+ def primary_metric(self) -> Union[str, ClassificationPrimaryMetrics]:
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, ClassificationPrimaryMetrics]) -> None:
+ """setter for primary metric
+
+ :param value: _description_
+ :type value: Union[str, ClassificationPrimaryMetrics]
+ """
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+
+ self._primary_metric = (
+ TextClassificationJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else ClassificationPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
+ def _to_rest_object(self) -> JobBase:
+ text_classification = RestTextClassification(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ sweep_settings=self._sweep._to_rest_object() if self._sweep else None,
+ fixed_parameters=self._training_parameters._to_rest_object() if self._training_parameters else None,
+ search_space=(
+ [entry._to_rest_object() for entry in self._search_space if entry is not None]
+ if self._search_space is not None
+ else None
+ ),
+ featurization_settings=self._featurization._to_rest_object() if self._featurization else None,
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ )
+ # resolve data inputs in rest object
+ self._resolve_data_inputs(text_classification)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=text_classification,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "TextClassificationJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestTextClassification = properties.task_details
+ assert isinstance(task_details, RestTextClassification)
+ limits = (
+ NlpLimitSettings._from_rest_object(task_details.limit_settings) if task_details.limit_settings else None
+ )
+ featurization = (
+ NlpFeaturizationSettings._from_rest_object(task_details.featurization_settings)
+ if task_details.featurization_settings
+ else None
+ )
+ sweep = NlpSweepSettings._from_rest_object(task_details.sweep_settings) if task_details.sweep_settings else None
+ training_parameters = (
+ NlpFixedParameters._from_rest_object(task_details.fixed_parameters)
+ if task_details.fixed_parameters
+ else None
+ )
+
+ text_classification_job = cls(
+ # ----- job specific params
+ id=obj.id,
+ name=obj.name,
+ description=properties.description,
+ tags=properties.tags,
+ properties=properties.properties,
+ experiment_name=properties.experiment_name,
+ services=properties.services,
+ status=properties.status,
+ creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ display_name=properties.display_name,
+ compute=properties.compute_id,
+ outputs=from_rest_data_outputs(properties.outputs),
+ resources=properties.resources,
+ # ----- task specific params
+ primary_metric=task_details.primary_metric,
+ log_verbosity=task_details.log_verbosity,
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ limits=limits,
+ sweep=sweep,
+ training_parameters=training_parameters,
+ search_space=cls._get_search_space_from_str(task_details.search_space),
+ featurization=featurization,
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ queue_settings=properties.queue_settings,
+ )
+
+ text_classification_job._restore_data_inputs()
+
+ return text_classification_job
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "Component":
+ raise NotImplementedError()
+
+ @classmethod
+ def _load_from_dict(
+ cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any
+ ) -> "TextClassificationJob":
+ from azure.ai.ml._schema.automl.nlp_vertical.text_classification import TextClassificationSchema
+
+ if kwargs.pop("inside_pipeline", False):
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextClassificationNode
+
+ loaded_data = load_from_dict(
+ AutoMLTextClassificationNode,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ else:
+ loaded_data = load_from_dict(TextClassificationSchema, data, context, additional_message, **kwargs)
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "TextClassificationJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ return TextClassificationJob(**loaded_data)
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.nlp_vertical.text_classification import TextClassificationSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextClassificationNode
+
+ if inside_pipeline:
+ res_autoML: dict = AutoMLTextClassificationNode(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res_autoML
+
+ res: dict = TextClassificationSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TextClassificationJob):
+ return NotImplemented
+
+ if not super(TextClassificationJob, self).__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
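+
+
+# --- Illustrative usage sketch (editor's addition, not part of the SDK) ---
+# A minimal sketch of constructing the job directly with the parameters shown
+# above; the column name and MLTable paths are hypothetical placeholders:
+if __name__ == "__main__":
+    job = TextClassificationJob(
+        target_column_name="Sentiment",
+        training_data=Input(type="mltable", path="./train-mltable-folder"),
+        validation_data=Input(type="mltable", path="./valid-mltable-folder"),
+        log_verbosity="info",
+    )
+    print(job.primary_metric)  # falls back to the class default (accuracy) when omitted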
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_multilabel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_multilabel_job.py
new file mode 100644
index 00000000..ac19b451
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_classification_multilabel_job.py
@@ -0,0 +1,252 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+    ClassificationMultilabelPrimaryMetrics,
+    JobBase,
+    TaskType,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ TextClassificationMultilabel as RestTextClassificationMultilabel,
+)
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.nlp.automl_nlp_job import AutoMLNLPJob
+from azure.ai.ml.entities._job.automl.nlp.nlp_featurization_settings import NlpFeaturizationSettings
+from azure.ai.ml.entities._job.automl.nlp.nlp_fixed_parameters import NlpFixedParameters
+from azure.ai.ml.entities._job.automl.nlp.nlp_limit_settings import NlpLimitSettings
+from azure.ai.ml.entities._job.automl.nlp.nlp_sweep_settings import NlpSweepSettings
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._component.component import Component
+
+
+class TextClassificationMultilabelJob(AutoMLNLPJob):
+ """Configuration for AutoML Text Classification Multilabel Job.
+
+ :param target_column_name: The name of the target column, defaults to None
+ :type target_column_name: Optional[str]
+ :param training_data: Training data to be used for training, defaults to None
+ :type training_data: Optional[~azure.ai.ml.Input]
+ :param validation_data: Validation data to be used for evaluating the trained model, defaults to None
+ :type validation_data: Optional[~azure.ai.ml.Input]
+    :param primary_metric: The primary metric to be displayed, defaults to None
+ :type primary_metric: Optional[str]
+ :param log_verbosity: Log verbosity level, defaults to None
+ :type log_verbosity: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_nlp.py
+ :start-after: [START automl.text_classification_multilabel_job]
+ :end-before: [END automl.text_classification_multilabel_job]
+ :language: python
+ :dedent: 8
+ :caption: creating an automl text classification multilabel job
+ """
+
+ _DEFAULT_PRIMARY_METRIC = ClassificationMultilabelPrimaryMetrics.ACCURACY
+
+ def __init__(
+ self,
+ *,
+ target_column_name: Optional[str] = None,
+ training_data: Optional[Input] = None,
+ validation_data: Optional[Input] = None,
+ primary_metric: Optional[str] = None,
+ log_verbosity: Optional[str] = None,
+ **kwargs: Any
+ ):
+ super().__init__(
+ task_type=TaskType.TEXT_CLASSIFICATION_MULTILABEL,
+ primary_metric=primary_metric or TextClassificationMultilabelJob._DEFAULT_PRIMARY_METRIC,
+ target_column_name=target_column_name,
+ training_data=training_data,
+ validation_data=validation_data,
+ log_verbosity=log_verbosity,
+ **kwargs,
+ )
+
+ @property
+ def primary_metric(self) -> Union[str, ClassificationMultilabelPrimaryMetrics]:
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, ClassificationMultilabelPrimaryMetrics]) -> None:
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+
+ self._primary_metric = (
+ TextClassificationMultilabelJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else ClassificationMultilabelPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
+ def _to_rest_object(self) -> JobBase:
+ text_classification_multilabel = RestTextClassificationMultilabel(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ sweep_settings=self._sweep._to_rest_object() if self._sweep else None,
+ fixed_parameters=self._training_parameters._to_rest_object() if self._training_parameters else None,
+ search_space=(
+ [entry._to_rest_object() for entry in self._search_space if entry is not None]
+ if self._search_space is not None
+ else None
+ ),
+ featurization_settings=self._featurization._to_rest_object() if self._featurization else None,
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ )
+ # resolve data inputs in rest object
+ self._resolve_data_inputs(text_classification_multilabel)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=text_classification_multilabel,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "TextClassificationMultilabelJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestTextClassificationMultilabel = properties.task_details
+ assert isinstance(task_details, RestTextClassificationMultilabel)
+ limits = (
+ NlpLimitSettings._from_rest_object(task_details.limit_settings) if task_details.limit_settings else None
+ )
+ featurization = (
+ NlpFeaturizationSettings._from_rest_object(task_details.featurization_settings)
+ if task_details.featurization_settings
+ else None
+ )
+ sweep = NlpSweepSettings._from_rest_object(task_details.sweep_settings) if task_details.sweep_settings else None
+ training_parameters = (
+ NlpFixedParameters._from_rest_object(task_details.fixed_parameters)
+ if task_details.fixed_parameters
+ else None
+ )
+
+ text_classification_multilabel_job = cls(
+ # ----- job specific params
+ id=obj.id,
+ name=obj.name,
+ description=properties.description,
+ tags=properties.tags,
+ properties=properties.properties,
+ experiment_name=properties.experiment_name,
+ services=properties.services,
+ status=properties.status,
+ creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ display_name=properties.display_name,
+ compute=properties.compute_id,
+ outputs=from_rest_data_outputs(properties.outputs),
+ resources=properties.resources,
+ # ----- task specific params
+ primary_metric=task_details.primary_metric,
+ log_verbosity=task_details.log_verbosity,
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ limits=limits,
+ sweep=sweep,
+ training_parameters=training_parameters,
+ search_space=cls._get_search_space_from_str(task_details.search_space),
+ featurization=featurization,
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ queue_settings=properties.queue_settings,
+ )
+
+ text_classification_multilabel_job._restore_data_inputs()
+
+ return text_classification_multilabel_job
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "Component":
+ raise NotImplementedError()
+
+ @classmethod
+ def _load_from_dict(
+ cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any
+ ) -> "TextClassificationMultilabelJob":
+ from azure.ai.ml._schema.automl.nlp_vertical.text_classification_multilabel import (
+ TextClassificationMultilabelSchema,
+ )
+
+ if kwargs.pop("inside_pipeline", False):
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextClassificationMultilabelNode
+
+ loaded_data = load_from_dict(
+ AutoMLTextClassificationMultilabelNode,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ else:
+ loaded_data = load_from_dict(
+ TextClassificationMultilabelSchema,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "TextClassificationMultilabelJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ return TextClassificationMultilabelJob(**loaded_data)
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.nlp_vertical.text_classification_multilabel import (
+ TextClassificationMultilabelSchema,
+ )
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextClassificationMultilabelNode
+
+ if inside_pipeline:
+ res_autoML: dict = AutoMLTextClassificationMultilabelNode(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res_autoML
+
+ res: dict = TextClassificationMultilabelSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TextClassificationMultilabelJob):
+ return NotImplemented
+
+ if not super(TextClassificationMultilabelJob, self).__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
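+
+
+# --- Illustrative usage sketch (editor's addition, not part of the SDK) ---
+# Same construction pattern as TextClassificationJob; shown only to note that a
+# string primary metric is normalized to the enum by the setter above. The
+# column name and MLTable paths are hypothetical placeholders:
+if __name__ == "__main__":
+    job = TextClassificationMultilabelJob(
+        target_column_name="labels",
+        training_data=Input(type="mltable", path="./train-mltable-folder"),
+        validation_data=Input(type="mltable", path="./valid-mltable-folder"),
+        primary_metric="accuracy",
+    )
+    print(job.primary_metric)  # ClassificationMultilabelPrimaryMetrics.ACCURACY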
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_ner_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_ner_job.py
new file mode 100644
index 00000000..a87965f1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/nlp/text_ner_job.py
@@ -0,0 +1,231 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, TaskType
+from azure.ai.ml._restclient.v2023_04_01_preview.models._azure_machine_learning_workspaces_enums import (
+ ClassificationPrimaryMetrics,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import TextNer as RestTextNER
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.nlp.automl_nlp_job import AutoMLNLPJob
+from azure.ai.ml.entities._job.automl.nlp.nlp_featurization_settings import NlpFeaturizationSettings
+from azure.ai.ml.entities._job.automl.nlp.nlp_fixed_parameters import NlpFixedParameters
+from azure.ai.ml.entities._job.automl.nlp.nlp_limit_settings import NlpLimitSettings
+from azure.ai.ml.entities._job.automl.nlp.nlp_sweep_settings import NlpSweepSettings
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._component.component import Component
+
+
+class TextNerJob(AutoMLNLPJob):
+ """Configuration for AutoML Text NER Job.
+
+ :param training_data: Training data to be used for training, defaults to None
+ :type training_data: Optional[~azure.ai.ml.Input]
+ :param validation_data: Validation data to be used for evaluating the trained model,
+ defaults to None
+ :type validation_data: Optional[~azure.ai.ml.Input]
+ :param primary_metric: The primary metric to be displayed, defaults to None
+ :type primary_metric: Optional[str]
+ :param log_verbosity: Log verbosity level, defaults to None
+ :type log_verbosity: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_automl_nlp.py
+ :start-after: [START automl.text_ner_job]
+ :end-before: [END automl.text_ner_job]
+ :language: python
+ :dedent: 8
+ :caption: creating an automl text ner job
+ """
+
+ _DEFAULT_PRIMARY_METRIC = ClassificationPrimaryMetrics.ACCURACY
+
+ def __init__(
+ self,
+ *,
+ training_data: Optional[Input] = None,
+ validation_data: Optional[Input] = None,
+ primary_metric: Optional[str] = None,
+ log_verbosity: Optional[str] = None,
+ **kwargs: Any
+ ):
+ super(TextNerJob, self).__init__(
+ task_type=TaskType.TEXT_NER,
+ primary_metric=primary_metric or TextNerJob._DEFAULT_PRIMARY_METRIC,
+ training_data=training_data,
+ validation_data=validation_data,
+ log_verbosity=log_verbosity,
+ **kwargs,
+ )
+
+ @property
+ def primary_metric(self) -> Union[str, ClassificationPrimaryMetrics]:
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, ClassificationPrimaryMetrics]) -> None:
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+
+ self._primary_metric = (
+ TextNerJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else ClassificationPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
+ def _to_rest_object(self) -> JobBase:
+ text_ner = RestTextNER(
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ sweep_settings=self._sweep._to_rest_object() if self._sweep else None,
+ fixed_parameters=self._training_parameters._to_rest_object() if self._training_parameters else None,
+ search_space=(
+ [entry._to_rest_object() for entry in self._search_space if entry is not None]
+ if self._search_space is not None
+ else None
+ ),
+ featurization_settings=self._featurization._to_rest_object() if self._featurization else None,
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ )
+ # resolve data inputs in rest object
+ self._resolve_data_inputs(text_ner)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=text_ner,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "TextNerJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestTextNER = properties.task_details
+ assert isinstance(task_details, RestTextNER)
+ limits = (
+ NlpLimitSettings._from_rest_object(task_details.limit_settings) if task_details.limit_settings else None
+ )
+ featurization = (
+ NlpFeaturizationSettings._from_rest_object(task_details.featurization_settings)
+ if task_details.featurization_settings
+ else None
+ )
+ sweep = NlpSweepSettings._from_rest_object(task_details.sweep_settings) if task_details.sweep_settings else None
+ training_parameters = (
+ NlpFixedParameters._from_rest_object(task_details.fixed_parameters)
+ if task_details.fixed_parameters
+ else None
+ )
+
+ text_ner_job = cls(
+ # ----- job specific params
+ id=obj.id,
+ name=obj.name,
+ description=properties.description,
+ tags=properties.tags,
+ properties=properties.properties,
+ experiment_name=properties.experiment_name,
+ services=properties.services,
+ status=properties.status,
+ creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ display_name=properties.display_name,
+ compute=properties.compute_id,
+ outputs=from_rest_data_outputs(properties.outputs),
+ resources=properties.resources,
+ # ----- task specific params
+ primary_metric=task_details.primary_metric,
+ log_verbosity=task_details.log_verbosity,
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ limits=limits,
+ sweep=sweep,
+ training_parameters=training_parameters,
+ search_space=cls._get_search_space_from_str(task_details.search_space),
+ featurization=featurization,
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ queue_settings=properties.queue_settings,
+ )
+
+ text_ner_job._restore_data_inputs()
+
+ return text_ner_job
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "Component":
+ raise NotImplementedError()
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "TextNerJob":
+ from azure.ai.ml._schema.automl.nlp_vertical.text_ner import TextNerSchema
+
+ if kwargs.pop("inside_pipeline", False):
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextNerNode
+
+ loaded_data = load_from_dict(AutoMLTextNerNode, data, context, additional_message, **kwargs)
+ else:
+ loaded_data = load_from_dict(TextNerSchema, data, context, additional_message, **kwargs)
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "TextNerJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ return TextNerJob(**loaded_data)
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.nlp_vertical.text_ner import TextNerSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLTextNerNode
+
+ if inside_pipeline:
+ res_autoML: dict = AutoMLTextNerNode(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res_autoML
+
+ res: dict = TextNerSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TextNerJob):
+ return NotImplemented
+
+ if not super(TextNerJob, self).__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
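+
+
+# --- Illustrative usage sketch (editor's addition, not part of the SDK) ---
+# Unlike the classification jobs, NER takes no target column; the training and
+# validation MLTables (CoNLL-formatted text) below are hypothetical placeholders:
+if __name__ == "__main__":
+    job = TextNerJob(
+        training_data=Input(type="mltable", path="./train-mltable-folder"),
+        validation_data=Input(type="mltable", path="./valid-mltable-folder"),
+    )
+    print(job.primary_metric)  # defaults to accuracy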
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space.py
new file mode 100644
index 00000000..a958de56
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space.py
@@ -0,0 +1,14 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from typing import Any
+
+
+class SearchSpace:
+ """SearchSpace class for AutoML verticals."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ for k, v in kwargs.items():
+ self.__setattr__(k, v)
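+
+
+# --- Illustrative usage sketch (editor's addition, not part of the SDK) ---
+# Arbitrary keyword arguments simply become attributes, so vertical-specific
+# search-space classes can later read back the fields they understand.
+# A minimal sketch with hypothetical values:
+if __name__ == "__main__":
+    space = SearchSpace(learning_rate=0.001, model_name="bert-base-cased")
+    print(space.learning_rate, space.model_name)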
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space_utils.py
new file mode 100644
index 00000000..732030d4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/search_space_utils.py
@@ -0,0 +1,276 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import re
+from typing import Any, List, Union
+
+from marshmallow import fields
+
+from azure.ai.ml._schema._sweep.search_space import (
+ ChoiceSchema,
+ NormalSchema,
+ QNormalSchema,
+ QUniformSchema,
+ RandintSchema,
+ UniformSchema,
+)
+from azure.ai.ml._schema.core.fields import DumpableIntegerField, DumpableStringField, NestedField, UnionField
+from azure.ai.ml._utils.utils import float_to_str
+from azure.ai.ml.constants._job.sweep import SearchSpace
+from azure.ai.ml.entities._job.sweep.search_space import (
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ SweepDistribution,
+ Uniform,
+)
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+def _convert_to_rest_object(sweep_distribution: Union[bool, int, float, str, SweepDistribution]) -> str:
+ if isinstance(sweep_distribution, float):
+ # Float requires some special handling for small values that get auto-represented with scientific notation.
+ res: str = float_to_str(sweep_distribution)
+ return res
+ if not isinstance(sweep_distribution, SweepDistribution):
+ # Convert [bool, float, str] types to str
+ return str(sweep_distribution)
+
+ rest_object = sweep_distribution._to_rest_object()
+ if not isinstance(rest_object, list):
+ msg = "Rest Object for sweep distribution should be a list."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ if len(rest_object) <= 1:
+ msg = "Rest object for sweep distribution should contain at least two elements."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ sweep_distribution_type = rest_object[0]
+ sweep_distribution_args = []
+
+ if not isinstance(rest_object[1], list):
+ msg = "The second element of Rest object for sweep distribution should be a list."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ if sweep_distribution_type == SearchSpace.CHOICE:
+ # Rest objects for choice distribution are of format ["choice", [[0, 1, 2]]]
+ if not isinstance(rest_object[1][0], list):
+ msg = "The second element of Rest object for choice distribution should be a list of list."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ for value in rest_object[1][0]:
+ if isinstance(value, str):
+ sweep_distribution_args.append("'" + value + "'")
+ elif isinstance(value, float):
+ sweep_distribution_args.append(float_to_str(value))
+ else:
+ sweep_distribution_args.append(str(value))
+ else:
+ for value in rest_object[1]:
+ if isinstance(value, float):
+ sweep_distribution_args.append(float_to_str(value))
+ else:
+ sweep_distribution_args.append(str(value))
+
+ sweep_distribution_str: str = sweep_distribution_type + "("
+ sweep_distribution_str += ",".join(sweep_distribution_args)
+ sweep_distribution_str += ")"
+ return sweep_distribution_str
+
+
+def _is_int(value: str) -> bool:
+ try:
+ int(value)
+ return True
+ except ValueError:
+ return False
+
+
+def _is_float(value: str) -> bool:
+ try:
+ float(value)
+ return True
+ except ValueError:
+ return False
+
+
+def _get_type_inferred_value(value: str) -> Union[bool, int, float, str]:
+ value = value.strip()
+ if _is_int(value):
+ # Int
+ return int(value)
+ if _is_float(value):
+ # Float
+ return float(value)
+ if value in ["True", "False"]:
+ # Convert "True", "False" to python boolean literals
+ return value == "True"
+ # string value. Remove quotes before returning.
+ return value.strip("'\"")
+
+
+def _convert_from_rest_object(
+ sweep_distribution_str: str,
+) -> Any:
+ # sweep_distribution_str can be a distribution like "choice('vitb16r224', 'vits16r224')" or
+ # a single value like "True", "1", "1.0567", "vitb16r224"
+
+ sweep_distribution_str = sweep_distribution_str.strip()
+ # Filter by the delimiters and remove splits that are empty strings
+ sweep_distribution_separated = list(filter(None, re.split("[ ,()]+", sweep_distribution_str)))
+
+ if len(sweep_distribution_separated) == 1:
+ # Single value.
+ return _get_type_inferred_value(sweep_distribution_separated[0])
+
+ # Distribution string
+ sweep_distribution_type = sweep_distribution_separated[0].strip().lower()
+ sweep_distribution_args: List = []
+ for value in sweep_distribution_separated[1:]:
+ sweep_distribution_args.append(_get_type_inferred_value(value))
+
+ if sweep_distribution_type == SearchSpace.CHOICE:
+ sweep_distribution_args = [sweep_distribution_args] # Choice values are list of lists
+
+ sweep_distribution = SweepDistribution._from_rest_object([sweep_distribution_type, sweep_distribution_args])
+ return sweep_distribution
+
+
+def _convert_sweep_dist_dict_to_str_dict(sweep_distribution: dict) -> dict:
+ for k, sweep_dist_dict in sweep_distribution.items():
+ if sweep_dist_dict is not None:
+ sweep_distribution[k] = _convert_sweep_dist_dict_item_to_str(sweep_dist_dict)
+ return sweep_distribution
+
+
+class ChoicePlusSchema(ChoiceSchema):
+ """Choice schema that allows boolean values also"""
+
+ values = fields.List(
+ UnionField(
+ [
+ DumpableIntegerField(strict=True),
+ DumpableStringField(),
+ fields.Float(),
+ fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField("ChoicePlusSchema"),
+ NestedField(NormalSchema()),
+ NestedField(QNormalSchema()),
+ NestedField(RandintSchema()),
+ NestedField(UniformSchema()),
+ NestedField(QUniformSchema()),
+ DumpableIntegerField(strict=True),
+ fields.Float(),
+ fields.Str(),
+ fields.Boolean(),
+ ]
+ ),
+ ),
+ fields.Boolean(),
+ ]
+ )
+ )
+
+
+def _convert_sweep_dist_dict_item_to_str(sweep_distribution: Union[bool, int, float, str, dict]) -> str:
+ # Convert a Sweep Distribution dict to Sweep Distribution string
+ # Eg. {type: 'choice', values: ['vitb16r224','vits16r224']} => "Choice('vitb16r224','vits16r224')"
+ if isinstance(sweep_distribution, dict):
+ sweep_dist_type = sweep_distribution["type"]
+ if sweep_dist_type == SearchSpace.CHOICE:
+ sweep_dist_obj = ChoicePlusSchema().load(sweep_distribution) # pylint: disable=no-member
+ elif sweep_dist_type in SearchSpace.UNIFORM_LOGUNIFORM:
+ sweep_dist_obj = UniformSchema().load(sweep_distribution) # pylint: disable=no-member
+ elif sweep_dist_type in SearchSpace.NORMAL_LOGNORMAL:
+ sweep_dist_obj = NormalSchema().load(sweep_distribution) # pylint: disable=no-member
+ elif sweep_dist_type in SearchSpace.QUNIFORM_QLOGUNIFORM:
+ sweep_dist_obj = QUniformSchema().load(sweep_distribution) # pylint: disable=no-member
+ elif sweep_dist_type in SearchSpace.QNORMAL_QLOGNORMAL:
+ sweep_dist_obj = QNormalSchema().load(sweep_distribution) # pylint: disable=no-member
+ elif sweep_dist_type in SearchSpace.RANDINT:
+ sweep_dist_obj = RandintSchema().load(sweep_distribution) # pylint: disable=no-member
+ else:
+ msg = f"Unsupported sweep distribution type {sweep_dist_type}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ else: # Case for other primitive types
+ sweep_dist_obj = sweep_distribution
+
+ sweep_dist_str = _convert_to_rest_object(sweep_dist_obj)
+ return sweep_dist_str
+
+
+def _convert_sweep_dist_str_to_dict(sweep_dist_str_list: dict) -> dict:
+ for k, val in sweep_dist_str_list.items():
+ if isinstance(val, str):
+ sweep_dist_str_list[k] = _convert_sweep_dist_str_item_to_dict(val)
+ return sweep_dist_str_list
+
+
+def _convert_sweep_dist_str_item_to_dict(
+ sweep_distribution_str: str,
+) -> Union[bool, int, float, str, dict]:
+ # sweep_distribution_str can be a distribution like "choice('vitb16r224', 'vits16r224')"
+ # return type is {type: 'choice', values: ['vitb16r224', 'vits16r224']}
+ sweep_dist_obj = _convert_from_rest_object(sweep_distribution_str)
+ sweep_dist: Union[bool, int, float, str, dict] = ""
+ if isinstance(sweep_dist_obj, SweepDistribution):
+ if isinstance(sweep_dist_obj, Choice):
+ sweep_dist = ChoicePlusSchema().dump(sweep_dist_obj) # pylint: disable=no-member
+ elif isinstance(sweep_dist_obj, (QNormal, QLogNormal)):
+ sweep_dist = QNormalSchema().dump(sweep_dist_obj) # pylint: disable=no-member
+ elif isinstance(sweep_dist_obj, (QUniform, QLogUniform)):
+ sweep_dist = QUniformSchema().dump(sweep_dist_obj) # pylint: disable=no-member
+ elif isinstance(sweep_dist_obj, (Uniform, LogUniform)):
+ sweep_dist = UniformSchema().dump(sweep_dist_obj) # pylint: disable=no-member
+ elif isinstance(sweep_dist_obj, (Normal, LogNormal)):
+ sweep_dist = NormalSchema().dump(sweep_dist_obj) # pylint: disable=no-member
+ elif isinstance(sweep_dist_obj, Randint):
+ sweep_dist = RandintSchema().dump(sweep_dist_obj) # pylint: disable=no-member
+ else:
+ msg = "Invalid sweep distribution {}"
+ raise ValidationException(
+ message=msg.format(sweep_distribution_str),
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ else: # Case for other primitive types
+ sweep_dist = sweep_dist_obj
+
+ return sweep_dist
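+
+
+# --- Illustrative usage sketch (editor's addition, not part of the SDK) ---
+# A minimal sketch of the string round trip performed by the helpers above,
+# using the Choice distribution already imported in this module; the model
+# names are the same ones used in the comments above:
+if __name__ == "__main__":
+    choice = Choice(values=["vitb16r224", "vits16r224"])
+    as_str = _convert_to_rest_object(choice)      # "choice('vitb16r224','vits16r224')"
+    restored = _convert_from_rest_object(as_str)  # back to a Choice distribution
+    print(as_str, type(restored).__name__)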
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/stack_ensemble_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/stack_ensemble_settings.py
new file mode 100644
index 00000000..c17fa7e3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/stack_ensemble_settings.py
@@ -0,0 +1,70 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import StackEnsembleSettings as RestStackEnsembleSettings
+from azure.ai.ml._restclient.v2023_04_01_preview.models import StackMetaLearnerType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class StackEnsembleSettings(RestTranslatableMixin):
+ """Advance setting to customize StackEnsemble run."""
+
+ def __init__(
+ self,
+ *,
+ stack_meta_learner_k_wargs: Optional[Any] = None,
+ stack_meta_learner_train_percentage: float = 0.2,
+ stack_meta_learner_type: Optional[StackMetaLearnerType] = None,
+ **kwargs: Any
+ ):
+ """
+ :param stack_meta_learner_k_wargs: Optional parameters to pass to the initializer of the
+ meta-learner.
+ :type stack_meta_learner_k_wargs: any
+ :param stack_meta_learner_train_percentage: Specifies the proportion of the training set
+ (when choosing train and validation type of training) to be reserved for training the
+ meta-learner. Default value is 0.2.
+ :type stack_meta_learner_train_percentage: float
+ :param stack_meta_learner_type: The meta-learner is a model trained on the output of the
+ individual heterogeneous models. Possible values include: "None", "LogisticRegression",
+ "LogisticRegressionCV", "LightGBMClassifier", "ElasticNet", "ElasticNetCV",
+ "LightGBMRegressor", "LinearRegression".
+ :type stack_meta_learner_type: str or
+ ~azure.mgmt.machinelearningservices.models.StackMetaLearnerType
+ """
+ super(StackEnsembleSettings, self).__init__(**kwargs)
+ self.stack_meta_learner_k_wargs = stack_meta_learner_k_wargs
+ self.stack_meta_learner_train_percentage = stack_meta_learner_train_percentage
+ self.stack_meta_learner_type = stack_meta_learner_type
+
+ def _to_rest_object(self) -> RestStackEnsembleSettings:
+ return RestStackEnsembleSettings(
+ stack_meta_learner_k_wargs=self.stack_meta_learner_k_wargs,
+ stack_meta_learner_train_percentage=self.stack_meta_learner_train_percentage,
+ stack_meta_learner_type=self.stack_meta_learner_type,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestStackEnsembleSettings) -> "StackEnsembleSettings":
+ return cls(
+ stack_meta_learner_k_wargs=obj.stack_meta_learner_k_wargs,
+ stack_meta_learner_train_percentage=obj.stack_meta_learner_train_percentage,
+ stack_meta_learner_type=obj.stack_meta_learner_type,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, StackEnsembleSettings):
+ return NotImplemented
+
+ return (
+ super().__eq__(other)
+ and self.stack_meta_learner_k_wargs == other.stack_meta_learner_k_wargs
+ and self.stack_meta_learner_train_percentage == other.stack_meta_learner_train_percentage
+ and self.stack_meta_learner_type == other.stack_meta_learner_type
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
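+
+
+# --- Illustrative usage sketch (editor's addition, not part of the SDK) ---
+# A minimal sketch using the keyword arguments documented above; the chosen
+# meta-learner type and train percentage are only examples:
+if __name__ == "__main__":
+    settings = StackEnsembleSettings(
+        stack_meta_learner_type="LogisticRegression",
+        stack_meta_learner_train_percentage=0.3,
+    )
+    rest = settings._to_rest_object()                        # RestStackEnsembleSettings
+    restored = StackEnsembleSettings._from_rest_object(rest)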
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/__init__.py
new file mode 100644
index 00000000..c0373010
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/__init__.py
@@ -0,0 +1,22 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from .automl_tabular import AutoMLTabular
+from .classification_job import ClassificationJob
+from .featurization_settings import ColumnTransformer, TabularFeaturizationSettings
+from .forecasting_job import ForecastingJob
+from .forecasting_settings import ForecastingSettings
+from .limit_settings import TabularLimitSettings
+from .regression_job import RegressionJob
+
+__all__ = [
+ "AutoMLTabular",
+ "ClassificationJob",
+ "ColumnTransformer",
+ "ForecastingJob",
+ "ForecastingSettings",
+ "RegressionJob",
+ "TabularFeaturizationSettings",
+ "TabularLimitSettings",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/automl_tabular.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/automl_tabular.py
new file mode 100644
index 00000000..5f4ed22b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/automl_tabular.py
@@ -0,0 +1,607 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=too-many-instance-attributes
+
+from abc import ABC
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ AutoNCrossValidations,
+ BlockedTransformers,
+ CustomNCrossValidations,
+ LogVerbosity,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants import TabularTrainingMode
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.automl.automl_vertical import AutoMLVertical
+from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings
+from azure.ai.ml.entities._job.automl.tabular.featurization_settings import (
+ ColumnTransformer,
+ TabularFeaturizationSettings,
+)
+from azure.ai.ml.entities._job.automl.tabular.limit_settings import TabularLimitSettings
+from azure.ai.ml.entities._job.automl.training_settings import TrainingSettings
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class AutoMLTabular(AutoMLVertical, ABC):
+ """Initialize an AutoML job entity for tabular data.
+
+ Constructor for AutoMLTabular.
+
+ :keyword task_type: The type of task to run. Possible values include: "classification", "regression"
+ , "forecasting".
+ :paramtype task_type: str
+ :keyword featurization: featurization settings. Defaults to None.
+ :paramtype featurization: typing.Optional[TabularFeaturizationSettings]
+ :keyword limits: limits settings. Defaults to None.
+ :paramtype limits: typing.Optional[TabularLimitSettings]
+ :keyword training: training settings. Defaults to None.
+ :paramtype training: typing.Optional[TrainingSettings]
+ :keyword log_verbosity: Verbosity of logging. Possible values include: "debug", "info", "warning", "error",
+ "critical". Defaults to "info".
+ :paramtype log_verbosity: str
+ :keyword target_column_name: The name of the target column. Defaults to None.
+ :paramtype target_column_name: typing.Optional[str]
+ :keyword weight_column_name: The name of the weight column. Defaults to None.
+ :paramtype weight_column_name: typing.Optional[str]
+ :keyword validation_data_size: The size of the validation data. Defaults to None.
+ :paramtype validation_data_size: typing.Optional[float]
+ :keyword cv_split_column_names: The names of the columns to use for cross validation. Defaults to None.
+ :paramtype cv_split_column_names: typing.Optional[List[str]]
+ :keyword n_cross_validations: The number of cross validations to run. Defaults to None.
+ :paramtype n_cross_validations: typing.Optional[int]
+ :keyword test_data_size: The size of the test data. Defaults to None.
+ :paramtype test_data_size: typing.Optional[float]
+ :keyword training_data: The training data. Defaults to None.
+ :paramtype training_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input]
+ :keyword validation_data: The validation data. Defaults to None.
+ :paramtype validation_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input]
+ :keyword test_data: The test data. Defaults to None.
+ :paramtype test_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input]
+ """
+
+ def __init__(
+ self,
+ *,
+ task_type: str,
+ featurization: Optional[TabularFeaturizationSettings] = None,
+ limits: Optional[TabularLimitSettings] = None,
+ training: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize an AutoML job entity for tabular data.
+
+ Constructor for AutoMLTabular.
+
+ :keyword task_type: The type of task to run. Possible values include: "classification", "regression"
+ , "forecasting".
+ :paramtype task_type: str
+ :keyword featurization: featurization settings. Defaults to None.
+ :paramtype featurization: typing.Optional[TabularFeaturizationSettings]
+ :keyword limits: limits settings. Defaults to None.
+ :paramtype limits: typing.Optional[TabularLimitSettings]
+ :keyword training: training settings. Defaults to None.
+ :paramtype training: typing.Optional[TrainingSettings]
+ :keyword log_verbosity: Verbosity of logging. Possible values include: "debug", "info", "warning", "error",
+ "critical". Defaults to "info".
+ :paramtype log_verbosity: str
+ :keyword target_column_name: The name of the target column. Defaults to None.
+ :paramtype target_column_name: typing.Optional[str]
+ :keyword weight_column_name: The name of the weight column. Defaults to None.
+ :paramtype weight_column_name: typing.Optional[str]
+ :keyword validation_data_size: The size of the validation data. Defaults to None.
+ :paramtype validation_data_size: typing.Optional[float]
+ :keyword cv_split_column_names: The names of the columns to use for cross validation. Defaults to None.
+ :paramtype cv_split_column_names: typing.Optional[List[str]]
+ :keyword n_cross_validations: The number of cross validations to run. Defaults to None.
+ :paramtype n_cross_validations: typing.Optional[int]
+ :keyword test_data_size: The size of the test data. Defaults to None.
+ :paramtype test_data_size: typing.Optional[float]
+ :keyword training_data: The training data. Defaults to None.
+ :paramtype training_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input]
+ :keyword validation_data: The validation data. Defaults to None.
+ :paramtype validation_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input]
+ :keyword test_data: The test data. Defaults to None.
+ :paramtype test_data: typing.Optional[azure.ai.ml.entities._inputs_outputs.Input]
+ :raises: :class:`azure.ai.ml.exceptions.ValidationException`
+ """
+ self.log_verbosity = kwargs.pop("log_verbosity", LogVerbosity.INFO)
+
+ self.target_column_name = kwargs.pop("target_column_name", None)
+ self.weight_column_name = kwargs.pop("weight_column_name", None)
+ self.validation_data_size = kwargs.pop("validation_data_size", None)
+ self.cv_split_column_names = kwargs.pop("cv_split_column_names", None)
+ self.n_cross_validations = kwargs.pop("n_cross_validations", None)
+ self.test_data_size = kwargs.pop("test_data_size", None)
+
+ super().__init__(
+ task_type=task_type,
+ training_data=kwargs.pop("training_data", None),
+ validation_data=kwargs.pop("validation_data", None),
+ test_data=kwargs.pop("test_data", None),
+ **kwargs,
+ )
+
+ self._featurization = featurization
+ self._limits = limits
+ self._training = training
+
+ @property
+ def log_verbosity(self) -> LogVerbosity:
+ """Get the log verbosity for the AutoML job.
+
+ :return: log verbosity for the AutoML job
+ :rtype: LogVerbosity
+ """
+ return self._log_verbosity
+
+ @log_verbosity.setter
+ def log_verbosity(self, value: Union[str, LogVerbosity]) -> None:
+ """Set the log verbosity for the AutoML job.
+
+ :param value: str or LogVerbosity
+ :type value: typing.Union[str, LogVerbosity]
+ """
+ self._log_verbosity = None if value is None else LogVerbosity[camel_to_snake(value).upper()]
+
+ @property
+ def limits(self) -> Optional[TabularLimitSettings]:
+ """Get the tabular limits for the AutoML job.
+
+ :return: Tabular limits for the AutoML job
+ :rtype: TabularLimitSettings
+ """
+ return self._limits
+
+ @limits.setter
+ def limits(self, value: Union[Dict, TabularLimitSettings]) -> None:
+ """Set the limits for the AutoML job.
+
+ :param value: typing.Dict or TabularLimitSettings
+ :type value: typing.Union[typing.Dict, TabularLimitSettings]
+ :raises ValidationException: Expected a dictionary for limit settings.
+ """
+ if isinstance(value, TabularLimitSettings):
+ self._limits = value
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for limit settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_limits(**value)
+
+ @property
+ def training(self) -> Any:
+ """Get the training settings for the AutoML job.
+
+ :return: Training settings for the AutoML job.
+ :rtype: TrainingSettings
+ """
+ return self._training
+
+ @training.setter
+ def training(self, value: Union[Dict, TrainingSettings]) -> None:
+ """Set the training settings for the AutoML job.
+
+ :param value: typing.Dict or TrainingSettings
+ :type value: typing.Union[typing.Dict, TrainingSettings]
+ :raises ValidationException: Expected a dictionary for training settings.
+ """
+ if isinstance(value, TrainingSettings):
+ self._training = value
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for training settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_training(**value)
+
+ @property
+ def featurization(self) -> Optional[TabularFeaturizationSettings]:
+ """Get the tabular featurization settings for the AutoML job.
+
+ :return: Tabular featurization settings for the AutoML job
+ :rtype: TabularFeaturizationSettings
+ """
+ return self._featurization
+
+ @featurization.setter
+ def featurization(self, value: Union[Dict, TabularFeaturizationSettings]) -> None:
+ """Set the featurization settings for the AutoML job.
+
+ :param value: typing.Dict or TabularFeaturizationSettings
+ :type value: typing.Union[typing.Dict, TabularFeaturizationSettings]
+ :raises ValidationException: Expected a dictionary for featurization settings
+ """
+ if isinstance(value, TabularFeaturizationSettings):
+ self._featurization = value
+ else:
+ if not isinstance(value, dict):
+ msg = "Expected a dictionary for featurization settings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.set_featurization(**value)
+
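+ # Editor's note: an illustrative sketch, not part of the SDK source. The limits and
+ # featurization setters above accept either a settings object or a plain dict, which
+ # is forwarded to set_limits / set_featurization on a concrete subclass such as
+ # ClassificationJob (defined elsewhere in this package); the values are arbitrary:
+ #
+ #     job = ClassificationJob(target_column_name="label")
+ #     job.limits = {"timeout_minutes": 120, "max_trials": 40}
+ #     job.featurization = {"mode": "auto"}
+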
+ def set_limits(
+ self,
+ *,
+ enable_early_termination: Optional[bool] = None,
+ exit_score: Optional[float] = None,
+ max_concurrent_trials: Optional[int] = None,
+ max_cores_per_trial: Optional[int] = None,
+ max_nodes: Optional[int] = None,
+ max_trials: Optional[int] = None,
+ timeout_minutes: Optional[int] = None,
+ trial_timeout_minutes: Optional[int] = None,
+ ) -> None:
+ """Set limits for the job.
+
+ :keyword enable_early_termination: Whether to enable early termination if the score is not improving in the
+ short term, defaults to None.
+
+ Early stopping logic:
+
+ * No early stopping for first 20 iterations (landmarks).
+ * Early stopping window starts on the 21st iteration and looks for early_stopping_n_iters iterations
+ (currently set to 10). This means that the first iteration where stopping can occur is the 31st.
+ * AutoML still schedules 2 ensemble iterations AFTER early stopping, which might result in higher scores.
+ * Early stopping is triggered if the absolute value of the best score calculated is the same for the past
+ early_stopping_n_iters iterations, that is, if there is no improvement in score for
+ early_stopping_n_iters iterations.
+
+ :paramtype enable_early_termination: typing.Optional[bool]
+ :keyword exit_score: Target score for experiment. The experiment terminates after this score is reached.
+ If not specified (no criteria), the experiment runs until no further progress is made
+ on the primary metric. For more information on exit criteria, see this `article
+ <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#exit-criteria>`_
+ , defaults to None
+ :paramtype exit_score: typing.Optional[float]
+ :keyword max_concurrent_trials: This is the maximum number of iterations that would be executed in parallel.
+ The default value is 1.
+
+ * AmlCompute clusters support one iteration running per node. For multiple AutoML experiment parent runs
+ executed in parallel on a single AmlCompute cluster, the sum of the ``max_concurrent_trials`` values
+ for all experiments should be less than or equal to the maximum number of nodes. Otherwise, runs
+ will be queued until nodes are available.
+
+ * DSVM supports multiple iterations per node. ``max_concurrent_trials`` should
+ be less than or equal to the number of cores on the DSVM. For multiple experiments
+ run in parallel on a single DSVM, the sum of the ``max_concurrent_trials`` values for all
+ experiments should be less than or equal to the maximum number of nodes.
+
+ * Databricks - ``max_concurrent_trials`` should be less than or equal to the number of
+ worker nodes on Databricks.
+
+ ``max_concurrent_trials`` does not apply to local runs. Formerly, this parameter
+ was named ``concurrent_iterations``.
+ :paramtype max_concurrent_trials: typing.Optional[int]
+ :keyword max_cores_per_trial: The maximum number of threads to use for a given training iteration.
+ Acceptable values:
+
+ * Greater than 1 and less than or equal to the maximum number of cores on the compute target.
+
+ * Equal to -1, which means to use all the possible cores per iteration per child-run.
+
+ * Equal to 1, the default.
+
+ :paramtype max_cores_per_trial: typing.Optional[int]
+ :keyword max_nodes: [Experimental] The maximum number of nodes to use for distributed training.
+
+ * For forecasting, each model is trained using max(2, int(max_nodes / max_concurrent_trials)) nodes.
+
+ * For classification/regression, each model is trained using max_nodes nodes.
+
+ Note: This parameter is in public preview and might change in the future.
+ :paramtype max_nodes: typing.Optional[int]
+ :keyword max_trials: The total number of different algorithm and parameter combinations to test during an
+ automated ML experiment. If not specified, the default is 1000 iterations.
+ :paramtype max_trials: typing.Optional[int]
+ :keyword timeout_minutes: Maximum amount of time in minutes that all iterations combined can take before the
+ experiment terminates. If not specified, the default experiment timeout is 6 days. To specify a timeout
+ less than or equal to 1 hour, make sure your dataset's size is not greater than
+ 10,000,000 (rows times columns) or an error results, defaults to None
+ :paramtype timeout_minutes: typing.Optional[int]
+ :keyword trial_timeout_minutes: Maximum time in minutes that each iteration can run for before it terminates.
+ If not specified, a value of 1 month or 43200 minutes is used, defaults to None
+ :paramtype trial_timeout_minutes: typing.Optional[int]
+ """
+ self._limits = self._limits or TabularLimitSettings()
+ self._limits.enable_early_termination = (
+ enable_early_termination if enable_early_termination is not None else self._limits.enable_early_termination
+ )
+ self._limits.exit_score = exit_score if exit_score is not None else self._limits.exit_score
+ self._limits.max_concurrent_trials = (
+ max_concurrent_trials if max_concurrent_trials is not None else self._limits.max_concurrent_trials
+ )
+ self._limits.max_cores_per_trial = (
+ max_cores_per_trial if max_cores_per_trial is not None else self._limits.max_cores_per_trial
+ )
+ self._limits.max_nodes = max_nodes if max_nodes is not None else self._limits.max_nodes
+ self._limits.max_trials = max_trials if max_trials is not None else self._limits.max_trials
+ self._limits.timeout_minutes = timeout_minutes if timeout_minutes is not None else self._limits.timeout_minutes
+ self._limits.trial_timeout_minutes = (
+ trial_timeout_minutes if trial_timeout_minutes is not None else self._limits.trial_timeout_minutes
+ )
+
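+ # Editor's note: an illustrative usage sketch, not part of the SDK source. It assumes
+ # a concrete subclass such as ClassificationJob from this package; the limit values
+ # are arbitrary examples:
+ #
+ #     job = ClassificationJob(target_column_name="label")
+ #     job.set_limits(
+ #         timeout_minutes=120,
+ #         trial_timeout_minutes=20,
+ #         max_trials=40,
+ #         max_concurrent_trials=4,
+ #         enable_early_termination=True,
+ #     )
+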
+ def set_training(
+ self,
+ *,
+ enable_onnx_compatible_models: Optional[bool] = None,
+ enable_dnn_training: Optional[bool] = None,
+ enable_model_explainability: Optional[bool] = None,
+ enable_stack_ensemble: Optional[bool] = None,
+ enable_vote_ensemble: Optional[bool] = None,
+ stack_ensemble_settings: Optional[StackEnsembleSettings] = None,
+ ensemble_model_download_timeout: Optional[int] = None,
+ allowed_training_algorithms: Optional[List[str]] = None,
+ blocked_training_algorithms: Optional[List[str]] = None,
+ training_mode: Optional[Union[str, TabularTrainingMode]] = None,
+ ) -> None:
+ """The method to configure training related settings.
+
+ :keyword enable_onnx_compatible_models: Whether to enable or disable enforcing ONNX-compatible models.
+ The default is False. For more information about Open Neural Network Exchange (ONNX) and Azure Machine
+ Learning, see this `article <https://learn.microsoft.com/azure/machine-learning/concept-onnx>`__.
+ :paramtype enable_onnx_compatible_models: typing.Optional[bool]
+ :keyword enable_dnn_training: Whether to include DNN-based models during model selection.
+ The default is True for DNN NLP tasks and False for all other AutoML tasks.
+ :paramtype enable_dnn_training: typing.Optional[bool]
+ :keyword enable_model_explainability: Whether to enable explaining the best AutoML model at the end of all
+ AutoML training iterations. For more information, see
+ `Interpretability: model explanations in automated machine learning
+ <https://learn.microsoft.com/azure/machine-learning/how-to-machine-learning-interpretability-automl>`__.
+ Defaults to None.
+ :paramtype enable_model_explainability: typing.Optional[bool]
+ :keyword enable_stack_ensemble: Whether to enable/disable StackEnsemble iteration.
+ If `enable_onnx_compatible_models` flag is being set, then StackEnsemble iteration will be disabled.
+ Similarly, for Timeseries tasks, StackEnsemble iteration will be disabled by default, to avoid risks of
+ overfitting due to small training set used in fitting the meta learner.
+ For more information about ensembles, see `Ensemble configuration
+ <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#ensemble>`__
+ , defaults to None
+ :paramtype enable_stack_ensemble: typing.Optional[bool]
+ :keyword enable_vote_ensemble: Whether to enable/disable VotingEnsemble iteration.
+ For more information about ensembles, see `Ensemble configuration
+ <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#ensemble>`__
+ , defaults to None
+ :paramtype enable_vote_ensemble: typing.Optional[bool]
+ :keyword stack_ensemble_settings: Settings for StackEnsemble iteration, defaults to None
+ :paramtype stack_ensemble_settings: typing.Optional[StackEnsembleSettings]
+ :keyword ensemble_model_download_timeout: During VotingEnsemble and StackEnsemble model generation,
+ multiple fitted models from the previous child runs are downloaded. Configure this parameter with a
+ value higher than 300 seconds if more time is needed, defaults to None
+ :paramtype ensemble_model_download_timeout: typing.Optional[int]
+ :keyword allowed_training_algorithms: A list of model names to search for an experiment. If not specified,
+ then all models supported for the task are used minus any specified in ``blocked_training_algorithms``
+ or deprecated TensorFlow models, defaults to None
+ :paramtype allowed_training_algorithms: typing.Optional[List[str]]
+ :keyword blocked_training_algorithms: A list of algorithms to ignore for an experiment, defaults to None
+ :paramtype blocked_training_algorithms: typing.Optional[List[str]]
+ :keyword training_mode: [Experimental] The training mode to use.
+ The possible values are:
+
+ * distributed: enables distributed training for supported algorithms.
+
+ * non_distributed: disables distributed training.
+
+ * auto: currently the same as non_distributed; this might change in the future.
+
+ Note: This parameter is in public preview and may change in the future.
+ :paramtype training_mode: typing.Optional[typing.Union[str, azure.ai.ml.constants.TabularTrainingMode]]
+ """
+ # get training object by calling training getter of respective tabular task
+ self._training = self.training
+ if self._training is not None:
+ self._training.enable_onnx_compatible_models = (
+ enable_onnx_compatible_models
+ if enable_onnx_compatible_models is not None
+ else self._training.enable_onnx_compatible_models
+ )
+ self._training.enable_dnn_training = (
+ enable_dnn_training if enable_dnn_training is not None else self._training.enable_dnn_training
+ )
+ self._training.enable_model_explainability = (
+ enable_model_explainability
+ if enable_model_explainability is not None
+ else self._training.enable_model_explainability
+ )
+ self._training.enable_stack_ensemble = (
+ enable_stack_ensemble if enable_stack_ensemble is not None else self._training.enable_stack_ensemble
+ )
+ self._training.enable_vote_ensemble = (
+ enable_vote_ensemble if enable_vote_ensemble is not None else self._training.enable_vote_ensemble
+ )
+ self._training.stack_ensemble_settings = (
+ stack_ensemble_settings
+ if stack_ensemble_settings is not None
+ else self._training.stack_ensemble_settings
+ )
+ self._training.ensemble_model_download_timeout = (
+ ensemble_model_download_timeout
+ if ensemble_model_download_timeout is not None
+ else self._training.ensemble_model_download_timeout
+ )
+
+ self._training.allowed_training_algorithms = allowed_training_algorithms
+ self._training.blocked_training_algorithms = blocked_training_algorithms
+ self._training.training_mode = training_mode if training_mode is not None else self._training.training_mode
+
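+ # Editor's note: an illustrative usage sketch, not part of the SDK source. It assumes
+ # `job` is a classification job created as in the earlier sketch; the algorithm names
+ # are example values only:
+ #
+ #     job.set_training(
+ #         enable_model_explainability=True,
+ #         enable_vote_ensemble=True,
+ #         enable_stack_ensemble=False,
+ #         allowed_training_algorithms=["LightGBM", "XGBoostClassifier"],
+ #     )
+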
+ def set_featurization(
+ self,
+ *,
+ blocked_transformers: Optional[List[Union[BlockedTransformers, str]]] = None,
+ column_name_and_types: Optional[Dict[str, str]] = None,
+ dataset_language: Optional[str] = None,
+ transformer_params: Optional[Dict[str, List[ColumnTransformer]]] = None,
+ mode: Optional[str] = None,
+ enable_dnn_featurization: Optional[bool] = None,
+ ) -> None:
+ """Define feature engineering configuration.
+
+ :keyword blocked_transformers: A list of transformer names to be blocked during featurization, defaults to None
+ :paramtype blocked_transformers: Optional[List[Union[BlockedTransformers, str]]]
+ :keyword column_name_and_types: A dictionary of column names and feature types used to update column purpose
+ , defaults to None
+ :paramtype column_name_and_types: Optional[Dict[str, str]]
+ :keyword dataset_language: Three character ISO 639-3 code for the language(s) contained in the dataset.
+ Languages other than English are only supported if you use GPU-enabled compute. The language_code
+ 'mul' should be used if the dataset contains multiple languages. To find ISO 639-3 codes for different
+ languages, please refer to https://en.wikipedia.org/wiki/List_of_ISO_639-3_codes, defaults to None
+ :paramtype dataset_language: Optional[str]
+ :keyword transformer_params: A dictionary of transformer and corresponding customization parameters
+ , defaults to None
+ :paramtype transformer_params: Optional[Dict[str, List[ColumnTransformer]]]
+ :keyword mode: "off", "auto", defaults to "auto", defaults to None
+ :paramtype mode: Optional[str]
+ :keyword enable_dnn_featurization: Whether to include DNN based feature engineering methods, defaults to None
+ :paramtype enable_dnn_featurization: Optional[bool]
+ """
+ self._featurization = self._featurization or TabularFeaturizationSettings()
+ self._featurization.blocked_transformers = (
+ blocked_transformers if blocked_transformers is not None else self._featurization.blocked_transformers
+ )
+ self._featurization.column_name_and_types = (
+ column_name_and_types if column_name_and_types is not None else self._featurization.column_name_and_types
+ )
+ self._featurization.dataset_language = (
+ dataset_language if dataset_language is not None else self._featurization.dataset_language
+ )
+ self._featurization.transformer_params = (
+ transformer_params if transformer_params is not None else self._featurization.transformer_params
+ )
+ self._featurization.mode = mode or self._featurization.mode
+ self._featurization.enable_dnn_featurization = (
+ enable_dnn_featurization
+ if enable_dnn_featurization is not None
+ else self._featurization.enable_dnn_featurization
+ )
+
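+ # Editor's note: an illustrative usage sketch, not part of the SDK source. The
+ # "Imputer" key, column name, and "strategy" parameter are assumed example values;
+ # ColumnTransformer is defined in featurization_settings.py in this package:
+ #
+ #     job.set_featurization(
+ #         dataset_language="eng",
+ #         transformer_params={
+ #             "Imputer": [ColumnTransformer(fields=["age"], parameters={"strategy": "median"})],
+ #         },
+ #     )
+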
+ def set_data(
+ self,
+ *,
+ training_data: Input,
+ target_column_name: str,
+ weight_column_name: Optional[str] = None,
+ validation_data: Optional[Input] = None,
+ validation_data_size: Optional[float] = None,
+ n_cross_validations: Optional[Union[str, int]] = None,
+ cv_split_column_names: Optional[List[str]] = None,
+ test_data: Optional[Input] = None,
+ test_data_size: Optional[float] = None,
+ ) -> None:
+ """Define data configuration.
+
+ :keyword training_data: Training data.
+ :paramtype training_data: Input
+ :keyword target_column_name: Column name of the target column.
+ :paramtype target_column_name: str
+ :keyword weight_column_name: Weight column name, defaults to None
+ :paramtype weight_column_name: typing.Optional[str]
+ :keyword validation_data: Validation data, defaults to None
+ :paramtype validation_data: typing.Optional[Input]
+ :keyword validation_data_size: Validation data size, defaults to None
+ :paramtype validation_data_size: typing.Optional[float]
+ :keyword n_cross_validations: n_cross_validations, defaults to None
+ :paramtype n_cross_validations: typing.Optional[typing.Union[str, int]]
+ :keyword cv_split_column_names: cv_split_column_names, defaults to None
+ :paramtype cv_split_column_names: typing.Optional[typing.List[str]]
+ :keyword test_data: Test data, defaults to None
+ :paramtype test_data: typing.Optional[Input]
+ :keyword test_data_size: Test data size, defaults to None
+ :paramtype test_data_size: typing.Optional[float]
+ """
+ self.target_column_name = target_column_name if target_column_name is not None else self.target_column_name
+ self.weight_column_name = weight_column_name if weight_column_name is not None else self.weight_column_name
+ self.training_data = training_data if training_data is not None else self.training_data
+ self.validation_data = validation_data if validation_data is not None else self.validation_data
+ self.validation_data_size = (
+ validation_data_size if validation_data_size is not None else self.validation_data_size
+ )
+ self.cv_split_column_names = (
+ cv_split_column_names if cv_split_column_names is not None else self.cv_split_column_names
+ )
+ self.n_cross_validations = n_cross_validations if n_cross_validations is not None else self.n_cross_validations
+ self.test_data = test_data if test_data is not None else self.test_data
+ self.test_data_size = test_data_size if test_data_size is not None else self.test_data_size
+
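+ # Editor's note: an illustrative usage sketch, not part of the SDK source. It assumes
+ # MLTable-format training data; the path and column name are placeholders:
+ #
+ #     from azure.ai.ml import Input
+ #     from azure.ai.ml.constants import AssetTypes
+ #
+ #     job.set_data(
+ #         training_data=Input(type=AssetTypes.MLTABLE, path="./training-mltable-folder"),
+ #         target_column_name="label",
+ #         validation_data_size=0.2,
+ #     )
+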
+ def _validation_data_to_rest(self, rest_obj: "AutoMLTabular") -> None:
+ """Validation data serialization.
+
+ :param rest_obj: Serialized object
+ :type rest_obj: AutoMLTabular
+ """
+ if rest_obj.n_cross_validations:
+ n_cross_val = rest_obj.n_cross_validations
+ # Convert n_cross_validations int value to CustomNCrossValidations
+ if isinstance(n_cross_val, int) and n_cross_val > 1:
+ rest_obj.n_cross_validations = CustomNCrossValidations(value=n_cross_val)
+ # Convert n_cross_validations str value to AutoNCrossValidations
+ elif isinstance(n_cross_val, str):
+ rest_obj.n_cross_validations = AutoNCrossValidations()
+
+ def _validation_data_from_rest(self) -> None:
+ """Validation data deserialization."""
+ if self.n_cross_validations:
+ n_cross_val = self.n_cross_validations
+ # Convert n_cross_validations CustomNCrossValidations back into int value
+ if isinstance(n_cross_val, CustomNCrossValidations):
+ self.n_cross_validations = n_cross_val.value
+ # Convert n_cross_validations AutoNCrossValidations to str value
+ elif isinstance(n_cross_val, AutoNCrossValidations):
+ self.n_cross_validations = AutoMLConstants.AUTO
+
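+ # Editor's note: an illustrative sketch of the n_cross_validations round trip above,
+ # not part of the SDK source. An integer value is serialized as CustomNCrossValidations,
+ # while the string "auto" is serialized as AutoNCrossValidations and deserialized back
+ # to AutoMLConstants.AUTO:
+ #
+ #     job.n_cross_validations = 5       # -> CustomNCrossValidations(value=5) on the wire
+ #     job.n_cross_validations = "auto"  # -> AutoNCrossValidations() on the wire
+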
+ def __eq__(self, other: object) -> bool:
+ """Return True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, AutoMLTabular):
+ return NotImplemented
+
+ return (
+ self.target_column_name == other.target_column_name
+ and self.weight_column_name == other.weight_column_name
+ and self.training_data == other.training_data
+ and self.validation_data == other.validation_data
+ and self.validation_data_size == other.validation_data_size
+ and self.cv_split_column_names == other.cv_split_column_names
+ and self.n_cross_validations == other.n_cross_validations
+ and self.test_data == other.test_data
+ and self.test_data_size == other.test_data_size
+ and self._featurization == other._featurization
+ and self._limits == other._limits
+ and self._training == other._training
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two AutoMLTabular objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/classification_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/classification_job.py
new file mode 100644
index 00000000..6f5ab271
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/classification_job.py
@@ -0,0 +1,352 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Classification as RestClassification
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationPrimaryMetrics, JobBase, TaskType
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.tabular.automl_tabular import AutoMLTabular
+from azure.ai.ml.entities._job.automl.tabular.featurization_settings import TabularFeaturizationSettings
+from azure.ai.ml.entities._job.automl.tabular.limit_settings import TabularLimitSettings
+from azure.ai.ml.entities._job.automl.training_settings import ( # noqa: F401 # pylint: disable=unused-import
+ ClassificationTrainingSettings,
+ TrainingSettings,
+)
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class ClassificationJob(AutoMLTabular):
+ """Configuration for AutoML Classification Job.
+
+ :keyword primary_metric: The primary metric to use for optimization, defaults to None
+ :paramtype primary_metric: typing.Optional[str]
+ :keyword positive_label: Positive label for binary metrics calculation, defaults to None
+ :paramtype positive_label: typing.Optional[str]
+ :keyword featurization: Featurization settings. Defaults to None.
+ :paramtype featurization: typing.Optional[TabularFeaturizationSettings]
+ :keyword limits: Limits settings. Defaults to None.
+ :paramtype limits: typing.Optional[TabularLimitSettings]
+ :keyword training: Training settings. Defaults to None.
+ :paramtype training: typing.Optional[TrainingSettings]
+ :return: An instance of ClassificationJob object.
+ :rtype: ~azure.ai.ml.entities.automl.ClassificationJob
+ :raises ValueError: If primary_metric is not a valid primary metric
+ :raises ValueError: If positive_label is not a valid positive label
+ :raises ValueError: If featurization is not a valid featurization settings
+ :raises ValueError: If limits is not a valid limits settings
+ :raises ValueError: If training is not a valid training settings
+ """
+
+ _DEFAULT_PRIMARY_METRIC = ClassificationPrimaryMetrics.ACCURACY
+
+ def __init__(
+ self,
+ *,
+ primary_metric: Optional[str] = None,
+ positive_label: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a new AutoML Classification task.
+
+ :keyword primary_metric: The primary metric to use for optimization, defaults to None
+ :paramtype primary_metric: typing.Optional[str]
+ :keyword positive_label: Positive label for binary metrics calculation, defaults to None
+ :paramtype positive_label: typing.Optional[str]
+ :keyword featurization: featurization settings. Defaults to None.
+ :paramtype featurization: typing.Optional[TabularFeaturizationSettings]
+ :keyword limits: limits settings. Defaults to None.
+ :paramtype limits: typing.Optional[TabularLimitSettings]
+ :keyword training: training settings. Defaults to None.
+ :paramtype training: typing.Optional[TrainingSettings]
+ :raises ValueError: If primary_metric is not a valid primary metric
+ :raises ValueError: If positive_label is not a valid positive label
+ :raises ValueError: If featurization is not a valid featurization settings
+ :raises ValueError: If limits is not a valid limits settings
+ :raises ValueError: If training is not a valid training settings
+ """
+ # Extract any task specific settings
+ featurization = kwargs.pop("featurization", None)
+ limits = kwargs.pop("limits", None)
+ training = kwargs.pop("training", None)
+
+ super().__init__(
+ task_type=TaskType.CLASSIFICATION,
+ featurization=featurization,
+ limits=limits,
+ training=training,
+ **kwargs,
+ )
+
+ self.primary_metric = primary_metric or ClassificationJob._DEFAULT_PRIMARY_METRIC
+ self.positive_label = positive_label
+
+ @property
+ def primary_metric(self) -> Union[str, ClassificationPrimaryMetrics]:
+ """The primary metric to use for optimization.
+
+ :return: The primary metric to use for optimization.
+ :rtype: typing.Union[str, ClassificationPrimaryMetrics]
+ """
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, ClassificationPrimaryMetrics]) -> None:
+ """The primary metric to use for optimization setter.
+
+ :param value: Primary metric to use for optimization.
+ :type value: typing.Union[str, ClassificationPrimaryMetrics]
+ """
+ # TODO: better way to do this
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+ self._primary_metric = (
+ ClassificationJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else ClassificationPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
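+ # Editor's note: an illustrative sketch, not part of the SDK source. The setter above
+ # accepts either the enum or its string form and normalizes it via camel_to_snake:
+ #
+ #     job = ClassificationJob()
+ #     job.primary_metric = "accuracy"   # normalized to ClassificationPrimaryMetrics.ACCURACY
+ #     job.primary_metric = ClassificationPrimaryMetrics.ACCURACY
+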
+ @property # type: ignore
+ def training(self) -> ClassificationTrainingSettings:
+ """Training Settings for AutoML Classification Job.
+
+ :return: Training settings used for AutoML Classification Job.
+ :rtype: ClassificationTrainingSettings
+ """
+ return self._training or ClassificationTrainingSettings()
+
+ @training.setter
+ def training(self, value: Union[Dict, ClassificationTrainingSettings]) -> None: # pylint: disable=unused-argument
+ ...
+
+ def _to_rest_object(self) -> JobBase:
+ """Convert ClassificationJob object to a REST object.
+
+ :return: REST object representation of this object.
+ :rtype: JobBase
+ """
+ classification_task = RestClassification(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ validation_data_size=self.validation_data_size,
+ weight_column_name=self.weight_column_name,
+ cv_split_column_names=self.cv_split_column_names,
+ n_cross_validations=self.n_cross_validations,
+ test_data=self.test_data,
+ test_data_size=self.test_data_size,
+ featurization_settings=self._featurization._to_rest_object() if self._featurization else None,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ training_settings=self._training._to_rest_object() if self._training else None,
+ primary_metric=self.primary_metric,
+ positive_label=self.positive_label,
+ log_verbosity=self.log_verbosity,
+ )
+ self._resolve_data_inputs(classification_task)
+ self._validation_data_to_rest(classification_task)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=classification_task,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "ClassificationJob":
+ """Convert a REST object to ClassificationJob object.
+
+ :param obj: ClassificationJob in Rest format.
+ :type obj: JobBase
+ :return: ClassificationJob objects.
+ :rtype: ClassificationJob
+ """
+
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestClassification = properties.task_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "services": properties.services,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "compute": properties.compute_id,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ "resources": properties.resources,
+ "identity": (
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ "queue_settings": properties.queue_settings,
+ }
+
+ classification_job = cls(
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ validation_data_size=task_details.validation_data_size,
+ weight_column_name=task_details.weight_column_name,
+ cv_split_column_names=task_details.cv_split_column_names,
+ n_cross_validations=task_details.n_cross_validations,
+ test_data=task_details.test_data,
+ test_data_size=task_details.test_data_size,
+ featurization=(
+ TabularFeaturizationSettings._from_rest_object(task_details.featurization_settings)
+ if task_details.featurization_settings
+ else None
+ ),
+ limits=(
+ TabularLimitSettings._from_rest_object(task_details.limit_settings)
+ if task_details.limit_settings
+ else None
+ ),
+ training=(
+ ClassificationTrainingSettings._from_rest_object(task_details.training_settings)
+ if task_details.training_settings
+ else None
+ ),
+ primary_metric=task_details.primary_metric,
+ positive_label=task_details.positive_label,
+ log_verbosity=task_details.log_verbosity,
+ **job_args_dict,
+ )
+
+ classification_job._restore_data_inputs()
+ classification_job._validation_data_from_rest()
+
+ return classification_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "ClassificationJob":
+ """Load from a dictionary.
+
+ :param data: dictionary representation of the object.
+ :type data: typing.Dict
+ :param context: dictionary containing the context.
+ :type context: typing.Dict
+ :param additional_message: additional message to be added to the error message.
+ :type additional_message: str
+ :return: ClassificationJob object.
+ :rtype: ClassificationJob
+ """
+ from azure.ai.ml._schema.automl.table_vertical.classification import AutoMLClassificationSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLClassificationNodeSchema
+
+ if kwargs.pop("inside_pipeline", False):
+ loaded_data = load_from_dict(
+ AutoMLClassificationNodeSchema,
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ else:
+ loaded_data = load_from_dict(AutoMLClassificationSchema, data, context, additional_message, **kwargs)
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ClassificationJob":
+ """Create an instance from a schema dictionary.
+
+ :param loaded_data: dictionary containing the data.
+ :type loaded_data: typing.Dict
+ :return: ClassificationJob object.
+ :rtype: ClassificationJob
+ """
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ data_settings = {
+ "training_data": loaded_data.pop("training_data"),
+ "target_column_name": loaded_data.pop("target_column_name"),
+ "weight_column_name": loaded_data.pop("weight_column_name", None),
+ "validation_data": loaded_data.pop("validation_data", None),
+ "validation_data_size": loaded_data.pop("validation_data_size", None),
+ "cv_split_column_names": loaded_data.pop("cv_split_column_names", None),
+ "n_cross_validations": loaded_data.pop("n_cross_validations", None),
+ "test_data": loaded_data.pop("test_data", None),
+ "test_data_size": loaded_data.pop("test_data_size", None),
+ }
+ job = ClassificationJob(**loaded_data)
+ job.set_data(**data_settings)
+ return job
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ """Convert the object to a dictionary.
+
+ :param inside_pipeline: whether the job is inside a pipeline or not, defaults to False
+ :type inside_pipeline: bool
+ :return: dictionary representation of the object.
+ :rtype: typing.Dict
+ """
+ from azure.ai.ml._schema.automl.table_vertical.classification import AutoMLClassificationSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLClassificationNodeSchema
+
+ schema_dict: dict = {}
+ if inside_pipeline:
+ schema_dict = AutoMLClassificationNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ else:
+ schema_dict = AutoMLClassificationSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, ClassificationJob):
+ return NotImplemented
+
+ if not super().__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two ImageLimitSettings objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
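+
+ # Editor's note: an end-to-end illustrative sketch, not part of the SDK source. The
+ # path, column name, compute target, and `ml_client` (an authenticated
+ # azure.ai.ml.MLClient configured elsewhere) are placeholders:
+ #
+ #     from azure.ai.ml import Input
+ #     from azure.ai.ml.constants import AssetTypes
+ #
+ #     job = ClassificationJob(primary_metric="accuracy")
+ #     job.set_data(
+ #         training_data=Input(type=AssetTypes.MLTABLE, path="./training-mltable-folder"),
+ #         target_column_name="label",
+ #         n_cross_validations=5,
+ #     )
+ #     job.set_limits(timeout_minutes=60, max_trials=20, max_concurrent_trials=2)
+ #     job.compute = "cpu-cluster"
+ #     returned_job = ml_client.jobs.create_or_update(job)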
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/featurization_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/featurization_settings.py
new file mode 100644
index 00000000..6ef2332e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/featurization_settings.py
@@ -0,0 +1,170 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import BlockedTransformers
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ColumnTransformer as RestColumnTransformer
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ TableVerticalFeaturizationSettings as RestTabularFeaturizationSettings,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._job.automl import AutoMLTransformerParameterKeys
+from azure.ai.ml.entities._job.automl.featurization_settings import FeaturizationSettings, FeaturizationSettingsType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class ColumnTransformer(RestTranslatableMixin):
+ """Column transformer settings.
+
+ :param fields: The fields on which to perform custom featurization
+ :type fields: List[str]
+ :param parameters: Parameters used for custom featurization
+ :type parameters: Dict[str, Union[str, float]]
+ """
+
+ def __init__(
+ self,
+ *,
+ fields: Optional[List[str]] = None,
+ parameters: Optional[Dict[str, Union[str, float]]] = None,
+ ):
+ self.fields = fields
+ self.parameters = parameters
+
+ def _to_rest_object(self) -> RestColumnTransformer:
+ return RestColumnTransformer(fields=self.fields, parameters=self.parameters)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestColumnTransformer) -> Optional["ColumnTransformer"]:
+ if obj:
+ fields = obj.fields
+ parameters = obj.parameters
+ return ColumnTransformer(fields=fields, parameters=parameters)
+ return None
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ColumnTransformer):
+ return NotImplemented
+ return self.fields == other.fields and self.parameters == other.parameters
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
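+ # Editor's note: an illustrative sketch, not part of the SDK source. The field name
+ # and parameter values are assumed examples:
+ #
+ #     transformer = ColumnTransformer(fields=["age"], parameters={"strategy": "median"})
+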
+
+class TabularFeaturizationSettings(FeaturizationSettings):
+ """Featurization settings for an AutoML Job."""
+
+ def __init__(
+ self,
+ *,
+ blocked_transformers: Optional[List[Union[BlockedTransformers, str]]] = None,
+ column_name_and_types: Optional[Dict[str, str]] = None,
+ dataset_language: Optional[str] = None,
+ transformer_params: Optional[Dict[str, List[ColumnTransformer]]] = None,
+ mode: Optional[str] = None,
+ enable_dnn_featurization: Optional[bool] = None,
+ ):
+ """
+ :param blocked_transformers: A list of transformers to ignore when featurizing.
+ :type blocked_transformers: List[Union[BlockedTransformers, str]]
+ :param column_name_and_types: A dictionary of column names and feature types used to update column purpose.
+ :type column_name_and_types: Dict[str, str]
+ :param dataset_language: The language of the dataset.
+ :type dataset_language: str
+ :param transformer_params: A dictionary of transformers and their parameters.
+ :type transformer_params: Dict[str, List[ColumnTransformer]]
+ :param mode: The mode of the featurization.
+ :type mode: str
+ :param enable_dnn_featurization: Whether to enable DNN featurization.
+ :type enable_dnn_featurization: bool
+ :ivar type: Specifies the type of FeaturizationSettings. Set automatically to "Tabular" for this class.
+ :vartype type: str
+ """
+ super().__init__(dataset_language=dataset_language)
+ self.blocked_transformers = blocked_transformers
+ self.column_name_and_types = column_name_and_types
+ self.transformer_params = transformer_params
+ self.mode = mode
+ self.enable_dnn_featurization = enable_dnn_featurization
+ self.type = FeaturizationSettingsType.TABULAR
+
+ @property
+ def transformer_params(self) -> Optional[Dict[str, List[ColumnTransformer]]]:
+ """A dictionary of transformers and their parameters."""
+ return self._transformer_params
+
+ @transformer_params.setter
+ def transformer_params(self, value: Dict[str, List[ColumnTransformer]]) -> None:
+ self._transformer_params = (
+ None
+ if not value
+ else {(AutoMLTransformerParameterKeys[camel_to_snake(k).upper()].value): v for k, v in value.items()}
+ )
+
+ @property
+ def blocked_transformers(self) -> Optional[List[Union[BlockedTransformers, str]]]:
+ """A list of transformers to ignore when featurizing."""
+ return self._blocked_transformers
+
+ @blocked_transformers.setter
+ def blocked_transformers(self, blocked_transformers_list: List[Union[BlockedTransformers, str]]) -> None:
+ self._blocked_transformers = (
+ None
+ if blocked_transformers_list is None
+ else [BlockedTransformers[camel_to_snake(o)] for o in blocked_transformers_list]
+ )
+
+ def _to_rest_object(self) -> RestTabularFeaturizationSettings:
+ transformer_dict = {}
+ if self.transformer_params:
+ for key, settings in self.transformer_params.items():
+ transformer_dict[key] = [o._to_rest_object() for o in settings]
+ return RestTabularFeaturizationSettings(
+ blocked_transformers=self.blocked_transformers,
+ column_name_and_types=self.column_name_and_types,
+ dataset_language=self.dataset_language,
+ mode=self.mode,
+ transformer_params=transformer_dict,
+ enable_dnn_featurization=self.enable_dnn_featurization,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestTabularFeaturizationSettings) -> "TabularFeaturizationSettings":
+ rest_transformers_params = obj.transformer_params
+ transformer_dict: Optional[Dict] = None
+ if rest_transformers_params:
+ transformer_dict = {}
+ for key, settings in rest_transformers_params.items():
+ transformer_dict[key] = [ColumnTransformer._from_rest_object(o) for o in settings]
+ transformer_params = transformer_dict
+
+ return TabularFeaturizationSettings(
+ blocked_transformers=obj.blocked_transformers,
+ column_name_and_types=obj.column_name_and_types,
+ dataset_language=obj.dataset_language,
+ transformer_params=transformer_params,
+ mode=obj.mode,
+ enable_dnn_featurization=obj.enable_dnn_featurization,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TabularFeaturizationSettings):
+ return NotImplemented
+ return (
+ super().__eq__(other)
+ and self.blocked_transformers == other.blocked_transformers
+ and self.column_name_and_types == other.column_name_and_types
+ and self.transformer_params == other.transformer_params
+ and self.mode == other.mode
+ and self.enable_dnn_featurization == other.enable_dnn_featurization
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
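+
+ # Editor's note: an illustrative sketch, not part of the SDK source. The "Imputer"
+ # key is an assumed example; transformer_params keys are normalized through
+ # AutoMLTransformerParameterKeys in the setter above:
+ #
+ #     featurization = TabularFeaturizationSettings(
+ #         mode="auto",
+ #         dataset_language="eng",
+ #         transformer_params={
+ #             "Imputer": [ColumnTransformer(fields=["age"], parameters={"strategy": "median"})],
+ #         },
+ #     )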
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_job.py
new file mode 100644
index 00000000..9bd10b19
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_job.py
@@ -0,0 +1,686 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Forecasting as RestForecasting
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ForecastingPrimaryMetrics, JobBase, TaskType
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants import TabularTrainingMode
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings
+from azure.ai.ml.entities._job.automl.tabular.automl_tabular import AutoMLTabular
+from azure.ai.ml.entities._job.automl.tabular.featurization_settings import TabularFeaturizationSettings
+from azure.ai.ml.entities._job.automl.tabular.forecasting_settings import ForecastingSettings
+from azure.ai.ml.entities._job.automl.tabular.limit_settings import TabularLimitSettings
+from azure.ai.ml.entities._job.automl.training_settings import ForecastingTrainingSettings
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class ForecastingJob(AutoMLTabular):
+ """
+ Configuration for AutoML Forecasting Task.
+
+ :param primary_metric: The primary metric to use for model selection.
+ :type primary_metric: Optional[str]
+ :param forecasting_settings: The settings for the forecasting task.
+ :type forecasting_settings:
+ Optional[~azure.ai.ml.automl.ForecastingSettings]
+ :param kwargs: Job-specific arguments
+ :type kwargs: Dict[str, Any]
+ """
+
+ _DEFAULT_PRIMARY_METRIC = ForecastingPrimaryMetrics.NORMALIZED_ROOT_MEAN_SQUARED_ERROR
+
+ def __init__(
+ self,
+ *,
+ primary_metric: Optional[str] = None,
+ forecasting_settings: Optional[ForecastingSettings] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a new AutoML Forecasting task."""
+ # Extract any task specific settings
+ featurization = kwargs.pop("featurization", None)
+ limits = kwargs.pop("limits", None)
+ training = kwargs.pop("training", None)
+
+ super().__init__(
+ task_type=TaskType.FORECASTING,
+ featurization=featurization,
+ limits=limits,
+ training=training,
+ **kwargs,
+ )
+
+ self.primary_metric = primary_metric or ForecastingJob._DEFAULT_PRIMARY_METRIC
+ self._forecasting_settings = forecasting_settings
+
+ @property
+ def primary_metric(self) -> Optional[str]:
+ """
+ Return the primary metric to use for model selection.
+
+ :return: The primary metric for model selection.
+ :rtype: Optional[str]
+ """
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, ForecastingPrimaryMetrics]) -> None:
+ """
+ Set the primary metric to use for model selection.
+
+ :param value: The primary metric for model selection.
+ :type: Union[str, ~azure.ai.ml.automl.ForecastingPrimaryMetrics]
+ """
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+ self._primary_metric = (
+ ForecastingJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else ForecastingPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
+ @property # type: ignore
+ def training(self) -> ForecastingTrainingSettings:
+ """
+ Return the forecast training settings.
+
+ :return: training settings.
+ :rtype: ~azure.ai.ml.automl.ForecastingTrainingSettings
+ """
+ return self._training or ForecastingTrainingSettings()
+
+ @training.setter
+ def training(self, value: Union[Dict, ForecastingTrainingSettings]) -> None: # pylint: disable=unused-argument
+ ...
+
+ @property
+ def forecasting_settings(self) -> Optional[ForecastingSettings]:
+ """
+ Return the forecast settings.
+
+ :return: forecast settings.
+ :rtype: ~azure.ai.ml.automl.ForecastingSettings
+ """
+ return self._forecasting_settings
+
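+ # Editor's note: an illustrative sketch, not part of the SDK source, using
+ # set_forecast_settings defined below; the column name and values are placeholders:
+ #
+ #     job = ForecastingJob(primary_metric="normalized_root_mean_squared_error")
+ #     job.set_forecast_settings(
+ #         time_column_name="date",
+ #         forecast_horizon=14,
+ #         frequency="D",
+ #         target_lags="auto",
+ #     )
+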
+ def set_forecast_settings(
+ self,
+ *,
+ time_column_name: Optional[str] = None,
+ forecast_horizon: Optional[Union[str, int]] = None,
+ time_series_id_column_names: Optional[Union[str, List[str]]] = None,
+ target_lags: Optional[Union[str, int, List[int]]] = None,
+ feature_lags: Optional[str] = None,
+ target_rolling_window_size: Optional[Union[str, int]] = None,
+ country_or_region_for_holidays: Optional[str] = None,
+ use_stl: Optional[str] = None,
+ seasonality: Optional[Union[str, int]] = None,
+ short_series_handling_config: Optional[str] = None,
+ frequency: Optional[str] = None,
+ target_aggregate_function: Optional[str] = None,
+ cv_step_size: Optional[int] = None,
+ features_unknown_at_forecast_time: Optional[Union[str, List[str]]] = None,
+ ) -> None:
+ """Manage parameters used by forecasting tasks.
+
+ :keyword time_column_name:
+ The name of the time column. This parameter is required when forecasting to specify the datetime
+ column in the input data used for building the time series and inferring its frequency.
+ :paramtype time_column_name: Optional[str]
+ :keyword forecast_horizon:
+ The desired maximum forecast horizon in units of time-series frequency. The default value is 1.
+
+ Units are based on the time interval of your training data, e.g., monthly or weekly, that the forecaster
+ should predict out. When task type is forecasting, this parameter is required. For more information on
+ setting forecasting parameters, see `Auto-train a time-series forecast model <https://learn.microsoft.com/
+ azure/machine-learning/how-to-auto-train-forecast>`_.
+ :type forecast_horizon: Optional[Union[int, str]]
+ :keyword time_series_id_column_names:
+ The names of columns used to group a time series.
+ They can be used to create multiple series. If the time series id column names are not defined or
+ the identifier columns specified do not identify all the series in the dataset, the time series identifiers
+ will be automatically created for your dataset.
+ :paramtype time_series_id_column_names: Optional[Union[str, List[str]]]
+ :keyword target_lags: The number of past periods to lag from the target column. By default the lags are turned
+ off.
+
+ When forecasting, this parameter represents the number of rows to lag the target values based
+ on the frequency of the data. This is represented as a list or single integer. Lag should be used
+ when the relationship between the independent variables and dependent variable does not match up or
+ correlate by default. For example, when trying to forecast demand for a product, the demand in any
+ month may depend on the price of specific commodities 3 months prior. In this example, you may want
+ to lag the target (demand) negatively by 3 months so that the model is training on the correct
+ relationship. For more information, see `Auto-train a time-series forecast model
+ <https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-forecast>`_.
+
+ **Note on auto detection of target lags and rolling window size.
+ Please see the corresponding comments in the rolling window section.**
+ We use the next algorithm to detect the optimal target lag and rolling window size.
+
+ #. Estimate the maximum lag order for the look back feature selection. In our case it is the number of
+ periods until the next date frequency granularity, i.e., if the frequency is daily, it will be a week (7);
+ if it is weekly, it will be a month (4). That value multiplied by two is the largest
+ possible value of lags/rolling windows. In our examples, we will consider maximum lag
+ orders of 14 and 8 respectively.
+ #. Create a de-seasonalized series by adding trend and residual components. This will be used
+ in the next step.
+ #. Estimate the PACF (Partial Auto Correlation Function) on the data from (2)
+ and search for points where the auto correlation is significant, i.e., its absolute
+ value is more than 1.96/square_root(maximal lag value), which corresponds to a significance of 95%.
+ #. If all points are significant, we consider it strong seasonality
+ and do not create look back features.
+ #. We scan the PACF values from the beginning, and the value before the first insignificant
+ auto correlation will designate the lag. If the first significant element (the value correlated with
+ itself) is followed by an insignificant one, the lag will be 0 and we will not use look back features.
+
+ :type target_lags: Optional[Union[str, int, List[int]]]
+ :keyword feature_lags: Flag for generating lags for the numeric features with 'auto' or None.
+ :paramtype feature_lags: Optional[str]
+ :keyword target_rolling_window_size: The number of past periods used to create a rolling window average of the
+ target column.
+
+ When forecasting, this parameter represents `n` historical periods to use to generate forecasted values,
+ <= training set size. If omitted, `n` is the full training set size. Specify this parameter
+ when you only want to consider a certain amount of history when training the model.
+ If set to 'auto', rolling window will be estimated as the last
+ value where the PACF is more than the significance threshold. Please see the target_lags section for details.
+ :paramtype target_rolling_window_size: Optional[Union[str, int]]
+ :keyword country_or_region_for_holidays: The country/region used to generate holiday features.
+ These should be ISO 3166 two-letter country/region codes, for example 'US' or 'GB'.
+ :paramtype country_or_region_for_holidays: Optional[str]
+ :keyword use_stl: Configure STL Decomposition of the time-series target column.
+ use_stl can take three values: None (default), meaning no STL decomposition; 'season', generating only the
+ season component; and 'season_trend', generating both season and trend components.
+ :type use_stl: Optional[str]
+ :keyword seasonality: Set time series seasonality as an integer multiple of the series frequency.
+ If seasonality is set to 'auto', it will be inferred.
+ If set to None, the time series is assumed non-seasonal which is equivalent to seasonality=1.
+ :paramtype seasonality: Optional[Union[int, str]]
+ :keyword short_series_handling_config:
+ The parameter defining how AutoML should handle short time series.
+
+ Possible values: 'auto' (default), 'pad', 'drop' and None.
+
+ * **auto** short series will be padded if there are no long series,
+ otherwise short series will be dropped.
+ * **pad** all the short series will be padded.
+ * **drop** all the short series will be dropped.
+ * **None** the short series will not be modified.
+
+ If set to 'pad', the table will be padded with zeroes and empty values for the regressors
+ and random values for the target, with the mean equal to the target value median for the
+ given time series id. If the median is greater than or equal to zero, the minimal padded
+ value will be clipped at zero:
+ Input:
+
+ +------------+---------------+----------+--------+
+ | Date | numeric_value | string | target |
+ +============+===============+==========+========+
+ | 2020-01-01 | 23 | green | 55 |
+ +------------+---------------+----------+--------+
+
+ Output assuming minimal number of values is four:
+
+ +------------+---------------+----------+--------+
+ | Date | numeric_value | string | target |
+ +============+===============+==========+========+
+ | 2019-12-29 | 0 | NA | 55.1 |
+ +------------+---------------+----------+--------+
+ | 2019-12-30 | 0 | NA | 55.6 |
+ +------------+---------------+----------+--------+
+ | 2019-12-31 | 0 | NA | 54.5 |
+ +------------+---------------+----------+--------+
+ | 2020-01-01 | 23 | green | 55 |
+ +------------+---------------+----------+--------+
+
+ **Note:** We have two parameters short_series_handling_configuration and
+ legacy short_series_handling. When both parameters are set, we
+ synchronize them as shown in the table below (short_series_handling_configuration and
+ short_series_handling for brevity are marked as handling_configuration and handling
+ respectively).
+
+ +------------+--------------------------+----------------------+-----------------------------+
+ | handling | handling configuration | resulting handling | resulting handling configuration |
+ +============+==========================+======================+=============================+
+ | True | auto | True | auto |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | True | pad | True | auto |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | True | drop | True | auto |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | True | None | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | False | auto | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | False | pad | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | False | drop | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | False | None | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+
+ :type short_series_handling_config: Optional[str]
+ :keyword frequency: Forecast frequency.
+
+ When forecasting, this parameter represents the period with which the forecast is desired,
+ for example daily, weekly, yearly, etc. The forecast frequency is dataset frequency by default.
+ You can optionally set it to a frequency greater (but not lesser) than the dataset frequency.
+ We'll aggregate the data and generate the results at forecast frequency. For example,
+ for daily data, you can set the frequency to be daily, weekly or monthly, but not hourly.
+ The frequency needs to be a pandas offset alias.
+ Please refer to pandas documentation for more information:
+ https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#dateoffset-objects
+ :type frequency: Optional[str]
+ :keyword target_aggregate_function: The function to be used to aggregate the time series target
+ column to conform to a user specified frequency. If the target_aggregation_function is set,
+ but the freq parameter is not set, the error is raised. The possible target aggregation
+ functions are: "sum", "max", "min" and "mean".
+
+ * The target column values are aggregated based on the specified operation.
+ Typically, sum is appropriate for most scenarios.
+ * Numerical predictor columns in your data are aggregated by sum, mean, minimum value,
+ and maximum value. As a result, automated ML generates new columns suffixed with the
+ aggregation function name and applies the selected aggregate operation.
+ * For categorical predictor columns, the data is aggregated by mode,
+ the most prominent category in the window.
+ * Date predictor columns are aggregated by minimum value, maximum value and mode.
+
+ +----------------+-------------------------------+--------------------------------------+
+ | freq           | target_aggregation_function   | Data regularity fixing mechanism     |
+ +================+===============================+======================================+
+ | None (Default) | None (Default)                | The aggregation is not applied.      |
+ |                |                               | If a valid frequency cannot be       |
+ |                |                               | determined, an error is raised.      |
+ +----------------+-------------------------------+--------------------------------------+
+ | Some Value     | None (Default)                | The aggregation is not applied.      |
+ |                |                               | If the number of data points         |
+ |                |                               | compliant with the given frequency   |
+ |                |                               | grid is less than 90%, these points  |
+ |                |                               | are removed; otherwise an error      |
+ |                |                               | is raised.                           |
+ +----------------+-------------------------------+--------------------------------------+
+ | None (Default) | Aggregation function          | An error about the missing           |
+ |                |                               | frequency parameter is raised.       |
+ +----------------+-------------------------------+--------------------------------------+
+ | Some Value     | Aggregation function          | Aggregate to frequency using the     |
+ |                |                               | provided aggregation function.       |
+ +----------------+-------------------------------+--------------------------------------+
+
+ :type target_aggregate_function: Optional[str]
+ :keyword cv_step_size: Number of periods between the origin_time of one CV fold and the next fold.
+ For example, if `n_step` = 3 for daily data, the origin time for each fold will be three days apart.
+ :paramtype cv_step_size: Optional[int]
+ :keyword features_unknown_at_forecast_time: The feature columns that are available for training but
+ unknown at the time of forecast/inference. If features_unknown_at_forecast_time is set to an empty
+ list, it is assumed that all the feature columns in the dataset are known at inference time. If this
+ parameter is not set, support for future features is not enabled.
+ :paramtype features_unknown_at_forecast_time: Optional[Union[str, List[str]]]
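+
+ Example: a minimal sketch of configuring these settings on a forecasting job. The job variable,
+ column name and values below are illustrative assumptions, not part of this module:
+
+ .. code-block:: python
+
+     # forecasting_job is assumed to be a ForecastingJob, for example one created
+     # with azure.ai.ml.automl.forecasting(...).
+     forecasting_job.set_forecast_settings(
+         time_column_name="Date",
+         forecast_horizon=14,
+         frequency="W",
+         target_aggregate_function="sum",
+         target_lags="auto",
+         target_rolling_window_size="auto",
+         short_series_handling_config="auto",
+         country_or_region_for_holidays="US",
+     )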
+ """
+ self._forecasting_settings = self._forecasting_settings or ForecastingSettings()
+
+ self._forecasting_settings.country_or_region_for_holidays = (
+ country_or_region_for_holidays
+ if country_or_region_for_holidays is not None
+ else self._forecasting_settings.country_or_region_for_holidays
+ )
+ self._forecasting_settings.cv_step_size = (
+ cv_step_size if cv_step_size is not None else self._forecasting_settings.cv_step_size
+ )
+ self._forecasting_settings.forecast_horizon = (
+ forecast_horizon if forecast_horizon is not None else self._forecasting_settings.forecast_horizon
+ )
+ self._forecasting_settings.target_lags = (
+ target_lags if target_lags is not None else self._forecasting_settings.target_lags
+ )
+ self._forecasting_settings.target_rolling_window_size = (
+ target_rolling_window_size
+ if target_rolling_window_size is not None
+ else self._forecasting_settings.target_rolling_window_size
+ )
+ self._forecasting_settings.frequency = (
+ frequency if frequency is not None else self._forecasting_settings.frequency
+ )
+ self._forecasting_settings.feature_lags = (
+ feature_lags if feature_lags is not None else self._forecasting_settings.feature_lags
+ )
+ self._forecasting_settings.seasonality = (
+ seasonality if seasonality is not None else self._forecasting_settings.seasonality
+ )
+ self._forecasting_settings.use_stl = use_stl if use_stl is not None else self._forecasting_settings.use_stl
+ self._forecasting_settings.short_series_handling_config = (
+ short_series_handling_config
+ if short_series_handling_config is not None
+ else self._forecasting_settings.short_series_handling_config
+ )
+ self._forecasting_settings.target_aggregate_function = (
+ target_aggregate_function
+ if target_aggregate_function is not None
+ else self._forecasting_settings.target_aggregate_function
+ )
+ self._forecasting_settings.time_column_name = (
+ time_column_name if time_column_name is not None else self._forecasting_settings.time_column_name
+ )
+ self._forecasting_settings.time_series_id_column_names = (
+ time_series_id_column_names
+ if time_series_id_column_names is not None
+ else self._forecasting_settings.time_series_id_column_names
+ )
+ self._forecasting_settings.features_unknown_at_forecast_time = (
+ features_unknown_at_forecast_time
+ if features_unknown_at_forecast_time is not None
+ else self._forecasting_settings.features_unknown_at_forecast_time
+ )
+
+ # override
+ def set_training(
+ self,
+ *,
+ enable_onnx_compatible_models: Optional[bool] = None,
+ enable_dnn_training: Optional[bool] = None,
+ enable_model_explainability: Optional[bool] = None,
+ enable_stack_ensemble: Optional[bool] = None,
+ enable_vote_ensemble: Optional[bool] = None,
+ stack_ensemble_settings: Optional[StackEnsembleSettings] = None,
+ ensemble_model_download_timeout: Optional[int] = None,
+ allowed_training_algorithms: Optional[List[str]] = None,
+ blocked_training_algorithms: Optional[List[str]] = None,
+ training_mode: Optional[Union[str, TabularTrainingMode]] = None,
+ ) -> None:
+ """
+ The method to configure forecast training related settings.
+
+ :keyword enable_onnx_compatible_models:
+ Whether to enforce ONNX-compatible models.
+ The default is False. For more information about Open Neural Network Exchange (ONNX) and Azure Machine
+ Learning, see this `article <https://learn.microsoft.com/azure/machine-learning/concept-onnx>`__.
+ :type enable_onnx_compatible_models: Optional[bool]
+ :keyword enable_dnn_training:
+ Whether to include DNN based models during model selection.
+ The default is True for DNN NLP tasks and False for all other AutoML tasks.
+ :paramtype enable_dnn_training: Optional[bool]
+ :keyword enable_model_explainability:
+ Whether to enable explaining the best AutoML model at the end of all AutoML training iterations.
+ For more information, see `Interpretability: model explanations in automated machine learning
+ <https://learn.microsoft.com/azure/machine-learning/how-to-machine-learning-interpretability-automl>`__,
+ defaults to None.
+ :type enable_model_explainability: Optional[bool]
+ :keyword enable_stack_ensemble:
+ Whether to enable/disable StackEnsemble iteration.
+ If the `enable_onnx_compatible_models` flag is set, the StackEnsemble iteration will be disabled.
+ Similarly, for timeseries tasks, the StackEnsemble iteration is disabled by default to avoid the risk of
+ overfitting due to the small training set used to fit the meta learner.
+ For more information about ensembles, see `Ensemble configuration
+ <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#ensemble>`__,
+ defaults to None.
+ :type enable_stack_ensemble: Optional[bool]
+ :keyword enable_vote_ensemble:
+ Whether to enable/disable VotingEnsemble iteration.
+ For more information about ensembles, see `Ensemble configuration
+ <https://learn.microsoft.com/azure/machine-learning/how-to-configure-auto-train#ensemble>`__,
+ defaults to None.
+ :type enable_vote_ensemble: Optional[bool]
+ :keyword stack_ensemble_settings:
+ Settings for StackEnsemble iteration, defaults to None
+ :paramtype stack_ensemble_settings: Optional[StackEnsembleSettings]
+ :keyword ensemble_model_download_timeout:
+ During VotingEnsemble and StackEnsemble model generation,
+ multiple fitted models from the previous child runs are downloaded. Configure this parameter with a
+ value higher than 300 seconds if more time is needed, defaults to None.
+ :paramtype ensemble_model_download_timeout: Optional[int]
+ :keyword allowed_training_algorithms:
+ A list of model names to search for an experiment. If not specified,
+ then all models supported for the task are used minus any specified in ``blocked_training_algorithms``
+ or deprecated TensorFlow models, defaults to None
+ :paramtype allowed_training_algorithms: Optional[List[str]]
+ :keyword blocked_training_algorithms:
+ A list of algorithms to ignore for an experiment, defaults to None
+ :paramtype blocked_training_algorithms: Optional[List[str]]
+ :keyword training_mode:
+ [Experimental] The training mode to use.
+ The possible values are:
+
+ * distributed- enables distributed training for supported algorithms.
+
+ * non_distributed- disables distributed training.
+
+ * auto- currently the same as non_distributed; this might change in the future.
+
+ Note: This parameter is in public preview and may change in the future.
+ :type training_mode: Optional[Union[~azure.ai.ml.constants.TabularTrainingMode, str]]
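+
+ Example: a minimal sketch of a call using these settings. The job variable and the blocked
+ algorithm name are illustrative assumptions:
+
+ .. code-block:: python
+
+     # forecasting_job is assumed to be a ForecastingJob instance.
+     forecasting_job.set_training(
+         enable_model_explainability=True,
+         enable_vote_ensemble=True,
+         enable_stack_ensemble=False,
+         blocked_training_algorithms=["ExtremeRandomTrees"],
+     )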
+ """
+ super().set_training(
+ enable_onnx_compatible_models=enable_onnx_compatible_models,
+ enable_dnn_training=enable_dnn_training,
+ enable_model_explainability=enable_model_explainability,
+ enable_stack_ensemble=enable_stack_ensemble,
+ enable_vote_ensemble=enable_vote_ensemble,
+ stack_ensemble_settings=stack_ensemble_settings,
+ ensemble_model_download_timeout=ensemble_model_download_timeout,
+ allowed_training_algorithms=allowed_training_algorithms,
+ blocked_training_algorithms=blocked_training_algorithms,
+ training_mode=training_mode,
+ )
+
+ # Disable stack ensemble by default, since it is currently not supported for forecasting tasks
+ if enable_stack_ensemble is None:
+ if self._training is not None:
+ self._training.enable_stack_ensemble = False
+
+ def _to_rest_object(self) -> JobBase:
+ # The REST payload is identical in both cases except for forecasting_settings,
+ # so build it once and pass None when no forecasting settings are configured.
+ forecasting_task = RestForecasting(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ validation_data_size=self.validation_data_size,
+ weight_column_name=self.weight_column_name,
+ cv_split_column_names=self.cv_split_column_names,
+ n_cross_validations=self.n_cross_validations,
+ test_data=self.test_data,
+ test_data_size=self.test_data_size,
+ featurization_settings=self._featurization._to_rest_object() if self._featurization else None,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ training_settings=self._training._to_rest_object() if self._training else None,
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ forecasting_settings=(
+ self._forecasting_settings._to_rest_object() if self._forecasting_settings else None
+ ),
+ )
+
+ self._resolve_data_inputs(forecasting_task)
+ self._validation_data_to_rest(forecasting_task)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=forecasting_task,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "ForecastingJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestForecasting = properties.task_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "services": properties.services,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "compute": properties.compute_id,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ "resources": properties.resources,
+ "identity": (
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ "queue_settings": properties.queue_settings,
+ }
+
+ forecasting_job = cls(
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ validation_data_size=task_details.validation_data_size,
+ weight_column_name=task_details.weight_column_name,
+ cv_split_column_names=task_details.cv_split_column_names,
+ n_cross_validations=task_details.n_cross_validations,
+ test_data=task_details.test_data,
+ test_data_size=task_details.test_data_size,
+ featurization=(
+ TabularFeaturizationSettings._from_rest_object(task_details.featurization_settings)
+ if task_details.featurization_settings
+ else None
+ ),
+ limits=(
+ TabularLimitSettings._from_rest_object(task_details.limit_settings)
+ if task_details.limit_settings
+ else None
+ ),
+ training=(
+ ForecastingTrainingSettings._from_rest_object(task_details.training_settings)
+ if task_details.training_settings
+ else None
+ ),
+ primary_metric=task_details.primary_metric,
+ forecasting_settings=(
+ ForecastingSettings._from_rest_object(task_details.forecasting_settings)
+ if task_details.forecasting_settings
+ else None
+ ),
+ log_verbosity=task_details.log_verbosity,
+ **job_args_dict,
+ )
+
+ forecasting_job._restore_data_inputs()
+ forecasting_job._validation_data_from_rest()
+
+ return forecasting_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "ForecastingJob":
+ from azure.ai.ml._schema.automl.table_vertical.forecasting import AutoMLForecastingSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLForecastingNodeSchema
+
+ if kwargs.pop("inside_pipeline", False):
+ loaded_data = load_from_dict(AutoMLForecastingNodeSchema, data, context, additional_message, **kwargs)
+ else:
+ loaded_data = load_from_dict(AutoMLForecastingSchema, data, context, additional_message, **kwargs)
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ForecastingJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ data_settings = {
+ "training_data": loaded_data.pop("training_data"),
+ "target_column_name": loaded_data.pop("target_column_name"),
+ "weight_column_name": loaded_data.pop("weight_column_name", None),
+ "validation_data": loaded_data.pop("validation_data", None),
+ "validation_data_size": loaded_data.pop("validation_data_size", None),
+ "cv_split_column_names": loaded_data.pop("cv_split_column_names", None),
+ "n_cross_validations": loaded_data.pop("n_cross_validations", None),
+ "test_data": loaded_data.pop("test_data", None),
+ "test_data_size": loaded_data.pop("test_data_size", None),
+ }
+ job = ForecastingJob(**loaded_data)
+ job.set_data(**data_settings)
+ return job
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.table_vertical.forecasting import AutoMLForecastingSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLForecastingNodeSchema
+
+ schema_dict: dict = {}
+ if inside_pipeline:
+ schema_dict = AutoMLForecastingNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ else:
+ schema_dict = AutoMLForecastingSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ForecastingJob):
+ return NotImplemented
+
+ if not super(ForecastingJob, self).__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric and self._forecasting_settings == other._forecasting_settings
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_settings.py
new file mode 100644
index 00000000..09439483
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/forecasting_settings.py
@@ -0,0 +1,383 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=too-many-instance-attributes
+
+from typing import List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ AutoForecastHorizon,
+ AutoSeasonality,
+ AutoTargetLags,
+ AutoTargetRollingWindowSize,
+ CustomForecastHorizon,
+ CustomSeasonality,
+ CustomTargetLags,
+ CustomTargetRollingWindowSize,
+ ForecastHorizonMode,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ForecastingSettings as RestForecastingSettings,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ SeasonalityMode,
+ TargetLagsMode,
+ TargetRollingWindowSizeMode,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class ForecastingSettings(RestTranslatableMixin):
+ """Forecasting settings for an AutoML Job.
+
+ :param country_or_region_for_holidays: The country/region used to generate holiday features. This should be an ISO
+ 3166 two-letter country/region code, for example 'US' or 'GB'.
+ :type country_or_region_for_holidays: Optional[str]
+ :param cv_step_size:
+ Number of periods between the origin_time of one CV fold and the next fold. For
+ example, if `n_step` = 3 for daily data, the origin time for each fold will be
+ three days apart.
+ :type cv_step_size: Optional[int]
+ :param forecast_horizon:
+ The desired maximum forecast horizon in units of time-series frequency. The default value is 1.
+
+ Units are based on the time interval of your training data, e.g., monthly or weekly, that the forecaster
+ should predict out. When the task type is forecasting, this parameter is required. For more information on
+ setting forecasting parameters, see `Auto-train a time-series forecast model <https://learn.microsoft.com/
+ azure/machine-learning/how-to-auto-train-forecast>`_.
+ :type forecast_horizon: Optional[Union[int, str]]
+ :param target_lags:
+ The number of past periods to lag from the target column. By default the lags are turned off.
+
+ When forecasting, this parameter represents the number of rows to lag the target values based
+ on the frequency of the data. This is represented as a list or single integer. Lag should be used
+ when the relationship between the independent variables and the dependent variable does not match up or
+ correlate by default. For example, when trying to forecast demand for a product, the demand in any
+ month may depend on the price of specific commodities 3 months prior. In this example, you may want
+ to lag the target (demand) negatively by 3 months so that the model is training on the correct
+ relationship. For more information, see `Auto-train a time-series forecast model
+ <https://learn.microsoft.com/azure/machine-learning/how-to-auto-train-forecast>`_.
+
+ **Note on auto detection of target lags and rolling window size.
+ Please see the corresponding comments in the rolling window section.**
+ We use the following algorithm to detect the optimal target lag and rolling window size.
+
+ #. Estimate the maximum lag order for the look-back feature selection. In our case it is the number of
+ periods until the next date frequency granularity, i.e. if the frequency is daily, it will be a week (7);
+ if it is weekly, it will be a month (4). Those values multiplied by two are the largest
+ possible values of lags/rolling windows. In our examples, we would consider maximum lag
+ orders of 14 and 8, respectively.
+ #. Create a de-seasonalized series by adding trend and residual components. This will be used
+ in the next step.
+ #. Estimate the PACF (Partial Auto Correlation Function) on the data from (2)
+ and search for points where the autocorrelation is significant, i.e. its absolute
+ value is greater than 1.96/square_root(maximal lag value), which corresponds to a significance of 95%.
+ #. If all points are significant, we consider it strong seasonality
+ and do not create look-back features.
+ #. We scan the PACF values from the beginning, and the value before the first insignificant
+ autocorrelation designates the lag. If the first significant element (the value correlated with
+ itself) is immediately followed by an insignificant one, the lag will be 0 and look-back features will not be used.
+ :type target_lags: Union[str, int, List[int]]
+ :param target_rolling_window_size:
+ The number of past periods used to create a rolling window average of the target column.
+
+ When forecasting, this parameter represents `n` historical periods to use to generate forecasted values,
+ <= training set size. If omitted, `n` is the full training set size. Specify this parameter
+ when you only want to consider a certain amount of history when training the model.
+ If set to 'auto', the rolling window will be estimated as the last
+ value where the PACF is greater than the significance threshold. Please see the target_lags section for details.
+ :type target_rolling_window_size: Optional[Union[str, int]]
+ :param frequency: Forecast frequency.
+
+ When forecasting, this parameter represents the period with which the forecast is desired,
+ for example daily, weekly, yearly, etc. The forecast frequency is the dataset frequency by default.
+ You can optionally set it to a value greater (but not less) than the dataset frequency.
+ The data will be aggregated and the results generated at the forecast frequency. For example,
+ for daily data, you can set the frequency to be daily, weekly or monthly, but not hourly.
+ The frequency needs to be a pandas offset alias.
+ Please refer to pandas documentation for more information:
+ https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#dateoffset-objects
+ :type frequency: Optional[str]
+ :param feature_lags: Flag for generating lags for the numeric features with 'auto' or None.
+ :type feature_lags: Optional[str]
+ :param seasonality: Set time series seasonality as an integer multiple of the series frequency.
+ If seasonality is set to 'auto', it will be inferred.
+ If set to None, the time series is assumed non-seasonal which is equivalent to seasonality=1.
+ :type seasonality: Optional[Union[int, str]]
+ :param use_stl: Configure STL Decomposition of the time-series target column.
+ use_stl can take three values: None (default) - no STL decomposition, 'season' - generate only the
+ season component, and 'season_trend' - generate both season and trend components.
+ :type use_stl: Optional[str]
+ :param short_series_handling_config:
+ The parameter defining how AutoML should handle short time series.
+
+ Possible values: 'auto' (default), 'pad', 'drop' and None.
+ * **auto** short series will be padded if there are no long series,
+ otherwise short series will be dropped.
+ * **pad** all the short series will be padded.
+ * **drop** all the short series will be dropped.
+ * **None** the short series will not be modified.
+ If set to 'pad', the table will be padded with zeroes and
+ empty values for the regressors, and random values for the target with a mean
+ equal to the target value median for the given time series id. If the median is greater than or
+ equal to zero, the minimal padded value will be clipped by zero.
+ Input:
+
+ +------------+---------------+----------+--------+
+ | Date | numeric_value | string | target |
+ +============+===============+==========+========+
+ | 2020-01-01 | 23 | green | 55 |
+ +------------+---------------+----------+--------+
+
+ Output assuming minimal number of values is four:
+
+ +------------+---------------+----------+--------+
+ | Date | numeric_value | string | target |
+ +============+===============+==========+========+
+ | 2019-12-29 | 0 | NA | 55.1 |
+ +------------+---------------+----------+--------+
+ | 2019-12-30 | 0 | NA | 55.6 |
+ +------------+---------------+----------+--------+
+ | 2019-12-31 | 0 | NA | 54.5 |
+ +------------+---------------+----------+--------+
+ | 2020-01-01 | 23 | green | 55 |
+ +------------+---------------+----------+--------+
+
+ **Note:** There are two parameters, short_series_handling_configuration and the
+ legacy short_series_handling. When both parameters are set, they are
+ synchronized as shown in the table below (for brevity, short_series_handling_configuration and
+ short_series_handling are denoted as handling_configuration and handling
+ respectively).
+
+ +------------+--------------------------+----------------------+-----------------------------+
+ | handling   | handling configuration   | resulting handling   | resulting handling          |
+ |            |                          |                      | configuration               |
+ +============+==========================+======================+=============================+
+ | True | auto | True | auto |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | True | pad | True | auto |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | True | drop | True | auto |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | True | None | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | False | auto | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | False | pad | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | False | drop | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+ | False | None | False | None |
+ +------------+--------------------------+----------------------+-----------------------------+
+
+ :type short_series_handling_config: Optional[str]
+ :param target_aggregate_function: The function to be used to aggregate the time series target
+ column to conform to a user specified frequency. If
+ target_aggregate_function is set but the frequency parameter
+ is not set, an error is raised. The possible target
+ aggregation functions are: "sum", "max", "min" and "mean".
+
+ * The target column values are aggregated based on the specified operation.
+ Typically, sum is appropriate for most scenarios.
+ * Numerical predictor columns in your data are aggregated by sum, mean, minimum value,
+ and maximum value. As a result, automated ML generates new columns suffixed with the
+ aggregation function name and applies the selected aggregate operation.
+ * For categorical predictor columns, the data is aggregated by mode,
+ the most prominent category in the window.
+ * Date predictor columns are aggregated by minimum value, maximum value and mode.
+
+ +----------------+-------------------------------+--------------------------------------+
+ | freq           | target_aggregation_function   | Data regularity fixing mechanism     |
+ +================+===============================+======================================+
+ | None (Default) | None (Default)                | The aggregation is not applied.      |
+ |                |                               | If a valid frequency cannot be       |
+ |                |                               | determined, an error is raised.      |
+ +----------------+-------------------------------+--------------------------------------+
+ | Some Value     | None (Default)                | The aggregation is not applied.      |
+ |                |                               | If the number of data points         |
+ |                |                               | compliant with the given frequency   |
+ |                |                               | grid is less than 90%, these points  |
+ |                |                               | are removed; otherwise an error      |
+ |                |                               | is raised.                           |
+ +----------------+-------------------------------+--------------------------------------+
+ | None (Default) | Aggregation function          | An error about the missing           |
+ |                |                               | frequency parameter is raised.       |
+ +----------------+-------------------------------+--------------------------------------+
+ | Some Value     | Aggregation function          | Aggregate to frequency using the     |
+ |                |                               | provided aggregation function.       |
+ +----------------+-------------------------------+--------------------------------------+
+ :type target_aggregate_function: Optional[str]
+ :param time_column_name:
+ The name of the time column. This parameter is required when forecasting to specify the datetime
+ column in the input data used for building the time series and inferring its frequency.
+ :type time_column_name: Optional[str]
+ :param time_series_id_column_names:
+ The names of columns used to group a timeseries.
+ It can be used to create multiple series. If the time series id column names are not defined or
+ the identifier columns specified do not identify all the series in the dataset, the time series identifiers
+ will be automatically created for your dataset.
+ :type time_series_id_column_names: Union[str, List[str]]
+ :param features_unknown_at_forecast_time:
+ The feature columns that are available for training but unknown at the time of forecast/inference.
+ If features_unknown_at_forecast_time is set to an empty list, it is assumed that
+ all the feature columns in the dataset are known at inference time. If this parameter is not set,
+ support for future features is not enabled.
+ :type features_unknown_at_forecast_time: Optional[Union[str, List[str]]]
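+
+ Example: a minimal sketch of constructing the settings directly. The import path, column name and
+ values are illustrative assumptions; in typical usage these values are supplied through a
+ forecasting job's set_forecast_settings method:
+
+ .. code-block:: python
+
+     from azure.ai.ml.automl import ForecastingSettings  # import path assumed
+
+     settings = ForecastingSettings(
+         time_column_name="Date",
+         forecast_horizon=7,
+         frequency="D",
+         target_lags="auto",
+         target_rolling_window_size="auto",
+         seasonality="auto",
+         short_series_handling_config="auto",
+     )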
+ """
+
+ def __init__(
+ self,
+ *,
+ country_or_region_for_holidays: Optional[str] = None,
+ cv_step_size: Optional[int] = None,
+ forecast_horizon: Optional[Union[str, int]] = None,
+ target_lags: Optional[Union[str, int, List[int]]] = None,
+ target_rolling_window_size: Optional[Union[str, int]] = None,
+ frequency: Optional[str] = None,
+ feature_lags: Optional[str] = None,
+ seasonality: Optional[Union[str, int]] = None,
+ use_stl: Optional[str] = None,
+ short_series_handling_config: Optional[str] = None,
+ target_aggregate_function: Optional[str] = None,
+ time_column_name: Optional[str] = None,
+ time_series_id_column_names: Optional[Union[str, List[str]]] = None,
+ features_unknown_at_forecast_time: Optional[Union[str, List[str]]] = None,
+ ):
+ self.country_or_region_for_holidays = country_or_region_for_holidays
+ self.cv_step_size = cv_step_size
+ self.forecast_horizon = forecast_horizon
+ self.target_lags = target_lags
+ self.target_rolling_window_size = target_rolling_window_size
+ self.frequency = frequency
+ self.feature_lags = feature_lags
+ self.seasonality = seasonality
+ self.use_stl = use_stl
+ self.short_series_handling_config = short_series_handling_config
+ self.target_aggregate_function = target_aggregate_function
+ self.time_column_name = time_column_name
+ self.time_series_id_column_names = time_series_id_column_names
+ self.features_unknown_at_forecast_time = features_unknown_at_forecast_time
+
+ def _to_rest_object(self) -> RestForecastingSettings:
+ forecast_horizon = None
+ if isinstance(self.forecast_horizon, str):
+ forecast_horizon = AutoForecastHorizon()
+ elif self.forecast_horizon:
+ forecast_horizon = CustomForecastHorizon(value=self.forecast_horizon)
+
+ target_lags = None
+ if isinstance(self.target_lags, str):
+ target_lags = AutoTargetLags()
+ elif self.target_lags:
+ lags = [self.target_lags] if not isinstance(self.target_lags, list) else self.target_lags
+ target_lags = CustomTargetLags(values=lags)
+
+ target_rolling_window_size = None
+ if isinstance(self.target_rolling_window_size, str):
+ target_rolling_window_size = AutoTargetRollingWindowSize()
+ elif self.target_rolling_window_size:
+ target_rolling_window_size = CustomTargetRollingWindowSize(value=self.target_rolling_window_size)
+
+ seasonality = None
+ if isinstance(self.seasonality, str):
+ seasonality = AutoSeasonality()
+ elif self.seasonality:
+ seasonality = CustomSeasonality(value=self.seasonality)
+
+ time_series_id_column_names = self.time_series_id_column_names
+ if isinstance(self.time_series_id_column_names, str) and self.time_series_id_column_names:
+ time_series_id_column_names = [self.time_series_id_column_names]
+
+ features_unknown_at_forecast_time = self.features_unknown_at_forecast_time
+ if isinstance(self.features_unknown_at_forecast_time, str) and self.features_unknown_at_forecast_time:
+ features_unknown_at_forecast_time = [self.features_unknown_at_forecast_time]
+
+ return RestForecastingSettings(
+ country_or_region_for_holidays=self.country_or_region_for_holidays,
+ cv_step_size=self.cv_step_size,
+ forecast_horizon=forecast_horizon,
+ time_column_name=self.time_column_name,
+ target_lags=target_lags,
+ target_rolling_window_size=target_rolling_window_size,
+ seasonality=seasonality,
+ frequency=self.frequency,
+ feature_lags=self.feature_lags,
+ use_stl=self.use_stl,
+ short_series_handling_config=self.short_series_handling_config,
+ target_aggregate_function=self.target_aggregate_function,
+ time_series_id_column_names=time_series_id_column_names,
+ features_unknown_at_forecast_time=features_unknown_at_forecast_time,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestForecastingSettings) -> "ForecastingSettings":
+ forecast_horizon = None
+ if obj.forecast_horizon and obj.forecast_horizon.mode == ForecastHorizonMode.AUTO:
+ forecast_horizon = obj.forecast_horizon.mode.lower()
+ elif obj.forecast_horizon:
+ forecast_horizon = obj.forecast_horizon.value
+
+ rest_target_lags = obj.target_lags
+ target_lags = None
+ if rest_target_lags and rest_target_lags.mode == TargetLagsMode.AUTO:
+ target_lags = rest_target_lags.mode.lower()
+ elif rest_target_lags:
+ target_lags = rest_target_lags.values
+
+ target_rolling_window_size = None
+ if obj.target_rolling_window_size and obj.target_rolling_window_size.mode == TargetRollingWindowSizeMode.AUTO:
+ target_rolling_window_size = obj.target_rolling_window_size.mode.lower()
+ elif obj.target_rolling_window_size:
+ target_rolling_window_size = obj.target_rolling_window_size.value
+
+ seasonality = None
+ if obj.seasonality and obj.seasonality.mode == SeasonalityMode.AUTO:
+ seasonality = obj.seasonality.mode.lower()
+ elif obj.seasonality:
+ seasonality = obj.seasonality.value
+
+ return cls(
+ country_or_region_for_holidays=obj.country_or_region_for_holidays,
+ cv_step_size=obj.cv_step_size,
+ forecast_horizon=forecast_horizon,
+ target_lags=target_lags,
+ target_rolling_window_size=target_rolling_window_size,
+ frequency=obj.frequency,
+ feature_lags=obj.feature_lags,
+ seasonality=seasonality,
+ use_stl=obj.use_stl,
+ short_series_handling_config=obj.short_series_handling_config,
+ target_aggregate_function=obj.target_aggregate_function,
+ time_column_name=obj.time_column_name,
+ time_series_id_column_names=obj.time_series_id_column_names,
+ features_unknown_at_forecast_time=obj.features_unknown_at_forecast_time,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ForecastingSettings):
+ return NotImplemented
+ return (
+ self.country_or_region_for_holidays == other.country_or_region_for_holidays
+ and self.cv_step_size == other.cv_step_size
+ and self.forecast_horizon == other.forecast_horizon
+ and self.target_lags == other.target_lags
+ and self.target_rolling_window_size == other.target_rolling_window_size
+ and self.frequency == other.frequency
+ and self.feature_lags == other.feature_lags
+ and self.seasonality == other.seasonality
+ and self.use_stl == other.use_stl
+ and self.short_series_handling_config == other.short_series_handling_config
+ and self.target_aggregate_function == other.target_aggregate_function
+ and self.time_column_name == other.time_column_name
+ and self.time_series_id_column_names == other.time_series_id_column_names
+ and self.features_unknown_at_forecast_time == other.features_unknown_at_forecast_time
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/limit_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/limit_settings.py
new file mode 100644
index 00000000..1024f504
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/limit_settings.py
@@ -0,0 +1,101 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import TableVerticalLimitSettings as RestTabularLimitSettings
+from azure.ai.ml._utils.utils import from_iso_duration_format_mins, to_iso_duration_format_mins
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class TabularLimitSettings(RestTranslatableMixin):
+ """Limit settings for a AutoML Table Verticals.
+
+ :param enable_early_termination: Whether to enable early termination if the score is not improving in
+ the short term. The default is True.
+ :type enable_early_termination: bool
+ :param exit_score: Target score for experiment. The experiment terminates after this score is reached.
+ :type exit_score: float
+ :param max_concurrent_trials: Maximum number of concurrent AutoML iterations.
+ :type max_concurrent_trials: int
+ :param max_cores_per_trial: The maximum number of threads to use for a given training iteration.
+ :type max_cores_per_trial: int
+ :param max_nodes: [Experimental] The maximum number of nodes to use for distributed training.
+
+ * For forecasting, each model is trained using max(2, int(max_nodes / max_concurrent_trials)) nodes.
+
+ * For classification/regression, each model is trained using max_nodes nodes.
+
+ Note- This parameter is in public preview and might change in the future.
+ :type max_nodes: int
+ :param max_trials: Maximum number of AutoML iterations.
+ :type max_trials: int
+ :param timeout_minutes: AutoML job timeout, in minutes.
+ :type timeout_minutes: int
+ :param trial_timeout_minutes: Timeout for each AutoML trial, in minutes.
+ :type trial_timeout_minutes: int
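+
+ Example: a minimal sketch of constructing these limits. The import path and values are
+ illustrative assumptions; in typical usage they are supplied through a job's set_limits method:
+
+ .. code-block:: python
+
+     from azure.ai.ml.automl import TabularLimitSettings  # import path assumed
+
+     limits = TabularLimitSettings(
+         max_trials=40,
+         max_concurrent_trials=4,
+         timeout_minutes=120,
+         trial_timeout_minutes=20,
+         enable_early_termination=True,
+     )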
+ """
+
+ def __init__(
+ self,
+ *,
+ enable_early_termination: Optional[bool] = None,
+ exit_score: Optional[float] = None,
+ max_concurrent_trials: Optional[int] = None,
+ max_cores_per_trial: Optional[int] = None,
+ max_nodes: Optional[int] = None,
+ max_trials: Optional[int] = None,
+ timeout_minutes: Optional[int] = None,
+ trial_timeout_minutes: Optional[int] = None,
+ ):
+ self.enable_early_termination = enable_early_termination
+ self.exit_score = exit_score
+ self.max_concurrent_trials = max_concurrent_trials
+ self.max_cores_per_trial = max_cores_per_trial
+ self.max_nodes = max_nodes
+ self.max_trials = max_trials
+ self.timeout_minutes = timeout_minutes
+ self.trial_timeout_minutes = trial_timeout_minutes
+
+ def _to_rest_object(self) -> RestTabularLimitSettings:
+ return RestTabularLimitSettings(
+ enable_early_termination=self.enable_early_termination,
+ exit_score=self.exit_score,
+ max_concurrent_trials=self.max_concurrent_trials,
+ max_cores_per_trial=self.max_cores_per_trial,
+ max_nodes=self.max_nodes,
+ max_trials=self.max_trials,
+ timeout=to_iso_duration_format_mins(self.timeout_minutes),
+ trial_timeout=to_iso_duration_format_mins(self.trial_timeout_minutes),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestTabularLimitSettings) -> "TabularLimitSettings":
+ return cls(
+ enable_early_termination=obj.enable_early_termination,
+ exit_score=obj.exit_score,
+ max_concurrent_trials=obj.max_concurrent_trials,
+ max_cores_per_trial=obj.max_cores_per_trial,
+ max_nodes=obj.max_nodes,
+ max_trials=obj.max_trials,
+ timeout_minutes=from_iso_duration_format_mins(obj.timeout),
+ trial_timeout_minutes=from_iso_duration_format_mins(obj.trial_timeout),
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TabularLimitSettings):
+ return NotImplemented
+ return (
+ self.enable_early_termination == other.enable_early_termination
+ and self.exit_score == other.exit_score
+ and self.max_concurrent_trials == other.max_concurrent_trials
+ and self.max_cores_per_trial == other.max_cores_per_trial
+ and self.max_nodes == other.max_nodes
+ and self.max_trials == other.max_trials
+ and self.timeout_minutes == other.timeout_minutes
+ and self.trial_timeout_minutes == other.trial_timeout_minutes
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/regression_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/regression_job.py
new file mode 100644
index 00000000..3531e52c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/tabular/regression_job.py
@@ -0,0 +1,239 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AutoMLJob as RestAutoMLJob
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Regression as RestRegression
+from azure.ai.ml._restclient.v2023_04_01_preview.models import RegressionPrimaryMetrics, TaskType
+from azure.ai.ml._utils.utils import camel_to_snake, is_data_binding_expression
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._job.automl import AutoMLConstants
+from azure.ai.ml.entities._credentials import _BaseJobIdentityConfiguration
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.automl.tabular import AutoMLTabular, TabularFeaturizationSettings, TabularLimitSettings
+from azure.ai.ml.entities._job.automl.training_settings import RegressionTrainingSettings
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class RegressionJob(AutoMLTabular):
+ """Configuration for AutoML Regression Job."""
+
+ _DEFAULT_PRIMARY_METRIC = RegressionPrimaryMetrics.NORMALIZED_ROOT_MEAN_SQUARED_ERROR
+
+ def __init__(
+ self,
+ *,
+ primary_metric: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a new AutoML Regression task.
+
+ :param primary_metric: The primary metric to use for optimization
+ :type primary_metric: str
+ :param kwargs: Job-specific arguments
+ :type kwargs: dict
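+
+ Example: a minimal sketch of creating a regression job through the azure.ai.ml.automl.regression
+ factory rather than this constructor. The input path, column name and values are illustrative
+ assumptions:
+
+ .. code-block:: python
+
+     from azure.ai.ml import Input, automl
+
+     regression_job = automl.regression(
+         training_data=Input(type="mltable", path="./train_data"),
+         target_column_name="price",
+         primary_metric="r2_score",
+     )
+     regression_job.set_limits(timeout_minutes=60, max_trials=20)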
+ """
+ # Extract any task specific settings
+ featurization = kwargs.pop("featurization", None)
+ limits = kwargs.pop("limits", None)
+ training = kwargs.pop("training", None)
+
+ super().__init__(
+ task_type=TaskType.REGRESSION,
+ featurization=featurization,
+ limits=limits,
+ training=training,
+ **kwargs,
+ )
+
+ self.primary_metric = primary_metric or RegressionJob._DEFAULT_PRIMARY_METRIC
+
+ @property
+ def primary_metric(self) -> Union[str, RegressionPrimaryMetrics]:
+ return self._primary_metric
+
+ @primary_metric.setter
+ def primary_metric(self, value: Union[str, RegressionPrimaryMetrics]) -> None:
+ # TODO: better way to do this
+ if is_data_binding_expression(str(value), ["parent"]):
+ self._primary_metric = value
+ return
+ self._primary_metric = (
+ RegressionJob._DEFAULT_PRIMARY_METRIC
+ if value is None
+ else RegressionPrimaryMetrics[camel_to_snake(value).upper()]
+ )
+
+ @property
+ def training(self) -> RegressionTrainingSettings:
+ return self._training or RegressionTrainingSettings()
+
+ @training.setter
+ def training(self, value: Union[Dict, RegressionTrainingSettings]) -> None: # pylint: disable=unused-argument
+ ...
+
+ def _to_rest_object(self) -> JobBase:
+ regression_task = RestRegression(
+ target_column_name=self.target_column_name,
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ validation_data_size=self.validation_data_size,
+ weight_column_name=self.weight_column_name,
+ cv_split_column_names=self.cv_split_column_names,
+ n_cross_validations=self.n_cross_validations,
+ test_data=self.test_data,
+ test_data_size=self.test_data_size,
+ featurization_settings=self._featurization._to_rest_object() if self._featurization else None,
+ limit_settings=self._limits._to_rest_object() if self._limits else None,
+ training_settings=self._training._to_rest_object() if self._training else None,
+ primary_metric=self.primary_metric,
+ log_verbosity=self.log_verbosity,
+ )
+ self._resolve_data_inputs(regression_task)
+ self._validation_data_to_rest(regression_task)
+
+ properties = RestAutoMLJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ compute_id=self.compute,
+ properties=self.properties,
+ environment_id=self.environment_id,
+ environment_variables=self.environment_variables,
+ services=self.services,
+ outputs=to_rest_data_outputs(self.outputs),
+ resources=self.resources,
+ task_details=regression_task,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings,
+ )
+
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _from_rest_object(cls, obj: JobBase) -> "RegressionJob":
+ properties: RestAutoMLJob = obj.properties
+ task_details: RestRegression = properties.task_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "services": properties.services,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "compute": properties.compute_id,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ "resources": properties.resources,
+ "identity": (
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ "queue_settings": properties.queue_settings,
+ }
+
+ regression_job = cls(
+ target_column_name=task_details.target_column_name,
+ training_data=task_details.training_data,
+ validation_data=task_details.validation_data,
+ validation_data_size=task_details.validation_data_size,
+ weight_column_name=task_details.weight_column_name,
+ cv_split_column_names=task_details.cv_split_column_names,
+ n_cross_validations=task_details.n_cross_validations,
+ test_data=task_details.test_data,
+ test_data_size=task_details.test_data_size,
+ featurization=(
+ TabularFeaturizationSettings._from_rest_object(task_details.featurization_settings)
+ if task_details.featurization_settings
+ else None
+ ),
+ limits=(
+ TabularLimitSettings._from_rest_object(task_details.limit_settings)
+ if task_details.limit_settings
+ else None
+ ),
+ training=(
+ RegressionTrainingSettings._from_rest_object(task_details.training_settings)
+ if task_details.training_settings
+ else None
+ ),
+ primary_metric=task_details.primary_metric,
+ log_verbosity=task_details.log_verbosity,
+ **job_args_dict,
+ )
+
+ regression_job._restore_data_inputs()
+ regression_job._validation_data_from_rest()
+
+ return regression_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "RegressionJob":
+ from azure.ai.ml._schema.automl.table_vertical.regression import AutoMLRegressionSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLRegressionNodeSchema
+
+ if kwargs.pop("inside_pipeline", False):
+ loaded_data = load_from_dict(AutoMLRegressionNodeSchema, data, context, additional_message, **kwargs)
+ else:
+ loaded_data = load_from_dict(AutoMLRegressionSchema, data, context, additional_message, **kwargs)
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "RegressionJob":
+ loaded_data.pop(AutoMLConstants.TASK_TYPE_YAML, None)
+ data_settings = {
+ "training_data": loaded_data.pop("training_data"),
+ "target_column_name": loaded_data.pop("target_column_name"),
+ "weight_column_name": loaded_data.pop("weight_column_name", None),
+ "validation_data": loaded_data.pop("validation_data", None),
+ "validation_data_size": loaded_data.pop("validation_data_size", None),
+ "cv_split_column_names": loaded_data.pop("cv_split_column_names", None),
+ "n_cross_validations": loaded_data.pop("n_cross_validations", None),
+ "test_data": loaded_data.pop("test_data", None),
+ "test_data_size": loaded_data.pop("test_data_size", None),
+ }
+ job = RegressionJob(**loaded_data)
+ job.set_data(**data_settings)
+ return job
+
+ def _to_dict(self, inside_pipeline: bool = False) -> Dict:
+ from azure.ai.ml._schema.automl.table_vertical.regression import AutoMLRegressionSchema
+ from azure.ai.ml._schema.pipeline.automl_node import AutoMLRegressionNodeSchema
+
+ schema_dict: dict = {}
+ if inside_pipeline:
+ schema_dict = AutoMLRegressionNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ else:
+ schema_dict = AutoMLRegressionSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, RegressionJob):
+ return NotImplemented
+
+ if not super(RegressionJob, self).__eq__(other):
+ return False
+
+ return self.primary_metric == other.primary_metric
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/training_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/training_settings.py
new file mode 100644
index 00000000..97bc7e17
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/training_settings.py
@@ -0,0 +1,357 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=R0902,protected-access
+
+from typing import Any, List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ClassificationModels
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ClassificationTrainingSettings as RestClassificationTrainingSettings,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ForecastingModels
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ ForecastingTrainingSettings as RestForecastingTrainingSettings,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import RegressionModels
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ RegressionTrainingSettings as RestRegressionTrainingSettings,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import TrainingSettings as RestTrainingSettings
+from azure.ai.ml._utils.utils import camel_to_snake, from_iso_duration_format_mins, to_iso_duration_format_mins
+from azure.ai.ml.constants import TabularTrainingMode
+from azure.ai.ml.entities._job.automl.stack_ensemble_settings import StackEnsembleSettings
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class TrainingSettings(RestTranslatableMixin):
+ """TrainingSettings class for Azure Machine Learning."""
+
+ def __init__(
+ self,
+ *,
+ enable_onnx_compatible_models: Optional[bool] = None,
+ enable_dnn_training: Optional[bool] = None,
+ enable_model_explainability: Optional[bool] = None,
+ enable_stack_ensemble: Optional[bool] = None,
+ enable_vote_ensemble: Optional[bool] = None,
+ stack_ensemble_settings: Optional[StackEnsembleSettings] = None,
+ ensemble_model_download_timeout: Optional[int] = None,
+ allowed_training_algorithms: Optional[List[str]] = None,
+ blocked_training_algorithms: Optional[List[str]] = None,
+ training_mode: Optional[Union[str, TabularTrainingMode]] = None,
+ ):
+ """TrainingSettings class for Azure Machine Learning.
+
+ :param enable_onnx_compatible_models: If set to True, the model will be trained to be compatible with ONNX
+ :type enable_onnx_compatible_models: typing.Optional[bool]
+ :param enable_dnn_training: If set to True, the model will use DNN training
+ :type enable_dnn_training: typing.Optional[bool]
+ :param enable_model_explainability: If set to True, the model will be trained to be explainable
+ :type enable_model_explainability: typing.Optional[bool]
+ :param enable_stack_ensemble: If set to True, a final ensemble model will be created using a stack of models
+ :type enable_stack_ensemble: typing.Optional[bool]
+ :param enable_vote_ensemble: If set to True, a final ensemble model will be created using a voting ensemble
+ :type enable_vote_ensemble: typing.Optional[bool]
+ :param stack_ensemble_settings: Settings for stack ensemble
+ :type stack_ensemble_settings: typing.Optional[azure.ai.ml.automl.StackEnsembleSettings]
+ :param ensemble_model_download_timeout: Timeout for downloading ensemble models
+ :type ensemble_model_download_timeout: typing.Optional[int]
+ :param allowed_training_algorithms: Models to train
+ :type allowed_training_algorithms: typing.Optional[typing.List[str]]
+ :param blocked_training_algorithms: Models that will not be considered for training
+ :type blocked_training_algorithms: typing.Optional[typing.List[str]]
+ :param training_mode: [Experimental] The training mode to use.
+ The possible values are-
+
+ * distributed- enables distributed training for supported algorithms.
+
+ * non_distributed- disables distributed training.
+
+ * auto- currently the same as non_distributed; this might change in the future.
+
+ Note: This parameter is in public preview and may change in the future.
+ :type training_mode: typing.Optional[typing.Union[str, azure.ai.ml.constants.TabularTrainingMode]]
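+
+ Example: a minimal sketch using the task-specific subclass defined later in this module. The
+ values and the blocked algorithm name are illustrative assumptions:
+
+ .. code-block:: python
+
+     settings = ClassificationTrainingSettings(
+         enable_model_explainability=True,
+         enable_vote_ensemble=True,
+         enable_stack_ensemble=False,
+         blocked_training_algorithms=["LogisticRegression"],
+     )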
+ """
+ self.enable_onnx_compatible_models = enable_onnx_compatible_models
+ self.enable_dnn_training = enable_dnn_training
+ self.enable_model_explainability = enable_model_explainability
+ self.enable_stack_ensemble = enable_stack_ensemble
+ self.enable_vote_ensemble = enable_vote_ensemble
+ self.stack_ensemble_settings = stack_ensemble_settings
+ self.ensemble_model_download_timeout = ensemble_model_download_timeout
+ self.allowed_training_algorithms = allowed_training_algorithms
+ self.blocked_training_algorithms = blocked_training_algorithms
+ self.training_mode = training_mode
+
+ @property
+ def training_mode(self) -> Optional[TabularTrainingMode]:
+ return self._training_mode
+
+ @training_mode.setter
+ def training_mode(self, value: Optional[Union[str, TabularTrainingMode]]) -> None:
+ # Accept None and TabularTrainingMode members directly; convert strings via the enum lookup below.
+ if value is None or isinstance(value, TabularTrainingMode):
+ self._training_mode = value
+ elif hasattr(TabularTrainingMode, camel_to_snake(value).upper()):
+ self._training_mode = TabularTrainingMode[camel_to_snake(value).upper()]
+ else:
+ supported_values = ", ".join([f'"{camel_to_snake(mode.value)}"' for mode in TabularTrainingMode])
+ msg = (
+ f"Unsupported training mode: {value}. Supported values are- {supported_values}. "
+ "Or you can use azure.ai.ml.constants.TabularTrainingMode enum."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @property
+ def allowed_training_algorithms(self) -> Optional[List[str]]:
+ return self._allowed_training_algorithms
+
+ @allowed_training_algorithms.setter
+ def allowed_training_algorithms(self, value: Optional[List[str]]) -> None:
+ self._allowed_training_algorithms = value
+
+ @property
+ def blocked_training_algorithms(self) -> Optional[List[str]]:
+ return self._blocked_training_algorithms
+
+ @blocked_training_algorithms.setter
+ def blocked_training_algorithms(self, value: Optional[List[str]]) -> None:
+ self._blocked_training_algorithms = value
+
+ def _to_rest_object(self) -> RestTrainingSettings:
+ return RestTrainingSettings(
+ enable_dnn_training=self.enable_dnn_training,
+ enable_onnx_compatible_models=self.enable_onnx_compatible_models,
+ enable_model_explainability=self.enable_model_explainability,
+ enable_stack_ensemble=self.enable_stack_ensemble,
+ enable_vote_ensemble=self.enable_vote_ensemble,
+ stack_ensemble_settings=(
+ self.stack_ensemble_settings._to_rest_object() if self.stack_ensemble_settings else None
+ ),
+ ensemble_model_download_timeout=to_iso_duration_format_mins(self.ensemble_model_download_timeout),
+ training_mode=self.training_mode,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestTrainingSettings) -> "TrainingSettings":
+ return cls(
+ enable_dnn_training=obj.enable_dnn_training,
+ enable_onnx_compatible_models=obj.enable_onnx_compatible_models,
+ enable_model_explainability=obj.enable_model_explainability,
+ enable_stack_ensemble=obj.enable_stack_ensemble,
+ enable_vote_ensemble=obj.enable_vote_ensemble,
+ ensemble_model_download_timeout=from_iso_duration_format_mins(obj.ensemble_model_download_timeout),
+ stack_ensemble_settings=(
+ StackEnsembleSettings._from_rest_object(obj.stack_ensemble_settings)
+ if obj.stack_ensemble_settings
+ else None
+ ),
+ training_mode=obj.training_mode,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TrainingSettings):
+ return NotImplemented
+ return (
+ self.enable_dnn_training == other.enable_dnn_training
+ and self.enable_onnx_compatible_models == other.enable_onnx_compatible_models
+ and self.enable_model_explainability == other.enable_model_explainability
+ and self.enable_stack_ensemble == other.enable_stack_ensemble
+ and self.enable_vote_ensemble == other.enable_vote_ensemble
+ and self.ensemble_model_download_timeout == other.ensemble_model_download_timeout
+ and self.stack_ensemble_settings == other.stack_ensemble_settings
+ and self.allowed_training_algorithms == other.allowed_training_algorithms
+ and self.blocked_training_algorithms == other.blocked_training_algorithms
+ and self.training_mode == other.training_mode
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+
+class ClassificationTrainingSettings(TrainingSettings):
+ """Classification TrainingSettings class for Azure Machine Learning."""
+
+ def __init__(
+ self,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+
+ @property
+ def allowed_training_algorithms(self) -> Optional[List]:
+ return self._allowed_training_algorithms
+
+ @allowed_training_algorithms.setter
+ def allowed_training_algorithms(self, allowed_model_list: Union[List[str], List[ClassificationModels]]) -> None:
+ self._allowed_training_algorithms = (
+ None
+ if allowed_model_list is None
+ else [ClassificationModels[camel_to_snake(o)] for o in allowed_model_list]
+ )
+
+ @property
+ def blocked_training_algorithms(self) -> Optional[List]:
+ return self._blocked_training_algorithms
+
+ @blocked_training_algorithms.setter
+ def blocked_training_algorithms(self, blocked_model_list: Union[List[str], List[ClassificationModels]]) -> None:
+ self._blocked_training_algorithms = (
+ None
+ if blocked_model_list is None
+ else [ClassificationModels[camel_to_snake(o)] for o in blocked_model_list]
+ )
+
+ def _to_rest_object(self) -> RestClassificationTrainingSettings:
+ return RestClassificationTrainingSettings(
+ enable_dnn_training=self.enable_dnn_training,
+ enable_onnx_compatible_models=self.enable_onnx_compatible_models,
+ enable_model_explainability=self.enable_model_explainability,
+ enable_stack_ensemble=self.enable_stack_ensemble,
+ enable_vote_ensemble=self.enable_vote_ensemble,
+ stack_ensemble_settings=self.stack_ensemble_settings,
+ ensemble_model_download_timeout=to_iso_duration_format_mins(self.ensemble_model_download_timeout),
+ allowed_training_algorithms=self.allowed_training_algorithms,
+ blocked_training_algorithms=self.blocked_training_algorithms,
+ training_mode=self.training_mode,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestClassificationTrainingSettings) -> "ClassificationTrainingSettings":
+ return cls(
+ enable_dnn_training=obj.enable_dnn_training,
+ enable_onnx_compatible_models=obj.enable_onnx_compatible_models,
+ enable_model_explainability=obj.enable_model_explainability,
+ enable_stack_ensemble=obj.enable_stack_ensemble,
+ enable_vote_ensemble=obj.enable_vote_ensemble,
+ ensemble_model_download_timeout=from_iso_duration_format_mins(obj.ensemble_model_download_timeout),
+ stack_ensemble_settings=obj.stack_ensemble_settings,
+ allowed_training_algorithms=obj.allowed_training_algorithms,
+ blocked_training_algorithms=obj.blocked_training_algorithms,
+ training_mode=obj.training_mode,
+ )
+
+
+class ForecastingTrainingSettings(TrainingSettings):
+ """Forecasting TrainingSettings class for Azure Machine Learning."""
+
+ def __init__(
+ self,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+
+ @property
+ def allowed_training_algorithms(self) -> Optional[List]:
+ return self._allowed_training_algorithms
+
+ @allowed_training_algorithms.setter
+ def allowed_training_algorithms(self, allowed_model_list: Union[List[str], List[ForecastingModels]]) -> None:
+ self._allowed_training_algorithms = (
+ None if allowed_model_list is None else [ForecastingModels[camel_to_snake(o)] for o in allowed_model_list]
+ )
+
+ @property
+ def blocked_training_algorithms(self) -> Optional[List]:
+ return self._blocked_training_algorithms
+
+ @blocked_training_algorithms.setter
+ def blocked_training_algorithms(self, blocked_model_list: Union[List[str], List[ForecastingModels]]) -> None:
+ self._blocked_training_algorithms = (
+ None if blocked_model_list is None else [ForecastingModels[camel_to_snake(o)] for o in blocked_model_list]
+ )
+
+ def _to_rest_object(self) -> RestForecastingTrainingSettings:
+ return RestForecastingTrainingSettings(
+ enable_dnn_training=self.enable_dnn_training,
+ enable_onnx_compatible_models=self.enable_onnx_compatible_models,
+ enable_model_explainability=self.enable_model_explainability,
+ enable_stack_ensemble=self.enable_stack_ensemble,
+ enable_vote_ensemble=self.enable_vote_ensemble,
+ stack_ensemble_settings=self.stack_ensemble_settings,
+ ensemble_model_download_timeout=to_iso_duration_format_mins(self.ensemble_model_download_timeout),
+ allowed_training_algorithms=self.allowed_training_algorithms,
+ blocked_training_algorithms=self.blocked_training_algorithms,
+ training_mode=self.training_mode,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestForecastingTrainingSettings) -> "ForecastingTrainingSettings":
+ return cls(
+ enable_dnn_training=obj.enable_dnn_training,
+ enable_onnx_compatible_models=obj.enable_onnx_compatible_models,
+ enable_model_explainability=obj.enable_model_explainability,
+ enable_stack_ensemble=obj.enable_stack_ensemble,
+ enable_vote_ensemble=obj.enable_vote_ensemble,
+ ensemble_model_download_timeout=from_iso_duration_format_mins(obj.ensemble_model_download_timeout),
+ stack_ensemble_settings=obj.stack_ensemble_settings,
+ allowed_training_algorithms=obj.allowed_training_algorithms,
+ blocked_training_algorithms=obj.blocked_training_algorithms,
+ training_mode=obj.training_mode,
+ )
+
+
+class RegressionTrainingSettings(TrainingSettings):
+ """Regression TrainingSettings class for Azure Machine Learning."""
+
+ def __init__(
+ self,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+
+ @property
+ def allowed_training_algorithms(self) -> Optional[List]:
+ return self._allowed_training_algorithms
+
+ @allowed_training_algorithms.setter
+ def allowed_training_algorithms(self, allowed_model_list: Union[List[str], List[RegressionModels]]) -> None:
+ self._allowed_training_algorithms = (
+ None if allowed_model_list is None else [RegressionModels[camel_to_snake(o)] for o in allowed_model_list]
+ )
+
+ @property
+ def blocked_training_algorithms(self) -> Optional[List]:
+ return self._blocked_training_algorithms
+
+ @blocked_training_algorithms.setter
+ def blocked_training_algorithms(self, blocked_model_list: Union[List[str], List[RegressionModels]]) -> None:
+ self._blocked_training_algorithms = (
+ None if blocked_model_list is None else [RegressionModels[camel_to_snake(o)] for o in blocked_model_list]
+ )
+
+ def _to_rest_object(self) -> RestRegressionTrainingSettings:
+ return RestRegressionTrainingSettings(
+ enable_dnn_training=self.enable_dnn_training,
+ enable_onnx_compatible_models=self.enable_onnx_compatible_models,
+ enable_model_explainability=self.enable_model_explainability,
+ enable_stack_ensemble=self.enable_stack_ensemble,
+ enable_vote_ensemble=self.enable_vote_ensemble,
+ stack_ensemble_settings=self.stack_ensemble_settings,
+ ensemble_model_download_timeout=to_iso_duration_format_mins(self.ensemble_model_download_timeout),
+ allowed_training_algorithms=self.allowed_training_algorithms,
+ blocked_training_algorithms=self.blocked_training_algorithms,
+ training_mode=self.training_mode,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestRegressionTrainingSettings) -> "RegressionTrainingSettings":
+ return cls(
+ enable_dnn_training=obj.enable_dnn_training,
+ enable_onnx_compatible_models=obj.enable_onnx_compatible_models,
+ enable_model_explainability=obj.enable_model_explainability,
+ enable_stack_ensemble=obj.enable_stack_ensemble,
+ enable_vote_ensemble=obj.enable_vote_ensemble,
+ ensemble_model_download_timeout=from_iso_duration_format_mins(obj.ensemble_model_download_timeout),
+ stack_ensemble_settings=obj.stack_ensemble_settings,
+ allowed_training_algorithms=obj.allowed_training_algorithms,
+ blocked_training_algorithms=obj.blocked_training_algorithms,
+ training_mode=obj.training_mode,
+ )
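A minimal usage sketch for the training settings classes defined in this file (the import path simply mirrors the file path in the diff; the flag values and the 5-minute timeout are illustrative assumptions, and the private _to_rest_object call is shown only to illustrate the conversion path):

from azure.ai.ml.entities._job.automl.training_settings import ClassificationTrainingSettings

# Keyword arguments map one-to-one onto the attributes set in TrainingSettings.__init__.
settings = ClassificationTrainingSettings(
    enable_onnx_compatible_models=True,
    enable_stack_ensemble=False,
    enable_vote_ensemble=True,
    ensemble_model_download_timeout=5,   # minutes; converted to an ISO 8601 duration on conversion
    training_mode="distributed",         # the setter coerces this string to a TabularTrainingMode member
)

# An unsupported training_mode string raises ValidationException from the setter above.
rest_settings = settings._to_rest_object()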
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py
new file mode 100644
index 00000000..08521d7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py
@@ -0,0 +1,47 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import TYPE_CHECKING, Dict, Type, Union
+
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._job.automl.image.image_classification_search_space import ImageClassificationSearchSpace
+ from azure.ai.ml.entities._job.automl.image.image_object_detection_search_space import (
+ ImageObjectDetectionSearchSpace,
+ )
+ from azure.ai.ml.entities._job.automl.nlp.nlp_search_space import NlpSearchSpace
+ from azure.ai.ml.entities._job.automl.search_space import SearchSpace
+
+
+def cast_to_specific_search_space(
+ input: Union[Dict, "SearchSpace"], # pylint: disable=redefined-builtin
+ class_name: Union[
+ Type["ImageClassificationSearchSpace"], Type["ImageObjectDetectionSearchSpace"], Type["NlpSearchSpace"]
+ ],
+ task_type: str,
+) -> Union["ImageClassificationSearchSpace", "ImageObjectDetectionSearchSpace", "NlpSearchSpace"]:
+ def validate_searchspace_args(input_dict: dict) -> None:
+ searchspace = class_name()
+ for key in input_dict:
+ if not hasattr(searchspace, key):
+ msg = f"Received unsupported search space parameter for {task_type} Job."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ if isinstance(input, dict):
+ validate_searchspace_args(input)
+ specific_search_space = class_name(**input)
+ else:
+ validate_searchspace_args(input.__dict__)
+ specific_search_space = class_name._from_search_space_object(input) # pylint: disable=protected-access
+
+ res: Union["ImageClassificationSearchSpace", "ImageObjectDetectionSearchSpace", "NlpSearchSpace"] = (
+ specific_search_space
+ )
+ return res
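A hedged sketch of how cast_to_specific_search_space is typically driven: a plain dict is validated key-by-key against the target search-space class before construction, and an unknown key raises ValidationException. The learning_rate attribute is an assumption about NlpSearchSpace's surface, not something shown in this diff:

from azure.ai.ml.entities._job.automl.nlp.nlp_search_space import NlpSearchSpace
from azure.ai.ml.entities._job.automl.utils import cast_to_specific_search_space

# Dict keys must match attributes of the target class; otherwise a USER_ERROR is raised.
space = cast_to_specific_search_space(
    {"learning_rate": 5e-5},   # assumed NlpSearchSpace attribute
    NlpSearchSpace,
    "Text Classification",
)
assert isinstance(space, NlpSearchSpace)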
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/base_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/base_job.py
new file mode 100644
index 00000000..72b464e5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/base_job.py
@@ -0,0 +1,85 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any, Dict
+
+from azure.ai.ml._restclient.runhistory.models import Run
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+from .job import Job
+
+module_logger = logging.getLogger(__name__)
+
+"""
+TODO[Joe]: This class is temporarily created to handle "Base" job type from the service.
+ We will be working on a more granular job type for pipeline child jobs in the future.
+ Spec Ref: https://github.com/Azure/azureml_run_specification/pull/340
+ MFE PR: https://msdata.visualstudio.com/DefaultCollection/Vienna/_workitems/edit/1167303/
+"""
+
+
+class _BaseJob(Job):
+ """Base Job, only used in pipeline child jobs.
+
+ :param name: Name of the resource.
+ :type name: str
+ :param description: Description of the resource.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param experiment_name: Name of the experiment the job will be created under,
+ if None is provided, the default will be set to the current directory name.
+ :type experiment_name: str
+ :param services: Information on services associated with the job, readonly.
+ :type services: dict[str, JobService]
+ :param compute: The compute target the job runs on.
+ :type compute: str
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(self, **kwargs: Any):
+ kwargs[TYPE] = JobType.BASE
+
+ super().__init__(**kwargs)
+
+ def _to_dict(self) -> Dict:
+ res: dict = BaseJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "_BaseJob":
+ loaded_data = load_from_dict(BaseJobSchema, data, context, additional_message, **kwargs)
+ return _BaseJob(**loaded_data)
+
+ @classmethod
+ def _load_from_rest(cls, obj: Run) -> "_BaseJob":
+ creation_context = SystemData(
+ created_by=obj.created_by,
+ created_by_type=obj.created_from,
+ created_at=obj.created_utc,
+ last_modified_by=obj.last_modified_by,
+ last_modified_at=obj.last_modified_utc,
+ )
+ base_job = _BaseJob(
+ name=obj.run_id,
+ display_name=obj.display_name,
+ description=obj.description,
+ tags=obj.tags,
+ properties=obj.properties,
+ experiment_name=obj.experiment_id,
+ services=obj.services,
+ status=obj.status,
+ creation_context=creation_context,
+ compute=f"{obj.compute.target}" if obj.compute else None,
+ )
+
+ return base_job
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py
new file mode 100644
index 00000000..0a0c7e82
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py
@@ -0,0 +1,314 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import copy
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2025_01_01_preview.models import CommandJob as RestCommandJob
+from azure.ai.ml._restclient.v2025_01_01_preview.models import JobBase
+from azure.ai.ml._schema.job.command_job import CommandJobSchema
+from azure.ai.ml._utils.utils import map_single_brackets_and_warn
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LOCAL_COMPUTE_PROPERTY, LOCAL_COMPUTE_TARGET, TYPE
+from azure.ai.ml.entities import Environment
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+ _BaseJobIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job._input_output_helpers import (
+ from_rest_data_outputs,
+ from_rest_inputs_to_dataset_literal,
+ to_rest_data_outputs,
+ to_rest_dataset_literal_inputs,
+ validate_inputs_for_command,
+)
+from azure.ai.ml.entities._job.distribution import DistributionConfiguration
+from azure.ai.ml.entities._job.job_service import (
+ JobService,
+ JobServiceBase,
+ JupyterLabJobService,
+ SshJobService,
+ TensorBoardJobService,
+ VsCodeJobService,
+)
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .job import Job
+from .job_io_mixin import JobIOMixin
+from .job_limits import CommandJobLimits
+from .job_resource_configuration import JobResourceConfiguration
+from .parameterized_command import ParameterizedCommand
+from .queue_settings import QueueSettings
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities import CommandComponent
+ from azure.ai.ml.entities._builders import Command
+
+module_logger = logging.getLogger(__name__)
+
+
+class CommandJob(Job, ParameterizedCommand, JobIOMixin):
+ """Command job.
+
+ .. note::
+ For sweep jobs, inputs, outputs, and parameters are accessible as environment variables using the prefix
+ ``AZUREML_PARAMETER_``. For example, if you have a parameter named "input_data", you can access it as
+ ``AZUREML_PARAMETER_input_data``.
+
+ :keyword services: Read-only information on services associated with the job.
+ :paramtype services: Optional[dict[str, ~azure.ai.ml.entities.JobService]]
+ :keyword inputs: Mapping of input data bindings used in the command.
+ :paramtype inputs: Optional[dict[str, Union[~azure.ai.ml.Input, str, bool, int, float]]]
+ :keyword outputs: Mapping of output data bindings used in the job.
+ :paramtype outputs: Optional[dict[str, ~azure.ai.ml.Output]]
+ :keyword identity: The identity that the job will use while running on compute.
+ :paramtype identity: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
+ ~azure.ai.ml.UserIdentityConfiguration]]
+ :keyword limits: The limits for the job.
+ :paramtype limits: Optional[~azure.ai.ml.entities.CommandJobLimits]
+ :keyword parent_job_name: The parent job ID for the command job.
+ :paramtype parent_job_name: Optional[str]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_command_configurations.py
+ :start-after: [START command_job_definition]
+ :end-before: [END command_job_definition]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a CommandJob.
+ """
+
+ def __init__(
+ self,
+ *,
+ inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
+ outputs: Optional[Dict[str, Output]] = None,
+ limits: Optional[CommandJobLimits] = None,
+ identity: Optional[
+ Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ services: Optional[
+ Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
+ ] = None,
+ parent_job_name: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = JobType.COMMAND
+ self._parameters: dict = kwargs.pop("parameters", {})
+ self.parent_job_name = parent_job_name
+
+ super().__init__(**kwargs)
+
+ self.outputs = outputs # type: ignore[assignment]
+ self.inputs = inputs # type: ignore[assignment]
+ self.limits = limits
+ self.identity = identity
+ self.services = services
+
+ @property
+ def parameters(self) -> Dict[str, str]:
+ """MLFlow parameters.
+
+ :return: MLFlow parameters logged in job.
+ :rtype: dict[str, str]
+ """
+ return self._parameters
+
+ def _to_dict(self) -> Dict:
+ res: dict = CommandJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def _to_rest_object(self) -> JobBase:
+ self._validate()
+ self.command = map_single_brackets_and_warn(self.command)
+ modified_properties = copy.deepcopy(self.properties)
+ # Remove any properties set on the service as read-only
+ modified_properties.pop("_azureml.ComputeTargetType", None)
+ # Handle local compute case
+ compute = self.compute
+ resources = self.resources
+ if self.compute == LOCAL_COMPUTE_TARGET:
+ compute = None
+ if resources is None:
+ resources = JobResourceConfiguration()
+ if not isinstance(resources, Dict):
+ if resources.properties is None:
+ resources.properties = {}
+ # This is the format of the October API response; we need to match it exactly
+ resources.properties[LOCAL_COMPUTE_PROPERTY] = {LOCAL_COMPUTE_PROPERTY: True}
+
+ properties = RestCommandJob(
+ display_name=self.display_name,
+ description=self.description,
+ command=self.command,
+ code_id=self.code,
+ compute_id=compute,
+ properties=modified_properties,
+ experiment_name=self.experiment_name,
+ inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=self.type),
+ outputs=to_rest_data_outputs(self.outputs),
+ environment_id=self.environment,
+ distribution=(
+ self.distribution._to_rest_object()
+ if self.distribution and not isinstance(self.distribution, Dict)
+ else None
+ ),
+ tags=self.tags,
+ identity=(
+ self.identity._to_job_rest_object() if self.identity and not isinstance(self.identity, Dict) else None
+ ),
+ environment_variables=self.environment_variables,
+ resources=resources._to_rest_object() if resources and not isinstance(resources, Dict) else None,
+ limits=self.limits._to_rest_object() if self.limits else None,
+ services=JobServiceBase._to_rest_job_services(self.services),
+ queue_settings=self.queue_settings._to_rest_object() if self.queue_settings else None,
+ parent_job_name=self.parent_job_name,
+ )
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "CommandJob":
+ loaded_data = load_from_dict(CommandJobSchema, data, context, additional_message, **kwargs)
+ return CommandJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ @classmethod
+ def _load_from_rest(cls, obj: JobBase) -> "CommandJob":
+ rest_command_job: RestCommandJob = obj.properties
+ command_job = CommandJob(
+ name=obj.name,
+ id=obj.id,
+ display_name=rest_command_job.display_name,
+ description=rest_command_job.description,
+ tags=rest_command_job.tags,
+ properties=rest_command_job.properties,
+ command=rest_command_job.command,
+ experiment_name=rest_command_job.experiment_name,
+ services=JobServiceBase._from_rest_job_services(rest_command_job.services),
+ status=rest_command_job.status,
+ creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ code=rest_command_job.code_id,
+ compute=rest_command_job.compute_id,
+ environment=rest_command_job.environment_id,
+ distribution=DistributionConfiguration._from_rest_object(rest_command_job.distribution),
+ parameters=rest_command_job.parameters,
+ # pylint: disable=protected-access
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(rest_command_job.identity)
+ if rest_command_job.identity
+ else None
+ ),
+ environment_variables=rest_command_job.environment_variables,
+ resources=JobResourceConfiguration._from_rest_object(rest_command_job.resources),
+ limits=CommandJobLimits._from_rest_object(rest_command_job.limits),
+ inputs=from_rest_inputs_to_dataset_literal(rest_command_job.inputs),
+ outputs=from_rest_data_outputs(rest_command_job.outputs),
+ queue_settings=QueueSettings._from_rest_object(rest_command_job.queue_settings),
+ parent_job_name=rest_command_job.parent_job_name,
+ )
+ # Handle special case of local job
+ if (
+ command_job.resources is not None
+ and not isinstance(command_job.resources, Dict)
+ and command_job.resources.properties is not None
+ and command_job.resources.properties.get(LOCAL_COMPUTE_PROPERTY, None)
+ ):
+ command_job.compute = LOCAL_COMPUTE_TARGET
+ command_job.resources.properties.pop(LOCAL_COMPUTE_PROPERTY)
+ return command_job
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "CommandComponent":
+ """Translate a command job to component.
+
+ :param context: Context of command job YAML file.
+ :type context: dict
+ :return: Translated command component.
+ :rtype: CommandComponent
+ """
+ from azure.ai.ml.entities import CommandComponent
+
+ pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+ context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}
+
+ # Create an anonymous command component with default version 1
+ return CommandComponent(
+ tags=self.tags,
+ is_anonymous=True,
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ code=self.code,
+ command=self.command,
+ environment=self.environment,
+ description=self.description,
+ inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
+ outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
+ resources=self.resources if self.resources else None,
+ distribution=self.distribution if self.distribution else None,
+ )
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Command":
+ """Translate a command job to a pipeline node.
+
+ :param context: Context of command job YAML file.
+ :type context: dict
+ :return: Translated command component.
+ :rtype: Command
+ """
+ from azure.ai.ml.entities._builders import Command
+
+ component = self._to_component(context, **kwargs)
+
+ return Command(
+ component=component,
+ compute=self.compute,
+ # Need to supply the inputs with double curly.
+ inputs=self.inputs, # type: ignore[arg-type]
+ outputs=self.outputs, # type: ignore[arg-type]
+ environment_variables=self.environment_variables,
+ description=self.description,
+ tags=self.tags,
+ display_name=self.display_name,
+ limits=self.limits,
+ services=self.services,
+ properties=self.properties,
+ identity=self.identity,
+ queue_settings=self.queue_settings,
+ )
+
+ def _validate(self) -> None:
+ if self.command is None:
+ msg = "command is required"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ if self.environment is None:
+ msg = "environment is required for non-local runs"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ if isinstance(self.environment, Environment):
+ self.environment.validate()
+ validate_inputs_for_command(self.command, self.inputs)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/compute_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/compute_configuration.py
new file mode 100644
index 00000000..dcc00825
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/compute_configuration.py
@@ -0,0 +1,110 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import json
+import logging
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import ComputeConfiguration as RestComputeConfiguration
+from azure.ai.ml.constants._common import LOCAL_COMPUTE_TARGET
+from azure.ai.ml.constants._job.job import JobComputePropertyFields
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class ComputeConfiguration(RestTranslatableMixin, DictMixin):
+ """Compute resource configuration
+
+ :param target: The compute target.
+ :type target: Optional[str]
+ :param instance_count: The number of instances.
+ :type instance_count: Optional[int]
+ :param is_local: Specifies if the compute will be on the local machine.
+ :type is_local: Optional[bool]
+ :param location: The location of the compute resource.
+ :type location: Optional[str]
+ :param properties: The resource properties
+ :type properties: Optional[Dict[str, Any]]
+ :param deserialize_properties: Specifies if property bag should be deserialized. Defaults to False.
+ :type deserialize_properties: bool
+ """
+
+ def __init__(
+ self,
+ *,
+ target: Optional[str] = None,
+ instance_count: Optional[int] = None,
+ is_local: Optional[bool] = None,
+ instance_type: Optional[str] = None,
+ location: Optional[str] = None,
+ properties: Optional[Dict[str, Any]] = None,
+ deserialize_properties: bool = False,
+ ) -> None:
+ self.instance_count = instance_count
+ self.target = target or LOCAL_COMPUTE_TARGET
+ self.is_local = is_local or self.target == LOCAL_COMPUTE_TARGET
+ self.instance_type = instance_type
+ self.location = location
+ self.properties = properties
+ if deserialize_properties and properties and self.properties is not None:
+ for key, value in self.properties.items():
+ try:
+ self.properties[key] = json.loads(value)
+ except Exception: # pylint: disable=W0718
+ # keep serialized string if load fails
+ pass
+
+ def _to_rest_object(self) -> RestComputeConfiguration:
+ if self.properties:
+ serialized_properties = {}
+ for key, value in self.properties.items():
+ try:
+ if key.lower() == JobComputePropertyFields.SINGULARITY.lower():
+ # Map Singularity -> AISupercomputer in SDK until MFE does mapping
+ key = JobComputePropertyFields.AISUPERCOMPUTER
+ # Ensure the key match is case-insensitive
+ elif key.lower() == JobComputePropertyFields.AISUPERCOMPUTER.lower():
+ key = JobComputePropertyFields.AISUPERCOMPUTER
+ serialized_properties[key] = json.dumps(value)
+ except Exception: # pylint: disable=W0718
+ pass
+ else:
+ serialized_properties = None
+ return RestComputeConfiguration(
+ target=self.target if not self.is_local else None,
+ is_local=self.is_local,
+ instance_count=self.instance_count,
+ instance_type=self.instance_type,
+ location=self.location,
+ properties=serialized_properties,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestComputeConfiguration) -> "ComputeConfiguration":
+ return ComputeConfiguration(
+ target=obj.target,
+ is_local=obj.is_local,
+ instance_count=obj.instance_count,
+ location=obj.location,
+ instance_type=obj.instance_type,
+ properties=obj.properties,
+ deserialize_properties=True,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ComputeConfiguration):
+ return NotImplemented
+ return (
+ self.instance_count == other.instance_count
+ and self.target == other.target
+ and self.is_local == other.is_local
+ and self.location == other.location
+ and self.instance_type == other.instance_type
+ )
+
+ def __ne__(self, other: object) -> bool:
+ if not isinstance(other, ComputeConfiguration):
+ return NotImplemented
+ return not self.__eq__(other)
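A sketch of the property round trip implemented above. ComputeConfiguration is normally populated from service responses; the Singularity payload below is an illustrative assumption used only to show the key remapping and JSON (de)serialization paths:

from azure.ai.ml.entities._job.compute_configuration import ComputeConfiguration

cc = ComputeConfiguration(
    target="cpu-cluster",
    instance_count=2,
    properties={"Singularity": {"slaTier": "Premium"}},   # assumed property payload
)

rest = cc._to_rest_object()
# Keys matching "singularity" are remapped to the AISupercomputer field name and JSON-encoded.
print(rest.properties)

# On the way back, deserialize_properties=True JSON-decodes each property value again.
round_tripped = ComputeConfiguration._from_rest_object(rest)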
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/data_transfer_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/data_transfer_job.py
new file mode 100644
index 00000000..b510da80
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/data_transfer/data_transfer_job.py
@@ -0,0 +1,358 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase
+from azure.ai.ml._schema.job.data_transfer_job import (
+ DataTransferCopyJobSchema,
+ DataTransferExportJobSchema,
+ DataTransferImportJobSchema,
+)
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.constants._component import DataTransferBuiltinComponentUri, DataTransferTaskType, ExternalDataType
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from ..job import Job
+from ..job_io_mixin import JobIOMixin
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._builders import DataTransferCopy, DataTransferExport, DataTransferImport
+ from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent
+
+module_logger = logging.getLogger(__name__)
+
+
+class DataTransferJob(Job, JobIOMixin):
+ """DataTransfer job.
+
+ :param name: Name of the job.
+ :type name: str
+ :param description: Description of the job.
+ :type description: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: dict[str, str]
+ :param display_name: Display name of the job.
+ :type display_name: str
+ :param properties: The asset property dictionary.
+ :type properties: dict[str, str]
+ :param experiment_name: Name of the experiment the job will be created under.
+ If None is provided, the default will be set to the current directory name.
+ :type experiment_name: str
+ :param services: Information on services associated with the job, readonly.
+ :type services: dict[str, JobService]
+ :param inputs: Inputs to the data transfer job.
+ :type inputs: dict[str, Union[azure.ai.ml.Input, str, bool, int, float]]
+ :param outputs: Mapping of output data bindings used in the job.
+ :type outputs: dict[str, azure.ai.ml.Output]
+ :param compute: The compute target the job runs on.
+ :type compute: str
+ :param task: The task type of the data transfer component. Possible values are "copy_data", "import_data", and "export_data".
+ :type task: str
+ :param data_copy_mode: The data copy mode for the copy task. Possible values are "merge_with_overwrite" and "fail_if_conflict".
+ :type data_copy_mode: str
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+ """
+
+ def __init__(
+ self,
+ task: str,
+ **kwargs: Any,
+ ):
+ kwargs[TYPE] = JobType.DATA_TRANSFER
+ self._parameters: Dict = kwargs.pop("parameters", {})
+ super().__init__(**kwargs)
+ self.task = task
+
+ @property
+ def parameters(self) -> Dict:
+ """MLFlow parameters.
+
+ :return: MLFlow parameters logged in job.
+ :rtype: Dict[str, str]
+ """
+ return self._parameters
+
+ def _validate(self) -> None:
+ if self.compute is None:
+ msg = "compute is required"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+
+ @classmethod
+ def _load_from_rest(cls, obj: JobBase) -> "DataTransferJob":
+ # Todo: need update rest api
+ raise NotImplementedError("Not support submit standalone job for now")
+
+ def _to_rest_object(self) -> JobBase:
+ # Todo: need update rest api
+ raise NotImplementedError("Not support submit standalone job for now")
+
+ @classmethod
+ def _build_source_sink(
+ cls, io_dict: Optional[Union[Dict, Database, FileSystem]]
+ ) -> Optional[Union[Database, FileSystem]]:
+ if io_dict is None:
+ return io_dict
+ if isinstance(io_dict, (Database, FileSystem)):
+ component_io = io_dict
+ else:
+ if isinstance(io_dict, dict):
+ data_type = io_dict.pop("type", None)
+ if data_type == ExternalDataType.DATABASE:
+ component_io = Database(**io_dict)
+ elif data_type == ExternalDataType.FILE_SYSTEM:
+ component_io = FileSystem(**io_dict)
+ else:
+ msg = "Type in source or sink only support {} and {}, currently got {}."
+ raise ValidationException(
+ message=msg.format(
+ ExternalDataType.DATABASE,
+ ExternalDataType.FILE_SYSTEM,
+ data_type,
+ ),
+ no_personal_data_message=msg.format(
+ ExternalDataType.DATABASE,
+ ExternalDataType.FILE_SYSTEM,
+ "data_type",
+ ),
+ target=ErrorTarget.DATA_TRANSFER_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ else:
+ msg = "Source or sink only support dict, Database and FileSystem"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.DATA_TRANSFER_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ return component_io
+
+
+class DataTransferCopyJob(DataTransferJob):
+ def __init__(
+ self,
+ *,
+ inputs: Optional[Dict[str, Union[Input, str]]] = None,
+ outputs: Optional[Dict[str, Union[Output]]] = None,
+ data_copy_mode: Optional[str] = None,
+ **kwargs: Any,
+ ):
+ kwargs["task"] = DataTransferTaskType.COPY_DATA
+ super().__init__(**kwargs)
+
+ self.outputs = outputs # type: ignore[assignment]
+ self.inputs = inputs # type: ignore[assignment]
+ self.data_copy_mode = data_copy_mode
+
+ def _to_dict(self) -> Dict:
+ res: dict = DataTransferCopyJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load_from_dict(
+ cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any
+ ) -> "DataTransferCopyJob":
+ loaded_data = load_from_dict(DataTransferCopyJobSchema, data, context, additional_message, **kwargs)
+ return DataTransferCopyJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "DataTransferCopyComponent":
+ """Translate a data transfer copy job to component.
+
+ :param context: Context of data transfer job YAML file.
+ :type context: dict
+ :return: Translated data transfer copy component.
+ :rtype: DataTransferCopyComponent
+ """
+ from azure.ai.ml.entities._component.datatransfer_component import DataTransferCopyComponent
+
+ pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+ context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}
+
+ # Create an anonymous data transfer copy component with default version 1
+ return DataTransferCopyComponent(
+ tags=self.tags,
+ is_anonymous=True,
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ description=self.description,
+ inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
+ outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
+ data_copy_mode=self.data_copy_mode,
+ )
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "DataTransferCopy":
+ """Translate a data transfer copy job to a pipeline node.
+
+ :param context: Context of data transfer job YAML file.
+ :type context: dict
+ :return: Translated data transfer component.
+ :rtype: DataTransferCopy
+ """
+ from azure.ai.ml.entities._builders import DataTransferCopy
+
+ component = self._to_component(context, **kwargs)
+
+ return DataTransferCopy(
+ component=component,
+ compute=self.compute,
+ # Need to supply the inputs with double curly.
+ inputs=self.inputs, # type: ignore[arg-type]
+ outputs=self.outputs, # type: ignore[arg-type]
+ description=self.description,
+ tags=self.tags,
+ display_name=self.display_name,
+ )
+
+
+class DataTransferImportJob(DataTransferJob):
+ def __init__(
+ self,
+ *,
+ outputs: Optional[Dict[str, Union[Output]]] = None,
+ source: Optional[Union[Dict, Database, FileSystem]] = None,
+ **kwargs: Any,
+ ):
+ kwargs["task"] = DataTransferTaskType.IMPORT_DATA
+ super().__init__(**kwargs)
+
+ self.outputs = outputs # type: ignore[assignment]
+ self.source = self._build_source_sink(source)
+
+ def _to_dict(self) -> Dict:
+ res: dict = DataTransferImportJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load_from_dict(
+ cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any
+ ) -> "DataTransferImportJob":
+ loaded_data = load_from_dict(DataTransferImportJobSchema, data, context, additional_message, **kwargs)
+ return DataTransferImportJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> str:
+ """Translate a data transfer import job to component.
+
+ :param context: Context of data transfer job YAML file.
+ :type context: dict
+ :return: Translated data transfer import component.
+ :rtype: str
+ """
+
+ component: str = ""
+ if self.source is not None and self.source.type == ExternalDataType.DATABASE:
+ component = DataTransferBuiltinComponentUri.IMPORT_DATABASE
+ else:
+ component = DataTransferBuiltinComponentUri.IMPORT_FILE_SYSTEM
+
+ return component
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "DataTransferImport":
+ """Translate a data transfer import job to a pipeline node.
+
+ :param context: Context of data transfer job YAML file.
+ :type context: dict
+ :return: Translated data transfer import node.
+ :rtype: DataTransferImport
+ """
+ from azure.ai.ml.entities._builders import DataTransferImport
+
+ component = self._to_component(context, **kwargs)
+
+ return DataTransferImport(
+ component=component,
+ compute=self.compute,
+ source=self.source,
+ outputs=self.outputs, # type: ignore[arg-type]
+ description=self.description,
+ tags=self.tags,
+ display_name=self.display_name,
+ properties=self.properties,
+ )
+
+
+class DataTransferExportJob(DataTransferJob):
+ def __init__(
+ self,
+ *,
+ inputs: Optional[Dict[str, Union[Input]]] = None,
+ sink: Optional[Union[Dict, Database, FileSystem]] = None,
+ **kwargs: Any,
+ ):
+ kwargs["task"] = DataTransferTaskType.EXPORT_DATA
+ super().__init__(**kwargs)
+
+ self.inputs = inputs # type: ignore[assignment]
+ self.sink = self._build_source_sink(sink)
+
+ def _to_dict(self) -> Dict:
+ res: dict = DataTransferExportJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load_from_dict(
+ cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any
+ ) -> "DataTransferExportJob":
+ loaded_data = load_from_dict(DataTransferExportJobSchema, data, context, additional_message, **kwargs)
+ return DataTransferExportJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> str:
+ """Translate a data transfer export job to component.
+
+ :param context: Context of data transfer job YAML file.
+ :type context: dict
+ :return: Translated data transfer export component.
+ :rtype: str
+ """
+ component: str = ""
+ if self.sink is not None and self.sink.type == ExternalDataType.DATABASE:
+ component = DataTransferBuiltinComponentUri.EXPORT_DATABASE
+ else:
+ msg = "Sink is a required field for export data task and we don't support exporting file system for now."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.DATA_TRANSFER_JOB,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return component
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "DataTransferExport":
+ """Translate a data transfer export job to a pipeline node.
+
+ :param context: Context of data transfer job YAML file.
+ :type context: dict
+ :return: Translated data transfer export node.
+ :rtype: DataTransferExport
+ """
+ from azure.ai.ml.entities._builders import DataTransferExport
+
+ component = self._to_component(context, **kwargs)
+
+ return DataTransferExport(
+ component=component,
+ compute=self.compute,
+ sink=self.sink,
+ inputs=self.inputs, # type: ignore[arg-type]
+ description=self.description,
+ tags=self.tags,
+ display_name=self.display_name,
+ properties=self.properties,
+ )
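A hedged sketch of the source/sink dispatch in DataTransferJob._build_source_sink, using the Database entity imported at the top of this file. The query and connection fields, the output type, and the compute name are assumptions for illustration; a dict with an unsupported "type" discriminator raises ValidationException instead:

from azure.ai.ml import Output
from azure.ai.ml.entities._inputs_outputs.external_data import Database
from azure.ai.ml.entities._job.data_transfer.data_transfer_job import DataTransferImportJob

# Either a typed entity or a plain dict with a "type" key is accepted as the source.
import_job = DataTransferImportJob(
    source=Database(query="SELECT * FROM my_table", connection="azureml:my_sql_connection"),  # assumed fields
    outputs={"sink": Output(type="mltable")},   # assumed output type for an import
    compute="data-transfer-cluster",            # placeholder compute name
)

# The builtin import component URI is chosen based on the source type (database vs. file system).
component_uri = import_job._to_component()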
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/constants.py
new file mode 100644
index 00000000..5084ffbd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/constants.py
@@ -0,0 +1,20 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+class AzureMLDistillationProperties:
+ ENABLE_DISTILLATION = "azureml.enable_distillation"
+ DATA_GENERATION_TYPE = "azureml.data_generation_type"
+ DATA_GENERATION_TASK_TYPE = "azureml.data_generation_task_type"
+ TEACHER_MODEL = "azureml.teacher_model"
+ INSTANCE_TYPE = "azureml.instance_type"
+ CONNECTION_INFORMATION = "azureml.connection_information"
+
+
+class EndpointSettings:
+ VALID_SETTINGS = {"request_batch_size", "min_endpoint_success_ratio"}
+
+
+class PromptSettingKeys:
+ VALID_SETTINGS = {"enable_chain_of_thought", "enable_chain_of_density", "max_len_summary"}
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/distillation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/distillation_job.py
new file mode 100644
index 00000000..469fde98
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/distillation_job.py
@@ -0,0 +1,542 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import json
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ CustomModelFineTuning as RestCustomModelFineTuningVertical,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import FineTuningJob as RestFineTuningJob
+from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase as RestJobBase
+from azure.ai.ml._restclient.v2024_01_01_preview.models import MLFlowModelJobInput, UriFileJobInput
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import DataGenerationType, JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE, AssetTypes
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+from azure.ai.ml.entities._job.distillation.constants import (
+ AzureMLDistillationProperties,
+ EndpointSettings,
+ PromptSettingKeys,
+)
+from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
+from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
+from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin
+from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+
+
+# pylint: disable=too-many-instance-attributes
+@experimental
+class DistillationJob(Job, JobIOMixin):
+ def __init__(
+ self,
+ *,
+ data_generation_type: str,
+ data_generation_task_type: str,
+ teacher_model_endpoint_connection: WorkspaceConnection,
+ student_model: Input,
+ training_data: Optional[Input] = None,
+ validation_data: Optional[Input] = None,
+ teacher_model_settings: Optional[TeacherModelSettings] = None,
+ prompt_settings: Optional[PromptSettings] = None,
+ hyperparameters: Optional[Dict] = None,
+ resources: Optional[ResourceConfiguration] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._data_generation_type = data_generation_type
+ self._data_generation_task_type = data_generation_task_type
+ self._teacher_model_endpoint_connection = teacher_model_endpoint_connection
+ self._student_model = student_model
+ self._training_data = training_data
+ self._validation_data = validation_data
+ self._teacher_model_settings = teacher_model_settings
+ self._prompt_settings = prompt_settings
+ self._hyperparameters = hyperparameters
+ self._resources = resources
+
+ if self._training_data is None and self._data_generation_type == DataGenerationType.LABEL_GENERATION:
+ raise ValueError(
+ f"Training data can not be None when data generation type is set to "
+ f"{DataGenerationType.LABEL_GENERATION}."
+ )
+
+ if self._validation_data is None and self._data_generation_type == DataGenerationType.LABEL_GENERATION:
+ raise ValueError(
+ f"Validation data can not be None when data generation type is set to "
+ f"{DataGenerationType.LABEL_GENERATION}."
+ )
+
+ kwargs[TYPE] = JobType.DISTILLATION
+ self._outputs = kwargs.pop("outputs", None)
+ super().__init__(**kwargs)
+
+ @property
+ def data_generation_type(self) -> str:
+ """Get the type of synthetic data generation to perform.
+
+ :return: str representing the type of synthetic data generation to perform.
+ :rtype: str
+ """
+ return self._data_generation_type
+
+ @data_generation_type.setter
+ def data_generation_type(self, task: str) -> None:
+ """Set the data generation task.
+
+ :param task: The data generation task. Possible values include 'Label_Generation' and 'Data_Generation'.
+ :type task: str
+ """
+ self._data_generation_type = task
+
+ @property
+ def data_generation_task_type(self) -> str:
+ """Get the type of synthetic data to generate.
+
+ :return: str representing the type of synthetic data to generate.
+ :rtype: str
+ """
+ return self._data_generation_task_type
+
+ @data_generation_task_type.setter
+ def data_generation_task_type(self, task: str) -> None:
+ """Set the data generation type.
+
+ :param task: The data generation type. Possible values include 'nli', 'nlu_qa', 'conversational',
+ 'math', and 'summarization'.
+ :type task: str
+ """
+ self._data_generation_task_type = task
+
+ @property
+ def teacher_model_endpoint_connection(self) -> WorkspaceConnection:
+ """Get the endpoint connection of the teacher model to use for data generation.
+
+ :return: Endpoint connection
+ :rtype: WorkspaceConnection
+ """
+ return self._teacher_model_endpoint_connection
+
+ @teacher_model_endpoint_connection.setter
+ def teacher_model_endpoint_connection(self, connection: WorkspaceConnection) -> None:
+ """Set the endpoint information of the teacher model.
+
+ :param connection: Workspace connection
+ :type connection: WorkspaceConnection
+ """
+ self._teacher_model_endpoint_connection = connection
+
+ @property
+ def student_model(self) -> Input:
+ """Get the student model to be trained with synthetic data
+
+ :return: The student model to be finetuned
+ :rtype: Input
+ """
+ return self._student_model
+
+ @student_model.setter
+ def student_model(self, model: Input) -> None:
+ """Set the student model to be trained.
+
+ :param model: The model to use for finetuning
+ :type model: Input
+ """
+ self._student_model = model
+
+ @property
+ def training_data(self) -> Optional[Input]:
+ """Get the training data.
+
+ :return: Training data input
+ :rtype: typing.Optional[Input]
+ """
+ return self._training_data
+
+ @training_data.setter
+ def training_data(self, training_data: Optional[Input]) -> None:
+ """Set the training data.
+
+ :param training_data: Training data input
+ :type training_data: typing.Optional[Input]
+ """
+ self._training_data = training_data
+
+ @property
+ def validation_data(self) -> Optional[Input]:
+ """Get the validation data.
+
+ :return: Validation data input
+ :rtype: typing.Optional[Input]
+ """
+ return self._validation_data
+
+ @validation_data.setter
+ def validation_data(self, validation_data: Optional[Input]) -> None:
+ """Set the validation data.
+
+ :param validation_data: Validation data input
+ :type validation_data: typing.Optional[Input]
+ """
+ self._validation_data = validation_data
+
+ @property
+ def teacher_model_settings(self) -> Optional[TeacherModelSettings]:
+ """Get the teacher model settings.
+
+ :return: The settings for the teacher model to use.
+ :rtype: typing.Optional[TeacherModelSettings]
+ """
+ return self._teacher_model_settings
+
+ @property
+ def prompt_settings(self) -> Optional[PromptSettings]:
+ """Get the settings for the prompt.
+
+ :return: The settings for the prompt.
+ :rtype: typing.Optional[PromptSettings]
+ """
+ return self._prompt_settings
+
+ @property
+ def hyperparameters(self) -> Optional[Dict]:
+ """Get the finetuning hyperparameters.
+
+ :return: The finetuning hyperparameters.
+ :rtype: typing.Optional[typing.Dict]
+ """
+ return self._hyperparameters
+
+ @property
+ def resources(self) -> Optional[ResourceConfiguration]:
+ """Get the resources for data generation.
+
+ :return: The resources for data generation.
+ :rtype: typing.Optional[ResourceConfiguration]
+ """
+ return self._resources
+
+ @resources.setter
+ def resources(self, resource: Optional[ResourceConfiguration]) -> None:
+ """Set the resources for data generation.
+
+ :param resource: The resources for data generation.
+ :type resource: typing.Optional[ResourceConfiguration]
+ """
+ self._resources = resource
+
+ def set_teacher_model_settings(
+ self,
+ inference_parameters: Optional[Dict] = None,
+ endpoint_request_settings: Optional[EndpointRequestSettings] = None,
+ ):
+ """Set settings related to the teacher model.
+
+ :param inference_parameters: Settings the teacher model uses during inferencing.
+ :type inference_parameters: typing.Optional[typing.Dict]
+ :param endpoint_request_settings: Settings for inference requests to the endpoint
+ :type endpoint_request_settings: typing.Optional[EndpointRequestSettings]
+ """
+ self._teacher_model_settings = TeacherModelSettings(
+ inference_parameters=inference_parameters, endpoint_request_settings=endpoint_request_settings
+ )
+
+ def set_prompt_settings(self, prompt_settings: Optional[PromptSettings]):
+ """Set settings related to the system prompt used for generating data.
+
+ :param prompt_settings: Settings related to the system prompt used for generating data.
+ :type prompt_settings: typing.Optional[PromptSettings]
+ """
+ self._prompt_settings = prompt_settings if prompt_settings is not None else self._prompt_settings
+
+ def set_finetuning_settings(self, hyperparameters: Optional[Dict]):
+ """Set the hyperparamters for finetuning.
+
+ :param hyperparameters: The hyperparameters for finetuning.
+ :type hyperparameters: typing.Optional[typing.Dict]
+ """
+ self._hyperparameters = hyperparameters if hyperparameters is not None else self._hyperparameters
+
+ def _to_dict(self) -> Dict:
+ """Convert the object to a dictionary.
+
+ :return: dictionary representation of the object.
+ :rtype: typing.Dict
+ """
+ from azure.ai.ml._schema._distillation.distillation_job import DistillationJobSchema
+
+ schema_dict: dict = {}
+ schema_dict = DistillationJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "DistillationJob":
+ """Load from a dictionary.
+
+ :param data: dictionary representation of the object.
+ :type data: typing.Dict
+ :param context: dictionary containing the context.
+ :type context: typing.Dict
+ :param additional_message: additional message to be added to the error message.
+ :type additional_message: str
+ :return: DistillationJob object.
+ :rtype: DistillationJob
+ """
+ from azure.ai.ml._schema._distillation.distillation_job import DistillationJobSchema
+
+ loaded_data = load_from_dict(DistillationJobSchema, data, context, additional_message, **kwargs)
+
+ training_data = loaded_data.get("training_data", None)
+ if isinstance(training_data, str):
+ loaded_data["training_data"] = Input(type="uri_file", path=training_data)
+
+ validation_data = loaded_data.get("validation_data", None)
+ if isinstance(validation_data, str):
+ loaded_data["validation_data"] = Input(type="uri_file", path=validation_data)
+
+ student_model = loaded_data.get("student_model", None)
+ if isinstance(student_model, str):
+ loaded_data["student_model"] = Input(type=AssetTypes.URI_FILE, path=student_model)
+
+ job_instance = DistillationJob(**loaded_data)
+ return job_instance
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobBase) -> "DistillationJob":
+ """Convert a REST object to DistillationJob object.
+
+ :param obj: CustomModelFineTuningJob in Rest format.
+ :type obj: JobBase
+ :return: DistillationJob objects.
+ :rtype: DistillationJob
+ """
+ properties: RestFineTuningJob = obj.properties
+ finetuning_details: RestCustomModelFineTuningVertical = properties.fine_tuning_details
+
+ job_kwargs_dict = DistillationJob._filter_properties(properties=properties.properties)
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "services": properties.services,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ }
+
+ distillation_job = cls(
+ student_model=finetuning_details.model,
+ training_data=finetuning_details.training_data,
+ validation_data=finetuning_details.validation_data,
+ hyperparameters=finetuning_details.hyper_parameters,
+ **job_kwargs_dict,
+ **job_args_dict,
+ )
+
+ distillation_job._restore_inputs()
+
+ return distillation_job
+
+ def _to_rest_object(self) -> "RestFineTuningJob":
+ """Convert DistillationJob object to a RestFineTuningJob object.
+
+ :return: REST object representation of this object.
+ :rtype: JobBase
+ """
+ distillation = RestCustomModelFineTuningVertical(
+ task_type="ChatCompletion",
+ model=self.student_model,
+ model_provider="Custom",
+ training_data=self.training_data,
+ validation_data=self.validation_data,
+ hyper_parameters=self._hyperparameters,
+ )
+
+ if isinstance(distillation.training_data, Input):
+ distillation.training_data = UriFileJobInput(uri=distillation.training_data.path)
+ if isinstance(distillation.validation_data, Input):
+ distillation.validation_data = UriFileJobInput(uri=distillation.validation_data.path)
+ if isinstance(distillation.model, Input):
+ distillation.model = MLFlowModelJobInput(uri=distillation.model.path)
+
+ self._add_distillation_properties(self.properties)
+
+ finetuning_job = RestFineTuningJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ services=self.services,
+ tags=self.tags,
+ properties=self.properties,
+ fine_tuning_details=distillation,
+ outputs=to_rest_data_outputs(self.outputs),
+ )
+
+ result = RestJobBase(properties=finetuning_job)
+ result.name = self.name
+
+ return result
+
+ @classmethod
+ def _load_from_rest(cls, obj: RestJobBase) -> "DistillationJob":
+ """Loads the rest object to a dict containing items to init the AutoMLJob objects.
+
+ :param obj: Azure Resource Manager resource envelope.
+ :type obj: JobBase
+ :raises ValidationException: task type validation error
+ :return: A DistillationJob
+ :rtype: DistillationJob
+ """
+ return DistillationJob._from_rest_object(obj)
+
+ # TODO: Remove once Distillation is added to MFE
+ def _add_distillation_properties(self, properties: Dict) -> None:
+ """Adds DistillationJob attributes to properties to pass into the FT Overloaded API property bag
+
+ :param properties: Current distillation properties
+ :type properties: typing.Dict
+ """
+ properties[AzureMLDistillationProperties.ENABLE_DISTILLATION] = True
+ properties[AzureMLDistillationProperties.DATA_GENERATION_TASK_TYPE] = self._data_generation_task_type.upper()
+ properties[f"{AzureMLDistillationProperties.TEACHER_MODEL}.endpoint_name"] = (
+ self._teacher_model_endpoint_connection.name
+ )
+
+ # Not needed for FT Overload API but additional info needed to convert from REST object to Distillation object
+ properties[AzureMLDistillationProperties.DATA_GENERATION_TYPE] = self._data_generation_type
+ properties[AzureMLDistillationProperties.CONNECTION_INFORMATION] = json.dumps(
+ self._teacher_model_endpoint_connection._to_dict() # pylint: disable=protected-access
+ )
+
+ if self._prompt_settings:
+ for setting, value in self._prompt_settings.items():
+ if value is not None:
+ properties[f"azureml.{setting.strip('_')}"] = value
+
+ if self._teacher_model_settings:
+ inference_settings = self._teacher_model_settings.inference_parameters
+ endpoint_settings = self._teacher_model_settings.endpoint_request_settings
+
+ if inference_settings:
+ for inference_key, value in inference_settings.items():
+ if value is not None:
+ properties[f"{AzureMLDistillationProperties.TEACHER_MODEL}.{inference_key}"] = value
+
+ if endpoint_settings:
+ for setting, value in endpoint_settings.items():
+ if value is not None:
+ properties[f"azureml.{setting.strip('_')}"] = value
+
+ if self._resources and self._resources.instance_type:
+ properties[f"{AzureMLDistillationProperties.INSTANCE_TYPE}.data_generation"] = self._resources.instance_type
+
+ # TODO: Remove once Distillation is added to MFE
+ @classmethod
+ def _filter_properties(cls, properties: Dict) -> Dict:
+ """Convert properties from REST object back to their original states.
+
+ :param properties: Properties from a REST object
+ :type properties: typing.Dict
+ :return: A dict that can be used to create a DistillationJob
+ :rtype: typing.Dict
+ """
+ inference_parameters = {}
+ endpoint_settings = {}
+ prompt_settings = {}
+ resources = {}
+ teacher_settings = {}
+ teacher_model_info = ""
+ for key, val in properties.items():
+ param = key.split(".")[-1]
+ if AzureMLDistillationProperties.TEACHER_MODEL in key and param != "endpoint_name":
+ inference_parameters[param] = val
+ elif AzureMLDistillationProperties.INSTANCE_TYPE in key:
+ resources[key.split(".")[1]] = val
+ elif AzureMLDistillationProperties.CONNECTION_INFORMATION in key:
+ teacher_model_info = val
+ else:
+ if param in EndpointSettings.VALID_SETTINGS:
+ endpoint_settings[param] = val
+ elif param in PromptSettingKeys.VALID_SETTINGS:
+ prompt_settings[param] = val
+
+ if inference_parameters:
+ teacher_settings["inference_parameters"] = inference_parameters
+ if endpoint_settings:
+ teacher_settings["endpoint_request_settings"] = EndpointRequestSettings(**endpoint_settings) # type: ignore
+
+ return {
+ "data_generation_task_type": properties.get(AzureMLDistillationProperties.DATA_GENERATION_TASK_TYPE),
+ "data_generation_type": properties.get(AzureMLDistillationProperties.DATA_GENERATION_TYPE),
+ "teacher_model_endpoint_connection": WorkspaceConnection._load( # pylint: disable=protected-access
+ data=json.loads(teacher_model_info)
+ ),
+ "teacher_model_settings": (
+ TeacherModelSettings(**teacher_settings) if teacher_settings else None # type: ignore
+ ),
+ "prompt_settings": PromptSettings(**prompt_settings) if prompt_settings else None,
+ "resources": ResourceConfiguration(**resources) if resources else None,
+ }
+
+ def _restore_inputs(self) -> None:
+ """Restore UriFileJobInputs to JobInputs within data_settings."""
+ if isinstance(self.training_data, UriFileJobInput):
+ self.training_data = Input(type=AssetTypes.URI_FILE, path=self.training_data.uri)
+ if isinstance(self.validation_data, UriFileJobInput):
+ self.validation_data = Input(type=AssetTypes.URI_FILE, path=self.validation_data.uri)
+ if isinstance(self.student_model, MLFlowModelJobInput):
+ self.student_model = Input(type=AssetTypes.MLFLOW_MODEL, path=self.student_model.uri)
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, DistillationJob):
+ return False
+ return (
+ super().__eq__(other)
+ and self.data_generation_type == other.data_generation_type
+ and self.data_generation_task_type == other.data_generation_task_type
+ and self.teacher_model_endpoint_connection.name == other.teacher_model_endpoint_connection.name
+ and self.student_model == other.student_model
+ and self.training_data == other.training_data
+ and self.validation_data == other.validation_data
+ and self.teacher_model_settings == other.teacher_model_settings
+ and self.prompt_settings == other.prompt_settings
+ and self.hyperparameters == other.hyperparameters
+ and self.resources == other.resources
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two DistillationJob objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
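A minimal usage sketch for the helper setters defined above. It assumes `distillation_job` is an already-constructed DistillationJob (its constructor appears earlier in this file); the inference parameter names and hyperparameter keys are illustrative assumptions, not values taken from this diff.

from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings

# `distillation_job` is assumed to exist; parameter names below are illustrative only.
distillation_job.set_teacher_model_settings(
    inference_parameters={"temperature": 0.2, "max_tokens": 1024},  # assumed parameter names
    endpoint_request_settings=EndpointRequestSettings(request_batch_size=8, min_endpoint_success_ratio=0.9),
)
distillation_job.set_prompt_settings(PromptSettings(enable_chain_of_thought=True))
distillation_job.set_finetuning_settings(hyperparameters={"num_train_epochs": "1"})  # assumed key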
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/endpoint_request_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/endpoint_request_settings.py
new file mode 100644
index 00000000..89fb8015
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/endpoint_request_settings.py
@@ -0,0 +1,90 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Optional
+
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class EndpointRequestSettings:
+ def __init__(self, *, request_batch_size: Optional[int] = None, min_endpoint_success_ratio: Optional[float] = None):
+ """Initialize EndpointRequestSettings.
+
+ :param request_batch_size: The number of requests to send to the teacher model endpoint as a batch,
+ defaults to None
+ :type request_batch_size: typing.Optional[int], optional
+ :param min_endpoint_success_ratio: The ratio of (successful requests / total requests) needed for the
+ data generation step to be considered successful. Must be a value between 0 and 1 inclusive,
+ defaults to None
+ :type min_endpoint_success_ratio: typing.Optional[float], optional
+ """
+ self._request_batch_size = request_batch_size
+ self._min_endpoint_success_ratio = min_endpoint_success_ratio
+
+ @property
+ def request_batch_size(self) -> Optional[int]:
+ """Get the number of inference requests to send to the teacher model as a batch.
+
+ :return: The number of inference requests to send to the teacher model as a batch.
+ :rtype: typing.Optional[int]
+ """
+ return self._request_batch_size
+
+ @request_batch_size.setter
+ def request_batch_size(self, value: Optional[int]) -> None:
+ """Set the number of inference requests to send to the teacher model as a batch.
+
+ :param value: The number of inference requests to send to the teacher model as a batch.
+ :type value: typing.Optional[int]
+ """
+ self._request_batch_size = value
+
+ @property
+ def min_endpoint_success_ratio(self) -> Optional[float]:
+ """Get the minimum ratio of successful inferencing requests.
+
+ :return: The minimum ratio of successful inferencing requests.
+ :rtype: typing.Optional[float]
+ """
+ return self._min_endpoint_success_ratio
+
+ @min_endpoint_success_ratio.setter
+ def min_endpoint_success_ratio(self, ratio: Optional[float]) -> None:
+ """Set the minimum ratio of successful inferencing requests.
+
+ :param ratio: The minimum ratio of successful inferencing requests.
+ :type ratio: typing.Optional[float]
+ """
+ self._min_endpoint_success_ratio = ratio
+
+ def items(self):
+ return self.__dict__.items()
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, EndpointRequestSettings):
+ return False
+ return (
+ self.request_batch_size == other.request_batch_size
+ and self.min_endpoint_success_ratio == other.min_endpoint_success_ratio
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two EndpointRequestSettings objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
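A minimal construction sketch for the class above; the batch size and success ratio are illustrative values.

from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings

# Send inference requests to the teacher endpoint in batches of 8 and require at least
# 90% of requests to succeed for the data generation step to be considered successful.
settings = EndpointRequestSettings(request_batch_size=8, min_endpoint_success_ratio=0.9)
settings.request_batch_size = 16  # values remain mutable via the property setters defined above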
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/prompt_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/prompt_settings.py
new file mode 100644
index 00000000..d74af748
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/prompt_settings.py
@@ -0,0 +1,138 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Optional
+
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class PromptSettings:
+ def __init__(
+ self,
+ *,
+ enable_chain_of_thought: bool = False,
+ enable_chain_of_density: bool = False,
+ max_len_summary: Optional[int] = None,
+ # custom_prompt: Optional[str] = None
+ ):
+ """Initialize PromptSettings.
+
+ :param enable_chain_of_thought: Whether or not to enable chain of thought which modifies the system prompt
+ used. Can be used for all `data_generation_task_type` values except `SUMMARIZATION`, defaults to False
+ :type enable_chain_of_thought: bool, optional
+ :param enable_chain_of_density: Whether or not to enable chain of density which modifies the system prompt
+ used. Can only be used for `data_generation_task_type` of `SUMMARIZATION`, defaults to False
+ :type enable_chain_of_density: bool, optional
+ :param max_len_summary: The maximum length of the summary generated for a `data_generation_task_type` of
+ `SUMMARIZATION`, defaults to None
+ :type max_len_summary: typing.Optional[int]
+ """
+ self._enable_chain_of_thought = enable_chain_of_thought
+ self._enable_chain_of_density = enable_chain_of_density
+ self._max_len_summary = max_len_summary
+ # self._custom_prompt = custom_prompt
+
+ @property
+ def enable_chain_of_thought(self) -> bool:
+ """Get whether or not chain of thought is enabled.
+
+ :return: Whether or not chain of thought is enabled.
+ :rtype: bool
+ """
+ return self._enable_chain_of_thought
+
+ @enable_chain_of_thought.setter
+ def enable_chain_of_thought(self, value: bool) -> None:
+ """Set chain of thought.
+
+ :param value: Whether or not chain of thought is enabled.
+ :type value: bool
+ """
+ self._enable_chain_of_thought = value
+
+ @property
+ def enable_chain_of_density(self) -> bool:
+ """Get whether or not chain of density is enabled.
+
+ :return: Whether or not chain of density is enabled.
+ :rtype: bool
+ """
+ return self._enable_chain_of_density
+
+ @enable_chain_of_density.setter
+ def enable_chain_of_density(self, value: bool) -> None:
+ """Set whether or not chain of thought is enabled.
+
+ :param value: Whether or not chain of thought is enabled
+ :type value: bool
+ """
+ self._enable_chain_of_density = value
+
+ @property
+ def max_len_summary(self) -> Optional[int]:
+ """The number of tokens to use for summarization.
+
+ :return: The number of tokens to use for summarization
+ :rtype: typing.Optional[int]
+ """
+ return self._max_len_summary
+
+ @max_len_summary.setter
+ def max_len_summary(self, length: Optional[int]) -> None:
+ """Set the number of tokens to use for summarization.
+
+ :param length: The number of tokens to use for summarization.
+ :type length: typing.Optional[int]
+ """
+ self._max_len_summary = length
+
+ # @property
+ # def custom_prompt(self) -> Optional[str]:
+ # """Get the custom system prompt to use for inferencing.
+ # :return: The custom prompt to use for inferencing.
+ # :rtype: Optional[str]
+ # """
+ # return self._custom_prompt
+
+ # @custom_prompt.setter
+ # def custom_prompt(self, prompt: Optional[str]) -> None:
+ # """Set the custom prompt to use for inferencing.
+
+ # :param prompt: The custom prompt to use for inferencing.
+ # :type prompt: Optional[str]
+ # """
+ # self._custom_prompt = prompt
+
+ def items(self):
+ return self.__dict__.items()
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, PromptSettings):
+ return False
+ return (
+ self.enable_chain_of_thought == other.enable_chain_of_thought
+ and self.enable_chain_of_density == other.enable_chain_of_density
+ and self.max_len_summary == other.max_len_summary
+ # self.custom_prompt == other.custom_prompt
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two PromptSettings objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
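A small sketch of the two configurations described in the docstrings above; the values are illustrative.

from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings

# Chain of thought applies to every data_generation_task_type except SUMMARIZATION.
qa_prompts = PromptSettings(enable_chain_of_thought=True)

# Chain of density and max_len_summary apply only to SUMMARIZATION.
summarization_prompts = PromptSettings(enable_chain_of_density=True, max_len_summary=256)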
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/teacher_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/teacher_model_settings.py
new file mode 100644
index 00000000..481800de
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distillation/teacher_model_settings.py
@@ -0,0 +1,93 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Dict, Optional
+
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
+
+
+@experimental
+class TeacherModelSettings:
+ def __init__(
+ self,
+ *,
+ inference_parameters: Optional[Dict] = None,
+ endpoint_request_settings: Optional[EndpointRequestSettings] = None,
+ ):
+ """Initialize TeacherModelSettings
+
+ :param inference_parameters: The inference parameters inferencing requests will use, defaults to None
+ :type inference_parameters: typing.Optional[typing.Dict], optional
+ :param endpoint_request_settings: The settings to use for the endpoint, defaults to None
+ :type endpoint_request_settings: typing.Optional[EndpointRequestSettings], optional
+ """
+ self._inference_parameters = inference_parameters
+ self._endpoint_request_settings = endpoint_request_settings
+
+ @property
+ def inference_parameters(self) -> Optional[Dict]:
+ """Get the inference parameters.
+
+ :return: The inference parameters.
+ :rtype: typing.Optional[typing.Dict]
+ """
+ return self._inference_parameters
+
+ @inference_parameters.setter
+ def inference_parameters(self, params: Optional[Dict]) -> None:
+ """Set the inference parameters.
+
+ :param params: Inference parameters.
+ :type params: typing.Optional[typing.Dict]
+ """
+ self._inference_parameters = params
+
+ @property
+ def endpoint_request_settings(self) -> Optional[EndpointRequestSettings]:
+ """Get the endpoint request settings.
+
+ :return: The endpoint request settings.
+ :rtype: typing.Optional[EndpointRequestSettings]
+ """
+ return self._endpoint_request_settings
+
+ @endpoint_request_settings.setter
+ def endpoint_request_settings(self, endpoint_settings: Optional[EndpointRequestSettings]) -> None:
+ """Set the endpoint request settings.
+
+ :param endpoint_settings: Endpoint request settings
+ :type endpoint_settings: typing.Optional[EndpointRequestSettings]
+ """
+ self._endpoint_request_settings = endpoint_settings
+
+ def items(self):
+ return self.__dict__.items()
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, TeacherModelSettings):
+ return False
+ return (
+ self.inference_parameters == other.inference_parameters
+ and self.endpoint_request_settings == other.endpoint_request_settings
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two TeacherModelSettings objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
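A sketch combining the two settings objects defined in this package; the inference parameter names are passed through as-is and are assumptions, not keys validated by this class.

from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings

teacher_settings = TeacherModelSettings(
    inference_parameters={"temperature": 0.2, "max_tokens": 1024},  # assumed parameter names
    endpoint_request_settings=EndpointRequestSettings(request_batch_size=8, min_endpoint_success_ratio=0.9),
)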
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py
new file mode 100644
index 00000000..ec7277c6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/distribution.py
@@ -0,0 +1,229 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ DistributionConfiguration as RestDistributionConfiguration,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview.models import DistributionType as RestDistributionType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Mpi as RestMpi
+from azure.ai.ml._restclient.v2023_04_01_preview.models import PyTorch as RestPyTorch
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Ray as RestRay
+from azure.ai.ml._restclient.v2023_04_01_preview.models import TensorFlow as RestTensorFlow
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import DistributionType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+SDK_TO_REST = {
+ DistributionType.MPI: RestDistributionType.MPI,
+ DistributionType.TENSORFLOW: RestDistributionType.TENSOR_FLOW,
+ DistributionType.PYTORCH: RestDistributionType.PY_TORCH,
+ DistributionType.RAY: RestDistributionType.RAY,
+}
+
+
+class DistributionConfiguration(RestTranslatableMixin):
+ """Distribution configuration for a component or job.
+
+ This class is not meant to be instantiated directly. Instead, use one of its subclasses.
+ """
+
+ def __init__(self, **kwargs: Any) -> None:
+ self.type: Any = None
+
+ @classmethod
+ def _from_rest_object(
+ cls, obj: Optional[Union[RestDistributionConfiguration, Dict]]
+ ) -> Optional["DistributionConfiguration"]:
+ """Constructs a DistributionConfiguration object from a REST object
+
+ This function works for distribution property of a Job object and of a Component object()
+
+ Distribution of Job when returned by MFE, is a RestDistributionConfiguration
+
+ Distribution of Component when returned by MFE, is a Dict.
+ e.g. {'type': 'Mpi', 'process_count_per_instance': '1'}
+
+ So in the job distribution case, we need to call as_dict() first and get type from "distribution_type" property.
+ In the componenet case, we need to extract type from key "type"
+
+
+ :param obj: The object to translate
+ :type obj: Optional[Union[RestDistributionConfiguration, Dict]]
+ :return: The distribution configuration
+ :rtype: DistributionConfiguration
+ """
+ if obj is None:
+ return None
+
+ if isinstance(obj, dict):
+ data = obj
+ else:
+ data = obj.as_dict()
+
+ type_str = data.pop("distribution_type", None) or data.pop("type", None)
+ klass = DISTRIBUTION_TYPE_MAP[type_str.lower()]
+ res: DistributionConfiguration = klass(**data)
+ return res
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, DistributionConfiguration):
+ return NotImplemented
+ res: bool = self._to_rest_object() == other._to_rest_object()
+ return res
+
+
+class MpiDistribution(DistributionConfiguration):
+ """MPI distribution configuration.
+
+ :keyword process_count_per_instance: The number of processes per node.
+ :paramtype process_count_per_instance: Optional[int]
+ :ivar type: Specifies the type of distribution. Set automatically to "mpi" for this class.
+ :vartype type: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START mpi_distribution_configuration]
+ :end-before: [END mpi_distribution_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a CommandComponent with an MpiDistribution.
+ """
+
+ def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self.type = DistributionType.MPI
+ self.process_count_per_instance = process_count_per_instance
+
+ def _to_rest_object(self) -> RestMpi:
+ return RestMpi(process_count_per_instance=self.process_count_per_instance)
+
+
+class PyTorchDistribution(DistributionConfiguration):
+ """PyTorch distribution configuration.
+
+ :keyword process_count_per_instance: The number of processes per node.
+ :paramtype process_count_per_instance: Optional[int]
+ :ivar type: Specifies the type of distribution. Set automatically to "pytorch" for this class.
+ :vartype type: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START pytorch_distribution_configuration]
+ :end-before: [END pytorch_distribution_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a CommandComponent with a PyTorchDistribution.
+ """
+
+ def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self.type = DistributionType.PYTORCH
+ self.process_count_per_instance = process_count_per_instance
+
+ def _to_rest_object(self) -> RestPyTorch:
+ return RestPyTorch(process_count_per_instance=self.process_count_per_instance)
+
+
+class TensorFlowDistribution(DistributionConfiguration):
+ """TensorFlow distribution configuration.
+
+ :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
+ :keyword parameter_server_count: The number of parameter server tasks. Defaults to 0.
+ :paramtype parameter_server_count: Optional[int]
+ :keyword worker_count: The number of workers. Defaults to the instance count.
+ :paramtype worker_count: Optional[int]
+ :ivar parameter_server_count: Number of parameter server tasks.
+ :vartype parameter_server_count: int
+ :ivar worker_count: Number of workers. If not specified, will default to the instance count.
+ :vartype worker_count: int
+ :ivar type: Specifies the type of distribution. Set automatically to "tensorflow" for this class.
+ :vartype type: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START tensorflow_distribution_configuration]
+ :end-before: [END tensorflow_distribution_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a CommandComponent with a TensorFlowDistribution.
+ """
+
+ def __init__(
+ self, *, parameter_server_count: Optional[int] = 0, worker_count: Optional[int] = None, **kwargs: Any
+ ) -> None:
+ super().__init__(**kwargs)
+ self.type = DistributionType.TENSORFLOW
+ self.parameter_server_count = parameter_server_count
+ self.worker_count = worker_count
+
+ def _to_rest_object(self) -> RestTensorFlow:
+ return RestTensorFlow(parameter_server_count=self.parameter_server_count, worker_count=self.worker_count)
+
+
+@experimental
+class RayDistribution(DistributionConfiguration):
+ """Ray distribution configuration.
+
+ :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
+ :ivar port: The port of the head ray process.
+ :vartype port: int
+ :ivar address: The address of Ray head node.
+ :vartype address: str
+ :ivar include_dashboard: Provide this argument to start the Ray dashboard GUI.
+ :vartype include_dashboard: bool
+ :ivar dashboard_port: The port to bind the dashboard server to.
+ :vartype dashboard_port: int
+ :ivar head_node_additional_args: Additional arguments passed to ray start in head node.
+ :vartype head_node_additional_args: str
+ :ivar worker_node_additional_args: Additional arguments passed to ray start in worker node.
+ :vartype worker_node_additional_args: str
+ :ivar type: Specifies the type of distribution. Set automatically to "Ray" for this class.
+ :vartype type: str
+ """
+
+ def __init__(
+ self,
+ *,
+ port: Optional[int] = None,
+ address: Optional[str] = None,
+ include_dashboard: Optional[bool] = None,
+ dashboard_port: Optional[int] = None,
+ head_node_additional_args: Optional[str] = None,
+ worker_node_additional_args: Optional[str] = None,
+ **kwargs: Any
+ ):
+ super().__init__(**kwargs)
+ self.type = DistributionType.RAY
+
+ self.port = port
+ self.address = address
+ self.include_dashboard = include_dashboard
+ self.dashboard_port = dashboard_port
+ self.head_node_additional_args = head_node_additional_args
+ self.worker_node_additional_args = worker_node_additional_args
+
+ def _to_rest_object(self) -> RestRay:
+ return RestRay(
+ port=self.port,
+ address=self.address,
+ include_dashboard=self.include_dashboard,
+ dashboard_port=self.dashboard_port,
+ head_node_additional_args=self.head_node_additional_args,
+ worker_node_additional_args=self.worker_node_additional_args,
+ )
+
+
+DISTRIBUTION_TYPE_MAP = {
+ DistributionType.MPI: MpiDistribution,
+ DistributionType.TENSORFLOW: TensorFlowDistribution,
+ DistributionType.PYTORCH: PyTorchDistribution,
+ DistributionType.RAY: RayDistribution,
+}
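A brief sketch of constructing the distribution configurations defined above; the process and worker counts are illustrative.

from azure.ai.ml.entities._job.distribution import (
    MpiDistribution,
    PyTorchDistribution,
    TensorFlowDistribution,
)

mpi = MpiDistribution(process_count_per_instance=4)
pytorch = PyTorchDistribution(process_count_per_instance=2)
tensorflow = TensorFlowDistribution(parameter_server_count=1, worker_count=4)

# Each instance carries its DistributionType and converts to the matching REST model.
print(mpi.type, pytorch.type, tensorflow.type)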
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_finetuning_job.py
new file mode 100644
index 00000000..e659c634
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_finetuning_job.py
@@ -0,0 +1,242 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ ModelProvider as RestModelProvider,
+ AzureOpenAiFineTuning as RestAzureOpenAIFineTuning,
+ FineTuningJob as RestFineTuningJob,
+ JobBase as RestJobBase,
+)
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._job._input_output_helpers import from_rest_data_outputs, to_rest_data_outputs
+
+from azure.ai.ml.entities._job.finetuning.finetuning_vertical import FineTuningVertical
+from azure.ai.ml.entities._job.finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparameters
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureOpenAIFineTuningJob(FineTuningVertical):
+ def __init__(
+ self,
+ **kwargs: Any,
+ ) -> None:
+ # Extract any task specific settings
+ model = kwargs.pop("model", None)
+ task = kwargs.pop("task", None)
+ # Convert task to lowercase first letter, this is when we create
+ # object from the schema, using dict object from the REST api response.
+ # TextCompletion => textCompletion
+ if task:
+ task = task[0].lower() + task[1:]
+ training_data = kwargs.pop("training_data", None)
+ validation_data = kwargs.pop("validation_data", None)
+ hyperparameters = kwargs.pop("hyperparameters", None)
+ if hyperparameters and not isinstance(hyperparameters, AzureOpenAIHyperparameters):
+ raise ValidationException(
+ category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.JOB,
+ message="Hyperparameters if provided should of type AzureOpenAIHyperparameters",
+ no_personal_data_message="Hyperparameters if provided should of type AzureOpenAIHyperparameters",
+ )
+
+ self._hyperparameters = hyperparameters
+
+ super().__init__(
+ task=task,
+ model=model,
+ model_provider=RestModelProvider.AZURE_OPEN_AI,
+ training_data=training_data,
+ validation_data=validation_data,
+ **kwargs,
+ )
+
+ @property
+ def hyperparameters(self) -> AzureOpenAIHyperparameters:
+ """Get hyperparameters.
+
+ :return: Hyperparameters for finetuning the model.
+ :rtype: AzureOpenAIHyperparameters
+ """
+ return self._hyperparameters
+
+ @hyperparameters.setter
+ def hyperparameters(self, hyperparameters: AzureOpenAIHyperparameters) -> None:
+ """Set hyperparameters.
+
+ :param hyperparameters: Hyperparameters for finetuning the model.
+ :type hyperparameters: AzureOpenAIHyperparameters
+ """
+ self._hyperparameters = hyperparameters
+
+ def _to_rest_object(self) -> "RestFineTuningJob":
+ """Convert CustomFineTuningVertical object to a RestFineTuningJob object.
+
+ :return: REST object representation of this object.
+ :rtype: JobBase
+ """
+ aoai_finetuning_vertical = RestAzureOpenAIFineTuning(
+ task_type=self._task,
+ model=self._model,
+ model_provider=self._model_provider,
+ training_data=self._training_data,
+ validation_data=self._validation_data,
+ hyper_parameters=self.hyperparameters._to_rest_object() if self.hyperparameters else None,
+ )
+
+ self._resolve_inputs(aoai_finetuning_vertical)
+
+ finetuning_job = RestFineTuningJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ tags=self.tags,
+ properties=self.properties,
+ fine_tuning_details=aoai_finetuning_vertical,
+ outputs=to_rest_data_outputs(self.outputs),
+ )
+
+ result = RestJobBase(properties=finetuning_job)
+ result.name = self.name
+
+ return result
+
+ def _to_dict(self) -> Dict:
+ """Convert the object to a dictionary.
+
+ :return: dictionary representation of the object.
+ :rtype: typing.Dict
+ """
+ from azure.ai.ml._schema._finetuning.azure_openai_finetuning import AzureOpenAIFineTuningSchema
+
+ schema_dict: dict = {}
+ # TODO: Come back to this later for FineTuningJob in Pipelines
+ # if inside_pipeline:
+ # schema_dict = AutoMLClassificationNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ # else:
+ schema_dict = AzureOpenAIFineTuningSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, AzureOpenAIFineTuningJob):
+ return NotImplemented
+
+ return super().__eq__(other) and self.hyperparameters == other.hyperparameters
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two AzureOpenAIFineTuningJob objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobBase) -> "AzureOpenAIFineTuningJob":
+ """Convert a REST object to AzureOpenAIFineTuningJob object.
+
+ :param obj: AzureOpenAIFineTuningJob in Rest format.
+ :type obj: JobBase
+ :return: AzureOpenAIFineTuningJob object.
+ :rtype: AzureOpenAIFineTuningJob
+ """
+
+ properties: RestFineTuningJob = obj.properties
+ finetuning_details: RestAzureOpenAIFineTuning = properties.fine_tuning_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "experiment_name": properties.experiment_name,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ }
+
+ aoai_finetuning_job = cls(
+ task=finetuning_details.task_type,
+ model=finetuning_details.model,
+ training_data=finetuning_details.training_data,
+ validation_data=finetuning_details.validation_data,
+ hyperparameters=AzureOpenAIHyperparameters._from_rest_object(finetuning_details.hyper_parameters),
+ **job_args_dict,
+ )
+
+ aoai_finetuning_job._restore_inputs()
+
+ return aoai_finetuning_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "AzureOpenAIFineTuningJob":
+ """Load from a dictionary.
+
+ :param data: dictionary representation of the object.
+ :type data: typing.Dict
+ :param context: dictionary containing the context.
+ :type context: typing.Dict
+ :param additional_message: additional message to be added to the error message.
+ :type additional_message: str
+ :return: AzureOpenAIFineTuningJob object.
+ :rtype: AzureOpenAIFineTuningJob
+ """
+ from azure.ai.ml._schema._finetuning.azure_openai_finetuning import AzureOpenAIFineTuningSchema
+
+ # TODO: Come back to this later - Pipeline part.
+ # from azure.ai.ml._schema.pipeline.automl_node import AutoMLClassificationNodeSchema
+
+ # if kwargs.pop("inside_pipeline", False):
+ # loaded_data = load_from_dict(
+ # AutoMLClassificationNodeSchema,
+ # data,
+ # context,
+ # additional_message,
+ # **kwargs,
+ # )
+ # else:
+ loaded_data = load_from_dict(AzureOpenAIFineTuningSchema, data, context, additional_message, **kwargs)
+
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "AzureOpenAIFineTuningJob":
+ """Create an instance from a schema dictionary.
+
+ :param loaded_data: dictionary containing the data.
+ :type loaded_data: typing.Dict
+ :return: AzureOpenAIFineTuningJob object.
+ :rtype: AzureOpenAIFineTuningJob
+ """
+
+ job = AzureOpenAIFineTuningJob(**loaded_data)
+ return job
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_hyperparameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_hyperparameters.py
new file mode 100644
index 00000000..2b420a46
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/azure_openai_hyperparameters.py
@@ -0,0 +1,125 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ AzureOpenAiHyperParameters as RestAzureOpenAiHyperParameters,
+)
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureOpenAIHyperparameters(RestTranslatableMixin):
+ """Hyperparameters for Azure OpenAI model finetuning."""
+
+ def __init__(
+ self,
+ *,
+ batch_size: Optional[int] = None,
+ learning_rate_multiplier: Optional[float] = None,
+ n_epochs: Optional[int] = None,
+ ):
+ """Initialize AzureOpenAIHyperparameters.
+
+ :param batch_size: Number of examples in each batch.
+ A larger batch size means that model parameters are updated less
+ frequently, but with lower variance. Defaults to None.
+ :type batch_size: typing.Optional[int]
+ :param learning_rate_multiplier: Scaling factor for the learning rate.
+ A smaller learning rate may be useful to avoid overfitting. Defaults to None.
+ :type learning_rate_multiplier: typing.Optional[float]
+ :param n_epochs: The number of epochs to train the model for.
+ An epoch refers to one full cycle through the training dataset. Defaults to None.
+ :type n_epochs: typing.Optional[int]
+ """
+ self._batch_size = batch_size
+ self._learning_rate_multiplier = learning_rate_multiplier
+ self._n_epochs = n_epochs
+ # Not exposed in the public API, so need to check how to handle this
+ # self._additional_properties = kwargs
+
+ @property
+ def batch_size(self) -> Optional[int]:
+ """Get the batch size for training."""
+ return self._batch_size
+
+ @batch_size.setter
+ def batch_size(self, value: Optional[int]) -> None:
+ """Set the batch size for training.
+ :param value: The batch size for training.
+ :type value: int
+ """
+ self._batch_size = value
+
+ @property
+ def learning_rate_multiplier(self) -> Optional[float]:
+ """Get the learning rate multiplier.
+ :return: The learning rate multiplier.
+ :rtype: float
+ """
+ return self._learning_rate_multiplier
+
+ @learning_rate_multiplier.setter
+ def learning_rate_multiplier(self, value: Optional[float]) -> None:
+ """Set the learning rate multiplier.
+ :param value: The learning rate multiplier.
+ :type value: float
+ """
+ self._learning_rate_multiplier = value
+
+ @property
+ def n_epochs(self) -> Optional[int]:
+ """Get the number of epochs.
+ :return: The number of epochs.
+ :rtype: int
+ """
+ return self._n_epochs
+
+ @n_epochs.setter
+ def n_epochs(self, value: Optional[int]) -> None:
+ """Set the number of epochs.
+ :param value: The number of epochs.
+ :type value: int
+ """
+ self._n_epochs = value
+
+ # Not exposed in the public API, so need to check how to handle this
+ # @property
+ # def additional_properties(self) -> dict:
+ # """Get additional properties."""
+ # return self._additional_properties
+
+ # @additional_properties.setter
+ # def additional_properties(self, value: dict) -> None:
+ # """Set additional properties."""
+ # self._additional_properties = value
+
+ def _to_rest_object(self) -> RestAzureOpenAiHyperParameters:
+ return RestAzureOpenAiHyperParameters(
+ batch_size=self._batch_size,
+ learning_rate_multiplier=self._learning_rate_multiplier,
+ n_epochs=self._n_epochs,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, AzureOpenAIHyperparameters):
+ return NotImplemented
+ return (
+ self._batch_size == other._batch_size
+ and self._learning_rate_multiplier == other._learning_rate_multiplier
+ and self._n_epochs == other._n_epochs
+ )
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestAzureOpenAiHyperParameters) -> "AzureOpenAIHyperparameters":
+ aoai_hyperparameters = cls(
+ batch_size=obj.batch_size,
+ learning_rate_multiplier=obj.learning_rate_multiplier,
+ n_epochs=obj.n_epochs,
+ )
+ return aoai_hyperparameters
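A minimal construction sketch for the class above; the values are illustrative.

from azure.ai.ml.entities._job.finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparameters

hyperparameters = AzureOpenAIHyperparameters(
    batch_size=16,
    learning_rate_multiplier=0.1,
    n_epochs=3,
)
# The REST translation defined above mirrors the three fields one-to-one.
rest_hyperparameters = hyperparameters._to_rest_object()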
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/custom_model_finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/custom_model_finetuning_job.py
new file mode 100644
index 00000000..e6ddd86d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/custom_model_finetuning_job.py
@@ -0,0 +1,258 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ ModelProvider as RestModelProvider,
+ CustomModelFineTuning as RestCustomModelFineTuningVertical,
+ FineTuningJob as RestFineTuningJob,
+ JobBase as RestJobBase,
+)
+from azure.ai.ml.entities._job._input_output_helpers import (
+ from_rest_data_outputs,
+ to_rest_data_outputs,
+)
+from azure.ai.ml.entities._job.job_resources import JobResources
+from azure.ai.ml.entities._job.queue_settings import QueueSettings
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._job.finetuning.finetuning_vertical import FineTuningVertical
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class CustomModelFineTuningJob(FineTuningVertical):
+ def __init__(
+ self,
+ **kwargs: Any,
+ ) -> None:
+ # Extract any task specific settings
+ model = kwargs.pop("model", None)
+ task = kwargs.pop("task", None)
+ # Convert task to lowercase first letter, this is when we create
+ # object from the schema, using dict object from the REST api response.
+ # TextCompletion => textCompletion
+ if task:
+ task = task[0].lower() + task[1:]
+ training_data = kwargs.pop("training_data", None)
+ validation_data = kwargs.pop("validation_data", None)
+ self._hyperparameters = kwargs.pop("hyperparameters", None)
+ super().__init__(
+ task=task,
+ model=model,
+ model_provider=RestModelProvider.CUSTOM,
+ training_data=training_data,
+ validation_data=validation_data,
+ **kwargs,
+ )
+
+ @property
+ def hyperparameters(self) -> Dict[str, str]:
+ """Get hyperparameters.
+
+ :return: Hyperparameters for finetuning the model.
+ :rtype: Dict[str, str]
+ """
+ return self._hyperparameters
+
+ @hyperparameters.setter
+ def hyperparameters(self, hyperparameters: Dict[str, str]) -> None:
+ """Set hyperparameters.
+
+ :param hyperparameters: Hyperparameters for finetuning the model
+ :type hyperparameters: Dict[str,str]
+ """
+ self._hyperparameters = hyperparameters
+
+ def _to_rest_object(self) -> "RestFineTuningJob":
+ """Convert CustomFineTuningVertical object to a RestFineTuningJob object.
+
+ :return: REST object representation of this object.
+ :rtype: JobBase
+ """
+ custom_finetuning_vertical = RestCustomModelFineTuningVertical(
+ task_type=self._task,
+ model=self._model,
+ model_provider=self._model_provider,
+ training_data=self._training_data,
+ validation_data=self._validation_data,
+ hyper_parameters=self._hyperparameters,
+ )
+ self._resolve_inputs(custom_finetuning_vertical)
+
+ finetuning_job = RestFineTuningJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ services=self.services,
+ tags=self.tags,
+ properties=self.properties,
+ compute_id=self.compute,
+ fine_tuning_details=custom_finetuning_vertical,
+ outputs=to_rest_data_outputs(self.outputs),
+ )
+ if self.resources:
+ finetuning_job.resources = self.resources._to_rest_object()
+ if self.queue_settings:
+ finetuning_job.queue_settings = self.queue_settings._to_rest_object()
+
+ result = RestJobBase(properties=finetuning_job)
+ result.name = self.name
+
+ return result
+
+ def _to_dict(self) -> Dict:
+ """Convert the object to a dictionary.
+
+ :return: dictionary representation of the object.
+ :rtype: typing.Dict
+ """
+ from azure.ai.ml._schema._finetuning.custom_model_finetuning import (
+ CustomModelFineTuningSchema,
+ )
+
+ schema_dict: dict = {}
+ # TODO: Come back to this later for FineTuningJob in pipeline
+ # if inside_pipeline:
+ # schema_dict = AutoMLClassificationNodeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ # else:
+ schema_dict = CustomModelFineTuningSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ return schema_dict
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, CustomModelFineTuningJob):
+ return NotImplemented
+
+ return super().__eq__(other) and self.hyperparameters == other.hyperparameters
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two CustomModelFineTuningJob objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobBase) -> "CustomModelFineTuningJob":
+ """Convert a REST object to CustomModelFineTuningJob object.
+
+ :param obj: CustomModelFineTuningJob in Rest format.
+ :type obj: JobBase
+ :return: CustomModelFineTuningJob object.
+ :rtype: CustomModelFineTuningJob
+ """
+
+ properties: RestFineTuningJob = obj.properties
+ finetuning_details: RestCustomModelFineTuningVertical = properties.fine_tuning_details
+
+ job_args_dict = {
+ "id": obj.id,
+ "name": obj.name,
+ "description": properties.description,
+ "tags": properties.tags,
+ "properties": properties.properties,
+ "services": properties.services,
+ "experiment_name": properties.experiment_name,
+ "status": properties.status,
+ "creation_context": obj.system_data,
+ "display_name": properties.display_name,
+ "compute": properties.compute_id,
+ "outputs": from_rest_data_outputs(properties.outputs),
+ }
+
+ if properties.resources:
+ job_args_dict["resources"] = JobResources._from_rest_object(properties.resources)
+ if properties.queue_settings:
+ job_args_dict["queue_settings"] = QueueSettings._from_rest_object(properties.queue_settings)
+
+ custom_model_finetuning_job = cls(
+ task=finetuning_details.task_type,
+ model=finetuning_details.model,
+ training_data=finetuning_details.training_data,
+ validation_data=finetuning_details.validation_data,
+ hyperparameters=finetuning_details.hyper_parameters,
+ **job_args_dict,
+ )
+
+ custom_model_finetuning_job._restore_inputs()
+
+ return custom_model_finetuning_job
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "CustomModelFineTuningJob":
+ """Load from a dictionary.
+
+ :param data: dictionary representation of the object.
+ :type data: typing.Dict
+ :param context: dictionary containing the context.
+ :type context: typing.Dict
+ :param additional_message: additional message to be added to the error message.
+ :type additional_message: str
+ :return: CustomModelFineTuningJob object.
+ :rtype: CustomModelFineTuningJob
+ """
+ from azure.ai.ml._schema._finetuning.custom_model_finetuning import (
+ CustomModelFineTuningSchema,
+ )
+
+ # TODO: Come back to this later - Pipeline part.
+ # from azure.ai.ml._schema.pipeline.automl_node import AutoMLClassificationNodeSchema
+
+ # if kwargs.pop("inside_pipeline", False):
+ # loaded_data = load_from_dict(
+ # AutoMLClassificationNodeSchema,
+ # data,
+ # context,
+ # additional_message,
+ # **kwargs,
+ # )
+ # else:
+ loaded_data = load_from_dict(CustomModelFineTuningSchema, data, context, additional_message, **kwargs)
+
+ training_data = loaded_data.get("training_data", None)
+ if isinstance(training_data, str):
+ loaded_data["training_data"] = Input(type="uri_file", path=training_data)
+
+ validation_data = loaded_data.get("validation_data", None)
+ if isinstance(validation_data, str):
+ loaded_data["validation_data"] = Input(type="uri_file", path=validation_data)
+
+ job_instance = cls._create_instance_from_schema_dict(loaded_data)
+ return job_instance
+
+ @classmethod
+ def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "CustomModelFineTuningJob":
+ """Create an instance from a schema dictionary.
+
+ :param loaded_data: dictionary containing the data.
+ :type loaded_data: typing.Dict
+ :return: CustomModelFineTuningJob object.
+ :rtype: CustomModelFineTuningJob
+ """
+ job = CustomModelFineTuningJob(**loaded_data)
+ return job
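A hedged construction sketch for the class above. The task name, asset paths, experiment name, and hyperparameter keys are illustrative assumptions; the model is shown as an MLflow model Input, mirroring how the related distillation code restores the student model.

from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities._inputs_outputs import Input
from azure.ai.ml.entities._job.finetuning.custom_model_finetuning_job import CustomModelFineTuningJob

# Illustrative values only; the paths and hyperparameter keys are placeholders.
job = CustomModelFineTuningJob(
    task="textCompletion",
    model=Input(type=AssetTypes.MLFLOW_MODEL, path="<model-asset-path>"),
    training_data=Input(type=AssetTypes.URI_FILE, path="./train.jsonl"),
    validation_data=Input(type=AssetTypes.URI_FILE, path="./validation.jsonl"),
    hyperparameters={"learning_rate_multiplier": "0.1", "num_train_epochs": "1"},
    experiment_name="custom-model-finetuning",
)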
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_job.py
new file mode 100644
index 00000000..ec8d9d5d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_job.py
@@ -0,0 +1,224 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ ModelProvider as RestModelProvider,
+ JobBase as RestJobBase,
+)
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import TYPE
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._job.job_resources import JobResources
+from azure.ai.ml.entities._job.queue_settings import QueueSettings
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+from azure.ai.ml.constants._job.finetuning import FineTuningConstants
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class FineTuningJob(Job, JobIOMixin):
+ def __init__(
+ self,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = JobType.FINE_TUNING
+ self.resources = kwargs.pop("resources", None)
+ self.queue_settings = kwargs.pop("queue_settings", None)
+ self.outputs = kwargs.pop("outputs", None)
+ super().__init__(**kwargs)
+
+ @property
+ def resources(self) -> Optional[JobResources]:
+ """Job resources to use during job execution.
+ :return: Job Resources object.
+ :rtype: JobResources
+ """
+ return self._resources if hasattr(self, "_resources") else None
+
+ @resources.setter
+ def resources(self, value: JobResources) -> None:
+ """Set JobResources.
+
+ :param value: JobResources object.
+ :type value: JobResources
+ :raises ValidationException: Expected a JobResources object.
+ """
+ if isinstance(value, JobResources):
+ self._resources = value
+ elif value:
+ msg = "Expected an instance of JobResources."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.FINETUNING,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @property
+ def queue_settings(self) -> Optional[QueueSettings]:
+ """Queue settings for job execution.
+ :return: QueueSettings object.
+ :rtype: QueueSettings
+ """
+ return self._queue_settings if hasattr(self, "_queue_settings") else None
+
+ @queue_settings.setter
+ def queue_settings(self, value: QueueSettings) -> None:
+ """Set queue settings for job execution.
+
+ :param value: QueueSettings object.
+ :type value: QueueSettings
+ :raises ValidationException: Expected a QueueSettings object.
+ """
+ if isinstance(value, QueueSettings):
+ self._queue_settings = value
+ elif value:
+ msg = "Expected an instance of QueueSettings."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.FINETUNING,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, FineTuningJob):
+ return NotImplemented
+
+ queue_settings_match = (not self.queue_settings and not other.queue_settings) or (
+ self.queue_settings is not None
+ and other.queue_settings is not None
+ and self.queue_settings.job_tier is not None
+ and other.queue_settings.job_tier is not None
+ and self.queue_settings.job_tier.lower() == other.queue_settings.job_tier.lower()
+ )
+
+ outputs_match = not self.outputs and not other.outputs
+ if self.outputs and other.outputs:
+ outputs_match = (
+ self.outputs["registered_model"].name == other.outputs["registered_model"].name
+ and self.outputs["registered_model"].type == other.outputs["registered_model"].type
+ )
+
+ return (
+ outputs_match
+ and self.resources == other.resources
+ and queue_settings_match
+ # add properties from base class
+ and self.name == other.name
+ and self.description == other.description
+ and self.tags == other.tags
+ and self.properties == other.properties
+ and self.compute == other.compute
+ and self.id == other.id
+ and self.experiment_name == other.experiment_name
+ and self.status == other.status
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two FineTuningJob objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
+
+ @classmethod
+ def _get_model_provider_mapping(cls) -> Dict:
+ """Create a mapping of task type to job class.
+
+ :return: An FineTuningVertical object containing the model provider type to job class mapping.
+ :rtype: FineTuningJob
+ """
+ from .custom_model_finetuning_job import CustomModelFineTuningJob
+ from .azure_openai_finetuning_job import AzureOpenAIFineTuningJob
+
+ return {
+ camel_to_snake(RestModelProvider.CUSTOM): CustomModelFineTuningJob,
+ camel_to_snake(RestModelProvider.AZURE_OPEN_AI): AzureOpenAIFineTuningJob,
+ }
+
+ @classmethod
+ def _load_from_rest(cls, obj: RestJobBase) -> "FineTuningJob":
+ """Loads the rest object to a dict containing items to init the AutoMLJob objects.
+
+ :param obj: Azure Resource Manager resource envelope.
+ :type obj: JobBase
+ :raises ValidationException: Raised if the model provider type is unsupported.
+ :return: A FineTuningJob
+ :rtype: FineTuningJob
+ """
+ model_provider = (
+ camel_to_snake(obj.properties.fine_tuning_details.model_provider)
+ if obj.properties.fine_tuning_details.model_provider
+ else None
+ )
+ class_type = cls._get_model_provider_mapping().get(model_provider, None)
+ if class_type:
+ res: FineTuningJob = class_type._from_rest_object(obj)
+ return res
+ msg = f"Unsupported model provider type: {obj.properties.fine_tuning_details.model_provider}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.FINETUNING,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: Dict,
+ context: Dict,
+ additional_message: str,
+ **kwargs: Any,
+ ) -> "FineTuningJob":
+ """Loads the dictionary objects to an FineTuningJob object.
+
+ :param data: A data dictionary.
+ :type data: typing.Dict
+ :param context: A context dictionary.
+ :type context: typing.Dict
+ :param additional_message: An additional message to be logged in the ValidationException.
+ :type additional_message: str
+
+ :raises ValidationException: Raised if the model provider type is unsupported.
+ :return: A FineTuningJob
+ :rtype: FineTuningJob
+ """
+ model_provider = data.get(FineTuningConstants.ModelProvider)
+ class_type = cls._get_model_provider_mapping().get(model_provider, None)
+ if class_type:
+ res: FineTuningJob = class_type._load_from_dict(
+ data,
+ context,
+ additional_message,
+ **kwargs,
+ )
+ return res
+ msg = f"Unsupported model provider type: {model_provider}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.AUTOML,
+ error_category=ErrorCategory.USER_ERROR,
+ )
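Both loaders above share one dispatch step: read the model provider, look it up in the provider-to-class mapping, and fail on anything unknown. Below is a minimal standalone sketch of that pattern; the plain dicts and the stub classes are illustrative stand-ins, not the SDK's own types.

from typing import Dict, Type


def resolve_job_class(payload: Dict[str, str], mapping: Dict[str, Type]) -> Type:
    """Pick the concrete job class for a payload by its model provider."""
    provider = payload.get("model_provider")  # e.g. "custom" or "azure_open_ai"
    job_class = mapping.get(provider)
    if job_class is None:
        raise ValueError(f"Unsupported model provider type: {provider}")
    return job_class


class _CustomJob: ...
class _AoaiJob: ...

providers = {"custom": _CustomJob, "azure_open_ai": _AoaiJob}
assert resolve_job_class({"model_provider": "custom"}, providers) is _CustomJob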
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_vertical.py
new file mode 100644
index 00000000..c9a5fe41
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/finetuning/finetuning_vertical.py
@@ -0,0 +1,202 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Optional, cast
+
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ ModelProvider as RestModelProvider,
+ FineTuningVertical as RestFineTuningVertical,
+ UriFileJobInput,
+ MLFlowModelJobInput,
+)
+from azure.ai.ml.constants._common import AssetTypes
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob
+
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class FineTuningVertical(FineTuningJob):
+ def __init__(
+ self,
+ *,
+ task: str,
+ model: Input,
+ model_provider: Optional[str],
+ training_data: Input,
+ validation_data: Optional[Input] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._task = task
+ self._model = model
+ self._model_provider = model_provider
+ self._training_data = training_data
+ self._validation_data = validation_data
+ super().__init__(**kwargs)
+
+ @property
+ def task(self) -> str:
+ """Get finetuning task.
+
+ :return: The type of task to run. Possible values include: "ChatCompletion"
+ "TextCompletion", "TextClassification", "QuestionAnswering","TextSummarization",
+ "TokenClassification", "TextTranslation", "ImageClassification", "ImageInstanceSegmentation",
+ "ImageObjectDetection","VideoMultiObjectTracking".
+
+ :rtype: str
+ """
+ return self._task
+
+ @task.setter
+ def task(self, task: str) -> None:
+ """Set finetuning task.
+
+ :param task: The type of task to run. Possible values include: "ChatCompletion"
+ "TextCompletion", "TextClassification", "QuestionAnswering","TextSummarization",
+ "TokenClassification", "TextTranslation", "ImageClassification", "ImageInstanceSegmentation",
+ "ImageObjectDetection","VideoMultiObjectTracking",.
+ :type task: str
+
+ :return: None
+ """
+ self._task = task
+
+ @property
+ def model(self) -> Optional[Input]:
+ """The model to be fine-tuned.
+ :return: Input object representing the MLflow or custom model to be fine-tuned.
+ :rtype: Input
+ """
+ return self._model
+
+ @model.setter
+ def model(self, value: Input) -> None:
+ """Set the model to be fine-tuned.
+
+ :param value: Input object representing the MLflow or custom model to be fine-tuned.
+ :type value: Input
+ :raises ValidationException: Raised if the value is not an MLflow model or custom model input.
+ """
+ if isinstance(value, Input) and (cast(Input, value).type in ("mlflow_model", "custom_model")):
+ self._model = value
+ else:
+ msg = "Expected a mlflow model input or custom model input."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.FINETUNING,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @property
+ def model_provider(self) -> Optional[str]:
+ """The model provider.
+ :return: The model provider.
+ :rtype: str
+ """
+ return self._model_provider
+
+ @model_provider.setter
+ def model_provider(self, value: str) -> None:
+ """Set the model provider.
+
+ :param value: The model provider.
+ :type value: str
+ """
+ self._model_provider = RestModelProvider[camel_to_snake(value).upper()] if value else None
+
+ @property
+ def training_data(self) -> Input:
+ """Get training data.
+
+ :return: Training data input
+ :rtype: Input
+ """
+ return self._training_data
+
+ @training_data.setter
+ def training_data(self, training_data: Input) -> None:
+ """Set training data.
+
+ :param training_data: Training data input
+ :type training_data: Input
+ """
+ self._training_data = training_data
+
+ @property
+ def validation_data(self) -> Optional[Input]:
+ """Get validation data.
+
+ :return: Validation data input
+ :rtype: Input
+ """
+ return self._validation_data
+
+ @validation_data.setter
+ def validation_data(self, validation_data: Input) -> None:
+ """Set validation data.
+
+ :param validation_data: Validation data input
+ :type validation_data: Input
+ """
+ self._validation_data = validation_data
+
+ def _resolve_inputs(self, rest_job: RestFineTuningVertical) -> None:
+ """Resolve JobInputs to UriFileJobInput within data_settings.
+
+ :param rest_job: The rest job object.
+ :type rest_job: RestFineTuningVertical
+ """
+ if isinstance(rest_job.training_data, Input):
+ rest_job.training_data = UriFileJobInput(uri=rest_job.training_data.path)
+ if isinstance(rest_job.validation_data, Input):
+ rest_job.validation_data = UriFileJobInput(uri=rest_job.validation_data.path)
+ if isinstance(rest_job.model, Input):
+ rest_job.model = MLFlowModelJobInput(uri=rest_job.model.path)
+
+ def _restore_inputs(self) -> None:
+ """Restore UriFileJobInputs to JobInputs within data_settings."""
+ if isinstance(self.training_data, UriFileJobInput):
+ self.training_data = Input(type=AssetTypes.URI_FILE, path=self.training_data.uri)
+ if isinstance(self.validation_data, UriFileJobInput):
+ self.validation_data = Input(type=AssetTypes.URI_FILE, path=self.validation_data.uri)
+ if isinstance(self.model, MLFlowModelJobInput):
+ self.model = Input(type=AssetTypes.MLFLOW_MODEL, path=self.model.uri)
+
+ def __eq__(self, other: object) -> bool:
+ """Returns True if both instances have the same values.
+
+ This method checks instance equality and returns True if both
+ instances have the same attributes with the same values.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ if not isinstance(other, FineTuningVertical):
+ return NotImplemented
+
+ return (
+ # TODO: Equality from base class does not work, no current precedence for this
+ super().__eq__(other)
+ and self.task == other.task
+ and self.model == other.model
+ and self.model_provider == other.model_provider
+ and self.training_data == other.training_data
+ and self.validation_data == other.validation_data
+ )
+
+ def __ne__(self, other: object) -> bool:
+ """Check inequality between two FineTuningJob objects.
+
+ :param other: Any object
+ :type other: object
+ :return: True or False
+ :rtype: bool
+ """
+ return not self.__eq__(other)
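The model, training_data, and validation_data properties above all take azure.ai.ml Input objects, and the model setter additionally requires the input type to be "mlflow_model" or "custom_model". A small sketch of building such inputs follows; the paths and asset names are placeholders.

from azure.ai.ml import Input
from azure.ai.ml.constants import AssetTypes

# Placeholder asset references; any URI or registered asset of the matching type works.
base_model = Input(type=AssetTypes.MLFLOW_MODEL, path="azureml:my-base-model:1")
training_data = Input(type=AssetTypes.URI_FILE, path="./data/train.jsonl")
validation_data = Input(type=AssetTypes.URI_FILE, path="./data/validation.jsonl")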
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/import_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/import_job.py
new file mode 100644
index 00000000..24d4ec90
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/import_job.py
@@ -0,0 +1,285 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import CommandJob as RestCommandJob
+from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData
+from azure.ai.ml._schema.job.import_job import ImportJobSchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._inputs_outputs import Output
+from azure.ai.ml.entities._job._input_output_helpers import (
+ from_rest_data_outputs,
+ from_rest_inputs_to_dataset_literal,
+ to_rest_data_outputs,
+ to_rest_dataset_literal_inputs,
+)
+from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import MlException
+
+from .job import Job
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._builders import Import
+ from azure.ai.ml.entities._component.import_component import ImportComponent
+
+module_logger = logging.getLogger(__name__)
+
+
+class ImportSource(ABC):
+ def __init__(
+ self,
+ *,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ connection: Optional[str] = None,
+ ):
+ self.type = type
+ self.connection = connection
+
+ @abstractmethod
+ def _to_job_inputs(self) -> Dict[str, Optional[str]]:
+ pass
+
+ @classmethod
+ def _from_job_inputs(cls, job_inputs: Dict[str, str]) -> "ImportSource":
+ """Translate job inputs to import source.
+
+ :param job_inputs: The job inputs
+ :type job_inputs: Dict[str, str]
+ :return: The import source
+ :rtype: ImportSource
+ """
+ type = job_inputs.get("type") # pylint: disable=redefined-builtin
+ connection = job_inputs.get("connection")
+ query = job_inputs.get("query")
+ path = job_inputs.get("path")
+
+ import_source = (
+ DatabaseImportSource(type=type, connection=connection, query=query)
+ if query is not None
+ else FileImportSource(type=type, connection=connection, path=path)
+ )
+ return import_source
+
+
+class DatabaseImportSource(ImportSource):
+ def __init__(
+ self,
+ *,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ connection: Optional[str] = None,
+ query: Optional[str] = None,
+ ):
+ ImportSource.__init__(
+ self,
+ type=type,
+ connection=connection,
+ )
+ self.query = query
+
+ def _to_job_inputs(self) -> Dict[str, Optional[str]]:
+ """Translate source to command Inputs.
+
+ :return: The job inputs dict
+ :rtype: Dict[str, str]
+ """
+ inputs = {
+ "type": self.type,
+ "connection": self.connection,
+ "query": self.query,
+ }
+ return inputs
+
+
+class FileImportSource(ImportSource):
+ def __init__(
+ self,
+ *,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ connection: Optional[str] = None,
+ path: Optional[str] = None,
+ ):
+ ImportSource.__init__(
+ self,
+ type=type,
+ connection=connection,
+ )
+ self.path = path
+
+ def _to_job_inputs(self) -> Dict[str, Optional[str]]:
+ """Translate source to command Inputs.
+
+ :return: The job inputs dict
+ :rtype: Dict[str, str]
+ """
+ inputs = {
+ "type": self.type,
+ "connection": self.connection,
+ "path": self.path,
+ }
+ return inputs
+
+
+class ImportJob(Job, JobIOMixin):
+ """Import job.
+
+ :param name: Name of the job.
+ :type name: str
+ :param description: Description of the job.
+ :type description: str
+ :param display_name: Display name of the job.
+ :type display_name: str
+ :param experiment_name: Name of the experiment the job will be created under.
+ If None is provided, default will be set to current directory name.
+ :type experiment_name: str
+ :param source: Input source parameters to the import job.
+ :type source: azure.ai.ml.entities.DatabaseImportSource or FileImportSource
+ :param output: output data binding used in the job.
+ :type output: azure.ai.ml.Output
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ source: Optional[ImportSource] = None,
+ output: Optional[Output] = None,
+ **kwargs: Any,
+ ):
+ kwargs[TYPE] = JobType.IMPORT
+
+ Job.__init__(
+ self,
+ name=name,
+ display_name=display_name,
+ description=description,
+ experiment_name=experiment_name,
+ **kwargs,
+ )
+
+ self.source = source
+ self.output = output
+
+ def _to_dict(self) -> Dict:
+ res: dict = ImportJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def _to_rest_object(self) -> JobBaseData:
+ # TODO: Remove in PuP
+ if not is_private_preview_enabled():
+ msg = JobType.IMPORT + " job not supported."
+ raise MlException(message=msg, no_personal_data_message=msg)
+
+ _inputs = self.source._to_job_inputs() if self.source is not None else None # pylint: disable=protected-access
+ if self.compute is None:
+ msg = "compute cannot be None."
+ raise MlException(message=msg, no_personal_data_message=msg)
+
+ properties = RestCommandJob(
+ display_name=self.display_name,
+ description=self.description,
+ compute_id=self.compute,
+ experiment_name=self.experiment_name,
+ inputs=to_rest_dataset_literal_inputs(_inputs, job_type=self.type),
+ outputs=to_rest_data_outputs({"output": self.output}),
+ # TODO: Remove in PuP with native import job/component type support in MFE/Designer
+ # No longer applicable once new import job type is ready on MFE in PuP
+ # command and environment are required as we use command type for import
+ # command can be random string and the particular environment name here is defined as default in MFE
+ # public const string DefaultEnvironmentName = "AzureML-sklearn-0.24-ubuntu18.04-py37-cpu";
+ # which is considered valid environment in MFE unless MFE changes current default logic
+ # but chance should be very low in PrP
+ command="import",
+ environment_id=self.compute.replace(
+ "/computes/DataFactory", "/environments/AzureML-sklearn-0.24-ubuntu18.04-py37-cpu"
+ ),
+ )
+ result = JobBaseData(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "ImportJob":
+ loaded_data = load_from_dict(ImportJobSchema, data, context, additional_message, **kwargs)
+ return ImportJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ @classmethod
+ def _load_from_rest(cls, obj: JobBaseData) -> "ImportJob":
+ rest_command_job: RestCommandJob = obj.properties
+ outputs = from_rest_data_outputs(rest_command_job.outputs)
+ inputs = from_rest_inputs_to_dataset_literal(rest_command_job.inputs)
+
+ import_job = ImportJob(
+ name=obj.name,
+ id=obj.id,
+ display_name=rest_command_job.display_name,
+ description=rest_command_job.description,
+ experiment_name=rest_command_job.experiment_name,
+ status=rest_command_job.status,
+ creation_context=obj.system_data,
+ source=ImportSource._from_job_inputs(inputs), # pylint: disable=protected-access
+ output=outputs["output"] if "output" in outputs else None,
+ )
+ return import_job
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "ImportComponent":
+ """Translate a import job to component.
+
+ :param context: Context of import job YAML file.
+ :type context: dict
+ :return: Translated import component.
+ :rtype: ImportComponent
+ """
+ from azure.ai.ml.entities._component.import_component import ImportComponent
+
+ pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+ context = context or {BASE_PATH_CONTEXT_KEY: Path("import/")}
+
+ _inputs = self.source._to_job_inputs() if self.source is not None else None # pylint: disable=protected-access
+
+ # Create anonymous command component with default version as 1
+ return ImportComponent(
+ is_anonymous=True,
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ description=self.description,
+ source=self._to_inputs(
+ inputs=_inputs,
+ pipeline_job_dict=pipeline_job_dict,
+ ),
+ output=self._to_outputs(outputs={"output": self.output}, pipeline_job_dict=pipeline_job_dict)["output"],
+ )
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Import":
+ """Translate a import job to a pipeline node.
+
+ :param context: Context of import job YAML file.
+ :type context: dict
+ :return: Translated import node.
+ :rtype: Import
+ """
+ from azure.ai.ml.entities._builders import Import
+
+ component = self._to_component(context, **kwargs)
+ _inputs = self.source._to_job_inputs() if self.source is not None else None # pylint: disable=protected-access
+ return Import(
+ component=component,
+ compute=self.compute,
+ inputs=_inputs,
+ outputs={"output": self.output},
+ description=self.description,
+ display_name=self.display_name,
+ properties=self.properties,
+ )
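ImportSource._from_job_inputs above picks the subclass by whether a query is present, and each subclass's _to_job_inputs flattens back to the same flat dict. A round-trip sketch follows; the source type, connection, and query values are placeholders.

from azure.ai.ml.entities._job.import_job import DatabaseImportSource, ImportSource

source = DatabaseImportSource(
    type="azuresqldb",                       # placeholder source type
    connection="azureml:my_sql_connection",  # placeholder workspace connection
    query="SELECT TOP 100 * FROM my_table",
)
inputs = source._to_job_inputs()             # {"type": ..., "connection": ..., "query": ...}
restored = ImportSource._from_job_inputs(inputs)
assert isinstance(restored, DatabaseImportSource)  # a query is present, so the database variant is chosen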
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_output_entry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_output_entry.py
new file mode 100644
index 00000000..aa0e73b1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_output_entry.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import collections.abc
+import logging
+from typing import Any, Optional, Union
+
+from azure.ai.ml.constants import InputOutputModes
+from azure.ai.ml.entities._assets import Data
+from azure.ai.ml.entities._mixins import DictMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class InputOutputEntry(DictMixin):
+ def __init__(
+ self, # pylint: disable=unused-argument
+ data: Optional[Union[str, "Data"]] = None,
+ mode: Optional[str] = InputOutputModes.MOUNT,
+ **kwargs: Any,
+ ):
+ # Data will be either a dataset id, inline dataset definition
+ self.data = data
+ self.mode = mode
+ if isinstance(self.data, collections.abc.Mapping) and not isinstance(self.data, Data):
+ self.data = Data(**self.data)
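InputOutputEntry accepts either an asset reference string or an inline mapping; a mapping that is not already a Data object is promoted to a Data asset in the constructor. A tiny sketch, with placeholder dataset names and mode value:

from azure.ai.ml.entities._job.input_output_entry import InputOutputEntry

by_reference = InputOutputEntry(data="azureml:my-dataset:1", mode="download")
by_definition = InputOutputEntry(data={"name": "my-dataset", "version": "1"})
# by_definition.data is now a Data entity built from the inline mapping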
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_port.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_port.py
new file mode 100644
index 00000000..7953bbde
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/input_port.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Optional, Union
+
+module_logger = logging.getLogger(__name__)
+
+
+class InputPort:
+ def __init__(self, *, type_string: str, default: Optional[str] = None, optional: Optional[bool] = False):
+ self.type_string = type_string
+ self.optional = optional
+ if self.type_string == "number" and default is not None:
+ self.default: Union[float, Optional[str]] = float(default)
+ else:
+ self.default = default
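InputPort only coerces the default when the port type is "number"; any other type keeps the value as given. For illustration:

from azure.ai.ml.entities._job.input_port import InputPort

number_port = InputPort(type_string="number", default="0.5")
string_port = InputPort(type_string="string", default="0.5")
assert number_port.default == 0.5      # coerced to float
assert string_port.default == "0.5"    # left as a string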
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job.py
new file mode 100644
index 00000000..b181636e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job.py
@@ -0,0 +1,363 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import json
+import logging
+import traceback
+from abc import abstractmethod
+from collections import OrderedDict
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Type, Union
+
+from azure.ai.ml._restclient.runhistory.models import Run
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, JobService
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobType as RestJobType
+from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase as JobBase_2401
+from azure.ai.ml._restclient.v2024_01_01_preview.models import JobType as RestJobType_20240101Preview
+from azure.ai.ml._utils._html_utils import make_link, to_html
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, CommonYamlFields
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.constants._job.job import JobServices, JobType
+from azure.ai.ml.entities._mixins import TelemetryMixin
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._util import find_type_in_override
+from azure.ai.ml.exceptions import (
+ ErrorCategory,
+ ErrorTarget,
+ JobException,
+ JobParsingError,
+ PipelineChildJobError,
+ ValidationErrorType,
+ ValidationException,
+)
+
+from ._studio_url_from_job_id import studio_url_from_job_id
+from .pipeline._component_translatable import ComponentTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+def _is_pipeline_child_job(job: JobBase) -> bool:
+ # a pipeline child job has no properties, so we can detect one by testing job.properties
+ # if the backend spec changes, this method needs to be updated
+ return job.properties is None
+
+
+class Job(Resource, ComponentTranslatableMixin, TelemetryMixin):
+ """Base class for jobs.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+ :param name: The name of the job.
+ :type name: Optional[str]
+ :param display_name: The display name of the job.
+ :type display_name: Optional[str]
+ :param description: The description of the job.
+ :type description: Optional[str]
+ :param tags: Tag dictionary. Tags can be added, removed, and updated.
+ :type tags: Optional[dict[str, str]]
+ :param properties: The job property dictionary.
+ :type properties: Optional[dict[str, str]]
+ :param experiment_name: The name of the experiment the job will be created under. Defaults to the name of the
+ current directory.
+ :type experiment_name: Optional[str]
+ :param services: Information on services associated with the job.
+ :type services: Optional[dict[str, ~azure.ai.ml.entities.JobService]]
+ :param compute: Information about the compute resources associated with the job.
+ :type compute: Optional[str]
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ experiment_name: Optional[str] = None,
+ compute: Optional[str] = None,
+ services: Optional[Dict[str, JobService]] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._type: Optional[str] = kwargs.pop("type", JobType.COMMAND)
+ self._status: Optional[str] = kwargs.pop("status", None)
+ self._log_files: Optional[Dict] = kwargs.pop("log_files", None)
+
+ super().__init__(
+ name=name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+
+ self.display_name = display_name
+ self.experiment_name = experiment_name
+ self.compute: Any = compute
+ self.services = services
+
+ @property
+ def type(self) -> Optional[str]:
+ """The type of the job.
+
+ :return: The type of the job.
+ :rtype: Optional[str]
+ """
+ return self._type
+
+ @property
+ def status(self) -> Optional[str]:
+ """The status of the job.
+
+ Common values returned include "Running", "Completed", and "Failed". All possible values are:
+
+ * NotStarted - This is a temporary state that client-side Run objects are in before cloud submission.
+ * Starting - The Run has started being processed in the cloud. The caller has a run ID at this point.
+ * Provisioning - On-demand compute is being created for a given job submission.
+ * Preparing - The run environment is being prepared and is in one of two stages:
+ * Docker image build
+ * conda environment setup
+ * Queued - The job is queued on the compute target. For example, in BatchAI, the job is in a queued state
+ while waiting for all the requested nodes to be ready.
+ * Running - The job has started to run on the compute target.
+ * Finalizing - User code execution has completed, and the run is in post-processing stages.
+ * CancelRequested - Cancellation has been requested for the job.
+ * Completed - The run has completed successfully. This includes both the user code execution and run
+ post-processing stages.
+ * Failed - The run failed. Usually the Error property on a run will provide details as to why.
+ * Canceled - Follows a cancellation request and indicates that the run is now successfully cancelled.
+ * NotResponding - For runs that have Heartbeats enabled, no heartbeat has been recently sent.
+
+ :return: Status of the job.
+ :rtype: Optional[str]
+ """
+ return self._status
+
+ @property
+ def log_files(self) -> Optional[Dict[str, str]]:
+ """Job output files.
+
+ :return: The dictionary of log names and URLs.
+ :rtype: Optional[Dict[str, str]]
+ """
+ return self._log_files
+
+ @property
+ def studio_url(self) -> Optional[str]:
+ """Azure ML studio endpoint.
+
+ :return: The URL to the job details page.
+ :rtype: Optional[str]
+ """
+ if self.services and (JobServices.STUDIO in self.services.keys()):
+ res: Optional[str] = self.services[JobServices.STUDIO].endpoint
+ return res
+
+ return studio_url_from_job_id(self.id) if self.id else None
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dumps the job content into a file in YAML format.
+
+ :param dest: The local path or file stream to write the YAML content to.
+ If dest is a file path, a new file will be created.
+ If dest is an open file, the file will be written to directly.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ :raises FileExistsError: Raised if dest is a file path and the file already exists.
+ :raises IOError: Raised if dest is an open file and the file is not writable.
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ def _get_base_info_dict(self) -> OrderedDict:
+ return OrderedDict(
+ [
+ ("Experiment", self.experiment_name),
+ ("Name", self.name),
+ ("Type", self._type),
+ ("Status", self._status),
+ ]
+ )
+
+ def _repr_html_(self) -> str:
+ info = self._get_base_info_dict()
+ if self.studio_url:
+ info.update(
+ [
+ (
+ "Details Page",
+ make_link(self.studio_url, "Link to Azure Machine Learning studio"),
+ ),
+ ]
+ )
+ res: str = to_html(info)
+ return res
+
+ @abstractmethod
+ def _to_dict(self) -> Dict:
+ pass
+
+ @classmethod
+ def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] = None) -> Tuple:
+ from azure.ai.ml.entities._builders.command import Command
+ from azure.ai.ml.entities._builders.spark import Spark
+ from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+ from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob
+ from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob
+ from azure.ai.ml.entities._job.import_job import ImportJob
+ from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob
+ from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob
+
+ job_type: Optional[Type["Job"]] = None
+ type_in_override = find_type_in_override(params_override)
+ type_str = type_in_override or data.get(CommonYamlFields.TYPE, JobType.COMMAND) # override takes the priority
+ if type_str == JobType.COMMAND:
+ job_type = Command
+ elif type_str == JobType.SPARK:
+ job_type = Spark
+ elif type_str == JobType.IMPORT:
+ job_type = ImportJob
+ elif type_str == JobType.SWEEP:
+ job_type = SweepJob
+ elif type_str == JobType.AUTOML:
+ job_type = AutoMLJob
+ elif type_str == JobType.PIPELINE:
+ job_type = PipelineJob
+ elif type_str == JobType.FINE_TUNING:
+ job_type = FineTuningJob
+ elif type_str == JobType.DISTILLATION:
+ job_type = DistillationJob
+ else:
+ msg = f"Unsupported job type: {type_str}."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return job_type, type_str
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Job":
+ """Load a job object from a yaml file.
+
+ :param cls: Indicates that this is a class method.
+ :type cls: class
+ :param data: Data Dictionary, defaults to None
+ :type data: Dict
+ :param yaml_path: YAML Path, defaults to None
+ :type yaml_path: Union[PathLike, str]
+ :param params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}], defaults to None
+ :type params_override: List[Dict]
+ :raises ValidationException: Raised if the job type in the data cannot be resolved.
+ :return: Loaded job object.
+ :rtype: Job
+ """
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ job_type, type_str = cls._resolve_cls_and_type(data, params_override)
+ job: Job = job_type._load_from_dict(
+ data=data,
+ context=context,
+ additional_message=f"If you are trying to configure a job that is not of type {type_str}, please specify "
+ f"the correct job type in the 'type' property.",
+ **kwargs,
+ )
+ if yaml_path:
+ job._source_path = yaml_path
+ return job
+
+ @classmethod
+ def _from_rest_object( # pylint: disable=too-many-return-statements
+ cls, obj: Union[JobBase, JobBase_2401, Run]
+ ) -> "Job":
+ from azure.ai.ml.entities import PipelineJob
+ from azure.ai.ml.entities._builders.command import Command
+ from azure.ai.ml.entities._builders.spark import Spark
+ from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+ from azure.ai.ml.entities._job.base_job import _BaseJob
+ from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob
+ from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob
+ from azure.ai.ml.entities._job.import_job import ImportJob
+ from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob
+
+ try:
+ if isinstance(obj, Run):
+ # special handling for child jobs
+ return _BaseJob._load_from_rest(obj)
+ if _is_pipeline_child_job(obj):
+ raise PipelineChildJobError(job_id=obj.id)
+ if obj.properties.job_type == RestJobType.COMMAND:
+ # PrP only until new import job type is ready on MFE in PuP
+ # compute type 'DataFactory' is reserved compute name for 'clusterless' ADF jobs
+ if obj.properties.compute_id and obj.properties.compute_id.endswith("/" + ComputeType.ADF):
+ return ImportJob._load_from_rest(obj)
+
+ res_command: Job = Command._load_from_rest_job(obj)
+ if hasattr(obj, "name"):
+ res_command._name = obj.name # type: ignore[attr-defined]
+ return res_command
+ if obj.properties.job_type == RestJobType.SPARK:
+ res_spark: Job = Spark._load_from_rest_job(obj)
+ if hasattr(obj, "name"):
+ res_spark._name = obj.name # type: ignore[attr-defined]
+ return res_spark
+ if obj.properties.job_type == RestJobType.SWEEP:
+ return SweepJob._load_from_rest(obj)
+ if obj.properties.job_type == RestJobType.AUTO_ML:
+ return AutoMLJob._load_from_rest(obj)
+ if obj.properties.job_type == RestJobType_20240101Preview.FINE_TUNING:
+ if obj.properties.properties.get("azureml.enable_distillation", False):
+ return DistillationJob._load_from_rest(obj)
+ return FineTuningJob._load_from_rest(obj)
+ if obj.properties.job_type == RestJobType.PIPELINE:
+ res_pipeline: Job = PipelineJob._load_from_rest(obj)
+ return res_pipeline
+ except PipelineChildJobError as ex:
+ raise ex
+ except Exception as ex:
+ error_message = json.dumps(obj.as_dict(), indent=2) if obj else None
+ module_logger.info(
+ "Exception: %s.\n%s\nUnable to parse the job resource: %s.\n",
+ ex,
+ traceback.format_exc(),
+ error_message,
+ )
+ raise JobParsingError(
+ message=str(ex),
+ no_personal_data_message=f"Unable to parse a job resource of type:{type(obj).__name__}",
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ ) from ex
+ msg = f"Unsupported job type {obj.properties.job_type}"
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ def _get_telemetry_values(self) -> Dict: # pylint: disable=arguments-differ
+ telemetry_values = {"type": self.type}
+ return telemetry_values
+
+ @classmethod
+ @abstractmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "Job":
+ pass
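_resolve_cls_and_type above is the single dispatch point for YAML loading: the "type" field (or an override) selects the concrete Job subclass, and anything unknown raises a ValidationException. A short sketch of that dispatch, exercising the private helper purely for illustration:

from azure.ai.ml.entities._job.job import Job
from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob

job_cls, type_str = Job._resolve_cls_and_type({"type": "sweep"})
assert job_cls is SweepJob and type_str == "sweep"
# An unrecognized "type" value raises a ValidationException instead.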
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_io_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_io_mixin.py
new file mode 100644
index 00000000..21db73ba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_io_mixin.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from typing import Dict, Union
+
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job._input_output_helpers import build_input_output
+
+
+class JobIOMixin:
+ @property
+ def inputs(self) -> Dict[str, Union[Input, str, bool, int, float]]:
+ return self._inputs
+
+ @inputs.setter
+ def inputs(self, value: Dict[str, Union[Input, str, bool, int, float]]) -> None:
+ self._inputs: Dict = {}
+ if not value:
+ return
+
+ for input_name, input_value in value.items():
+ self._inputs[input_name] = build_input_output(input_value)
+
+ @property
+ def outputs(self) -> Dict[str, Output]:
+ return self._outputs
+
+ @outputs.setter
+ def outputs(self, value: Dict[str, Output]) -> None:
+ self._outputs: Dict = {}
+ if not value:
+ return
+
+ for output_name, output_value in value.items():
+ self._outputs[output_name] = build_input_output(output_value, inputs=False)
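The setters above normalize every entry through build_input_output, so literal values and Input/Output objects can be mixed freely in one dict. A small sketch, with placeholder paths and names:

from azure.ai.ml import Input, Output
from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin

io = JobIOMixin()
io.inputs = {"learning_rate": 0.01, "training_data": Input(type="uri_file", path="./train.csv")}
io.outputs = {"model_dir": Output(type="uri_folder")}
# io.inputs / io.outputs now hold the normalized entries produced by build_input_output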
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_limits.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_limits.py
new file mode 100644
index 00000000..7aae9263
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_limits.py
@@ -0,0 +1,201 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from abc import ABC
+from typing import Any, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import CommandJobLimits as RestCommandJobLimits
+from azure.ai.ml._restclient.v2023_08_01_preview.models import SweepJobLimits as RestSweepJobLimits
+from azure.ai.ml._utils.utils import from_iso_duration_format, is_data_binding_expression, to_iso_duration_format
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobLimits(RestTranslatableMixin, ABC):
+ """Base class for Job limits.
+
+ This class should not be instantiated directly. Instead, one of its child classes should be used.
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ self.type: Any = None
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, JobLimits):
+ return NotImplemented
+ res: bool = self._to_rest_object() == other._to_rest_object()
+ return res
+
+
+class CommandJobLimits(JobLimits):
+ """Limits for Command Jobs.
+
+ :keyword timeout: The maximum run duration, in seconds, after which the job will be cancelled.
+ :paramtype timeout: Optional[Union[int, str]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_command_configurations.py
+ :start-after: [START command_job_definition]
+ :end-before: [END command_job_definition]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a CommandJob with CommandJobLimits.
+ """
+
+ def __init__(self, *, timeout: Optional[Union[int, str]] = None) -> None:
+ super().__init__()
+ self.type = JobType.COMMAND
+ self.timeout = timeout
+
+ def _to_rest_object(self) -> RestCommandJobLimits:
+ if is_data_binding_expression(self.timeout):
+ return RestCommandJobLimits(timeout=self.timeout)
+ return RestCommandJobLimits(timeout=to_iso_duration_format(self.timeout))
+
+ @classmethod
+ def _from_rest_object(cls, obj: Union[RestCommandJobLimits, dict]) -> Optional["CommandJobLimits"]:
+ if not obj:
+ return None
+ if isinstance(obj, dict):
+ timeout_value = obj.get("timeout", None)
+ # if timeout value is a binding string
+ if is_data_binding_expression(timeout_value):
+ return cls(timeout=timeout_value)
+ # if response timeout is a normal iso date string
+ obj = RestCommandJobLimits.from_dict(obj)
+ return cls(timeout=from_iso_duration_format(obj.timeout))
+
+
+class SweepJobLimits(JobLimits):
+ """Limits for Sweep Jobs.
+
+ :keyword max_concurrent_trials: The maximum number of concurrent trials for the Sweep Job.
+ :paramtype max_concurrent_trials: Optional[int]
+ :keyword max_total_trials: The maximum number of total trials for the Sweep Job.
+ :paramtype max_total_trials: Optional[int]
+ :keyword timeout: The maximum run duration, in seconds, after which the job will be cancelled.
+ :paramtype timeout: Optional[int]
+ :keyword trial_timeout: The timeout value, in seconds, for each Sweep Job trial.
+ :paramtype trial_timeout: Optional[int]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_bayesian_sampling_algorithm]
+ :end-before: [END configure_sweep_job_bayesian_sampling_algorithm]
+ :language: python
+ :dedent: 8
+ :caption: Assigning limits to a SweepJob
+ """
+
+ def __init__(
+ self,
+ *,
+ max_concurrent_trials: Optional[int] = None,
+ max_total_trials: Optional[int] = None,
+ timeout: Optional[int] = None,
+ trial_timeout: Optional[Union[int, str]] = None,
+ ) -> None:
+ super().__init__()
+ self.type = JobType.SWEEP
+ self.max_concurrent_trials = max_concurrent_trials
+ self.max_total_trials = max_total_trials
+ self._timeout = _get_floored_timeout(timeout)
+ self._trial_timeout = _get_floored_timeout(trial_timeout)
+
+ @property
+ def timeout(self) -> Optional[Union[int, str]]:
+ """The maximum run duration, in seconds, after which the job will be cancelled.
+
+ :return: The maximum run duration, in seconds, after which the job will be cancelled.
+ :rtype: int
+ """
+ return self._timeout
+
+ @timeout.setter
+ def timeout(self, value: int) -> None:
+ """Sets the maximum run duration.
+
+ :param value: The maximum run duration, in seconds, after which the job will be cancelled.
+ :type value: int
+ """
+ self._timeout = _get_floored_timeout(value)
+
+ @property
+ def trial_timeout(self) -> Optional[Union[int, str]]:
+ """The timeout value, in seconds, for each Sweep Job trial.
+
+ :return: The timeout value, in seconds, for each Sweep Job trial.
+ :rtype: int
+ """
+ return self._trial_timeout
+
+ @trial_timeout.setter
+ def trial_timeout(self, value: int) -> None:
+ """Sets the timeout value for each Sweep Job trial.
+
+ :param value: The timeout value, in seconds, for each Sweep Job trial.
+ :type value: int
+ """
+ self._trial_timeout = _get_floored_timeout(value)
+
+ def _to_rest_object(self) -> RestSweepJobLimits:
+ return RestSweepJobLimits(
+ max_concurrent_trials=self.max_concurrent_trials,
+ max_total_trials=self.max_total_trials,
+ timeout=to_iso_duration_format(self.timeout),
+ trial_timeout=to_iso_duration_format(self.trial_timeout),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestSweepJobLimits) -> Optional["SweepJobLimits"]:
+ if not obj:
+ return None
+
+ return cls(
+ max_concurrent_trials=obj.max_concurrent_trials,
+ max_total_trials=obj.max_total_trials,
+ timeout=from_iso_duration_format(obj.timeout),
+ trial_timeout=from_iso_duration_format(obj.trial_timeout),
+ )
+
+
+def _get_floored_timeout(value: Optional[Union[int, str]]) -> Optional[Union[int, str]]:
+ # Bug 1335978: Service rounds durations less than 60 seconds to 60 days.
+ # If duration is non-0 and less than 60, set to 60.
+ if isinstance(value, int):
+ return value if not value or value > 60 else 60
+
+ return None
+
+
+class DoWhileJobLimits(JobLimits):
+ """DoWhile Job limit class.
+
+ :keyword max_iteration_count: The maximum number of iterations for the DoWhile Job.
+ :paramtype max_iteration_count: Optional[int]
+ """
+
+ def __init__(
+ self, # pylint: disable=unused-argument
+ *,
+ max_iteration_count: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__()
+ self._max_iteration_count = max_iteration_count
+
+ @property
+ def max_iteration_count(self) -> Optional[int]:
+ """The maximum number of iterations for the DoWhile Job.
+
+ :rtype: int
+ """
+ return self._max_iteration_count
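Two behaviours above are easy to miss: CommandJobLimits serializes its timeout to an ISO-8601 duration string (unless it is a data-binding expression), and SweepJobLimits floors any non-zero timeout of 60 seconds or less up to 60. For illustration:

from azure.ai.ml.entities._job.job_limits import CommandJobLimits, SweepJobLimits

command_limits = CommandJobLimits(timeout=3600)
print(command_limits._to_rest_object().timeout)  # an ISO-8601 duration string

sweep_limits = SweepJobLimits(max_total_trials=20, trial_timeout=30)
assert sweep_limits.trial_timeout == 60          # 30 seconds is floored to the 60-second minimum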
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_name_generator.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_name_generator.py
new file mode 100644
index 00000000..e4f62d3d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_name_generator.py
@@ -0,0 +1,487 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import random
+
+SUFFIX_LENGTH = 10
+ALLOWED_CHARS = "bcdfghjklmnpqrstvwxyz0123456789"
+
+ALLOWED_ADJECTIVES = [
+ "affable",
+ "amiable",
+ "amusing",
+ "ashy",
+ "blue",
+ "bold",
+ "boring",
+ "brave",
+ "bright",
+ "bubbly",
+ "busy",
+ "calm",
+ "careful",
+ "clever",
+ "cool",
+ "coral",
+ "crimson",
+ "cyan",
+ "dreamy",
+ "dynamic",
+ "eager",
+ "elated",
+ "epic",
+ "frank",
+ "frosty",
+ "funny",
+ "gentle",
+ "gifted",
+ "good",
+ "goofy",
+ "gray",
+ "great",
+ "green",
+ "happy",
+ "helpful",
+ "heroic",
+ "honest",
+ "hungry",
+ "icy",
+ "ivory",
+ "jolly",
+ "jovial",
+ "joyful",
+ "keen",
+ "khaki",
+ "kind",
+ "lemon",
+ "lime",
+ "loving",
+ "loyal",
+ "lucid",
+ "magenta",
+ "mango",
+ "maroon",
+ "mighty",
+ "modest",
+ "musing",
+ "neat",
+ "nice",
+ "nifty",
+ "olden",
+ "olive",
+ "orange",
+ "patient",
+ "placid",
+ "plucky",
+ "plum",
+ "polite",
+ "purple",
+ "quiet",
+ "quirky",
+ "red",
+ "sad",
+ "salmon",
+ "serene",
+ "sharp",
+ "shy",
+ "silly",
+ "silver",
+ "sincere",
+ "sleepy",
+ "stoic",
+ "strong",
+ "sweet",
+ "teal",
+ "tender",
+ "tidy",
+ "tough",
+ "upbeat",
+ "wheat",
+ "willing",
+ "witty",
+ "yellow",
+ "zen",
+]
+
+ALLOWED_NOUNS = [
+ "actor",
+ "airport",
+ "angle",
+ "animal",
+ "answer",
+ "ant",
+ "apple",
+ "apricot",
+ "arch",
+ "arm",
+ "atemoya",
+ "avocado",
+ "bag",
+ "ball",
+ "balloon",
+ "band",
+ "basil",
+ "basin",
+ "basket",
+ "battery",
+ "beach",
+ "bean",
+ "bear",
+ "beard",
+ "bee",
+ "beet",
+ "bell",
+ "berry",
+ "bird",
+ "board",
+ "boat",
+ "bone",
+ "boniato",
+ "book",
+ "boot",
+ "bottle",
+ "box",
+ "brain",
+ "brake",
+ "branch",
+ "bread",
+ "brick",
+ "bridge",
+ "brush",
+ "bucket",
+ "bulb",
+ "button",
+ "cabbage",
+ "cake",
+ "calypso",
+ "camel",
+ "camera",
+ "candle",
+ "car",
+ "caravan",
+ "card",
+ "carnival",
+ "carpet",
+ "carrot",
+ "cart",
+ "cartoon",
+ "cassava",
+ "cat",
+ "celery",
+ "chaconia",
+ "chain",
+ "chayote",
+ "cheese",
+ "cheetah",
+ "cherry",
+ "chicken",
+ "chin",
+ "circle",
+ "clock",
+ "cloud",
+ "coat",
+ "coconut",
+ "collar",
+ "comb",
+ "cord",
+ "corn",
+ "cow",
+ "crayon",
+ "crowd",
+ "cumin",
+ "cup",
+ "curtain",
+ "cushion",
+ "date",
+ "deer",
+ "diamond",
+ "dinner",
+ "dog",
+ "dolphin",
+ "door",
+ "double",
+ "drain",
+ "drawer",
+ "dream",
+ "dress",
+ "drop",
+ "duck",
+ "eagle",
+ "ear",
+ "egg",
+ "endive",
+ "energy",
+ "engine",
+ "evening",
+ "eye",
+ "farm",
+ "feast",
+ "feather",
+ "feijoa",
+ "fennel",
+ "fig",
+ "fish",
+ "flag",
+ "floor",
+ "flower",
+ "fly",
+ "foot",
+ "forest",
+ "fork",
+ "fowl",
+ "fox",
+ "frame",
+ "frog",
+ "garage",
+ "garden",
+ "garlic",
+ "gas",
+ "ghost",
+ "giraffe",
+ "glass",
+ "glove",
+ "goat",
+ "gold",
+ "grape",
+ "grass",
+ "guava",
+ "guitar",
+ "gyro",
+ "hair",
+ "hamster",
+ "hand",
+ "hat",
+ "head",
+ "heart",
+ "helmet",
+ "holiday",
+ "hominy",
+ "honey",
+ "hook",
+ "horse",
+ "house",
+ "ice",
+ "insect",
+ "iron",
+ "island",
+ "jackal",
+ "jelly",
+ "jewel",
+ "jicama",
+ "juice",
+ "kale",
+ "kettle",
+ "key",
+ "king",
+ "kitchen",
+ "kite",
+ "kitten",
+ "kiwi",
+ "knee",
+ "knot",
+ "kumquat",
+ "lamp",
+ "leaf",
+ "leather",
+ "leek",
+ "leg",
+ "lemon",
+ "lettuce",
+ "library",
+ "lime",
+ "line",
+ "lion",
+ "lizard",
+ "lobster",
+ "lock",
+ "longan",
+ "loquat",
+ "lunch",
+ "lychee",
+ "machine",
+ "malanga",
+ "mango",
+ "mangos",
+ "map",
+ "market",
+ "match",
+ "melon",
+ "milk",
+ "monkey",
+ "moon",
+ "morning",
+ "muscle",
+ "music",
+ "nail",
+ "napa",
+ "napkin",
+ "neck",
+ "needle",
+ "nerve",
+ "nest",
+ "net",
+ "night",
+ "nose",
+ "nut",
+ "nutmeg",
+ "ocean",
+ "octopus",
+ "office",
+ "oil",
+ "okra",
+ "onion",
+ "orange",
+ "oregano",
+ "oven",
+ "owl",
+ "oxygen",
+ "oyster",
+ "panda",
+ "papaya",
+ "parang",
+ "parcel",
+ "parrot",
+ "parsnip",
+ "pasta",
+ "pea",
+ "peach",
+ "pear",
+ "pen",
+ "pencil",
+ "pepper",
+ "piano",
+ "picture",
+ "pig",
+ "pillow",
+ "pin",
+ "pipe",
+ "pizza",
+ "plane",
+ "planet",
+ "plastic",
+ "plate",
+ "plow",
+ "plum",
+ "pocket",
+ "pot",
+ "potato",
+ "prune",
+ "pummelo",
+ "pump",
+ "pumpkin",
+ "puppy",
+ "queen",
+ "quill",
+ "quince",
+ "rabbit",
+ "rail",
+ "rain",
+ "rainbow",
+ "raisin",
+ "rat",
+ "receipt",
+ "reggae",
+ "rhubarb",
+ "rhythm",
+ "rice",
+ "ring",
+ "river",
+ "rocket",
+ "rod",
+ "roof",
+ "room",
+ "root",
+ "rose",
+ "roti",
+ "sail",
+ "salt",
+ "sand",
+ "school",
+ "scooter",
+ "screw",
+ "seal",
+ "seed",
+ "shampoo",
+ "shark",
+ "sheep",
+ "shelf",
+ "ship",
+ "shirt",
+ "shoe",
+ "skin",
+ "snail",
+ "snake",
+ "soca",
+ "soccer",
+ "sock",
+ "soursop",
+ "spade",
+ "spider",
+ "spinach",
+ "sponge",
+ "spoon",
+ "spring",
+ "sprout",
+ "square",
+ "squash",
+ "stamp",
+ "star",
+ "station",
+ "steelpan",
+ "stem",
+ "stick",
+ "stomach",
+ "stone",
+ "store",
+ "street",
+ "sugar",
+ "sun",
+ "table",
+ "tail",
+ "tangelo",
+ "tent",
+ "thread",
+ "ticket",
+ "tiger",
+ "toe",
+ "tomato",
+ "tongue",
+ "tooth",
+ "town",
+ "train",
+ "tray",
+ "tree",
+ "truck",
+ "turnip",
+ "turtle",
+ "van",
+ "vase",
+ "vinegar",
+ "vulture",
+ "wall",
+ "watch",
+ "whale",
+ "wheel",
+ "whistle",
+ "window",
+ "wing",
+ "wire",
+ "wolf",
+ "worm",
+ "yacht",
+ "yak",
+ "yam",
+ "yogurt",
+ "yuca",
+ "zebra",
+ "zoo",
+]
+
+
+def generate_job_name() -> str:
+ adj = random.choice(ALLOWED_ADJECTIVES)
+ noun = random.choice(ALLOWED_NOUNS)
+ suffix = "".join(random.choices(ALLOWED_CHARS, k=SUFFIX_LENGTH))
+
+ return "_".join([adj, noun, suffix])
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resource_configuration.py
new file mode 100644
index 00000000..a27b5ba1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resource_configuration.py
@@ -0,0 +1,239 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import json
+import logging
+from typing import Any, Dict, List, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
+from azure.ai.ml._restclient.v2025_01_01_preview.models import (
+ JobResourceConfiguration as RestJobResourceConfiguration202501,
+)
+from azure.ai.ml.constants._job.job import JobComputePropertyFields
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.entities._util import convert_ordered_dict_to_dict
+
+module_logger = logging.getLogger(__name__)
+
+
+class BaseProperty(dict):
+ """Base class for entity classes to be used as value of JobResourceConfiguration.properties."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__()
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ if key.startswith("_"):
+ super().__setattr__(key, value)
+ else:
+ self[key] = value
+
+ def __getattr__(self, key: str) -> Any:
+ if key.startswith("_"):
+ super().__getattribute__(key)
+ return None
+
+ return self[key]
+
+ def __repr__(self) -> str:
+ return json.dumps(self.as_dict())
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, dict):
+ return self.as_dict() == other
+ if isinstance(other, BaseProperty):
+ return self.as_dict() == other.as_dict()
+ return False
+
+ def as_dict(self) -> Dict[str, Any]:
+ res: dict = self._to_dict(self)
+ return res
+
+ @classmethod
+ def _to_dict(cls, obj: Any) -> Any:
+ if isinstance(obj, dict):
+ result = {}
+ for key, value in obj.items():
+ if value is None:
+ continue
+ if isinstance(value, dict):
+ result[key] = cls._to_dict(value)
+ else:
+ result[key] = value
+ return result
+ return obj
+
+
+class Properties(BaseProperty):
+ # pre-defined properties are case-insensitive
+ # Map Singularity -> AISupercomputer in SDK until MFE does mapping
+ _KEY_MAPPING = {
+ JobComputePropertyFields.AISUPERCOMPUTER.lower(): JobComputePropertyFields.AISUPERCOMPUTER,
+ JobComputePropertyFields.SINGULARITY.lower(): JobComputePropertyFields.AISUPERCOMPUTER,
+ JobComputePropertyFields.ITP.lower(): JobComputePropertyFields.ITP,
+ JobComputePropertyFields.TARGET_SELECTOR.lower(): JobComputePropertyFields.TARGET_SELECTOR,
+ }
+
+ def as_dict(self) -> Dict[str, Any]:
+ result = {}
+ for key, value in super().as_dict().items():
+ if key.lower() in self._KEY_MAPPING:
+ key = self._KEY_MAPPING[key.lower()]
+ result[key] = value
+ # recursively convert Ordered Dict to dictionary
+ return cast(dict, convert_ordered_dict_to_dict(result))
+
+
+class JobResourceConfiguration(RestTranslatableMixin, DictMixin):
+ """Job resource configuration class, inherited and extended functionalities from ResourceConfiguration.
+
+ :keyword locations: A list of locations where the job can run.
+ :paramtype locations: Optional[List[str]]
+ :keyword instance_count: The number of instances or nodes used by the compute target.
+ :paramtype instance_count: Optional[int]
+ :keyword instance_type: The type of VM to be used, as supported by the compute target.
+ :paramtype instance_type: Optional[str]
+ :keyword properties: A dictionary of properties for the job.
+ :paramtype properties: Optional[dict[str, Any]]
+ :keyword docker_args: Extra arguments to pass to the Docker run command. This would override any
+ parameters that have already been set by the system, or in this section. This parameter is only
+ supported for Azure ML compute types.
+ :paramtype docker_args: Optional[Union[str, List[str]]]
+ :keyword shm_size: The size of the docker container's shared memory block. This should be in the
+ format of (number)(unit) where the number has to be greater than 0 and the unit can be one of
+ b(bytes), k(kilobytes), m(megabytes), or g(gigabytes).
+ :paramtype shm_size: Optional[str]
+ :keyword max_instance_count: The maximum number of instances or nodes used by the compute target.
+ :paramtype max_instance_count: Optional[int]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_command_configurations.py
+ :start-after: [START command_job_resource_configuration]
+ :end-before: [END command_job_resource_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a CommandJob with a JobResourceConfiguration.
+ """
+
+ def __init__(
+ self, # pylint: disable=unused-argument
+ *,
+ locations: Optional[List[str]] = None,
+ instance_count: Optional[int] = None,
+ instance_type: Optional[Union[str, List]] = None,
+ properties: Optional[Union[Properties, Dict]] = None,
+ docker_args: Optional[Union[str, List[str]]] = None,
+ shm_size: Optional[str] = None,
+ max_instance_count: Optional[int] = None,
+ **kwargs: Any
+ ) -> None:
+ self.locations = locations
+ self.instance_count = instance_count
+ self.instance_type = instance_type
+ self.shm_size = shm_size
+ self.max_instance_count = max_instance_count
+ self.docker_args = docker_args
+ self._properties = None
+ self.properties = properties
+
+ @property
+ def properties(self) -> Optional[Union[Properties, Dict]]:
+ """The properties of the job.
+
+ :rtype: ~azure.ai.ml.entities._job.job_resource_configuration.Properties
+ """
+ return self._properties
+
+ @properties.setter
+ def properties(self, properties: Dict[str, Any]) -> None:
+ """Sets the properties of the job.
+
+ :param properties: A dictionary of properties for the job.
+ :type properties: Dict[str, Any]
+ :raises TypeError: Raised if properties is not a dictionary type.
+ """
+ if properties is None:
+ self._properties = Properties()
+ elif isinstance(properties, dict):
+ self._properties = Properties(**properties)
+ else:
+ raise TypeError("properties must be a dict.")
+
+ def _to_rest_object(self) -> Union[RestJobResourceConfiguration, RestJobResourceConfiguration202501]:
+ if self.docker_args and isinstance(self.docker_args, list):
+ return RestJobResourceConfiguration202501(
+ instance_count=self.instance_count,
+ instance_type=self.instance_type,
+ max_instance_count=self.max_instance_count,
+ properties=self.properties.as_dict() if isinstance(self.properties, Properties) else None,
+ docker_args_list=self.docker_args,
+ shm_size=self.shm_size,
+ )
+ return RestJobResourceConfiguration(
+ locations=self.locations,
+ instance_count=self.instance_count,
+ instance_type=self.instance_type,
+ max_instance_count=self.max_instance_count,
+ properties=self.properties.as_dict() if isinstance(self.properties, Properties) else None,
+ docker_args=self.docker_args,
+ shm_size=self.shm_size,
+ )
+
+ @classmethod
+ def _from_rest_object(
+ cls, obj: Optional[Union[RestJobResourceConfiguration, RestJobResourceConfiguration202501]]
+ ) -> Optional["JobResourceConfiguration"]:
+ if obj is None:
+ return None
+ if isinstance(obj, dict):
+ return cls(**obj)
+ return JobResourceConfiguration(
+ locations=obj.locations if hasattr(obj, "locations") else None,
+ instance_count=obj.instance_count,
+ instance_type=obj.instance_type,
+ max_instance_count=obj.max_instance_count if hasattr(obj, "max_instance_count") else None,
+ properties=obj.properties,
+ docker_args=obj.docker_args_list if hasattr(obj, "docker_args_list") else obj.docker_args,
+ shm_size=obj.shm_size,
+ deserialize_properties=True,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, JobResourceConfiguration):
+ return NotImplemented
+ return (
+ self.locations == other.locations
+ and self.instance_count == other.instance_count
+ and self.instance_type == other.instance_type
+ and self.max_instance_count == other.max_instance_count
+ and self.docker_args == other.docker_args
+ and self.shm_size == other.shm_size
+ )
+
+ def __ne__(self, other: object) -> bool:
+ if not isinstance(other, JobResourceConfiguration):
+ return NotImplemented
+ return not self.__eq__(other)
+
+ def _merge_with(self, other: "JobResourceConfiguration") -> None:
+ if other:
+ if other.locations:
+ self.locations = other.locations
+ if other.instance_count:
+ self.instance_count = other.instance_count
+ if other.instance_type:
+ self.instance_type = other.instance_type
+ if other.max_instance_count:
+ self.max_instance_count = other.max_instance_count
+ if other.properties:
+ self.properties = other.properties
+ if other.docker_args:
+ self.docker_args = other.docker_args
+ if other.shm_size:
+ self.shm_size = other.shm_size
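+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the SDK: the instance type, shared memory
+    # size and docker arguments below are illustrative values only.
+    example = JobResourceConfiguration(
+        instance_count=2,
+        instance_type="STANDARD_D2_V2",
+        shm_size="2g",  # <number><unit>, unit one of b, k, m, or g
+        docker_args="--ipc=host",
+    )
+    print(example.instance_count, example.instance_type, example.shm_size)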
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resources.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resources.py
new file mode 100644
index 00000000..bd1cdad5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_resources.py
@@ -0,0 +1,33 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, List
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml._restclient.v2024_10_01_preview.models import JobResources as RestJobResources
+
+
+class JobResources(RestTranslatableMixin):
+ """Resource configuration for a job.
+
+ This class should not be instantiated directly. Instead, use its subclasses.
+ """
+
+ def __init__(self, *, instance_types: List[str]) -> None:
+ self.instance_types = instance_types
+
+ def _to_rest_object(self) -> Any:
+ return RestJobResources(instance_types=self.instance_types)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobResources) -> "JobResources":
+ job_resources = cls(instance_types=obj.instance_types)
+ return job_resources
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, JobResources):
+ return NotImplemented
+ return self.instance_types == other.instance_types
+
+ def __ne__(self, other: object) -> bool:
+ return not self.__eq__(other)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py
new file mode 100644
index 00000000..a97048fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/job_service.py
@@ -0,0 +1,424 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Any, Dict, Optional, cast
+
+from typing_extensions import Literal
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import AllNodes
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobService as RestJobService
+from azure.ai.ml.constants._job.job import JobServiceTypeNames
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobServiceBase(RestTranslatableMixin, DictMixin):
+ """Base class for job service configuration.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+ :keyword endpoint: The endpoint URL.
+ :paramtype endpoint: Optional[str]
+ :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code".
+ :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]]
+ :keyword port: The port for the endpoint.
+ :paramtype port: Optional[int]
+ :keyword nodes: Indicates whether the service must run on all nodes.
+ :paramtype nodes: Optional[Literal["all"]]
+ :keyword properties: Additional properties to set on the endpoint.
+ :paramtype properties: Optional[dict[str, str]]
+ :keyword status: The status of the endpoint.
+ :paramtype status: Optional[str]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+ """
+
+ def __init__( # pylint: disable=unused-argument
+ self,
+ *,
+ endpoint: Optional[str] = None,
+ type: Optional[ # pylint: disable=redefined-builtin
+ Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]
+ ] = None,
+ nodes: Optional[Literal["all"]] = None,
+ status: Optional[str] = None,
+ port: Optional[int] = None,
+ properties: Optional[Dict[str, str]] = None,
+ **kwargs: Dict,
+ ) -> None:
+ self.endpoint = endpoint
+ self.type: Any = type
+ self.nodes = nodes
+ self.status = status
+ self.port = port
+ self.properties = properties
+ self._validate_nodes()
+ self._validate_type_name()
+
+ def _validate_nodes(self) -> None:
+ if self.nodes not in ["all", None]:
+ msg = f"nodes should be either 'all' or None, but received '{self.nodes}'."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def _validate_type_name(self) -> None:
+ if self.type and self.type not in JobServiceTypeNames.ENTITY_TO_REST:
+ msg = (
+ f"type should be one of {JobServiceTypeNames.NAMES_ALLOWED_FOR_PUBLIC}, but received '{self.type}'."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def _to_rest_job_service(self, updated_properties: Optional[Dict[str, str]] = None) -> RestJobService:
+ return RestJobService(
+ endpoint=self.endpoint,
+ job_service_type=JobServiceTypeNames.ENTITY_TO_REST.get(self.type, None) if self.type else None,
+ nodes=AllNodes() if self.nodes else None,
+ status=self.status,
+ port=self.port,
+ properties=updated_properties if updated_properties else self.properties,
+ )
+
+ @classmethod
+ def _to_rest_job_services(
+ cls,
+ services: Optional[Dict],
+ ) -> Optional[Dict[str, RestJobService]]:
+ if services is None:
+ return None
+
+ return {name: service._to_rest_object() for name, service in services.items()}
+
+ @classmethod
+ def _from_rest_job_service_object(cls, obj: RestJobService) -> "JobServiceBase":
+ return cls(
+ endpoint=obj.endpoint,
+ type=(
+ JobServiceTypeNames.REST_TO_ENTITY.get(obj.job_service_type, None) # type: ignore[arg-type]
+ if obj.job_service_type
+ else None
+ ),
+ nodes="all" if obj.nodes else None,
+ status=obj.status,
+ port=obj.port,
+ # ssh_public_keys=_get_property(obj.properties, "sshPublicKeys"),
+ properties=obj.properties,
+ )
+
+ @classmethod
+ def _from_rest_job_services(cls, services: Optional[Dict[str, RestJobService]]) -> Optional[Dict]:
+ """Resolve a Dict[str, RestJobService] into a Dict of the specific JobService subclasses."""
+ if services is None:
+ return None
+
+ result: dict = {}
+ for name, service in services.items():
+ if service.job_service_type == JobServiceTypeNames.RestNames.JUPYTER_LAB:
+ result[name] = JupyterLabJobService._from_rest_object(service)
+ elif service.job_service_type == JobServiceTypeNames.RestNames.SSH:
+ result[name] = SshJobService._from_rest_object(service)
+ elif service.job_service_type == JobServiceTypeNames.RestNames.TENSOR_BOARD:
+ result[name] = TensorBoardJobService._from_rest_object(service)
+ elif service.job_service_type == JobServiceTypeNames.RestNames.VS_CODE:
+ result[name] = VsCodeJobService._from_rest_object(service)
+ else:
+ result[name] = JobService._from_rest_object(service)
+ return result
+
+
+class JobService(JobServiceBase):
+ """Basic job service configuration for backward compatibility.
+
+ This class is not intended to be used directly. Instead, use one of its subclasses specific to your job type.
+
+ :keyword endpoint: The endpoint URL.
+ :paramtype endpoint: Optional[str]
+ :keyword type: The endpoint type. Accepted values are "jupyter_lab", "ssh", "tensor_board", and "vs_code".
+ :paramtype type: Optional[Literal["jupyter_lab", "ssh", "tensor_board", "vs_code"]]
+ :keyword port: The port for the endpoint.
+ :paramtype port: Optional[int]
+ :keyword nodes: Indicates whether the service must run on all nodes.
+ :paramtype nodes: Optional[Literal["all"]]
+ :keyword properties: Additional properties to set on the endpoint.
+ :paramtype properties: Optional[dict[str, str]]
+ :keyword status: The status of the endpoint.
+ :paramtype status: Optional[str]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+ """
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobService) -> "JobService":
+ return cast(JobService, cls._from_rest_job_service_object(obj))
+
+ def _to_rest_object(self) -> RestJobService:
+ return self._to_rest_job_service()
+
+
+class SshJobService(JobServiceBase):
+ """SSH job service configuration.
+
+ :ivar type: Specifies the type of job service. Set automatically to "ssh" for this class.
+ :vartype type: str
+ :keyword endpoint: The endpoint URL.
+ :paramtype endpoint: Optional[str]
+ :keyword port: The port for the endpoint.
+ :paramtype port: Optional[int]
+ :keyword nodes: Indicates whether the service must run on all nodes.
+ :paramtype nodes: Optional[Literal["all"]]
+ :keyword properties: Additional properties to set on the endpoint.
+ :paramtype properties: Optional[dict[str, str]]
+ :keyword status: The status of the endpoint.
+ :paramtype status: Optional[str]
+ :keyword ssh_public_keys: The SSH Public Key to access the job container.
+ :paramtype ssh_public_keys: Optional[str]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START ssh_job_service_configuration]
+ :end-before: [END ssh_job_service_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring an SshJobService on a command job.
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: Optional[str] = None,
+ nodes: Optional[Literal["all"]] = None,
+ status: Optional[str] = None,
+ port: Optional[int] = None,
+ ssh_public_keys: Optional[str] = None,
+ properties: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ endpoint=endpoint,
+ nodes=nodes,
+ status=status,
+ port=port,
+ properties=properties,
+ **kwargs,
+ )
+ self.type = JobServiceTypeNames.EntityNames.SSH
+ self.ssh_public_keys = ssh_public_keys
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobService) -> "SshJobService":
+ ssh_job_service = cast(SshJobService, cls._from_rest_job_service_object(obj))
+ ssh_job_service.ssh_public_keys = _get_property(obj.properties, "sshPublicKeys")
+ return ssh_job_service
+
+ def _to_rest_object(self) -> RestJobService:
+ updated_properties = _append_or_update_properties(self.properties, "sshPublicKeys", self.ssh_public_keys)
+ return self._to_rest_job_service(updated_properties)
+
+
+class TensorBoardJobService(JobServiceBase):
+ """TensorBoard job service configuration.
+
+ :ivar type: Specifies the type of job service. Set automatically to "tensor_board" for this class.
+ :vartype type: str
+ :keyword endpoint: The endpoint URL.
+ :paramtype endpoint: Optional[str]
+ :keyword port: The port for the endpoint.
+ :paramtype port: Optional[int]
+ :keyword nodes: Indicates whether the service must run on all nodes.
+ :paramtype nodes: Optional[Literal["all"]]
+ :keyword properties: Additional properties to set on the endpoint.
+ :paramtype properties: Optional[dict[str, str]]
+ :keyword status: The status of the endpoint.
+ :paramtype status: Optional[str]
+ :keyword log_dir: The directory path for the log file.
+ :paramtype log_dir: Optional[str]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START ssh_job_service_configuration]
+ :end-before: [END ssh_job_service_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a TensorBoardJobService on a command job.
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: Optional[str] = None,
+ nodes: Optional[Literal["all"]] = None,
+ status: Optional[str] = None,
+ port: Optional[int] = None,
+ log_dir: Optional[str] = None,
+ properties: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ endpoint=endpoint,
+ nodes=nodes,
+ status=status,
+ port=port,
+ properties=properties,
+ **kwargs,
+ )
+ self.type = JobServiceTypeNames.EntityNames.TENSOR_BOARD
+ self.log_dir = log_dir
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobService) -> "TensorBoardJobService":
+ tensorboard_job_service = cast(TensorBoardJobService, cls._from_rest_job_service_object(obj))
+ tensorboard_job_service.log_dir = _get_property(obj.properties, "logDir")
+ return tensorboard_job_service
+
+ def _to_rest_object(self) -> RestJobService:
+ updated_properties = _append_or_update_properties(self.properties, "logDir", self.log_dir)
+ return self._to_rest_job_service(updated_properties)
+
+
+class JupyterLabJobService(JobServiceBase):
+ """JupyterLab job service configuration.
+
+ :ivar type: Specifies the type of job service. Set automatically to "jupyter_lab" for this class.
+ :vartype type: str
+ :keyword endpoint: The endpoint URL.
+ :paramtype endpoint: Optional[str]
+ :keyword port: The port for the endpoint.
+ :paramtype port: Optional[int]
+ :keyword nodes: Indicates whether the service must run on all nodes.
+ :paramtype nodes: Optional[Literal["all"]]
+ :keyword properties: Additional properties to set on the endpoint.
+ :paramtype properties: Optional[dict[str, str]]
+ :keyword status: The status of the endpoint.
+ :paramtype status: Optional[str]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START ssh_job_service_configuration]
+ :end-before: [END ssh_job_service_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a JupyterLabJobService on a command job.
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: Optional[str] = None,
+ nodes: Optional[Literal["all"]] = None,
+ status: Optional[str] = None,
+ port: Optional[int] = None,
+ properties: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ endpoint=endpoint,
+ nodes=nodes,
+ status=status,
+ port=port,
+ properties=properties,
+ **kwargs,
+ )
+ self.type = JobServiceTypeNames.EntityNames.JUPYTER_LAB
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobService) -> "JupyterLabJobService":
+ return cast(JupyterLabJobService, cls._from_rest_job_service_object(obj))
+
+ def _to_rest_object(self) -> RestJobService:
+ return self._to_rest_job_service()
+
+
+class VsCodeJobService(JobServiceBase):
+ """VS Code job service configuration.
+
+ :ivar type: Specifies the type of job service. Set automatically to "vs_code" for this class.
+ :vartype type: str
+ :keyword endpoint: The endpoint URL.
+ :paramtype endpoint: Optional[str]
+ :keyword port: The port for the endpoint.
+ :paramtype port: Optional[int]
+ :keyword nodes: Indicates whether the service must run on all nodes.
+ :paramtype nodes: Optional[Literal["all"]]
+ :keyword properties: Additional properties to set on the endpoint.
+ :paramtype properties: Optional[dict[str, str]]
+ :keyword status: The status of the endpoint.
+ :paramtype status: Optional[str]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START ssh_job_service_configuration]
+ :end-before: [END ssh_job_service_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a VsCodeJobService on a command job.
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: Optional[str] = None,
+ nodes: Optional[Literal["all"]] = None,
+ status: Optional[str] = None,
+ port: Optional[int] = None,
+ properties: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ endpoint=endpoint,
+ nodes=nodes,
+ status=status,
+ port=port,
+ properties=properties,
+ **kwargs,
+ )
+ self.type = JobServiceTypeNames.EntityNames.VS_CODE
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestJobService) -> "VsCodeJobService":
+ return cast(VsCodeJobService, cls._from_rest_job_service_object(obj))
+
+ def _to_rest_object(self) -> RestJobService:
+ return self._to_rest_job_service()
+
+
+def _append_or_update_properties(
+ properties: Optional[Dict[str, str]], key: str, value: Optional[str]
+) -> Dict[str, str]:
+ if value and not properties:
+ properties = {key: value}
+
+ if value and properties:
+ properties.update({key: value})
+ return properties if properties is not None else {}
+
+
+def _get_property(properties: Optional[Dict[str, str]], key: str) -> Optional[str]:
+ return properties.get(key, None) if properties else None
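+
+
+if __name__ == "__main__":
+    # Minimal sketch of building the per-job "services" mapping from the classes
+    # above; the service names, SSH key and log directory are illustrative values.
+    # _to_rest_job_services is an internal helper, called here only to show the
+    # REST shape each service translates to.
+    services = {
+        "my_ssh": SshJobService(ssh_public_keys="ssh-rsa AAAAB3... user@example", nodes="all"),
+        "my_tensorboard": TensorBoardJobService(log_dir="./logs"),
+        "my_jupyterlab": JupyterLabJobService(),
+    }
+    rest_services = JobServiceBase._to_rest_job_services(services) or {}
+    for name, rest_service in rest_services.items():
+        print(name, rest_service.job_service_type, rest_service.properties)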
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py
new file mode 100644
index 00000000..49b2c992
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py
@@ -0,0 +1,244 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData
+from azure.ai.ml._schema.job.parallel_job import ParallelJobSchema
+from azure.ai.ml._utils.utils import is_data_binding_expression
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from ..job import Job
+from ..job_io_mixin import JobIOMixin
+from .parameterized_parallel import ParameterizedParallel
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._builders import Parallel
+ from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+
+module_logger = logging.getLogger(__name__)
+
+
+class ParallelJob(Job, ParameterizedParallel, JobIOMixin):
+ """Parallel job.
+
+ :param name: Name of the job.
+ :type name: str
+ :param version: Version of the job.
+ :type version: str
+ :param id: Global id of the resource, Azure Resource Manager ID.
+ :type id: str
+ :param type: Type of the job; the only supported value is 'parallel'.
+ :type type: str
+ :param description: Description of the job.
+ :type description: str
+ :param tags: Internal use only.
+ :type tags: dict
+ :param properties: Internal use only.
+ :type properties: dict
+ :param display_name: Display name of the job.
+ :type display_name: str
+ :param retry_settings: Retry settings for a failed parallel job run.
+ :type retry_settings: BatchRetrySettings
+ :param logging_level: The name of the logging level.
+ :type logging_level: str
+ :param max_concurrency_per_instance: The maximum degree of parallelism per compute instance.
+ :type max_concurrency_per_instance: int
+ :param error_threshold: The number of item processing failures that should be ignored.
+ :type error_threshold: int
+ :param mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored.
+ :type mini_batch_error_threshold: int
+ :keyword identity: The identity that the job will use while running on compute.
+ :paramtype identity: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
+ ~azure.ai.ml.UserIdentityConfiguration]]
+ :param task: The parallel task.
+ :type task: ParallelTask
+ :param mini_batch_size: The mini batch size.
+ :type mini_batch_size: str
+ :param partition_keys: The partition keys.
+ :type partition_keys: list
+ :param input_data: The input data.
+ :type input_data: str
+ :param inputs: Inputs of the job.
+ :type inputs: dict
+ :param outputs: Outputs of the job.
+ :type outputs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
+ outputs: Optional[Dict[str, Output]] = None,
+ identity: Optional[
+ Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, Dict]
+ ] = None,
+ **kwargs: Any,
+ ):
+ kwargs[TYPE] = JobType.PARALLEL
+
+ super().__init__(**kwargs)
+
+ self.inputs = inputs # type: ignore[assignment]
+ self.outputs = outputs # type: ignore[assignment]
+ self.identity = identity
+
+ def _to_dict(self) -> Dict:
+ res: dict = ParallelJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def _to_rest_object(self) -> None:
+ pass
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "ParallelJob":
+ loaded_data = load_from_dict(ParallelJobSchema, data, context, additional_message, **kwargs)
+ return ParallelJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ @classmethod
+ def _load_from_rest(cls, obj: JobBaseData) -> None:
+ pass
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "ParallelComponent":
+ """Translate a parallel job to component job.
+
+ :param context: Context of parallel job YAML file.
+ :type context: dict
+ :return: Translated parallel component.
+ :rtype: ParallelComponent
+ """
+ from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+
+ pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+ context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}
+
+ # Create anonymous parallel component with default version as 1
+ init_kwargs = {}
+ for key in [
+ "mini_batch_size",
+ "partition_keys",
+ "logging_level",
+ "max_concurrency_per_instance",
+ "error_threshold",
+ "mini_batch_error_threshold",
+ "retry_settings",
+ "resources",
+ ]:
+ value = getattr(self, key)
+ from azure.ai.ml.entities import BatchRetrySettings, JobResourceConfiguration
+
+ values_to_check: List = []
+ if key == "retry_settings" and isinstance(value, BatchRetrySettings):
+ values_to_check = [value.max_retries, value.timeout]
+ elif key == "resources" and isinstance(value, JobResourceConfiguration):
+ values_to_check = [
+ value.locations,
+ value.instance_count,
+ value.instance_type,
+ value.shm_size,
+ value.max_instance_count,
+ value.docker_args,
+ ]
+ else:
+ values_to_check = [value]
+
+ # note that component level attributes can not be data binding expressions
+ # so filter out data binding expression properties here;
+ # they will still take effect at node level according to _to_node
+ if any(
+ map(
+ lambda x: is_data_binding_expression(x, binding_prefix=["parent", "inputs"], is_singular=False)
+ or is_data_binding_expression(x, binding_prefix=["inputs"], is_singular=False),
+ values_to_check,
+ )
+ ):
+ continue
+
+ init_kwargs[key] = getattr(self, key)
+
+ return ParallelComponent(
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ # for parallel_job.task, all attributes for this are string for now so data binding expression is allowed
+ # in SDK level naturally, but not sure if such component is valid. leave the validation to service side.
+ task=self.task,
+ inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
+ outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
+ input_data=self.input_data,
+ # keep them if no data binding expression detected to keep the behavior of to_component
+ **init_kwargs,
+ )
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Parallel":
+ """Translate a parallel job to a pipeline node.
+
+ :param context: Context of parallel job YAML file.
+ :type context: dict
+ :return: Translated parallel component.
+ :rtype: Parallel
+ """
+ from azure.ai.ml.entities._builders import Parallel
+
+ component = self._to_component(context, **kwargs)
+
+ return Parallel(
+ component=component,
+ compute=self.compute,
+ # Need to supply the inputs with double curly.
+ inputs=self.inputs, # type: ignore[arg-type]
+ outputs=self.outputs, # type: ignore[arg-type]
+ mini_batch_size=self.mini_batch_size,
+ partition_keys=self.partition_keys,
+ input_data=self.input_data,
+ # task will be inherited from component & base_path will be set correctly.
+ retry_settings=self.retry_settings,
+ logging_level=self.logging_level,
+ max_concurrency_per_instance=self.max_concurrency_per_instance,
+ error_threshold=self.error_threshold,
+ mini_batch_error_threshold=self.mini_batch_error_threshold,
+ environment_variables=self.environment_variables,
+ properties=self.properties,
+ identity=self.identity,
+ resources=self.resources if self.resources and not isinstance(self.resources, dict) else None,
+ )
+
+ def _validate(self) -> None:
+ if self.name is None:
+ msg = "Job name is required"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ if self.compute is None:
+ msg = "compute is required"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ if self.task is None:
+ msg = "task is required"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py
new file mode 100644
index 00000000..7325aed3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py
@@ -0,0 +1,119 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+# from azure.ai.ml.entities._deployment.code_configuration import CodeConfiguration
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets.environment import Environment
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class ParallelTask(RestTranslatableMixin, DictMixin):
+ """Parallel task.
+
+ :param type: The type of the parallel task.
+ Possible values are 'run_function' and 'model'.
+ :type type: str
+ :param code: A local or remote path pointing at source code.
+ :type code: str
+ :param entry_script: User script which will be run in parallel on multiple nodes. This is
+ specified as a local file path.
+ The entry_script should contain two functions:
+ ``init()``: this function should be used for any costly or common preparation for subsequent inferences,
+ e.g., deserializing and loading the model into a global object.
+ ``run(mini_batch)``: The method to be parallelized. Each invocation will have one mini-batch.
+ 'mini_batch': Batch inference will invoke run method and pass either a list or a Pandas DataFrame as an
+ argument to the method. Each entry in mini_batch will be a file path if the input is a FileDataset,
+ or a Pandas DataFrame if the input is a TabularDataset.
+ run() method should return a Pandas DataFrame or an array.
+ For append_row output_action, these returned elements are appended into the common output file.
+ For summary_only, the contents of the elements are ignored. For all output actions,
+ each returned output element indicates one successful inference of input element in the input mini-batch.
+ Each parallel worker process will call `init` once and then loop over `run` function until all mini-batches
+ are processed.
+ :type entry_script: str
+ :param program_arguments: The arguments of the parallel task.
+ :type program_arguments: str
+ :param model: The model of the parallel task.
+ :type model: str
+ :param append_row_to: All values output by run() method invocations will be aggregated into
+ one unique file which is created in the output location.
+ If it is not set, 'summary_only' is used instead, which means the user script is expected to store the output itself.
+ :type append_row_to: str
+ :param environment: The environment that the training job will run in.
+ :type environment: Union[Environment, str]
+ """
+
+ def __init__(
+ self, # pylint: disable=unused-argument
+ *,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ code: Optional[str] = None,
+ entry_script: Optional[str] = None,
+ program_arguments: Optional[str] = None,
+ model: Optional[str] = None,
+ append_row_to: Optional[str] = None,
+ environment: Optional[Union[Environment, str]] = None,
+ **kwargs: Any,
+ ):
+ self.type = type
+ self.code = code
+ self.entry_script = entry_script
+ self.program_arguments = program_arguments
+ self.model = model
+ self.append_row_to = append_row_to
+ self.environment: Any = environment
+
+ def _to_dict(self) -> Dict:
+ # pylint: disable=no-member
+ res: dict = ComponentParallelTaskSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _load(
+ cls, # pylint: disable=unused-argument
+ path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "ParallelTask":
+ params_override = params_override or []
+ data = load_yaml(path)
+ return ParallelTask._load_from_dict(data=data, path=path, params_override=params_override)
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: dict,
+ path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "ParallelTask":
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(path).parent if path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: ParallelTask = load_from_dict(ComponentParallelTaskSchema, data, context, **kwargs)
+ return res
+
+ @classmethod
+ def _from_dict(cls, dct: dict) -> "ParallelTask":
+ obj = cls(**dict(dct.items()))
+ return obj
+
+ def _validate(self) -> None:
+ if self.type is None:
+ msg = "'type' is required for ParallelTask {}."
+ raise ValidationException(
+ message=msg.format(self.type),
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg.format(""),
+ error_category=ErrorCategory.USER_ERROR,
+ )
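+
+
+if __name__ == "__main__":
+    # Minimal sketch (a hypothetical user entry script, not part of the SDK) of the
+    # init()/run(mini_batch) contract described in the class docstring: init() runs
+    # once per worker process, run() is called once per mini-batch and returns one
+    # output element per successfully processed input element.
+    def init():
+        # costly one-off preparation, e.g. deserializing a model into a global
+        pass
+
+    def run(mini_batch):
+        # mini_batch is a list of file paths (FileDataset) or a pandas DataFrame
+        # (TabularDataset); return a list or a DataFrame of results.
+        return [f"processed {item}" for item in mini_batch]
+
+    init()
+    print(run(["data/part-0.csv", "data/part-1.csv"]))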
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py
new file mode 100644
index 00000000..6b5dbced
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py
@@ -0,0 +1,96 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from ..job_resource_configuration import JobResourceConfiguration
+from .parallel_task import ParallelTask
+from .retry_settings import RetrySettings
+
+module_logger = logging.getLogger(__name__)
+
+
+class ParameterizedParallel:
+ """Parallel component that contains the traning parallel and supporting parameters for the parallel.
+
+ :param retry_settings: parallel component run failed retry
+ :type retry_settings: BatchRetrySettings
+ :param logging_level: A string of the logging level name
+ :type logging_level: str
+ :param max_concurrency_per_instance: The max parallellism that each compute instance has.
+ :type max_concurrency_per_instance: int
+ :param error_threshold: The number of item processing failures should be ignored.
+ :type error_threshold: int
+ :param mini_batch_error_threshold: The number of mini batch processing failures should be ignored.
+ :type mini_batch_error_threshold: int
+ :param task: The parallel task.
+ :type task: ParallelTask
+ :param mini_batch_size: The mini batch size.
+ :type mini_batch_size: str
+ :param input_data: The input data.
+ :type input_data: str
+ :param resources: Compute Resource configuration for the job.
+ :type resources: Union[Dict, ~azure.ai.ml.entities.JobResourceConfiguration]
+ """
+
+ # pylint: disable=too-many-instance-attributes
+ def __init__(
+ self,
+ retry_settings: Optional[RetrySettings] = None,
+ logging_level: Optional[str] = None,
+ max_concurrency_per_instance: Optional[int] = None,
+ error_threshold: Optional[int] = None,
+ mini_batch_error_threshold: Optional[int] = None,
+ input_data: Optional[str] = None,
+ task: Optional[ParallelTask] = None,
+ mini_batch_size: Optional[int] = None,
+ partition_keys: Optional[List] = None,
+ resources: Optional[Union[dict, JobResourceConfiguration]] = None,
+ environment_variables: Optional[Dict] = None,
+ ):
+ self.mini_batch_size = mini_batch_size
+ self.partition_keys = partition_keys
+ self.task = task
+ self.retry_settings = retry_settings
+ self.input_data = input_data
+ self.logging_level = logging_level
+ self.max_concurrency_per_instance = max_concurrency_per_instance
+ self.error_threshold = error_threshold
+ self.mini_batch_error_threshold = mini_batch_error_threshold
+ self.resources = resources
+ self.environment_variables = dict(environment_variables) if environment_variables else {}
+
+ @property
+ def task(self) -> Optional[ParallelTask]:
+ res: Optional[ParallelTask] = self._task
+ return res
+
+ @task.setter
+ def task(self, value: Any) -> None:
+ if isinstance(value, dict):
+ value = ParallelTask(**value)
+ self._task = value
+
+ @property
+ def resources(self) -> Optional[Union[dict, JobResourceConfiguration]]:
+ res: Optional[Union[dict, JobResourceConfiguration]] = self._resources
+ return res
+
+ @resources.setter
+ def resources(self, value: Any) -> None:
+ if isinstance(value, dict):
+ value = JobResourceConfiguration(**value)
+ self._resources = value
+
+ @property
+ def retry_settings(self) -> Optional[RetrySettings]:
+ res: Optional[RetrySettings] = self._retry_settings
+ return res
+
+ @retry_settings.setter
+ def retry_settings(self, value: Any) -> None:
+ if isinstance(value, dict):
+ value = RetrySettings(**value)
+ self._retry_settings = value
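+
+
+if __name__ == "__main__":
+    # Minimal sketch of the dict coercion performed by the task, retry_settings and
+    # resources setters above; the entry script and code path are illustrative values.
+    parameterized = ParameterizedParallel(
+        task={"type": "run_function", "code": "./src", "entry_script": "score.py"},
+        retry_settings={"max_retries": 3, "timeout": 60},
+        resources={"instance_count": 2},
+        mini_batch_error_threshold=5,
+    )
+    print(type(parameterized.task).__name__)            # ParallelTask
+    print(type(parameterized.retry_settings).__name__)  # RetrySettings
+    print(type(parameterized.resources).__name__)       # JobResourceConfiguration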
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py
new file mode 100644
index 00000000..2fb19ba1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py
@@ -0,0 +1,78 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class RetrySettings(RestTranslatableMixin, DictMixin):
+ """Parallel RetrySettings.
+
+ :param timeout: Timeout in seconds for each invocation of the run() method.
+ (optional) This value could be set through PipelineParameter.
+ :type timeout: int
+ :param max_retries: The number of maximum tries for a failed or timeout mini batch.
+ The range is [1, int.max]. This value could be set through PipelineParameter.
+ A mini batch with dequeue count greater than this won't be processed again and will be deleted directly.
+ :type max_retries: int
+ """
+
+ def __init__(
+ self, # pylint: disable=unused-argument
+ *,
+ timeout: Optional[Union[int, str]] = None,
+ max_retries: Optional[Union[int, str]] = None,
+ **kwargs: Any,
+ ):
+ self.timeout = timeout
+ self.max_retries = max_retries
+
+ def _to_dict(self) -> Dict:
+ res: dict = RetrySettingsSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self) # pylint: disable=no-member
+ return res
+
+ @classmethod
+ def _load(
+ cls, # pylint: disable=unused-argument
+ path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "RetrySettings":
+ params_override = params_override or []
+ data = load_yaml(path)
+ return RetrySettings._load_from_dict(data=data, path=path, params_override=params_override)
+
+ @classmethod
+ def _load_from_dict(
+ cls,
+ data: dict,
+ path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "RetrySettings":
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(path).parent if path else Path.cwd(),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ res: RetrySettings = load_from_dict(RetrySettingsSchema, data, context, **kwargs)
+ return res
+
+ @classmethod
+ def _from_dict(cls, dct: dict) -> "RetrySettings":
+ obj = cls(**dict(dct.items()))
+ return obj
+
+ def _to_rest_object(self) -> Dict:
+ return self._to_dict()
+
+ @classmethod
+ def _from_rest_object(cls, obj: dict) -> "RetrySettings":
+ return cls._from_dict(obj)
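+
+
+if __name__ == "__main__":
+    # Minimal sketch, values illustrative only: strings are also accepted so the
+    # settings can be bound to pipeline parameters.
+    settings = RetrySettings(timeout=60, max_retries=3)
+    print(settings.timeout, settings.max_retries)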
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py
new file mode 100644
index 00000000..180cee76
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py
@@ -0,0 +1,66 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from typing import Any, Optional, Union
+
+from azure.ai.ml.constants import ParallelTaskType
+from azure.ai.ml.entities._assets.environment import Environment
+
+from .parallel_task import ParallelTask
+
+
+class RunFunction(ParallelTask):
+ """Run Function.
+
+ :param code: A local or remote path pointing at source code.
+ :type code: str
+ :param entry_script: User script which will be run in parallel on multiple nodes. This is
+ specified as a local file path.
+ The entry_script should contain two functions:
+ ``init()``: this function should be used for any costly or common preparation for subsequent inferences,
+ e.g., deserializing and loading the model into a global object.
+ ``run(mini_batch)``: The method to be parallelized. Each invocation will have one mini-batch.
+ 'mini_batch': Batch inference will invoke run method and pass either a list or a Pandas DataFrame as an
+ argument to the method. Each entry in mini_batch will be a file path if the input is a FileDataset,
+ or a Pandas DataFrame if the input is a TabularDataset.
+ run() method should return a Pandas DataFrame or an array.
+ For append_row output_action, these returned elements are appended into the common output file.
+ For summary_only, the contents of the elements are ignored. For all output actions,
+ each returned output element indicates one successful inference of input element in the input mini-batch.
+ Each parallel worker process will call `init` once and then loop over `run` function until all mini-batches
+ are processed.
+ :type entry_script: str
+ :param program_arguments: The arguments of the parallel task.
+ :type program_arguments: str
+ :param model: The model of the parallel task.
+ :type model: str
+ :param append_row_to: All values output by run() method invocations will be aggregated into
+ one unique file which is created in the output location.
+ If it is not set, 'summary_only' is used instead, which means the user script is expected to store the output itself.
+ :type append_row_to: str
+ :param environment: The environment that the training job will run in.
+ :type environment: Union[Environment, str]
+ """
+
+ def __init__(
+ self,
+ *,
+ code: Optional[str] = None,
+ entry_script: Optional[str] = None,
+ program_arguments: Optional[str] = None,
+ model: Optional[str] = None,
+ append_row_to: Optional[str] = None,
+ environment: Optional[Union[Environment, str]] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ code=code,
+ entry_script=entry_script,
+ program_arguments=program_arguments,
+ model=model,
+ append_row_to=append_row_to,
+ environment=environment,
+ type=ParallelTaskType.RUN_FUNCTION,
+ )
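+
+
+if __name__ == "__main__":
+    # Minimal sketch of declaring the parallel task; the code path, entry script,
+    # argument bindings and environment reference are illustrative values only.
+    task = RunFunction(
+        code="./src",
+        entry_script="score.py",
+        program_arguments="--output_path ${{outputs.job_output_path}}",
+        append_row_to="${{outputs.job_output_file}}",
+        environment="azureml:my-parallel-env:1",
+    )
+    print(task.type)  # expected: the ParallelTaskType.RUN_FUNCTION constant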
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_command.py
new file mode 100644
index 00000000..57604b38
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_command.py
@@ -0,0 +1,170 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+import os
+from typing import Dict, Optional, Union
+
+from marshmallow import INCLUDE
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SweepJob
+from azure.ai.ml._schema.core.fields import ExperimentalField
+from azure.ai.ml.entities._assets import Environment
+
+from ..._schema import NestedField, UnionField
+from ..._schema.job.distribution import (
+ MPIDistributionSchema,
+ PyTorchDistributionSchema,
+ RayDistributionSchema,
+ TensorFlowDistributionSchema,
+)
+from .distribution import (
+ DistributionConfiguration,
+ MpiDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ TensorFlowDistribution,
+)
+from .job_resource_configuration import JobResourceConfiguration
+from .queue_settings import QueueSettings
+
+module_logger = logging.getLogger(__name__)
+
+# no reference found. leave it for future use.
+INPUT_BINDING_PREFIX = "AZURE_ML_INPUT_"
+OLD_INPUT_BINDING_PREFIX = "AZURE_ML_INPUT"
+
+
+class ParameterizedCommand:
+ """Command component version that contains the command and supporting parameters for a Command component
+ or job.
+
+ This class should not be instantiated directly. Instead, use the child class
+ ~azure.ai.ml.entities.CommandComponent.
+
+ :param command: The command to be executed. Defaults to "".
+ :type command: str
+ :param resources: The compute resource configuration for the command.
+ :type resources: Optional[Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration]]
+ :param code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing
+ to a remote location.
+ :type code: Optional[str]
+ :param environment_variables: A dictionary of environment variable names and values.
+ These environment variables are set on the process where user script is being executed.
+ :type environment_variables: Optional[dict[str, str]]
+ :param distribution: The distribution configuration for distributed jobs.
+ :type distribution: Optional[Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
+ ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]]
+ :param environment: The environment that the job will run in.
+ :type environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ :param queue_settings: The queue settings for the job.
+ :type queue_settings: Optional[~azure.ai.ml.entities.QueueSettings]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+ """
+
+ def __init__(
+ self,
+ command: Optional[str] = "",
+ resources: Optional[Union[dict, JobResourceConfiguration]] = None,
+ code: Optional[Union[str, os.PathLike]] = None,
+ environment_variables: Optional[Dict] = None,
+ distribution: Optional[
+ Union[
+ Dict,
+ MpiDistribution,
+ TensorFlowDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ DistributionConfiguration,
+ ]
+ ] = None,
+ environment: Optional[Union[Environment, str]] = None,
+ queue_settings: Optional[QueueSettings] = None,
+ **kwargs: Dict,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.command = command
+ self.code = code
+ self.environment_variables = dict(environment_variables) if environment_variables else {}
+ self.environment = environment
+ self.distribution = distribution
+ self.resources = resources # type: ignore[assignment]
+ self.queue_settings = queue_settings
+
+ @property
+ def distribution(
+ self,
+ ) -> Optional[
+ Union[
+ dict,
+ MpiDistribution,
+ TensorFlowDistribution,
+ PyTorchDistribution,
+ RayDistribution,
+ DistributionConfiguration,
+ ]
+ ]:
+ """The configuration for the distributed command component or job.
+
+ :return: The distribution configuration.
+ :rtype: Union[~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
+ ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]
+ """
+ return self._distribution
+
+ @distribution.setter
+ def distribution(self, value: Union[dict, PyTorchDistribution, MpiDistribution]) -> None:
+ """Sets the configuration for the distributed command component or job.
+
+ :param value: The distribution configuration for distributed jobs.
+ :type value: Union[dict, ~azure.ai.ml.PyTorchDistribution, ~azure.ai.ml.MpiDistribution,
+ ~azure.ai.ml.TensorFlowDistribution, ~azure.ai.ml.RayDistribution]
+ """
+ if isinstance(value, dict):
+ dist_schema = UnionField(
+ [
+ NestedField(PyTorchDistributionSchema, unknown=INCLUDE),
+ NestedField(TensorFlowDistributionSchema, unknown=INCLUDE),
+ NestedField(MPIDistributionSchema, unknown=INCLUDE),
+ ExperimentalField(NestedField(RayDistributionSchema, unknown=INCLUDE)),
+ ]
+ )
+ value = dist_schema._deserialize(value=value, attr=None, data=None)
+ self._distribution = value
+
+ @property
+ def resources(self) -> JobResourceConfiguration:
+ """The compute resource configuration for the command component or job.
+
+ :return: The compute resource configuration for the command component or job.
+ :rtype: ~azure.ai.ml.entities.JobResourceConfiguration
+ """
+ return self._resources
+
+ @resources.setter
+ def resources(self, value: Union[dict, JobResourceConfiguration]) -> None:
+ """Sets the compute resource configuration for the command component or job.
+
+ :param value: The compute resource configuration for the command component or job.
+ :type value: Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration]
+ """
+ if isinstance(value, dict):
+ value = JobResourceConfiguration(**value)
+ self._resources = value
+
+ @classmethod
+ def _load_from_sweep_job(cls, sweep_job: SweepJob) -> "ParameterizedCommand":
+ parameterized_command = cls(
+ command=sweep_job.trial.command,
+ code=sweep_job.trial.code_id,
+ environment_variables=sweep_job.trial.environment_variables,
+ environment=sweep_job.trial.environment_id,
+ distribution=DistributionConfiguration._from_rest_object(sweep_job.trial.distribution),
+ resources=JobResourceConfiguration._from_rest_object(sweep_job.trial.resources),
+ queue_settings=QueueSettings._from_rest_object(sweep_job.queue_settings),
+ )
+ return parameterized_command
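+
+
+if __name__ == "__main__":
+    # Minimal sketch of the dict coercion performed by the resources and distribution
+    # setters above; the command, environment reference and values are illustrative,
+    # and in application code the child class CommandComponent would be used instead.
+    cmd = ParameterizedCommand(
+        command="python train.py --epochs 10",
+        environment="azureml:my-training-env:1",
+        resources={"instance_count": 2, "shm_size": "2g"},
+        distribution={"type": "pytorch", "process_count_per_instance": 2},
+    )
+    print(type(cmd.resources).__name__)     # JobResourceConfiguration
+    print(type(cmd.distribution).__name__)  # expected: PyTorchDistribution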
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_spark.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_spark.py
new file mode 100644
index 00000000..c8a9a0c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parameterized_spark.py
@@ -0,0 +1,88 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml.entities._assets import Environment
+from azure.ai.ml.entities._job.spark_job_entry import SparkJobEntry
+
+from .._job.spark_job_entry_mixin import SparkJobEntryMixin
+
+DUMMY_IMAGE = "conda/miniconda3"
+
+
+class ParameterizedSpark(SparkJobEntryMixin):
+ """
+ This class should not be instantiated directly. Instead, use the child class ~azure.ai.ml.entities.SparkComponent.
+
+ Spark component that contains supporting parameters.
+
+ :param code: The source code to run the job. Can be a local path or "http:", "https:", or "azureml:" url pointing
+ to a remote location.
+ :type code: Optional[Union[str, os.PathLike]]
+ :param entry: The file or class entry point.
+ :type entry: dict[str, str]
+ :param py_files: The list of .zip, .egg or .py files to place on the PYTHONPATH for Python apps.
+ :type py_files: Optional[list[str]]
+ :param jars: The list of .JAR files to include on the driver and executor classpaths.
+ :type jars: Optional[list[str]]
+ :param files: The list of files to be placed in the working directory of each executor.
+ :type files: Optional[list[str]]
+ :param archives: The list of archives to be extracted into the working directory of each executor.
+ :type archives: Optional[list[str]]
+ :param conf: A dictionary with pre-defined Spark configurations key and values.
+ :type conf: Optional[dict[str, str]]
+ :param environment: The Azure ML environment to run the job in.
+ :type environment: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ :param args: The arguments for the job.
+ :type args: Optional[str]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+ """
+
+ def __init__(
+ self,
+ code: Optional[Union[str, os.PathLike]] = ".",
+ entry: Optional[Union[Dict[str, str], SparkJobEntry]] = None,
+ py_files: Optional[List[str]] = None,
+ jars: Optional[List[str]] = None,
+ files: Optional[List[str]] = None,
+ archives: Optional[List[str]] = None,
+ conf: Optional[Dict[str, str]] = None,
+ environment: Optional[Union[str, Environment]] = None,
+ args: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.args = None
+
+ super().__init__(**kwargs)
+ self.code = code
+ self.entry = entry
+ self.py_files = py_files
+ self.jars = jars
+ self.files = files
+ self.archives = archives
+ self.conf = conf
+ self.environment = environment
+ self.args = args
+
+ @property
+ def environment(self) -> Optional[Union[str, Environment]]:
+ """The Azure ML environment to run the Spark component or job in.
+
+ :return: The Azure ML environment to run the Spark component or job in.
+ :rtype: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ """
+ if isinstance(self._environment, Environment) and self._environment.image is None:
+ return Environment(conda_file=self._environment.conda_file, image=DUMMY_IMAGE)
+ return self._environment
+
+ @environment.setter
+ def environment(self, value: Optional[Union[str, Environment]]) -> None:
+ """Sets the Azure ML environment to run the Spark component or job in.
+
+ :param value: The Azure ML environment to run the Spark component or job in.
+ :type value: Optional[Union[str, ~azure.ai.ml.entities.Environment]]
+ """
+ self._environment = value
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_attr_dict.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_attr_dict.py
new file mode 100644
index 00000000..cf8d92be
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_attr_dict.py
@@ -0,0 +1,161 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from abc import ABC
+from typing import Any, Dict, Generic, List, Optional, TypeVar
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class _AttrDict(Generic[K, V], Dict, ABC):
+ """This class is used for accessing values with instance.some_key. It supports the following scenarios:
+
+ 1. Setting arbitrary attributes, e.g.: obj.resource_layout.node_count = 2
+ 1.1 Setting the same nested field twice will return the same object, e.g.:
+ obj.resource_layout.node_count = 2
+ obj.resource_layout.process_count_per_node = 2
+ obj.resource_layout will be {"node_count": 2, "process_count_per_node": 2}
+ 1.2 Only public attributes are supported, e.g.: obj._resource_layout._node_count = 2 will raise AttributeError
+ 1.3 All set attributes are recorded, e.g.:
+ obj.target = "aml"
+ obj.resource_layout.process_count_per_node = 2
+ obj._get_attrs() will return {"target": "aml", "resource_layout": {"process_count_per_node": 2}}
+ 2. Getting arbitrary attributes; getting a non-existent attribute will return an empty dict.
+ 3. Calling arbitrary methods is not allowed, e.g.: obj.resource_layout() should raise AttributeError
+ """
+
+ def __init__(self, allowed_keys: Optional[Dict] = None, **kwargs: Any):
+ """Initialize a attribute dictionary.
+
+ :param allowed_keys: A dictionary of keys that allowed to set as arbitrary attributes. None means all keys can
+ be set as arbitrary attributes.
+
+ :type dict
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+ super(_AttrDict, self).__init__(**kwargs)
+ if allowed_keys is None:
+ # None allowed_keys means no restriction on keys can be set for _AttrDict
+ self._allowed_keys = {}
+ self._key_restriction = False
+ else:
+ # Otherwise use allowed_keys to restrict keys can be set for _AttrDict
+ self._allowed_keys = dict(allowed_keys)
+ self._key_restriction = True
+ self._logger = logging.getLogger("attr_dict")
+
+ def _initializing(self) -> bool:
+ # Indicates an ongoing init process; subclasses need to make sure this returns True during their init process.
+ return False
+
+ def _get_attrs(self) -> dict:
+ """Get all arbitrary attributes which has been set, empty values are excluded.
+
+ :return: A dict which contains all arbitrary attributes set by user.
+ :rtype: dict
+ """
+
+ # TODO: check this
+ def remove_empty_values(data: Dict) -> Dict:
+ if not isinstance(data, dict):
+ return data
+ # skip empty dicts as default value of _AttrDict is empty dict
+ return {k: remove_empty_values(v) for k, v in data.items() if v or not isinstance(v, dict)}
+
+ return remove_empty_values(self)
+
+ def _is_arbitrary_attr(self, attr_name: str) -> bool:
+ """Checks if a given attribute name should be treat as arbitrary attribute.
+
+ Attributes inside _AttrDict can be non-arbitrary attribute or arbitrary attribute.
+ Non-arbitrary attributes are normal attributes like other object which stores in self.__dict__.
+ Arbitrary attributes are attributes stored in the dictionary it self, what makes it special it it's value
+ can be an instance of _AttrDict
+ Take `obj = _AttrDict(allowed_keys={"resource_layout": {"node_count": None}})` as an example.
+ `obj.some_key` is accessing non-arbitrary attribute.
+ `obj.resource_layout` is accessing arbitrary attribute, user can use `obj.resource_layout.node_count = 1` to
+ assign value to it.
+
+ :param attr_name: Attribute name
+ :type attr_name: str
+ :return: Whether the given attribute name should be treated as an arbitrary attribute.
+ :rtype: bool
+ """
+ # Internal attribute won't be set as arbitrary attribute.
+ if attr_name.startswith("_"):
+ return False
+ # All attributes set in __init__ won't be set as arbitrary attribute
+ if self._initializing():
+ return False
+ # If there's key restriction, only keys in it can be set as arbitrary attribute.
+ if self._key_restriction and attr_name not in self._allowed_keys:
+ return False
+ # Attributes already in attribute dict will not be set as arbitrary attribute.
+ try:
+ self.__getattribute__(attr_name)
+ except AttributeError:
+ return True
+ return False
+
+ def __getattr__(self, key: Any) -> Any:
+ if not self._is_arbitrary_attr(key):
+ return super().__getattribute__(key)
+ self._logger.debug("getting %s", key)
+ try:
+ return super().__getitem__(key)
+ except KeyError:
+ allowed_keys = self._allowed_keys.get(key, None) if self._key_restriction else None
+ result: Any = _AttrDict(allowed_keys=allowed_keys)
+ self.__setattr__(key, result)
+ return result
+
+ def __setattr__(self, key: Any, value: V) -> None:
+ if not self._is_arbitrary_attr(key):
+ super().__setattr__(key, value)
+ else:
+ self._logger.debug("setting %s to %s", key, value)
+ super().__setitem__(key, value)
+
+ def __setitem__(self, key: Any, value: V) -> None:
+ self.__setattr__(key, value)
+
+ def __getitem__(self, item: V) -> Any:
+ # support attr_dict[item] since dumping it in marshmallow requires this.
+ return self.__getattr__(item)
+
+ def __dir__(self) -> List:
+ # For Jupyter Notebook auto-completion
+ return list(super().__dir__()) + list(self.keys())
+
+
+def has_attr_safe(obj: Any, attr: Any) -> bool:
+ if isinstance(obj, _AttrDict):
+ has_attr = not obj._is_arbitrary_attr(attr)
+ elif isinstance(obj, dict):
+ return attr in obj
+ else:
+ has_attr = hasattr(obj, attr)
+ return has_attr
+
+
+def try_get_non_arbitrary_attr(obj: Any, attr: str) -> Optional[Any]:
+ """Try to get non-arbitrary attribute for potential attribute dict.
+
+ Will not create target attribute if it is an arbitrary attribute in _AttrDict.
+
+ :param obj: The obj
+ :type obj: Any
+ :param attr: The attribute name
+ :type attr: str
+ :return: obj.attr
+ :rtype: Any
+ """
+ if has_attr_safe(obj, attr):
+ return obj[attr] if isinstance(obj, dict) else getattr(obj, attr)
+ return None
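+# Illustrative sketch (hypothetical objects): these helpers avoid the side effect of
+# _AttrDict.__getattr__, which would otherwise create and store an empty _AttrDict for a missing key.
+#   d = _AttrDict()
+#   has_attr_safe(d, "resource_layout")               # False, and "resource_layout" is NOT created
+#   try_get_non_arbitrary_attr(d, "resource_layout")  # None
+#   try_get_non_arbitrary_attr({"a": 1}, "a")         # 1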
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_component_translatable.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_component_translatable.py
new file mode 100644
index 00000000..22be939d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_component_translatable.py
@@ -0,0 +1,412 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access, redefined-builtin
+# disable redefined-builtin to use input as argument name
+import re
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+from pydash import get
+
+from azure.ai.ml._utils.utils import is_data_binding_expression
+from azure.ai.ml.constants._common import AssetTypes
+from azure.ai.ml.constants._component import ComponentJobConstants
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.pipeline._io import PipelineInput, PipelineOutput
+from azure.ai.ml.entities._job.sweep.search_space import Choice, Randint, SweepDistribution
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._builders import BaseNode
+ from azure.ai.ml.entities._component.component import Component
+
+
+class ComponentTranslatableMixin:
+ _PYTHON_SDK_TYPE_MAPPING = {
+ float: "number",
+ int: "integer",
+ bool: "boolean",
+ str: "string",
+ }
+
+ @classmethod
+ def _find_source_from_parent_inputs(cls, input: str, pipeline_job_inputs: dict) -> Tuple[str, Optional[str]]:
+ """Find source type and mode of input/output from parent input.
+
+ :param input: The input name
+ :type input: str
+ :param pipeline_job_inputs: The pipeline job inputs
+ :type pipeline_job_inputs: dict
+ :return: A 2-tuple of the type and the mode
+ :rtype: Tuple[str, Optional[str]]
+ """
+ _input_name = input.split(".")[2][:-2]
+ if _input_name not in pipeline_job_inputs.keys():
+ msg = "Failed to find top level definition for input binding {}."
+ raise JobException(
+ message=msg.format(input),
+ no_personal_data_message=msg.format("[input]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ input_data = pipeline_job_inputs[_input_name]
+ input_type = type(input_data)
+ if input_type in cls._PYTHON_SDK_TYPE_MAPPING:
+ return cls._PYTHON_SDK_TYPE_MAPPING[input_type], None
+ return getattr(input_data, "type", AssetTypes.URI_FOLDER), getattr(input_data, "mode", None)
+
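+ # Illustrative sketch (hypothetical binding and inputs):
+ #   input = "${{parent.inputs.my_data}}"   # input.split(".")[2][:-2] -> "my_data"
+ #   pipeline_job_inputs = {"my_data": 5}   # -> ("integer", None)
+ #   pipeline_job_inputs = {"my_data": Input(type="uri_file", mode="ro_mount")}  # -> ("uri_file", "ro_mount")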
+ @classmethod
+ def _find_source_from_parent_outputs(cls, input: str, pipeline_job_outputs: dict) -> Tuple[str, Optional[str]]:
+ """Find source type and mode of input/output from parent output.
+
+ :param input: The input name
+ :type input: str
+ :param pipeline_job_outputs: The pipeline job outputs
+ :type pipeline_job_outputs: dict
+ :return: A 2-tuple of the type and the mode
+ :rtype: Tuple[str, Optional[str]]
+ """
+ _output_name = input.split(".")[2][:-2]
+ if _output_name not in pipeline_job_outputs.keys():
+ msg = "Failed to find top level definition for output binding {}."
+ raise JobException(
+ message=msg.format(input),
+ no_personal_data_message=msg.format("[input]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ output_data = pipeline_job_outputs[_output_name]
+ output_type = type(output_data)
+ if output_type in cls._PYTHON_SDK_TYPE_MAPPING:
+ return cls._PYTHON_SDK_TYPE_MAPPING[output_type], None
+ if isinstance(output_data, dict):
+ if "type" in output_data:
+ output_data_type = output_data["type"]
+ else:
+ output_data_type = AssetTypes.URI_FOLDER
+ if "mode" in output_data:
+ output_data_mode = output_data["mode"]
+ else:
+ output_data_mode = None
+ return output_data_type, output_data_mode
+ return getattr(output_data, "type", AssetTypes.URI_FOLDER), getattr(output_data, "mode", None)
+
+ @classmethod
+ def _find_source_from_other_jobs(
+ cls, input: str, jobs_dict: dict, pipeline_job_dict: dict
+ ) -> Tuple[str, Optional[str]]:
+ """Find source type and mode of input/output from other job.
+
+ :param input: The input name
+ :type input: str
+ :param jobs_dict: The job dict
+ :type jobs_dict: dict
+ :param pipeline_job_dict: The pipeline job dict
+ :type pipeline_job_dict: dict
+ :return: A 2-tuple of the type and the mode
+ :rtype: Tuple[str, Optional[str]]
+ """
+ from azure.ai.ml.entities import CommandJob
+ from azure.ai.ml.entities._builders import BaseNode
+ from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+ from azure.ai.ml.parallel import ParallelJob
+
+ _input_regex = r"\${{parent.jobs.([^.]+).([^.]+).([^.]+)}}"
+ m = re.match(_input_regex, input)
+ if m is None:
+ msg = "Failed to find top level definition for job binding {}."
+ raise JobException(
+ message=msg.format(input),
+ no_personal_data_message=msg.format("[input]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ _input_job_name, _io_type, _name = m.groups()
+ _input_job = jobs_dict[_input_job_name]
+
+ # we only support the input of one job coming from the output of another job, but input mode should be
+ # decoupled from output mode, so we always return None as source_mode
+ source_mode = None
+ if isinstance(_input_job, BaseNode):
+ # If source is base node, get type from io builder
+ _source = _input_job[_io_type][_name]
+ try:
+ source_type = _source.type
+ # TODO: get component type for registered components; the following code would then be unnecessary.
+ # source_type being None means _input_job's component is a registered component, which results in its
+ # input/output type being None.
+ if source_type is None:
+ if _source._data is None:
+ # return default type if _input_job's output data is None
+ source_type = AssetTypes.URI_FOLDER
+ elif isinstance(_source._data, Output):
+ # if _input_job's data is an Output object, return its type.
+ source_type = _source._data.type
+ else:
+ # otherwise _input_job's input/output is bound to a pipeline input/output; we continue to
+ # infer the type according to _source._data. This returns the corresponding pipeline
+ # input/output type because we didn't get the component.
+ source_type, _ = cls._find_source_input_output_type(_source._data, pipeline_job_dict)
+ return source_type, source_mode
+ except AttributeError as e:
+ msg = "Failed to get referenced component type {}."
+ raise JobException(
+ message=msg.format(_input_regex),
+ no_personal_data_message=msg.format("[_input_regex]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ ) from e
+ if isinstance(_input_job, (CommandJob, ParallelJob)):
+ # If source has not been parsed to a Command yet, infer the type
+ _source = get(_input_job, f"{_io_type}.{_name}")
+ if isinstance(_source, str):
+ source_type, _ = cls._find_source_input_output_type(_source, pipeline_job_dict)
+ return source_type, source_mode
+ return getattr(_source, "type", AssetTypes.URI_FOLDER), source_mode
+ if isinstance(_input_job, AutoMLJob):
+ # If source is an AutoMLJob, only outputs are supported
+ if _io_type != "outputs":
+ msg = f"Only binding to AutoMLJob output is supported, currently got {_io_type}"
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ # AutoMLJob's output type can only be MLTABLE
+ return AssetTypes.MLTABLE, source_mode
+ msg = f"Unknown referenced source job type: {type(_input_job)}."
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @classmethod
+ def _find_source_input_output_type(cls, input: str, pipeline_job_dict: dict) -> Tuple[str, Optional[str]]:
+ """Find source type and mode of input/output.
+
+ :param input: The input binding
+ :type input: str
+ :param pipeline_job_dict: The pipeline job dict
+ :type pipeline_job_dict: dict
+ :return: A 2-tuple of the type and the mode
+ :rtype: Tuple[str, Optional[str]]
+ """
+ pipeline_job_inputs = pipeline_job_dict.get("inputs", {})
+ pipeline_job_outputs = pipeline_job_dict.get("outputs", {})
+ jobs_dict = pipeline_job_dict.get("jobs", {})
+ if is_data_binding_expression(input, ["parent", "inputs"]):
+ return cls._find_source_from_parent_inputs(input, pipeline_job_inputs)
+ if is_data_binding_expression(input, ["parent", "outputs"]):
+ return cls._find_source_from_parent_outputs(input, pipeline_job_outputs)
+ if is_data_binding_expression(input, ["parent", "jobs"]):
+ try:
+ return cls._find_source_from_other_jobs(input, jobs_dict, pipeline_job_dict)
+ except JobException as e:
+ raise e
+ except Exception as e:
+ msg = "Failed to find referenced source for input binding {}"
+ raise JobException(
+ message=msg.format(input),
+ no_personal_data_message=msg.format("[input]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ ) from e
+ else:
+ msg = "Job input in a pipeline can bind only to a job output or a pipeline input"
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @classmethod
+ def _to_input(
+ cls, # pylint: disable=unused-argument
+ input: Union[Input, str, bool, int, float],
+ pipeline_job_dict: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> Input:
+ """Convert a single job input value to component input.
+
+ :param input: The input
+ :type input: Union[Input, str, bool, int, float]
+ :param pipeline_job_dict: The pipeline job dict
+ :type pipeline_job_dict: Optional[dict]
+ :return: The Component Input
+ :rtype: Input
+ """
+ pipeline_job_dict = pipeline_job_dict or {}
+ input_variable: Dict = {}
+
+ if isinstance(input, str) and bool(re.search(ComponentJobConstants.INPUT_PATTERN, input)):
+ # handle input bindings
+ input_variable["type"], input_variable["mode"] = cls._find_source_input_output_type(
+ input, pipeline_job_dict
+ )
+
+ elif isinstance(input, Input):
+ input_variable = input._to_dict()
+ elif isinstance(input, SweepDistribution):
+ if isinstance(input, Choice):
+ if input.values is not None:
+ input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(input.values[0])]
+ elif isinstance(input, Randint):
+ input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[int]
+ else:
+ input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[float]
+
+ input_variable["optional"] = False
+ elif type(input) in cls._PYTHON_SDK_TYPE_MAPPING:
+ input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(input)]
+ input_variable["default"] = input
+ elif isinstance(input, PipelineInput):
+ # Infer input type from input data
+ input_variable = input._to_input()._to_dict()
+ else:
+ msg = "'{}' is not supported as component input, supported types are '{}'.".format(
+ type(input), cls._PYTHON_SDK_TYPE_MAPPING.keys()
+ )
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return Input(**input_variable)
+
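+ # Illustrative sketch (hypothetical values):
+ #   cls._to_input(5)                           # -> Input(type="integer", default=5)
+ #   cls._to_input(Choice(values=[0.01, 0.1]))  # -> Input(type="number", optional=False)
+ #   cls._to_input("${{parent.inputs.my_data}}", pipeline_job_dict)
+ #                                              # -> an Input whose type/mode are inferred from the bound pipeline input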
+ @classmethod
+ def _to_input_builder_function(cls, input: Union[Dict, SweepDistribution, Input, str, bool, int, float]) -> Input:
+ input_variable = {}
+
+ if isinstance(input, Input):
+ input_variable = input._to_dict()
+ elif isinstance(input, SweepDistribution):
+ if isinstance(input, Choice):
+ if input.values is not None:
+ input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(input.values[0])]
+ elif isinstance(input, Randint):
+ input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[int]
+ else:
+ input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[float]
+
+ input_variable["optional"] = False
+ else:
+ input_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(input)]
+ input_variable["default"] = input
+ return Input(**input_variable)
+
+ @classmethod
+ def _to_output(
+ cls, # pylint: disable=unused-argument
+ output: Optional[Union[Output, Dict, str, bool, int, float]],
+ pipeline_job_dict: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> Output:
+ """Translate output value to Output and infer component output type
+ from linked pipeline output, its original type or default type.
+
+ :param output: The output
+ :type output: Union[Output, str, bool, int, float]
+ :param pipeline_job_dict: The pipeline job dict
+ :type pipeline_job_dict: Optional[dict]
+ :return: The output object
+ :rtype: Output
+ """
+ pipeline_job_dict = pipeline_job_dict or {}
+ output_type = None
+ if not pipeline_job_dict or output is None:
+ try:
+ output_type = output.type # type: ignore
+ except AttributeError:
+ # default to uri_folder if we failed to get the type
+ output_type = AssetTypes.URI_FOLDER
+ output_variable = {"type": output_type}
+ return Output(**output_variable)
+ output_variable = {}
+
+ if isinstance(output, str) and bool(re.search(ComponentJobConstants.OUTPUT_PATTERN, output)):
+ # handle output bindings
+ output_variable["type"], output_variable["mode"] = cls._find_source_input_output_type(
+ output, pipeline_job_dict
+ )
+
+ elif isinstance(output, Output):
+ output_variable = output._to_dict()
+
+ elif isinstance(output, PipelineOutput):
+ output_variable = output._to_output()._to_dict()
+
+ elif type(output) in cls._PYTHON_SDK_TYPE_MAPPING:
+ output_variable["type"] = cls._PYTHON_SDK_TYPE_MAPPING[type(output)]
+ output_variable["default"] = output
+ else:
+ msg = "'{}' is not supported as component output, supported types are '{}'.".format(
+ type(output), cls._PYTHON_SDK_TYPE_MAPPING.keys()
+ )
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return Output(**output_variable)
+
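+ # Illustrative sketch (hypothetical values):
+ #   cls._to_output(None)                    # -> Output(type="uri_folder"), the default when no type is found
+ #   cls._to_output(Output(type="mltable"))  # -> Output(type="mltable")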
+ def _to_inputs(self, inputs: Optional[Dict], **kwargs: Any) -> Dict:
+ """Translate inputs to Inputs.
+
+ :param inputs: mapping from input name to input object.
+ :type inputs: Dict[str, Union[Input, str, bool, int, float]]
+ :return: mapping from input name to translated component input.
+ :rtype: Dict[str, Input]
+ """
+ pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+ translated_component_inputs = {}
+ if inputs is not None:
+ for io_name, io_value in inputs.items():
+ translated_component_inputs[io_name] = self._to_input(io_value, pipeline_job_dict)
+ return translated_component_inputs
+
+ def _to_outputs(self, outputs: Optional[Dict], **kwargs: Any) -> Dict:
+ """Translate outputs to Outputs.
+
+ :param outputs: mapping from output name to output object.
+ :type outputs: Dict[str, Output]
+ :return: mapping from output name to translated component output.
+ :rtype: Dict[str, Output]
+ """
+ # Translate outputs to Outputs.
+ pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+ translated_component_outputs = {}
+ if outputs is not None:
+ for output_name, output_value in outputs.items():
+ translated_component_outputs[output_name] = self._to_output(output_value, pipeline_job_dict)
+ return translated_component_outputs
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> Union["Component", str]:
+ """Translate to Component.
+
+ :param context: The context
+ :type context: Optional[Dict]
+ :return: Translated Component.
+ :rtype: Component
+ """
+ # Note: Source of the translated component should be the same as the Job's,
+ # and should be set after calling _to_component/_to_node, as the job has no _source for now.
+ raise NotImplementedError()
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "BaseNode":
+ """Translate to pipeline node.
+
+ :param context: The context
+ :type context: Optional[Dict]
+ :return: Translated node.
+ :rtype: BaseNode
+ """
+ # Note: Source of the translated component should be the same as the Job's,
+ # and should be set after calling _to_component/_to_node, as the job has no _source for now.
+ raise NotImplementedError()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/__init__.py
new file mode 100644
index 00000000..3ccde947
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/__init__.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+"""Classes in this package converts input & output set by user to pipeline job input & output."""
+
+from .attr_dict import OutputsAttrDict, _GroupAttrDict
+from .base import InputOutputBase, NodeInput, NodeOutput, PipelineInput, PipelineOutput
+from .mixin import AutoMLNodeIOMixin, NodeWithGroupInputMixin, PipelineJobIOMixin
+
+__all__ = [
+ "PipelineOutput",
+ "PipelineInput",
+ "NodeOutput",
+ "NodeInput",
+ "InputOutputBase",
+ "OutputsAttrDict",
+ "_GroupAttrDict",
+ "NodeWithGroupInputMixin",
+ "AutoMLNodeIOMixin",
+ "PipelineJobIOMixin",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/attr_dict.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/attr_dict.py
new file mode 100644
index 00000000..0ae08bcd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/attr_dict.py
@@ -0,0 +1,170 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from azure.ai.ml.entities._assets import Data
+from azure.ai.ml.entities._inputs_outputs import GroupInput, Input, Output
+from azure.ai.ml.entities._job.pipeline._attr_dict import K
+from azure.ai.ml.entities._job.pipeline._io.base import NodeInput, NodeOutput, PipelineInput
+from azure.ai.ml.exceptions import (
+ ErrorCategory,
+ ErrorTarget,
+ UnexpectedAttributeError,
+ UnexpectedKeywordError,
+ ValidationException,
+)
+
+
+class InputsAttrDict(dict):
+ def __init__(self, inputs: dict, **kwargs: Any):
+ self._validate_inputs(inputs)
+ super(InputsAttrDict, self).__init__(**inputs, **kwargs)
+
+ @classmethod
+ def _validate_inputs(cls, inputs: Any) -> None:
+ msg = "Pipeline/component input should be a \
+ azure.ai.ml.entities._job.pipeline._io.NodeInput with owner, got {}."
+ for val in inputs.values():
+ if isinstance(val, NodeInput) and val._owner is not None: # pylint: disable=protected-access
+ continue
+ if isinstance(val, _GroupAttrDict):
+ continue
+ raise ValidationException(
+ message=msg.format(val),
+ no_personal_data_message=msg.format("[val]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def __setattr__(
+ self,
+ key: str,
+ value: Union[int, bool, float, str, NodeOutput, PipelineInput, Input],
+ ) -> None:
+ # Extract enum value.
+ value = value.value if isinstance(value, Enum) else value
+ original_input = self.__getattr__(key) # Note that an exception will be raised if the keyword is invalid.
+ if isinstance(original_input, _GroupAttrDict) or isinstance(value, _GroupAttrDict):
+ # Set the value directly if it is a parameter group.
+ self._set_group_with_type_check(key, GroupInput.custom_class_value_to_attr_dict(value))
+ return
+ original_input._data = original_input._build_data(value)
+
+ def _set_group_with_type_check(self, key: Any, value: Any) -> None:
+ msg = "{!r} is expected to be a parameter group, but got {}."
+ if not isinstance(value, _GroupAttrDict):
+ raise ValidationException(
+ message=msg.format(key, type(value)),
+ no_personal_data_message=msg.format("[key]", "[value_type]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ self.__setitem__(key, GroupInput.custom_class_value_to_attr_dict(value))
+
+ def __getattr__(self, item: Any) -> NodeInput:
+ res: NodeInput = self.__getitem__(item)
+ return res
+
+
+class _GroupAttrDict(InputsAttrDict):
+ """This class is used for accessing values with instance.some_key."""
+
+ @classmethod
+ def _validate_inputs(cls, inputs: Any) -> None:
+ msg = "Pipeline/component input should be a azure.ai.ml.entities._job.pipeline._io.NodeInput, got {}."
+ for val in inputs.values():
+ if isinstance(val, NodeInput) and val._owner is not None: # pylint: disable=protected-access
+ continue
+ if isinstance(val, _GroupAttrDict):
+ continue
+ # Allow PipelineInput, as a Group may appear as a top-level pipeline input.
+ if isinstance(val, PipelineInput):
+ continue
+ raise ValidationException(
+ message=msg.format(val),
+ no_personal_data_message=msg.format("[val]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def __getattr__(self, name: K) -> Any:
+ if name not in self:
+ raise UnexpectedAttributeError(keyword=name, keywords=list(self))
+ return super().__getitem__(name)
+
+ def __getitem__(self, item: K) -> Any:
+ # We raise this exception instead of KeyError
+ if item not in self:
+ raise UnexpectedKeywordError(func_name="ParameterGroup", keyword=item, keywords=list(self))
+ return super().__getitem__(item)
+
+ # For Jupyter Notebook auto-completion
+ def __dir__(self) -> List:
+ return list(super().__dir__()) + list(self.keys())
+
+ def flatten(self, group_parameter_name: Optional[str]) -> Dict:
+ # Return the flattened result of self
+
+ group_parameter_name = group_parameter_name if group_parameter_name else ""
+ flattened_parameters = {}
+ msg = "'%s' in parameter group should be a azure.ai.ml.entities._job._io.NodeInput, got '%s'."
+ for k, v in self.items():
+ flattened_name = ".".join([group_parameter_name, k])
+ if isinstance(v, _GroupAttrDict):
+ flattened_parameters.update(v.flatten(flattened_name))
+ elif isinstance(v, NodeInput):
+ flattened_parameters[flattened_name] = v._to_job_input() # pylint: disable=protected-access
+ else:
+ raise ValidationException(
+ message=msg % (flattened_name, type(v)),
+ no_personal_data_message=msg % ("name", "type"),
+ target=ErrorTarget.PIPELINE,
+ )
+ return flattened_parameters
+
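+ # Illustrative sketch (hypothetical group): for a group holding NodeInput "lr" and a nested
+ # group "sub" holding NodeInput "epochs", flatten("train") yields
+ # {"train.lr": <job input>, "train.sub.epochs": <job input>}.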
+ def insert_group_name_for_items(self, group_name: Any) -> None:
+ # Insert one group name for all items.
+ for v in self.values():
+ if isinstance(v, _GroupAttrDict):
+ v.insert_group_name_for_items(group_name)
+ elif isinstance(v, PipelineInput):
+ # Insert group names for pipeline input
+ v._group_names = [group_name] + v._group_names # pylint: disable=protected-access
+
+
+class OutputsAttrDict(dict):
+ def __init__(self, outputs: dict, **kwargs: Any):
+ for val in outputs.values():
+ if not isinstance(val, NodeOutput) or val._owner is None:
+ msg = "Pipeline/component output should be a azure.ai.ml.dsl.Output with owner, got {}."
+ raise ValidationException(
+ message=msg.format(val),
+ no_personal_data_message=msg.format("[val]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ super(OutputsAttrDict, self).__init__(**outputs, **kwargs)
+
+ def __getattr__(self, item: Any) -> NodeOutput:
+ return self.__getitem__(item)
+
+ def __getitem__(self, item: Any) -> NodeOutput:
+ if item not in self:
+ # We raise this exception instead of KeyError as OutputsAttrDict doesn't support adding new items after
+ # __init__.
+ raise UnexpectedAttributeError(keyword=item, keywords=list(self))
+ res: NodeOutput = super().__getitem__(item)
+ return res
+
+ def __setattr__(self, key: str, value: Union[Data, Output]) -> None:
+ if isinstance(value, Output):
+ mode = value.mode
+ value = Output(type=value.type, path=value.path, mode=mode, name=value.name, version=value.version)
+ original_output = self.__getattr__(key) # Note that an exception will be raised if the keyword is invalid.
+ original_output._data = original_output._build_data(value)
+
+ def __setitem__(self, key: str, value: Output) -> None:
+ return self.__setattr__(key, value)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/base.py
new file mode 100644
index 00000000..b17972ae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/base.py
@@ -0,0 +1,848 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import copy
+import re
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union, cast, overload
+
+from azure.ai.ml._utils.utils import is_data_binding_expression
+from azure.ai.ml.constants import AssetTypes
+from azure.ai.ml.constants._component import IOConstants
+from azure.ai.ml.entities._assets._artifacts.data import Data
+from azure.ai.ml.entities._assets._artifacts.model import Model
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpressionMixin
+from azure.ai.ml.entities._util import resolve_pipeline_parameter
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, UserErrorException, ValidationException
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities import PipelineJob
+ from azure.ai.ml.entities._builders import BaseNode
+
+T = TypeVar("T")
+
+
+def _build_data_binding(data: Union[str, "PipelineInput", "Output"]) -> Union[str, Output]:
+ """Build input builders to data bindings.
+
+ :param data: The data to build a data binding from
+ :type data: Union[str, PipelineInput, Output]
+ :return: A data binding string if data is an InputOutputBase (eg: PipelineInput, Output builder), otherwise data unchanged
+ :rtype: Union[str, Output]
+ """
+ result: Union[str, Output] = ""
+
+ if isinstance(data, (InputOutputBase)):
+ # Build data binding when data is PipelineInput, Output
+ result = data._data_binding()
+ else:
+ # Otherwise just return the data
+ result = data
+ return result
+
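+# Illustrative sketch (hypothetical names): a PipelineInput named "lr" resolves to the binding string
+# "${{parent.inputs.lr}}", and a NodeOutput "out" owned by a node named "train" resolves to
+# "${{parent.jobs.train.outputs.out}}"; plain strings are returned unchanged.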
+
+def _resolve_builders_2_data_bindings(
+ data: Union[list, dict, str, "PipelineInput", "Output"]
+) -> Union[dict, list, str, Output]:
+ """Traverse data and build input builders inside it to data bindings.
+
+ :param data: The bindings to resolve
+ :type data: Union[list, dict, str, "PipelineInput", "Output"]
+ :return:
+ * A dict if data was a dict
+ * A list if data was a list
+ * A str otherwise
+ :rtype: Union[list, dict, str]
+ """
+ if isinstance(data, dict):
+ for key, val in data.items():
+ if isinstance(val, (dict, list)):
+ data[key] = _resolve_builders_2_data_bindings(val)
+ else:
+ data[key] = _build_data_binding(val)
+ return data
+ if isinstance(data, list):
+ resolved_data = []
+ for val in data:
+ resolved_data.append(_resolve_builders_2_data_bindings(val))
+ return resolved_data
+ return _build_data_binding(data)
+
+
+def _data_to_input(data: Union[Data, Model]) -> Input:
+ """Convert a Data object to an Input object.
+
+ :param data: The data to convert
+ :type data: Data
+ :return: The Input object
+ :rtype: Input
+ """
+ if data.id:
+ return Input(type=data.type, path=data.id)
+ return Input(type=data.type, path=f"{data.name}:{data.version}")
+
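+# Illustrative sketch (hypothetical asset): a registered Data asset named "titanic" with version "1"
+# and no ARM id becomes Input(type=<data.type>, path="titanic:1"); if the asset has an id, that id is
+# used as the path instead.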
+
+class InputOutputBase(ABC):
+ # TODO: refine this code, always use _data to store builder level settings and use _meta to store definition
+ # TODO: when _data missing, return value from _meta
+
+ def __init__(
+ self,
+ meta: Optional[Union[Input, Output]],
+ data: Optional[Union[int, bool, float, str, Input, Output, "PipelineInput"]],
+ default_data: Optional[Union[int, bool, float, str, Input, Output]] = None,
+ **kwargs: Any,
+ ):
+ """Base class of input & output.
+
+ :param meta: Metadata of this input/output, eg: type, min, max, etc.
+ :type meta: Union[Input, Output]
+ :param data: Actual value of input/output, None means un-configured data.
+ :type data: Union[None, int, bool, float, str,
+ azure.ai.ml.Input,
+ azure.ai.ml.Output]
+ :param default_data: default value of input/output, None means un-configured data.
+ :type default_data: Union[None, int, bool, float, str,
+ azure.ai.ml.Input,
+ azure.ai.ml.Output]
+ """
+ self._meta = meta
+ self._original_data = data
+ self._data: Any = self._build_data(data)
+ self._default_data = default_data
+ self._type: str = meta.type if meta is not None else kwargs.pop("type", None)
+ self._mode = self._get_mode(original_data=data, data=self._data, kwargs=kwargs)
+ self._description = (
+ self._data.description
+ if self._data is not None and hasattr(self._data, "description") and self._data.description
+ else kwargs.pop("description", None)
+ )
+ # TODO: remove this
+ self._attribute_map: Dict = {}
+ self._name: Optional[str] = ""
+ self._version: Optional[str] = ""
+ super(InputOutputBase, self).__init__(**kwargs)
+
+ @abstractmethod
+ def _build_data(self, data: T) -> Union[T, str, Input, "InputOutputBase"]:
+ """Validate if data matches type and translate it to Input/Output acceptable type.
+
+ :param data: The data
+ :type data: T
+ :return: The built data
+ :rtype: Union[T, str, Input, InputOutputBase]
+ """
+
+ @abstractmethod
+ def _build_default_data(self) -> None:
+ """Build default data when data not configured."""
+
+ @property
+ def type(self) -> str:
+ """Type of input/output.
+
+ :return: The type
+ :rtype: str
+ """
+ return self._type
+
+ @type.setter
+ def type(self, type: Any) -> None: # pylint: disable=redefined-builtin
+ # For un-configured input/output, we build a default data entry for them.
+ self._build_default_data()
+ self._type = type
+ if isinstance(self._data, (Input, Output)):
+ self._data.type = type
+ elif self._data is not None and not isinstance(
+ self._data, (int, float, str)
+ ): # when type of self._data is InputOutputBase or its child class
+ self._data._type = type
+
+ @property
+ def mode(self) -> Optional[str]:
+ return self._mode
+
+ @mode.setter
+ def mode(self, mode: Optional[str]) -> None:
+ # For un-configured input/output, we build a default data entry for them.
+ self._build_default_data()
+ self._mode = mode
+ if isinstance(self._data, (Input, Output)):
+ self._data.mode = mode
+ elif self._data is not None and not isinstance(self._data, (int, float, str)):
+ self._data._mode = mode
+
+ @property
+ def description(self) -> Any:
+ return self._description
+
+ @description.setter
+ def description(self, description: str) -> None:
+ # For un-configured input/output, we build a default data entry for them.
+ self._build_default_data()
+ self._description = description
+ if isinstance(self._data, (Input, Output)):
+ self._data.description = description
+ elif self._data is not None and not isinstance(self._data, (int, float, str)):
+ self._data._description = description
+
+ @property
+ def path(self) -> Optional[str]:
+ # This property is introduced for static intellisense.
+ if hasattr(self._data, "path"):
+ if self._data is not None and not isinstance(self._data, (int, float, str)):
+ res: Optional[str] = self._data.path
+ return res
+ msg = f"{type(self._data)} does not have path."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @path.setter
+ def path(self, path: str) -> None:
+ # For un-configured input/output, we build a default data entry for them.
+ self._build_default_data()
+ if hasattr(self._data, "path"):
+ if self._data is not None and not isinstance(self._data, (int, float, str)):
+ self._data.path = path
+ else:
+ msg = f"{type(self._data)} does not support setting path."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def _data_binding(self) -> str:
+ """Return data binding string representation for this input/output.
+
+ :return: The data binding string
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ # Why did we have this function? It prevents the DictMixin from being applied.
+ # Unclear if we explicitly do NOT want the mapping protocol to be applied to this, or if this was just
+ # confirmation that it didn't at the time.
+ def keys(self) -> None:
+ # This property is introduced to raise catchable Exception in marshmallow mapping validation trial.
+ raise TypeError(f"'{type(self).__name__}' object is not a mapping")
+
+ def __str__(self) -> str:
+ try:
+ return self._data_binding()
+ except AttributeError:
+ return super(InputOutputBase, self).__str__()
+
+ def __hash__(self) -> int:
+ return id(self)
+
+ @classmethod
+ def _get_mode(
+ cls,
+ original_data: Optional[Union[int, bool, float, str, Input, Output, "PipelineInput"]],
+ data: Optional[Union[int, bool, float, str, Input, Output]],
+ kwargs: dict,
+ ) -> Optional[str]:
+ """Get mode of this input/output builder.
+
+ :param original_data: Original value of input/output.
+ :type original_data: Union[None, int, bool, float, str
+ azure.ai.ml.Input,
+ azure.ai.ml.Output,
+ azure.ai.ml.entities._job.pipeline._io.PipelineInput]
+ :param data: Built input/output data.
+ :type data: Union[None, int, bool, float, str
+ azure.ai.ml.Input,
+ azure.ai.ml.Output]
+ :param kwargs: The kwargs
+ :type kwargs: Dict
+ :return: The mode
+ :rtype: Optional[str]
+ """
+ # pipeline level inputs won't pass mode to bound node level inputs
+ if isinstance(original_data, PipelineInput):
+ return None
+ return data.mode if data is not None and hasattr(data, "mode") else kwargs.pop("mode", None)
+
+ @property
+ def _is_primitive_type(self) -> bool:
+ return self.type in IOConstants.PRIMITIVE_STR_2_TYPE
+
+
+class NodeInput(InputOutputBase):
+ """Define one input of a Component."""
+
+ def __init__(
+ self,
+ port_name: str,
+ meta: Optional[Input],
+ *,
+ data: Optional[Union[int, bool, float, str, Output, "PipelineInput", Input]] = None,
+ # TODO: Bug Item number: 2883405
+ owner: Optional[Union["BaseComponent", "PipelineJob"]] = None, # type: ignore
+ **kwargs: Any,
+ ):
+ """Initialize an input of a component.
+
+ :param name: The name of the input.
+ :type name: str
+ :param meta: Metadata of this input, eg: type, min, max, etc.
+ :type meta: Input
+ :param data: The input data. Valid types include int, bool, float, str,
+ Output of another component or pipeline input and Input.
+ Note that the associated output of another component or pipeline input should be reachable in the scope
+ of the current pipeline. Input is introduced to support cases like
+ TODO: new examples
+ component.inputs.xxx = Input(path="arm_id")
+ :type data: Union[int, bool, float, str
+ azure.ai.ml.Output,
+ azure.ai.ml.Input]
+ :param owner: The owner component of the input, used to calculate binding.
+ :type owner: Union[azure.ai.ml.entities.BaseNode, azure.ai.ml.entities.PipelineJob]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+ # TODO: validate data matches type in meta
+ # TODO: validate supported data
+ self._port_name = port_name
+ self._owner = owner
+ super().__init__(meta=meta, data=data, **kwargs)
+
+ def _build_default_data(self) -> None:
+ """Build default data when input not configured."""
+ if self._data is None:
+ self._data = Input()
+
+ def _build_data(self, data: T) -> Union[T, str, Input, InputOutputBase]:
+ """Build input data according to assigned input
+
+ eg: node.inputs.key = data
+
+ :param data: The data
+ :type data: T
+ :return: The built data
+ :rtype: Union[T, str, Input, "PipelineInput", "NodeOutput"]
+ """
+ _data: Union[T, str, NodeOutput] = resolve_pipeline_parameter(data)
+ if _data is None:
+ return _data
+ # Unidiomatic typecheck: Checks that data is _exactly_ this type, and not potentially a subtype
+ if type(_data) is NodeInput: # pylint: disable=unidiomatic-typecheck
+ msg = "Can not bind input to another component's input."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if isinstance(_data, (PipelineInput, NodeOutput)):
+ # If value is an input or output, it's a data binding; we require it to have an owner so we can convert
+ # it to a data binding, eg: ${{inputs.xxx}}
+ if isinstance(_data, NodeOutput) and _data._owner is None:
+ msg = "Setting input binding {} to output without owner is not allowed."
+ raise ValidationException(
+ message=msg.format(_data),
+ no_personal_data_message=msg.format("[_data]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return _data
+ # for data binding case, set is_singular=False for case like "${{parent.inputs.job_in_folder}}/sample1.csv"
+ if isinstance(_data, Input) or is_data_binding_expression(_data, is_singular=False):
+ return _data
+ if isinstance(_data, (Data, Model)):
+ return _data_to_input(_data)
+ # self._meta.type could be None when sub pipeline has no annotation
+ if isinstance(self._meta, Input) and self._meta.type and not self._meta._is_primitive_type:
+ if isinstance(_data, str):
+ return Input(type=self._meta.type, path=_data)
+ msg = "only path input is supported now but get {}: {}."
+ raise UserErrorException(
+ message=msg.format(type(_data), _data),
+ no_personal_data_message=msg.format(type(_data), "[_data]"),
+ )
+ return _data
+
+ def _to_job_input(self) -> Optional[Union[Input, str, Output]]:
+ """convert the input to Input, this logic will change if backend contract changes."""
+ result: Optional[Union[Input, str, Output]] = None
+
+ if self._data is None:
+ # None data means this input is not configured.
+ result = None
+ elif isinstance(self._data, (PipelineInput, NodeOutput)):
+ # Build data binding when data is PipelineInput, Output
+ result = Input(path=self._data._data_binding(), mode=self.mode)
+ elif is_data_binding_expression(self._data):
+ result = Input(path=self._data, mode=self.mode)
+ else:
+ data_binding = _build_data_binding(self._data)
+ if is_data_binding_expression(self._data):
+ result = Input(path=data_binding, mode=self.mode)
+ else:
+ result = data_binding
+ # TODO: validate if self._data is supported
+
+ return result
+
+ def _data_binding(self) -> str:
+ msg = "Input binding {} can only come from a pipeline, currently got {}"
+ # call type(self._owner) to avoid circular import
+ raise ValidationException(
+ message=msg.format(self._port_name, type(self._owner)),
+ target=ErrorTarget.PIPELINE,
+ no_personal_data_message=msg.format("[port_name]", "[owner]"),
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def _copy(self, owner: Any) -> "NodeInput":
+ return NodeInput(
+ port_name=self._port_name,
+ data=self._data,
+ owner=owner,
+ meta=cast(Input, self._meta),
+ )
+
+ def _deepcopy(self) -> "NodeInput":
+ return NodeInput(
+ port_name=self._port_name,
+ data=copy.copy(self._data),
+ owner=self._owner,
+ meta=cast(Input, self._meta),
+ )
+
+ def _get_data_owner(self) -> Optional["BaseNode"]:
+ """Gets the data owner of the node
+
+ Note: This only works for @pipeline, not for YAML pipeline.
+
+ Note: Inner step will be returned as the owner when node's input is from sub pipeline's output.
+ @pipeline
+ def sub_pipeline():
+ inner_node = component_func()
+ return inner_node.outputs
+
+ @pipeline
+ def root_pipeline():
+ pipeline_node = sub_pipeline()
+ node = copy_files_component_func(input_dir=pipeline_node.outputs.output_dir)
+ owner = node.inputs.input_dir._get_data_owner()
+ assert owner == pipeline_node.nodes[0]
+
+ :return: The node if Input is from another node's output. Returns None for literal value.
+ :rtype: Optional[BaseNode]
+ """
+ from azure.ai.ml.entities import Pipeline
+ from azure.ai.ml.entities._builders import BaseNode
+
+ def _resolve_data_owner(data: Any) -> Optional["BaseNode"]:
+ if isinstance(data, BaseNode) and not isinstance(data, Pipeline):
+ return data
+ while isinstance(data, PipelineInput):
+ # for pipeline input, its original value (a literal value or another node's output)
+ # is stored in _original_data
+ return _resolve_data_owner(data._original_data)
+ if isinstance(data, NodeOutput):
+ if isinstance(data._owner, Pipeline):
+ # for input from subgraph's output, trace back to inner node
+ return _resolve_data_owner(data._binding_output)
+ # for input from another node's output, return the node
+ return _resolve_data_owner(data._owner)
+ return None
+
+ return _resolve_data_owner(self._data)
+
+
+class NodeOutput(InputOutputBase, PipelineExpressionMixin):
+ """Define one output of a Component."""
+
+ def __init__(
+ self,
+ port_name: str,
+ meta: Optional[Union[Input, Output]],
+ *,
+ data: Optional[Union[Output, str]] = None,
+ # TODO: Bug Item number: 2883405
+ owner: Optional[Union["BaseComponent", "PipelineJob"]] = None, # type: ignore
+ binding_output: Optional["NodeOutput"] = None,
+ **kwargs: Any,
+ ):
+ """Initialize an Output of a component.
+
+ :param port_name: The port_name of the output.
+ :type port_name: str
+ :param name: The name used to register NodeOutput/PipelineOutput data.
+ :type name: str
+ :param version: The version used to register NodeOutput/PipelineOutput data.
+ :type version: str
+ :param data: The output data. Valid types include str, Output
+ :type data: Union[str
+ azure.ai.ml.entities.Output]
+ :param mode: The mode of the output.
+ :type mode: str
+ :param owner: The owner component of the output, used to calculate binding.
+ :type owner: Union[azure.ai.ml.entities.BaseNode, azure.ai.ml.entities.PipelineJob]
+ :param binding_output: The node output bound to pipeline output, only available for pipeline.
+ :type binding_output: azure.ai.ml.entities.NodeOutput
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if object cannot be successfully validated.
+ Details will be provided in the error message.
+ """
+ # Allow inline output binding with string, eg: "component_out_path_1": "${{parent.outputs.job_out_data_1}}"
+ if data is not None and not isinstance(data, (Output, str)):
+ msg = "Got unexpected type for output: {}."
+ raise ValidationException(
+ message=msg.format(data),
+ target=ErrorTarget.PIPELINE,
+ no_personal_data_message=msg.format("[data]"),
+ )
+ super().__init__(meta=meta, data=data, **kwargs)
+ self._port_name = port_name
+ self._owner = owner
+ self._name: Optional[str] = self._data.name if isinstance(self._data, Output) else None
+ self._version: Optional[str] = self._data.version if isinstance(self._data, Output) else None
+
+ self._assert_name_and_version()
+
+ # store original node output to be able to trace back to inner node from a pipeline output builder.
+ self._binding_output = binding_output
+
+ @property
+ def port_name(self) -> str:
+ """The output port name, eg: node.outputs.port_name.
+
+ :return: The port name
+ :rtype: str
+ """
+ return self._port_name
+
+ @property
+ def name(self) -> Optional[str]:
+ """Used in registering output data.
+
+ :return: The output name
+ :rtype: str
+ """
+ return self._name
+
+ @name.setter
+ def name(self, name: str) -> None:
+ """Assigns the name to NodeOutput/PipelineOutput and builds data according to the name.
+
+ :param name: The new name
+ :type name: str
+ """
+ self._build_default_data()
+ self._name = name
+ if isinstance(self._data, Output):
+ self._data.name = name
+ elif isinstance(self._data, InputOutputBase):
+ self._data._name = name
+ else:
+ raise UserErrorException(
+ f"We support self._data of Input, Output, InputOutputBase, NodeOutput and NodeInput,"
+ f"but got type: {type(self._data)}."
+ )
+
+ @property
+ def version(self) -> Optional[str]:
+ """Used in registering output data.
+
+ :return: The output version
+ :rtype: str
+ """
+ return self._version
+
+ @version.setter
+ def version(self, version: str) -> None:
+ """Assigns the version to NodeOutput/PipelineOutput and builds data according to the version.
+
+ :param version: The new version
+ :type version: str
+ """
+ self._build_default_data()
+ self._version = version
+ if isinstance(self._data, Output):
+ self._data.version = version
+ elif isinstance(self._data, InputOutputBase):
+ self._data._version = version
+ else:
+ raise UserErrorException(
+ f"We support self._data of Input, Output, InputOutputBase, NodeOutput and NodeInput,"
+ f"but got type: {type(self._data)}."
+ )
+
+ @property
+ def path(self) -> Any:
+ # For node output, return the path of the underlying data if available.
+ if self._data is not None and hasattr(self._data, "path"):
+ return self._data.path
+ return None
+
+ @path.setter
+ def path(self, path: Optional[str]) -> None:
+ # For un-configured output, we build a default data entry for them.
+ self._build_default_data()
+ if self._data is not None and hasattr(self._data, "path"):
+ self._data.path = path
+ else:
+ # YAML job will have a string output binding, which does not support setting path.
+ msg = f"{type(self._data)} does not support setting path."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def _assert_name_and_version(self) -> None:
+ if self.name and not (re.match("^[A-Za-z0-9_-]*$", self.name) and len(self.name) <= 255):
+ raise UserErrorException(
+ f"The output name {self.name} can only contain alphanumeric characters, dashes and underscores, "
+ f"with a limit of 255 characters."
+ )
+ if self.version and not self.name:
+ raise UserErrorException("Output name is required when output version is specified.")
+
+ def _build_default_data(self) -> None:
+ """Build default data when output not configured."""
+ if self._data is None:
+ # _meta will be None when node._component is not a Component object
+ # so we just leave the type inference work to backend
+ self._data = Output(type=None) # type: ignore[call-overload]
+
+ def _build_data(self, data: T) -> Any:
+ """Build output data according to assigned input, eg: node.outputs.key = data
+
+ :param data: The data
+ :type data: T
+ :return: `data`
+ :rtype: T
+ """
+ if data is None:
+ return data
+ if not isinstance(data, (Output, str)):
+ msg = f"{self.__class__.__name__} only allow set {Output.__name__} object, {type(data)} is not supported."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.PIPELINE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ res: T = cast(T, data)
+ return res
+
+ def _to_job_output(self) -> Optional[Output]:
+ """Convert the output to Output, this logic will change if backend contract changes."""
+ if self._data is None:
+ # None data means this output is not configured.
+ result = None
+ elif isinstance(self._data, str):
+ result = Output(
+ type=AssetTypes.URI_FOLDER, path=self._data, mode=self.mode, name=self.name, version=self.version
+ )
+ elif isinstance(self._data, Output):
+ result = self._data
+ elif isinstance(self._data, PipelineOutput):
+ result = Output(
+ type=AssetTypes.URI_FOLDER,
+ path=self._data._data_binding(),
+ mode=self.mode,
+ name=self._data.name,
+ version=self._data.version,
+ description=self.description,
+ )
+ else:
+ msg = "Got unexpected type for output: {}."
+ raise ValidationException(
+ message=msg.format(self._data),
+ target=ErrorTarget.PIPELINE,
+ no_personal_data_message=msg.format("[data]"),
+ )
+ return result
+
+ def _data_binding(self) -> str:
+ if self._owner is not None:
+ return f"${{{{parent.jobs.{self._owner.name}.outputs.{self._port_name}}}}}"
+
+ return ""
+
+ def _copy(self, owner: Any) -> "NodeOutput":
+ return NodeOutput(
+ port_name=self._port_name,
+ data=cast(Output, self._data),
+ owner=owner,
+ meta=self._meta,
+ )
+
+ def _deepcopy(self) -> "NodeOutput":
+ return NodeOutput(
+ port_name=self._port_name,
+ data=cast(Output, copy.copy(self._data)),
+ owner=self._owner,
+ meta=self._meta,
+ binding_output=self._binding_output,
+ )
+
+
+class PipelineInput(NodeInput, PipelineExpressionMixin):
+ """Define one input of a Pipeline."""
+
+ def __init__(self, name: str, meta: Optional[Input], group_names: Optional[List[str]] = None, **kwargs: Any):
+ """Initialize a PipelineInput.
+
+ :param name: The name of the input.
+ :type name: str
+ :param meta: Metadata of this input, eg: type, min, max, etc.
+ :type meta: Input
+ :param group_names: The input parameter's group names.
+ :type group_names: List[str]
+ """
+ super(PipelineInput, self).__init__(port_name=name, meta=meta, **kwargs)
+ self._group_names = group_names if group_names else []
+
+ def result(self) -> Any:
+ """Return original value of pipeline input.
+
+ :return: The original value of pipeline input
+ :rtype: Any
+
+ Example:
+
+ .. code-block:: python
+
+ @pipeline
+ def pipeline_func(param1):
+ # node1's param1 will get actual value of param1 instead of a input binding.
+ node1 = component_func(param1=param1.result())
+ """
+
+ # use this to break self loop
+ original_data_cache: Set = set()
+ original_data = self._original_data
+ while isinstance(original_data, PipelineInput) and original_data not in original_data_cache:
+ original_data_cache.add(original_data)
+ original_data = original_data._original_data
+ return original_data
+
+ def __str__(self) -> str:
+ return self._data_binding()
+
+ @overload
+ def _build_data(self, data: Union[Model, Data]) -> Input: ...
+
+ @overload
+ def _build_data(self, data: T) -> Any: ...
+
+ def _build_data(self, data: Union[Model, Data, T]) -> Any:
+ """Build data according to input type.
+
+ :param data: The data
+ :type data: Union[Model, Data, T]
+ :return:
+ * Input if data is a Model or Data
+ * data otherwise
+ :rtype: Union[Input, T]
+ """
+ if data is None:
+ return data
+ # Unidiomatic typecheck: Checks that data is _exactly_ this type, and not potentially a subtype
+ if type(data) is NodeInput: # pylint: disable=unidiomatic-typecheck
+ msg = "Can not bind input to another component's input."
+ raise ValidationException(message=msg, no_personal_data_message=msg, target=ErrorTarget.PIPELINE)
+ if isinstance(data, (PipelineInput, NodeOutput)):
+ # If value is an input or output, it's a data binding; an owner is required to convert it to
+ # a data binding, eg: ${{parent.inputs.xxx}}
+ if isinstance(data, NodeOutput) and data._owner is None:
+ msg = "Setting input binding {} to output without owner is not allowed."
+ raise ValidationException(
+ message=msg.format(data),
+ no_personal_data_message=msg.format("[data]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return data
+ if isinstance(data, (Data, Model)):
+ # If value is Data, we convert it to a corresponding Input
+ return _data_to_input(data)
+ return data
+
+ def _data_binding(self) -> str:
+ full_name = "%s.%s" % (".".join(self._group_names), self._port_name) if self._group_names else self._port_name
+ return f"${{{{parent.inputs.{full_name}}}}}"
+
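+ # Illustrative sketch (hypothetical names): with _group_names == ["train_group"] and
+ # port name "lr", the binding resolves to "${{parent.inputs.train_group.lr}}".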
+ def _to_input(self) -> Optional[Union[Input, Output]]:
+ """Convert pipeline input to component input for pipeline component.
+
+ :return: The component input
+ :rtype: Input
+ """
+ if self._data is None:
+ # None data means this input is not configured.
+ return self._meta
+ data_type = self._data.type if isinstance(self._data, Input) else None
+ # If type is asset type, return data type without default.
+ # Else infer type from data and set it as default.
+ if data_type and data_type.lower() in AssetTypes.__dict__.values():
+ if not isinstance(self._data, (int, float, str)):
+ result = Input(type=data_type, mode=self._data.mode)
+ elif type(self._data) in IOConstants.PRIMITIVE_TYPE_2_STR:
+ result = Input(
+ type=IOConstants.PRIMITIVE_TYPE_2_STR[type(self._data)],
+ default=self._data,
+ )
+ else:
+ msg = f"Unsupported Input type {type(self._data)} detected when translate job to component."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return result # pylint: disable=possibly-used-before-assignment
+
+
+class PipelineOutput(NodeOutput):
+ """Define one output of a Pipeline."""
+
+ def _to_job_output(self) -> Optional[Output]:
+ result: Optional[Output] = None
+ if isinstance(self._data, Output):
+ # For pipeline output with type Output, always pass to backend.
+ result = self._data
+ elif self._data is None and self._meta and self._meta.type:
+ # For un-configured pipeline output with meta, we need to return Output with accurate type,
+ # so it won't default to uri_folder.
+ result = Output(type=self._meta.type, mode=self._meta.mode, description=self._meta.description)
+ else:
+ result = super(PipelineOutput, self)._to_job_output()
+ # Copy meta type to avoid the built output's None type defaulting to uri_folder.
+ if self.type and result is not None and not result.type:
+ result.type = self.type
+ return result
+
+ def _data_binding(self) -> str:
+ return f"${{{{parent.outputs.{self._port_name}}}}}"
+
+ def _to_output(self) -> Optional[Output]:
+ """Convert pipeline output to component output for pipeline component."""
+ if self._data is None:
+ # None data means this output is not configured.
+ return None
+ if isinstance(self._meta, Output):
+ return self._meta
+ # Assign type directly as we don't have primitive output types for now.
+ if not isinstance(self._data, (int, float, str)):
+ return Output(type=self._data.type, mode=self._data.mode)
+ return Output()
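+
+# Editorial sketch, not part of the vendored SDK source: the binding strings the
+# _data_binding() helpers above produce, reproduced standalone. The port and group
+# names ("lr", "group1", "score") are hypothetical.
+def _example_input_binding(port_name, group_names=None):
+ full_name = ".".join(list(group_names or []) + [port_name])
+ return f"${{{{parent.inputs.{full_name}}}}}"
+# _example_input_binding("lr") == "${{parent.inputs.lr}}"
+# _example_input_binding("lr", ["group1"]) == "${{parent.inputs.group1.lr}}"
+# PipelineOutput uses the same pattern with "parent.outputs", eg: "${{parent.outputs.score}}"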
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/mixin.py
new file mode 100644
index 00000000..6c3d9357
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_io/mixin.py
@@ -0,0 +1,623 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import copy
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobInput as RestJobInput
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobOutput as RestJobOutput
+from azure.ai.ml.constants._component import ComponentJobConstants
+from azure.ai.ml.entities._inputs_outputs import GroupInput, Input, Output
+from azure.ai.ml.entities._util import copy_output_setting
+from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException
+
+from ..._input_output_helpers import (
+ from_rest_data_outputs,
+ from_rest_inputs_to_dataset_literal,
+ to_rest_data_outputs,
+ to_rest_dataset_literal_inputs,
+)
+from .._pipeline_job_helpers import from_dict_to_rest_io, process_sdk_component_job_io
+from .attr_dict import InputsAttrDict, OutputsAttrDict, _GroupAttrDict
+from .base import NodeInput, NodeOutput, PipelineInput, PipelineOutput
+
+
+class NodeIOMixin:
+ """Provides ability to wrap node inputs/outputs and build data bindings
+ dynamically."""
+
+ @classmethod
+ def _get_supported_inputs_types(cls) -> Optional[Any]:
+ return None
+
+ @classmethod
+ def _get_supported_outputs_types(cls) -> Optional[Any]:
+ return None
+
+ @classmethod
+ def _validate_io(cls, value: Any, allowed_types: Optional[tuple], *, key: Optional[str] = None) -> None:
+ if allowed_types is None:
+ return
+
+ if value is None or isinstance(value, allowed_types):
+ pass
+ else:
+ msg = "Expecting {} for input/output {}, got {} instead."
+ raise ValidationException(
+ message=msg.format(allowed_types, key, type(value)),
+ no_personal_data_message=msg.format(allowed_types, "[key]", type(value)),
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def _build_input(
+ self,
+ name: str,
+ meta: Optional[Input],
+ data: Optional[Union[dict, int, bool, float, str, Output, "PipelineInput", Input]],
+ ) -> NodeInput:
+ # output mode of last node should not affect input mode of next node
+ if isinstance(data, NodeOutput):
+ # Decoupled input and output
+ # value = copy.deepcopy(value)
+ data = data._deepcopy() # pylint: disable=protected-access
+ data.mode = None
+ elif isinstance(data, dict):
+ # Use type comparison instead of isinstance to skip _GroupAttrDict.
+ # When loading from yaml, io will be a dict,
+ # like {'job_data_path': '${{parent.inputs.pipeline_job_data_path}}'};
+ # parse the dict to the allowed type.
+ data = Input(**data)
+
+ # parameter group can be of custom type, so we don't check it here
+ if meta is not None and not isinstance(meta, GroupInput):
+ self._validate_io(data, self._get_supported_inputs_types(), key=name)
+ return NodeInput(port_name=name, meta=meta, data=data, owner=self)
+
+ def _build_output(self, name: str, meta: Optional[Output], data: Optional[Union[Output, str]]) -> NodeOutput:
+ if isinstance(data, dict):
+ data = Output(**data)
+
+ self._validate_io(data, self._get_supported_outputs_types(), key=name)
+ # For un-configured outputs, setting data to None so we won't pass extra fields (eg: default mode)
+ return NodeOutput(port_name=name, meta=meta, data=data, owner=self)
+
+ # pylint: disable=unused-argument
+ def _get_default_input_val(self, val: Any): # type: ignore
+ # use None value as data placeholder for unfilled inputs.
+ # server side will fill the default value
+ return None
+
+ def _build_inputs_dict(
+ self,
+ inputs: Dict[str, Union[Input, str, bool, int, float]],
+ *,
+ input_definition_dict: Optional[dict] = None,
+ ) -> InputsAttrDict:
+ """Build an input attribute dict so user can get/set inputs by
+ accessing attribute, eg: node1.inputs.xxx.
+
+ :param inputs: Provided kwargs when parameterizing component func.
+ :type inputs: Dict[str, Union[Input, str, bool, int, float]]
+ :keyword input_definition_dict: Static input definition dict. If not provided, will build inputs without meta.
+ :paramtype input_definition_dict: dict
+ :return: Built dynamic input attribute dict.
+ :rtype: InputsAttrDict
+ """
+ if input_definition_dict is not None:
+ # TODO: validate inputs.keys() in input_definitions.keys()
+ input_dict = {}
+ for key, val in input_definition_dict.items():
+ if key in inputs.keys():
+ # If input is set through component functions' kwargs, create an input object with real value.
+ data = inputs[key]
+ else:
+ data = self._get_default_input_val(val) # pylint: disable=assignment-from-none
+
+ val = self._build_input(name=key, meta=val, data=data)
+ input_dict[key] = val
+ else:
+ input_dict = {key: self._build_input(name=key, meta=None, data=val) for key, val in inputs.items()}
+ return InputsAttrDict(input_dict)
+
+ def _build_outputs_dict(
+ self, outputs: Dict, *, output_definition_dict: Optional[dict] = None, none_data: bool = False
+ ) -> OutputsAttrDict:
+ """Build an output attribute dict so user can get/set outputs by
+ accessing attribute, eg: node1.outputs.xxx.
+
+ :param outputs: Provided kwargs when parameterizing component func.
+ :type outputs: Dict[str, Output]
+ :keyword output_definition_dict: Static output definition dict.
+ :paramtype output_definition_dict: Dict
+ :keyword none_data: If True, will set output data to None.
+ :paramtype none_data: bool
+ :return: Built dynamic output attribute dict.
+ :rtype: OutputsAttrDict
+ """
+ if output_definition_dict is not None:
+ # TODO: check if we need another way to mark an un-configured output instead of just setting None.
+ # Create None as data placeholder for all outputs.
+ output_dict = {}
+ for key, val in output_definition_dict.items():
+ if key in outputs.keys():
+ # If output has given value, create an output object with real value.
+ val = self._build_output(name=key, meta=val, data=outputs[key])
+ else:
+ val = self._build_output(name=key, meta=val, data=None)
+ output_dict[key] = val
+ else:
+ output_dict = {}
+ for key, val in outputs.items():
+ output_val = self._build_output(name=key, meta=None, data=val if not none_data else None)
+ output_dict[key] = output_val
+ return OutputsAttrDict(output_dict)
+
+ def _build_inputs(self) -> Dict:
+ """Build inputs of this component to a dict dict which maps output to
+ actual value.
+
+ The built input dict will have same input format as other jobs, eg:
+ {
+ "input_data": Input(path="path/to/input/data", mode="Mount"),
+ "input_value": 10,
+ "learning_rate": "${{jobs.step1.inputs.learning_rate}}"
+ }
+
+ :return: The input dict
+ :rtype: Dict[str, Union[Input, str, bool, int, float]]
+ """
+ inputs = {}
+ # pylint: disable=redefined-builtin
+ for name, input in self.inputs.items(): # type: ignore
+ if isinstance(input, _GroupAttrDict):
+ # Flatten group inputs into inputs dict
+ inputs.update(input.flatten(group_parameter_name=name))
+ continue
+ inputs[name] = input._to_job_input() # pylint: disable=protected-access
+ return inputs
+
+ def _build_outputs(self) -> Dict[str, Output]:
+ """Build outputs of this component to a dict which maps output to
+ actual value.
+
+ The built output dict will have the same output format as other jobs, eg:
+ {
+ "eval_output": "${{jobs.eval.outputs.eval_output}}"
+ }
+
+ :return: The output dict
+ :rtype: Dict[str, Output]
+ """
+ outputs = {}
+ for name, output in self.outputs.items(): # type: ignore
+ if isinstance(output, NodeOutput):
+ output = output._to_job_output() # pylint: disable=protected-access
+ outputs[name] = output
+ # Remove non-configured output
+ return {k: v for k, v in outputs.items() if v is not None}
+
+ def _to_rest_inputs(self) -> Dict[str, Dict]:
+ """Translate input builders to rest input dicts.
+
+ The built dictionary's format aligns with component job's input yaml, eg:
+ {
+ "input_data": {"data": {"path": "path/to/input/data"}, "mode"="Mount"},
+ "input_value": 10,
+ "learning_rate": "${{jobs.step1.inputs.learning_rate}}"
+ }
+
+ :return: The REST inputs
+ :rtype: Dict[str, Dict]
+ """
+ built_inputs = self._build_inputs()
+ return self._input_entity_to_rest_inputs(input_entity=built_inputs)
+
+ @classmethod
+ def _input_entity_to_rest_inputs(cls, input_entity: Dict[str, Input]) -> Dict[str, Dict]:
+ # Convert io entity to rest io objects
+ input_bindings, dataset_literal_inputs = process_sdk_component_job_io(
+ input_entity, [ComponentJobConstants.INPUT_PATTERN]
+ )
+
+ # parse input_bindings to InputLiteral(value=str(binding))
+ rest_inputs = {**input_bindings, **dataset_literal_inputs}
+ # Note: The function will only be called from BaseNode,
+ # and job_type is used to enable dot in pipeline job input keys,
+ # so pass job_type as None directly here.
+ rest_inputs = to_rest_dataset_literal_inputs(rest_inputs, job_type=None)
+
+ # convert rest io to dict
+ rest_dataset_literal_inputs = {}
+ for name, val in rest_inputs.items():
+ rest_dataset_literal_inputs[name] = val.as_dict()
+ if hasattr(val, "mode") and val.mode:
+ rest_dataset_literal_inputs[name].update({"mode": val.mode.value})
+ return rest_dataset_literal_inputs
+
+ def _to_rest_outputs(self) -> Dict[str, Dict]:
+ """Translate output builders to rest output dicts.
+
+ The built dictionary's format aligns with component job's output yaml, eg:
+ {"eval_output": "${{jobs.eval.outputs.eval_output}}"}
+
+ :return: The REST outputs
+ :rtype: Dict[str, Dict]
+ """
+ built_outputs = self._build_outputs()
+
+ # Convert io entity to rest io objects
+ output_bindings, data_outputs = process_sdk_component_job_io(
+ built_outputs, [ComponentJobConstants.OUTPUT_PATTERN]
+ )
+ rest_data_outputs = to_rest_data_outputs(data_outputs)
+
+ # convert rest io to dict
+ # parse output_bindings to {"value": binding, "type": "literal"} since there's no mode for it
+ rest_output_bindings = {}
+ for key, binding in output_bindings.items():
+ rest_output_bindings[key] = {"value": binding["value"], "type": "literal"}
+ if "mode" in binding:
+ rest_output_bindings[key].update({"mode": binding["mode"].value})
+ if "name" in binding:
+ rest_output_bindings[key].update({"name": binding["name"]})
+ if "version" in binding:
+ rest_output_bindings[key].update({"version": binding["version"]})
+
+ def _rename_name_and_version(output_dict: Dict) -> Dict:
+ # NodeOutput can only be registered with name and version, therefore we rename here
+ if "asset_name" in output_dict.keys():
+ output_dict["name"] = output_dict.pop("asset_name")
+ if "asset_version" in output_dict.keys():
+ output_dict["version"] = output_dict.pop("asset_version")
+ return output_dict
+
+ rest_data_outputs = {name: _rename_name_and_version(val.as_dict()) for name, val in rest_data_outputs.items()}
+ self._update_output_types(rest_data_outputs)
+ rest_data_outputs.update(rest_output_bindings)
+ return rest_data_outputs
+
+ @classmethod
+ def _from_rest_inputs(cls, inputs: Dict) -> Dict[str, Union[Input, str, bool, int, float]]:
+ """Load inputs from rest inputs.
+
+ :param inputs: The REST inputs
+ :type inputs: Dict[str, Union[str, dict]]
+ :return: Input dict
+ :rtype: Dict[str, Union[Input, str, bool, int, float]]
+ """
+
+ # JObject -> RestJobInput/RestJobOutput
+ input_bindings, rest_inputs = from_dict_to_rest_io(inputs, RestJobInput, [ComponentJobConstants.INPUT_PATTERN])
+
+ # RestJobInput/RestJobOutput -> Input/Output
+ dataset_literal_inputs = from_rest_inputs_to_dataset_literal(rest_inputs)
+
+ return {**dataset_literal_inputs, **input_bindings}
+
+ @classmethod
+ def _from_rest_outputs(cls, outputs: Dict[str, Union[str, dict]]) -> Dict:
+ """Load outputs from rest outputs.
+
+ :param outputs: The REST outputs
+ :type outputs: Dict[str, Union[str, dict]]
+ :return: Output dict
+ :rtype: Dict[str, Output]
+ """
+
+ # JObject -> RestJobInput/RestJobOutput
+ output_bindings, rest_outputs = from_dict_to_rest_io(
+ outputs, RestJobOutput, [ComponentJobConstants.OUTPUT_PATTERN]
+ )
+
+ # RestJobInput/RestJobOutput -> Input/Output
+ data_outputs = from_rest_data_outputs(rest_outputs)
+
+ return {**data_outputs, **output_bindings}
+
+ def _update_output_types(self, rest_data_outputs: dict) -> None:
+ """Update output types in rest_data_outputs according to meta level output.
+
+ :param rest_data_outputs: The REST data outputs
+ :type rest_data_outputs: Dict
+ """
+
+ for name, rest_output in rest_data_outputs.items():
+ original_output = self.outputs[name] # type: ignore
+ # for configured output with meta, "correct" the output type to file to avoid the uri_folder default value
+ if original_output and original_output.type:
+ if original_output.type in ["AnyFile", "uri_file"]:
+ rest_output["job_output_type"] = "uri_file"
+
+
+def flatten_dict(
+ dct: Optional[Dict],
+ _type: Union[Type["_GroupAttrDict"], Type[GroupInput]],
+ *,
+ allow_dict_fields: Optional[List[str]] = None,
+) -> Dict:
+ """Flatten inputs/input_definitions dict for inputs dict build.
+
+ :param dct: The dictionary to flatten
+ :type dct: Dict
+ :param _type: Either _GroupAttrDict or GroupInput (both have the method `flatten`)
+ :type _type: Union[Type["_GroupAttrDict"], Type[GroupInput]]
+ :keyword allow_dict_fields: A list of keys for dictionary values that will be included in flattened output
+ :paramtype allow_dict_fields: Optional[List[str]]
+ :return: The flattened dict
+ :rtype: Dict
+ """
+ _result = {}
+ if dct is not None:
+ for key, val in dct.items():
+ # to support passing dict value as parameter group
+ if allow_dict_fields and key in allow_dict_fields and isinstance(val, dict):
+ # for child dict, all values are allowed to be dict
+ for flattened_key, flattened_val in flatten_dict(
+ val, _type, allow_dict_fields=list(val.keys())
+ ).items():
+ _result[key + "." + flattened_key] = flattened_val
+ continue
+ val = GroupInput.custom_class_value_to_attr_dict(val)
+ if isinstance(val, _type):
+ _result.update(val.flatten(group_parameter_name=key))
+ continue
+ _result[key] = val
+ return _result
+
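+# Editorial sketch, not part of the vendored SDK source: expected flattening of a plain
+# nested dict by flatten_dict above, assuming GroupInput.custom_class_value_to_attr_dict
+# passes plain values through unchanged. Keys and values are hypothetical.
+_example_flattened = flatten_dict(
+ {"group": {"lr": 0.1, "sub": {"epochs": 3}}, "top": 1},
+ _GroupAttrDict,
+ allow_dict_fields=["group"],
+)
+# expected: {"group.lr": 0.1, "group.sub.epochs": 3, "top": 1}
+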
+
+class NodeWithGroupInputMixin(NodeIOMixin):
+ """This class provide build_inputs_dict for a node to use ParameterGroup as an input."""
+
+ @classmethod
+ def _validate_group_input_type(
+ cls,
+ input_definition_dict: dict,
+ inputs: Dict[str, Union[Input, str, bool, int, float]],
+ ) -> None:
+ """Raise error when group input receive a value not group type.
+
+ :param input_definition_dict: The input definition dict
+ :type input_definition_dict: dict
+ :param inputs: The inputs
+ :type inputs: Dict[str, Union[Input, str, bool, int, float]]
+ """
+ # Note: We put an extra validation here instead of doing it in pipeline._validate()
+ # because a group input will be discarded silently if it is assigned to a non-group parameter.
+ group_msg = "'%s' is defined as a parameter group but got input '%s' with type '%s'."
+ non_group_msg = "'%s' is defined as a parameter but got a parameter group as input."
+ for key, val in inputs.items():
+ definition = input_definition_dict.get(key)
+ val = GroupInput.custom_class_value_to_attr_dict(val)
+ if val is None:
+ continue
+ # 1. inputs.group = 'a string'
+ if isinstance(definition, GroupInput) and not isinstance(val, (_GroupAttrDict, dict)):
+ raise ValidationException(
+ message=group_msg % (key, val, type(val)),
+ no_personal_data_message=group_msg % ("[key]", "[val]", "[type(val)]"),
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ # 2. inputs.str_param = group
+ if not isinstance(definition, GroupInput) and isinstance(val, _GroupAttrDict):
+ raise ValidationException(
+ message=non_group_msg % key,
+ no_personal_data_message=non_group_msg % "[key]",
+ target=ErrorTarget.PIPELINE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ @classmethod
+ def _flatten_inputs_and_definition(
+ cls,
+ inputs: Dict[str, Union[Input, str, bool, int, float]],
+ input_definition_dict: dict,
+ ) -> Tuple[Dict, Dict]:
+ """
+ Flatten all GroupInput(definition) and GroupAttrDict recursively and build input dict.
+ For example:
+ input_definition_dict = {
+ "group1": GroupInput(
+ values={
+ "param1": GroupInput(
+ values={
+ "param1_1": Input(type="str"),
+ }
+ ),
+ "param2": Input(type="int"),
+ }
+ ),
+ "group2": GroupInput(
+ values={
+ "param3": Input(type="str"),
+ }
+ ),
+ } => {
+ "group1.param1.param1_1": Input(type="str"),
+ "group1.param2": Input(type="int"),
+ "group2.param3": Input(type="str"),
+ }
+ inputs = {
+ "group1": {
+ "param1": {
+ "param1_1": "value1",
+ },
+ "param2": 2,
+ },
+ "group2": {
+ "param3": "value3",
+ },
+ } => {
+ "group1.param1.param1_1": "value1",
+ "group1.param2": 2,
+ "group2.param3": "value3",
+ }
+ :param inputs: The inputs
+ :type inputs: Dict[str, Union[Input, str, bool, int, float]]
+ :param input_definition_dict: The input definition dict
+ :type input_definition_dict: dict
+ :return: The flattened inputs and definition
+ :rtype: Tuple[Dict, Dict]
+ """
+ group_input_names = [key for key, val in input_definition_dict.items() if isinstance(val, GroupInput)]
+ flattened_inputs = flatten_dict(inputs, _GroupAttrDict, allow_dict_fields=group_input_names)
+ flattened_definition_dict = flatten_dict(input_definition_dict, GroupInput)
+ return flattened_inputs, flattened_definition_dict
+
+ def _build_inputs_dict(
+ self,
+ inputs: Dict[str, Union[Input, str, bool, int, float]],
+ *,
+ input_definition_dict: Optional[dict] = None,
+ ) -> InputsAttrDict:
+ """Build an input attribute dict so user can get/set inputs by
+ accessing attribute, eg: node1.inputs.xxx.
+
+ :param inputs: Provided kwargs when parameterizing component func.
+ :type inputs: Dict[str, Union[Input, str, bool, int, float]]
+ :keyword input_definition_dict: Input definition dict from component entity.
+ :paramtype input_definition_dict: dict
+ :return: Built input attribute dict.
+ :rtype: InputsAttrDict
+ """
+
+ # TODO: should we support group input when there is no local input definition?
+ if input_definition_dict is not None:
+ # Validate group mismatch
+ self._validate_group_input_type(input_definition_dict, inputs)
+
+ # Flatten inputs and definition
+ flattened_inputs, flattened_definition_dict = self._flatten_inputs_and_definition(
+ inputs, input_definition_dict
+ )
+ # Build: zip all flattened parameter with definition
+ inputs = super()._build_inputs_dict(flattened_inputs, input_definition_dict=flattened_definition_dict)
+ return InputsAttrDict(GroupInput.restore_flattened_inputs(inputs))
+ return super()._build_inputs_dict(inputs)
+
+
+class PipelineJobIOMixin(NodeWithGroupInputMixin):
+ """Provides ability to wrap pipeline job inputs/outputs and build data bindings
+ dynamically."""
+
+ def _build_input(self, name: str, meta: Optional[Input], data: Any) -> "PipelineInput":
+ return PipelineInput(name=name, meta=meta, data=data, owner=self)
+
+ def _build_output(
+ self, name: str, meta: Optional[Union[Input, Output]], data: Optional[Union[Output, str]]
+ ) -> "PipelineOutput":
+ # TODO: set data to None for un-configured outputs so we won't pass extra fields (eg: default mode)
+ result = PipelineOutput(port_name=name, meta=meta, data=data, owner=self)
+ return result
+
+ def _build_inputs_dict(
+ self,
+ inputs: Dict[str, Union[Input, str, bool, int, float]],
+ *,
+ input_definition_dict: Optional[dict] = None,
+ ) -> InputsAttrDict:
+ """Build an input attribute dict so user can get/set inputs by
+ accessing attribute, eg: node1.inputs.xxx.
+
+ :param inputs: Provided kwargs when parameterizing component func.
+ :type inputs: Dict[str, Union[Input, str, bool, int, float]]
+ :keyword input_definition_dict: Input definition dict from component entity.
+ :return: Built input attribute dict.
+ :rtype: InputsAttrDict
+ """
+ input_dict = super()._build_inputs_dict(inputs, input_definition_dict=input_definition_dict)
+ # TODO: should we do this when input_definition_dict is not None?
+ # TODO: should we put this in super()._build_inputs_dict?
+ if input_definition_dict is None:
+ return InputsAttrDict(GroupInput.restore_flattened_inputs(input_dict))
+ return input_dict
+
+ def _build_output_for_pipeline(self, name: str, data: Optional[Union[Output, NodeOutput]]) -> "PipelineOutput":
+ """Build an output object for pipeline and copy settings from source output.
+
+ :param name: Output name.
+ :type name: str
+ :param data: Output data.
+ :type data: Optional[Union[Output, NodeOutput]]
+ :return: Built output object.
+ :rtype: PipelineOutput
+ """
+ # pylint: disable=protected-access
+ if data is None:
+ # For None output, build an empty output builder
+ output_val = self._build_output(name=name, meta=None, data=None)
+ elif isinstance(data, Output):
+ # For output entity, build an output builder with data points to it
+ output_val = self._build_output(name=name, meta=data, data=data)
+ elif isinstance(data, NodeOutput):
+ # For output builder, build a new output builder and copy settings from it
+ output_val = self._build_output(name=name, meta=data._meta, data=None)
+ copy_output_setting(source=data, target=output_val)
+ else:
+ message = "Unsupported output type: {} for pipeline output: {}: {}"
+ raise ValidationException(
+ message=message.format(type(data), name, data),
+ no_personal_data_message=message,
+ target=ErrorTarget.PIPELINE,
+ )
+ return output_val
+
+ def _build_pipeline_outputs_dict(self, outputs: Dict) -> OutputsAttrDict:
+ """Build an output attribute dict without output definition metadata.
+ For pipeline outputs, their settings should be copied from node-level outputs.
+
+ :param outputs: Node output dict or pipeline component's outputs.
+ :type outputs: Dict[str, Union[Output, NodeOutput]]
+ :return: Built dynamic output attribute dict.
+ :rtype: OutputsAttrDict
+ """
+ output_dict = {}
+ for key, val in outputs.items():
+ output_dict[key] = self._build_output_for_pipeline(name=key, data=val)
+ return OutputsAttrDict(output_dict)
+
+ def _build_outputs(self) -> Dict[str, Output]:
+ """Build outputs of this pipeline to a dict which maps output to actual
+ value.
+
+ The built dictionary's format aligns with component job's output yaml,
+ un-configured outputs will be None, eg:
+ {"eval_output": "${{jobs.eval.outputs.eval_output}}", "un_configured": None}
+
+ :return: The output dict
+ :rtype: Dict[str, Output]
+ """
+ outputs = {}
+ for name, output in self.outputs.items(): # type: ignore
+ if isinstance(output, NodeOutput):
+ output = output._to_job_output() # pylint: disable=protected-access
+ outputs[name] = output
+ return outputs
+
+ def _get_default_input_val(self, val: Any): # type: ignore
+ # Use the default value as the data placeholder for unfilled inputs.
+ # The client side needs to fill in the default value for dsl.pipeline.
+ if isinstance(val, GroupInput):
+ # Copy default value dict for group
+ return copy.deepcopy(val.default)
+ return val.default
+
+ def _update_output_types(self, rest_data_outputs: Dict) -> None:
+ """Won't clear output type for pipeline level outputs since it's required in rest object.
+
+ :param rest_data_outputs: The REST data outputs
+ :type rest_data_outputs: Dict
+ """
+
+
+class AutoMLNodeIOMixin(NodeIOMixin):
+ """Wrap outputs of automl node and build data bindings dynamically."""
+
+ def __init__(self, **kwargs): # type: ignore
+ # add an inputs field to align with other nodes
+ self.inputs = {}
+ super(AutoMLNodeIOMixin, self).__init__(**kwargs)
+ if getattr(self, "outputs", None):
+ self._outputs = self._build_outputs_dict(self.outputs or {})
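+
+# Editorial sketch, not part of the vendored SDK source: the plain-dict shape that
+# _build_inputs()/_build_outputs() in the mixins above aim to return, following the
+# examples in their docstrings. Names, paths and binding targets are hypothetical.
+_example_built_inputs = {
+ "input_data": Input(path="azureml:my_data:1", mode="ro_mount"),
+ "input_value": 10,
+ "learning_rate": "${{jobs.step1.inputs.learning_rate}}",
+}
+_example_built_outputs = {"eval_output": "${{jobs.eval.outputs.eval_output}}"}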
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_load_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_load_component.py
new file mode 100644
index 00000000..60c4cbe7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_load_component.py
@@ -0,0 +1,313 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+from typing import Any, Callable, Dict, List, Mapping, Optional, Union, cast
+
+from marshmallow import INCLUDE
+
+from azure.ai.ml import Output
+from azure.ai.ml._schema import NestedField
+from azure.ai.ml._schema.pipeline.component_job import SweepSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, SOURCE_PATH_CONTEXT_KEY, CommonYamlFields
+from azure.ai.ml.constants._component import ControlFlowType, DataTransferTaskType, NodeType
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.dsl._component_func import to_component_func
+from azure.ai.ml.dsl._overrides_definition import OverrideDefinition
+from azure.ai.ml.entities._builders import (
+ BaseNode,
+ Command,
+ DataTransferCopy,
+ DataTransferExport,
+ DataTransferImport,
+ Import,
+ Parallel,
+ Spark,
+ Sweep,
+)
+from azure.ai.ml.entities._builders.condition_node import ConditionNode
+from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode
+from azure.ai.ml.entities._builders.do_while import DoWhile
+from azure.ai.ml.entities._builders.parallel_for import ParallelFor
+from azure.ai.ml.entities._builders.pipeline import Pipeline
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+from azure.ai.ml.entities._util import get_type_from_spec
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class _PipelineNodeFactory:
+ """A class to create pipeline node instances from yaml dict or rest objects without hard-coded type check."""
+
+ def __init__(self) -> None:
+ self._create_instance_funcs: dict = {}
+ self._load_from_rest_object_funcs: dict = {}
+
+ self.register_type(
+ _type=NodeType.COMMAND,
+ create_instance_func=lambda: Command.__new__(Command),
+ load_from_rest_object_func=Command._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=NodeType.IMPORT,
+ create_instance_func=lambda: Import.__new__(Import),
+ load_from_rest_object_func=Import._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=NodeType.PARALLEL,
+ create_instance_func=lambda: Parallel.__new__(Parallel),
+ load_from_rest_object_func=Parallel._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=NodeType.PIPELINE,
+ create_instance_func=lambda: Pipeline.__new__(Pipeline),
+ load_from_rest_object_func=Pipeline._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=NodeType.SWEEP,
+ create_instance_func=lambda: Sweep.__new__(Sweep),
+ load_from_rest_object_func=Sweep._from_rest_object,
+ nested_schema=NestedField(SweepSchema, unknown=INCLUDE),
+ )
+ self.register_type(
+ _type=NodeType.AUTOML,
+ create_instance_func=None,
+ load_from_rest_object_func=self._automl_from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=NodeType.SPARK,
+ create_instance_func=lambda: Spark.__new__(Spark),
+ load_from_rest_object_func=Spark._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=ControlFlowType.DO_WHILE,
+ create_instance_func=None,
+ load_from_rest_object_func=DoWhile._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=ControlFlowType.IF_ELSE,
+ create_instance_func=None,
+ load_from_rest_object_func=ConditionNode._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=ControlFlowType.PARALLEL_FOR,
+ create_instance_func=None,
+ load_from_rest_object_func=ParallelFor._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.COPY_DATA]),
+ create_instance_func=lambda: DataTransferCopy.__new__(DataTransferCopy),
+ load_from_rest_object_func=DataTransferCopy._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.IMPORT_DATA]),
+ create_instance_func=lambda: DataTransferImport.__new__(DataTransferImport),
+ load_from_rest_object_func=DataTransferImport._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type="_".join([NodeType.DATA_TRANSFER, DataTransferTaskType.EXPORT_DATA]),
+ create_instance_func=lambda: DataTransferExport.__new__(DataTransferExport),
+ load_from_rest_object_func=DataTransferExport._from_rest_object,
+ nested_schema=None,
+ )
+ self.register_type(
+ _type=NodeType.FLOW_PARALLEL,
+ create_instance_func=lambda: Parallel.__new__(Parallel),
+ load_from_rest_object_func=None,
+ nested_schema=None,
+ )
+
+ @classmethod
+ def _get_func(cls, _type: str, funcs: Dict[str, Callable]) -> Callable:
+ if _type == NodeType._CONTAINER:
+ msg = (
+ "Component returned by 'list' is abbreviated and can not be used directly, "
+ "please use result from 'get'."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ _type = get_type_from_spec({CommonYamlFields.TYPE: _type}, valid_keys=funcs)
+ return funcs[_type]
+
+ def get_create_instance_func(self, _type: str) -> Callable[..., BaseNode]:
+ """Get the function to create a new instance of the node.
+
+ :param _type: The type of the node.
+ :type _type: str
+ :return: The create instance function
+ :rtype: Callable[..., BaseNode]
+ """
+ return self._get_func(_type, self._create_instance_funcs)
+
+ def get_load_from_rest_object_func(self, _type: str) -> Callable:
+ """Get the function to load a node from a rest object.
+
+ :param _type: The type of the node.
+ :type _type: str
+ :return: The `_load_from_rest_object` function
+ :rtype: Callable[[Any], Union[BaseNode, AutoMLJob, ControlFlowNode]]
+ """
+ return self._get_func(_type, self._load_from_rest_object_funcs)
+
+ def register_type(
+ self,
+ _type: str,
+ *,
+ create_instance_func: Optional[Callable[..., Union[BaseNode, AutoMLJob]]] = None,
+ load_from_rest_object_func: Optional[Callable] = None,
+ nested_schema: Optional[Union[NestedField, List[NestedField]]] = None,
+ ) -> None:
+ """Register a type of node.
+
+ :param _type: The type of the node.
+ :type _type: str
+ :keyword create_instance_func: A function to create a new instance of the node
+ :paramtype create_instance_func: typing.Optional[typing.Callable[..., typing.Union[BaseNode, AutoMLJob]]]
+ :keyword load_from_rest_object_func: A function to load a node from a rest object
+ :paramtype load_from_rest_object_func: typing.Optional[typing.Callable[[Any], typing.Union[BaseNode, AutoMLJob\
+ , ControlFlowNode]]]
+ :keyword nested_schema: schema/schemas of corresponding nested field, will be used in \
+ PipelineJobSchema.jobs.value
+ :paramtype nested_schema: typing.Optional[typing.Union[NestedField, List[NestedField]]]
+ """
+ if create_instance_func is not None:
+ self._create_instance_funcs[_type] = create_instance_func
+ if load_from_rest_object_func is not None:
+ self._load_from_rest_object_funcs[_type] = load_from_rest_object_func
+ if nested_schema is not None:
+ from azure.ai.ml._schema.core.fields import TypeSensitiveUnionField
+ from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentSchema
+ from azure.ai.ml._schema.pipeline.pipeline_job import PipelineJobSchema
+
+ for declared_fields in [
+ PipelineJobSchema._declared_fields,
+ PipelineComponentSchema._declared_fields,
+ ]:
+ jobs_value_field: TypeSensitiveUnionField = declared_fields["jobs"].value_field
+ if not isinstance(nested_schema, list):
+ nested_schema = [nested_schema]
+ for nested_field in nested_schema:
+ jobs_value_field.insert_type_sensitive_field(type_name=_type, field=nested_field)
+
+ def load_from_dict(self, *, data: dict, _type: Optional[str] = None) -> Union[BaseNode, AutoMLJob]:
+ """Load a node from a dict.
+
+ :keyword data: A dict containing the node's data.
+ :paramtype data: dict
+ :keyword _type: The type of the node. If not specified, it will be inferred from the data.
+ :paramtype _type: str
+ :return: The node
+ :rtype: Union[BaseNode, AutoMLJob]
+ """
+ if _type is None:
+ _type = data[CommonYamlFields.TYPE] if CommonYamlFields.TYPE in data else NodeType.COMMAND
+ # TODO: refine. Hard-coded for now to support different task types for the DataTransfer node.
+ if _type == NodeType.DATA_TRANSFER:
+ _type = "_".join([NodeType.DATA_TRANSFER, data.get("task", " ")])
+ else:
+ data[CommonYamlFields.TYPE] = _type
+
+ new_instance: Union[BaseNode, AutoMLJob] = self.get_create_instance_func(_type)()
+
+ if isinstance(new_instance, BaseNode):
+ # parse component
+ component_key = new_instance._get_component_attr_name()
+ if component_key in data and isinstance(data[component_key], dict):
+ data[component_key] = Component._load(
+ data=data[component_key],
+ yaml_path=data[component_key].pop(SOURCE_PATH_CONTEXT_KEY, None),
+ )
+ # TODO: Bug Item number: 2883415
+ new_instance.__init__(**data) # type: ignore
+ return new_instance
+
+ def load_from_rest_object(
+ self, *, obj: dict, _type: Optional[str] = None, **kwargs: Any
+ ) -> Union[BaseNode, AutoMLJob, ControlFlowNode]:
+ """Load a node from a rest object.
+
+ :keyword obj: A rest object containing the node's data.
+ :paramtype obj: dict
+ :keyword _type: The type of the node. If not specified, it will be inferred from the data.
+ :paramtype _type: str
+ :return: The node
+ :rtype: Union[BaseNode, AutoMLJob, ControlFlowNode]
+ """
+
+ # TODO: Remove in PuP with native import job/component type support in MFE/Designer
+ if "computeId" in obj and obj["computeId"] and obj["computeId"].endswith("/" + ComputeType.ADF):
+ _type = NodeType.IMPORT
+
+ if _type is None:
+ _type = obj[CommonYamlFields.TYPE] if CommonYamlFields.TYPE in obj else NodeType.COMMAND
+ # TODO: refine. Hard-coded for now to support different task types for the DataTransfer node.
+ if _type == NodeType.DATA_TRANSFER:
+ _type = "_".join([NodeType.DATA_TRANSFER, obj.get("task", " ")])
+ else:
+ obj[CommonYamlFields.TYPE] = _type
+
+ res: Union[BaseNode, AutoMLJob, ControlFlowNode] = self.get_load_from_rest_object_func(_type)(obj, **kwargs)
+ return res
+
+ @classmethod
+ def _automl_from_rest_object(cls, node: Dict) -> AutoMLJob:
+ _outputs = cast(Dict[str, Union[str, dict]], node.get("outputs"))
+ # rest dict outputs -> Output objects
+ outputs = AutoMLJob._from_rest_outputs(_outputs)
+ # Output objects -> yaml dict outputs
+ parsed_outputs = {}
+ for key, val in outputs.items():
+ if isinstance(val, Output):
+ val = val._to_dict()
+ parsed_outputs[key] = val
+ node["outputs"] = parsed_outputs
+ return AutoMLJob._load_from_dict(
+ node,
+ context={BASE_PATH_CONTEXT_KEY: "./"},
+ additional_message="Failed to load automl task from backend.",
+ inside_pipeline=True,
+ )
+
+
+def _generate_component_function(
+ component_entity: Component,
+ override_definitions: Optional[Mapping[str, OverrideDefinition]] = None, # pylint: disable=unused-argument
+) -> Callable[..., Union[Command, Parallel]]:
+ # Generate a function which returns a component node.
+ def create_component_func(**kwargs: Any) -> Union[BaseNode, AutoMLJob]:
+ # TODO: refine. Hard-coded for now to support different task types for the DataTransfer node.
+ _type = component_entity.type
+ if _type == NodeType.DATA_TRANSFER:
+ # TODO: Bug Item number: 2883431
+ _type = "_".join([NodeType.DATA_TRANSFER, component_entity.task]) # type: ignore
+ if component_entity.task == DataTransferTaskType.IMPORT_DATA: # type: ignore
+ return pipeline_node_factory.load_from_dict(
+ data={"component": component_entity, "_from_component_func": True, **kwargs},
+ _type=_type,
+ )
+ return pipeline_node_factory.load_from_dict(
+ data={"component": component_entity, "inputs": kwargs, "_from_component_func": True},
+ _type=_type,
+ )
+
+ res: Callable = to_component_func(component_entity, create_component_func)
+ return res
+
+
+pipeline_node_factory = _PipelineNodeFactory()
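+
+# Editorial sketch, not part of the vendored SDK source: registering an extra alias with
+# the factory above via the documented register_type keywords, reusing the real Command
+# builders. The alias string "command_alias" is an assumption for illustration only.
+pipeline_node_factory.register_type(
+ _type="command_alias",
+ create_instance_func=lambda: Command.__new__(Command),
+ load_from_rest_object_func=Command._from_rest_object,
+ nested_schema=None,
+)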
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_expression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_expression.py
new file mode 100644
index 00000000..49bb8a61
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_expression.py
@@ -0,0 +1,662 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import re
+import tempfile
+from collections import namedtuple
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+
+from azure.ai.ml._utils.utils import dump_yaml_to_file, get_all_data_binding_expressions, load_yaml
+from azure.ai.ml.constants._common import AZUREML_PRIVATE_FEATURES_ENV_VAR, DefaultOpenEncoding
+from azure.ai.ml.constants._component import ComponentParameterTypes, IOConstants
+from azure.ai.ml.exceptions import UserErrorException
+
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._builders import BaseNode
+
+ExpressionInput = namedtuple("ExpressionInput", ["name", "type", "value"])
+NONE_PARAMETER_TYPE = "None"
+
+
+class PipelineExpressionOperator:
+ """Support operator in native Python experience."""
+
+ ADD = "+"
+ SUB = "-"
+ MUL = "*"
+ DIV = "/"
+ MOD = "%"
+ POW = "**"
+ FLOORDIV = "//"
+ LT = "<"
+ GT = ">"
+ LTE = "<="
+ GTE = ">="
+ EQ = "=="
+ NE = "!="
+ AND = "&"
+ OR = "|"
+ XOR = "^"
+
+
+_SUPPORTED_OPERATORS = {
+ getattr(PipelineExpressionOperator, attr)
+ for attr in PipelineExpressionOperator.__dict__
+ if not attr.startswith("__")
+}
+
+
+def _enumerate_operation_combination() -> Dict[str, Union[str, Exception]]:
+ """Enumerate the result type of binary operations on types
+
+ Leverages `eval` to validate operation and get its result type.
+
+ :return: A dictionary that maps an operation to either:
+ * A result type
+ * An Exception
+ :rtype: Dict[str, Union[str, Exception]]
+ """
+ res: Dict = {}
+ primitive_types_values = {
+ NONE_PARAMETER_TYPE: repr(None),
+ ComponentParameterTypes.BOOLEAN: repr(True),
+ ComponentParameterTypes.INTEGER: repr(1),
+ ComponentParameterTypes.NUMBER: repr(1.0),
+ ComponentParameterTypes.STRING: repr("1"),
+ }
+ for type1, operand1 in primitive_types_values.items():
+ for type2, operand2 in primitive_types_values.items():
+ for operator in _SUPPORTED_OPERATORS:
+ k = f"{type1} {operator} {type2}"
+ try:
+ eval_result = eval(f"{operand1} {operator} {operand2}") # pylint: disable=eval-used # nosec
+ res[k] = IOConstants.PRIMITIVE_TYPE_2_STR[type(eval_result)]
+ except TypeError:
+ error_message = (
+ f"Operator '{operator}' is not supported between instances of '{type1}' and '{type2}'."
+ )
+ res[k] = UserErrorException(message=error_message, no_personal_data_message=error_message)
+ return res
+
+
+# enumerate and store as a lookup table:
+# key format is "<operand1_type> <operator> <operand2_type>"
+# value can be either the result type as str or a UserErrorException for an invalid operation
+_OPERATION_RESULT_TYPE_LOOKUP = _enumerate_operation_combination()
+
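+# Editorial sketch, not part of the vendored SDK source: one entry the lookup table above
+# is expected to contain, keyed as "<operand1_type> <operator> <operand2_type>".
+_example_add_key = f"{ComponentParameterTypes.INTEGER} + {ComponentParameterTypes.NUMBER}"
+# expected: _OPERATION_RESULT_TYPE_LOOKUP[_example_add_key] == ComponentParameterTypes.NUMBER,
+# since eval("1 + 1.0") yields a float; "string - integer" maps to a UserErrorException instead.
+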
+
+class PipelineExpressionMixin:
+ _SUPPORTED_PRIMITIVE_TYPES = (bool, int, float, str)
+ _SUPPORTED_PIPELINE_INPUT_TYPES = (
+ ComponentParameterTypes.BOOLEAN,
+ ComponentParameterTypes.INTEGER,
+ ComponentParameterTypes.NUMBER,
+ ComponentParameterTypes.STRING,
+ )
+
+ def _validate_binary_operation(self, other: Any, operator: str) -> None:
+ from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput
+
+ if (
+ other is not None
+ and not isinstance(other, self._SUPPORTED_PRIMITIVE_TYPES)
+ and not isinstance(other, (PipelineInput, NodeOutput, PipelineExpression))
+ ):
+ error_message = (
+ f"Operator '{operator}' is not supported with {type(other)}; "
+ "currently only support primitive types (None, bool, int, float and str), "
+ "pipeline input, component output and expression."
+ )
+ raise UserErrorException(message=error_message, no_personal_data_message=error_message)
+
+ def __add__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.ADD)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.ADD)
+
+ def __radd__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.ADD)
+ return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.ADD)
+
+ def __sub__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.SUB)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.SUB)
+
+ def __rsub__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.SUB)
+ return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.SUB)
+
+ def __mul__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.MUL)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.MUL)
+
+ def __rmul__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.MUL)
+ return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.MUL)
+
+ def __truediv__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.DIV)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.DIV)
+
+ def __rtruediv__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.DIV)
+ return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.DIV)
+
+ def __mod__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.MOD)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.MOD)
+
+ def __rmod__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.MOD)
+ return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.MOD)
+
+ def __pow__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.POW)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.POW)
+
+ def __rpow__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.POW)
+ return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.POW)
+
+ def __floordiv__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.FLOORDIV)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.FLOORDIV)
+
+ def __rfloordiv__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.FLOORDIV)
+ return PipelineExpression._from_operation(other, self, PipelineExpressionOperator.FLOORDIV)
+
+ def __lt__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.LT)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.LT)
+
+ def __gt__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.GT)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.GT)
+
+ def __le__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.LTE)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.LTE)
+
+ def __ge__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.GTE)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.GTE)
+
+ # TODO: Bug Item number: 2883354
+ def __eq__(self, other: Any) -> "PipelineExpression": # type: ignore
+ self._validate_binary_operation(other, PipelineExpressionOperator.EQ)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.EQ)
+
+ # TODO: Bug Item number: 2883354
+ def __ne__(self, other: Any) -> "PipelineExpression": # type: ignore
+ self._validate_binary_operation(other, PipelineExpressionOperator.NE)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.NE)
+
+ def __and__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.AND)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.AND)
+
+ def __or__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.OR)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.OR)
+
+ def __xor__(self, other: Any) -> "PipelineExpression":
+ self._validate_binary_operation(other, PipelineExpressionOperator.XOR)
+ return PipelineExpression._from_operation(self, other, PipelineExpressionOperator.XOR)
+
+ def __bool__(self) -> bool:
+ """Python method that is used to implement truth value testing and the built-in operation bool().
+
+ This method is not supported as PipelineExpressionMixin is designed to record operation history,
+ while this method can only return False or True, which would break the recorded history.
+ As overloadable boolean operators PEP (refer to: https://www.python.org/dev/peps/pep-0335/)
+ was rejected, logical operations are also not supported.
+
+ :return: True if not inside dsl pipeline func, raises otherwise
+ :rtype: bool
+ """
+ from azure.ai.ml.dsl._pipeline_component_builder import _is_inside_dsl_pipeline_func
+
+ # note: unexpected bool test always be checking if the object is None;
+ # so for non-pipeline scenarios, directly return True to avoid unexpected breaking,
+ # and for pipeline scenarios, will use is not None to replace bool test.
+ if not _is_inside_dsl_pipeline_func():
+ return True
+
+ error_message = f"Type {type(self)} is not supported for operation bool()."
+ raise UserErrorException(message=error_message, no_personal_data_message=error_message)
+
+
+class PipelineExpression(PipelineExpressionMixin):
+ """Pipeline expression entity.
+
+ Use PipelineExpression to support simple and trivial parameter transformation tasks with constants
+ or other parameters. Operations are recorded in this class during execution, and the expected result
+ will be generated for the corresponding scenario.
+ """
+
+ _PIPELINE_INPUT_PREFIX = ["parent", "inputs"]
+ _PIPELINE_INPUT_PATTERN = re.compile(pattern=r"parent.inputs.(?P<pipeline_input_name>[^.]+)")
+ _PIPELINE_INPUT_NAME_GROUP = "pipeline_input_name"
+ # AML type to Python type, for generated Python code
+ _TO_PYTHON_TYPE = {
+ ComponentParameterTypes.BOOLEAN: bool.__name__,
+ ComponentParameterTypes.INTEGER: int.__name__,
+ ComponentParameterTypes.NUMBER: float.__name__,
+ ComponentParameterTypes.STRING: str.__name__,
+ }
+
+ _INDENTATION = " "
+ _IMPORT_MLDESIGNER_LINE = "from mldesigner import command_component, Output"
+ _DECORATOR_LINE = "@command_component(@@decorator_parameters@@)"
+ _COMPONENT_FUNC_NAME = "expression_func"
+ _COMPONENT_FUNC_DECLARATION_LINE = (
+ f"def {_COMPONENT_FUNC_NAME}(@@component_parameters@@)" " -> Output(type=@@return_type@@):"
+ )
+ _PYTHON_CACHE_FOLDER_NAME = "__pycache__"
+
+ def __init__(self, postfix: List[str], inputs: Dict[str, ExpressionInput]):
+ self._postfix = postfix
+ self._inputs = inputs.copy() # including PipelineInput and Output; name and type are stored in addition
+ self._result_type: Optional[str] = None
+ self._created_component = None
+
+ @property
+ def expression(self) -> str:
+ """Infix expression string, wrapped with parentheses.
+
+ :return: The infix expression
+ :rtype: str
+ """
+ return self._to_infix()
+
+ def __str__(self) -> str:
+ return self._to_data_binding()
+
+ def _data_binding(self) -> str:
+ return self._to_data_binding()
+
+ def _to_infix(self) -> str:
+ stack = []
+ for token in self._postfix:
+ if token not in _SUPPORTED_OPERATORS:
+ stack.append(token)
+ continue
+ operand2, operand1 = stack.pop(), stack.pop()
+ stack.append(f"({operand1} {token} {operand2})")
+ return stack.pop()
+
+ # pylint: disable=too-many-statements
+ @staticmethod
+ def _handle_operand(
+ operand: "PipelineExpression",
+ postfix: List[str],
+ expression_inputs: Dict[str, ExpressionInput],
+ pipeline_inputs: dict,
+ ) -> Tuple[List[str], Dict[str, ExpressionInput]]:
+ """Handle operand in expression, update postfix expression and expression inputs.
+
+ :param operand: The operand
+ :type operand: "PipelineExpression"
+ :param postfix:
+ :type postfix: List[str]
+ :param expression_inputs: The expression inputs
+ :type expression_inputs: Dict[str, ExpressionInput]
+ :param pipeline_inputs: The pipeline inputs
+ :type pipeline_inputs: dict
+ :return: A 2-tuple of the updated postfix expression and expression inputs
+ :rtype: Tuple[List[str], Dict[str, ExpressionInput]]
+ """
+ from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput
+
+ def _update_postfix(_postfix: List[str], _old_name: str, _new_name: str) -> List[str]:
+ return list(map(lambda _x: _new_name if _x == _old_name else _x, _postfix))
+
+ def _get_or_create_input_name(
+ _original_name: str,
+ _operand: Union[PipelineInput, NodeOutput],
+ _expression_inputs: Dict[str, ExpressionInput],
+ ) -> str:
+ """Get or create expression input name as current operand may have appeared in expression.
+
+ :param _original_name: The original name
+ :type _original_name: str
+ :param _operand: The expression operand
+ :type _operand: Union[PipelineInput, NodeOutput]
+ :param _expression_inputs: The expression inputs
+ :type _expression_inputs: Dict[str, ExpressionInput]
+ :return: The input name
+ :rtype: str
+ """
+ _existing_id_to_name = {id(_v.value): _k for _k, _v in _expression_inputs.items()}
+ if id(_operand) in _existing_id_to_name:
+ return _existing_id_to_name[id(_operand)]
+ # use a counter to generate a unique name for current operand
+ _name, _counter = _original_name, 0
+ while _name in _expression_inputs:
+ _name = f"{_original_name}_{_counter}"
+ _counter += 1
+ return _name
+
+ def _handle_pipeline_input(
+ _pipeline_input: PipelineInput,
+ _postfix: List[str],
+ _expression_inputs: Dict[str, ExpressionInput],
+ ) -> Tuple[List[str], dict]:
+ _name = _pipeline_input._port_name
+ # 1. use name with counter for pipeline input; 2. add component's name to component output
+ if _name in _expression_inputs:
+ _seen_input = _expression_inputs[_name]
+ if isinstance(_seen_input.value, PipelineInput):
+ _name = _get_or_create_input_name(_name, _pipeline_input, _expression_inputs)
+ else:
+ _expression_inputs.pop(_name)
+ _new_name = f"{_seen_input.value._owner.component.name}__{_seen_input.value._port_name}"
+ _postfix = _update_postfix(_postfix, _name, _new_name)
+ _expression_inputs[_new_name] = ExpressionInput(_new_name, _seen_input.type, _seen_input)
+ _postfix.append(_name)
+
+ param_input = pipeline_inputs
+ for group_name in _pipeline_input._group_names:
+ param_input = param_input[group_name].values
+ _expression_inputs[_name] = ExpressionInput(
+ _name, param_input[_pipeline_input._port_name].type, _pipeline_input
+ )
+ return _postfix, _expression_inputs
+
+ def _handle_component_output(
+ _component_output: NodeOutput,
+ _postfix: List[str],
+ _expression_inputs: Dict[str, ExpressionInput],
+ ) -> Tuple[List[str], dict]:
+ if _component_output._meta is not None and not _component_output._meta._is_primitive_type:
+ error_message = (
+ f"Component output {_component_output._port_name} in expression must "
+ f"be a primitive type with value {True!r}, "
+ f"got {_component_output._meta._is_primitive_type!r}"
+ )
+ raise UserErrorException(message=error_message, no_personal_data_message=error_message)
+ _name = _component_output._port_name
+ _has_prefix = False
+ # "output" is the default output name for command component, add component's name as prefix
+ if _name == "output":
+ if _component_output._owner is not None and not isinstance(_component_output._owner.component, str):
+ _name = f"{_component_output._owner.component.name}__output"
+ _has_prefix = True
+ # following loop is expected to execute at most twice:
+ # 1. add component's name to output(s)
+ # 2. use name with counter
+ while _name in _expression_inputs:
+ _seen_input = _expression_inputs[_name]
+ if isinstance(_seen_input.value, PipelineInput):
+ if not _has_prefix:
+ if _component_output._owner is not None and not isinstance(
+ _component_output._owner.component, str
+ ):
+ _name = f"{_component_output._owner.component.name}__{_component_output._port_name}"
+ _has_prefix = True
+ continue
+ _name = _get_or_create_input_name(_name, _component_output, _expression_inputs)
+ else:
+ if not _has_prefix:
+ _expression_inputs.pop(_name)
+ _new_name = f"{_seen_input.value._owner.component.name}__{_seen_input.value._port_name}"
+ _postfix = _update_postfix(_postfix, _name, _new_name)
+ _expression_inputs[_new_name] = ExpressionInput(_new_name, _seen_input.type, _seen_input)
+ if _component_output._owner is not None and not isinstance(
+ _component_output._owner.component, str
+ ):
+ _name = f"{_component_output._owner.component.name}__{_component_output._port_name}"
+ _has_prefix = True
+ _name = _get_or_create_input_name(_name, _component_output, _expression_inputs)
+ _postfix.append(_name)
+ _expression_inputs[_name] = ExpressionInput(_name, _component_output.type, _component_output)
+ return _postfix, _expression_inputs
+
+ if operand is None or isinstance(operand, PipelineExpression._SUPPORTED_PRIMITIVE_TYPES):
+ postfix.append(repr(operand))
+ elif isinstance(operand, PipelineInput):
+ postfix, expression_inputs = _handle_pipeline_input(operand, postfix, expression_inputs)
+ elif isinstance(operand, NodeOutput):
+ postfix, expression_inputs = _handle_component_output(operand, postfix, expression_inputs)
+ elif isinstance(operand, PipelineExpression):
+ postfix.extend(operand._postfix.copy())
+ expression_inputs.update(operand._inputs.copy())
+ return postfix, expression_inputs
+
+ @staticmethod
+ def _from_operation(operand1: Any, operand2: Any, operator: str) -> "PipelineExpression":
+ if operator not in _SUPPORTED_OPERATORS:
+ error_message = (
+ f"Operator '{operator}' is not supported operator, "
+ f"currently supported operators are {','.join(_SUPPORTED_OPERATORS)}."
+ )
+ raise UserErrorException(message=error_message, no_personal_data_message=error_message)
+
+ # get all pipeline input types from builder stack
+ # TODO: check if there is a pipeline input whose type we cannot know (missing in `PipelineComponentBuilder.inputs`)
+ from azure.ai.ml.dsl._pipeline_component_builder import _definition_builder_stack
+
+ res = _definition_builder_stack.top()
+ pipeline_inputs = res.inputs if res is not None else {}
+ postfix: List[str] = []
+ inputs: Dict[str, ExpressionInput] = {}
+ postfix, inputs = PipelineExpression._handle_operand(operand1, postfix, inputs, pipeline_inputs)
+ postfix, inputs = PipelineExpression._handle_operand(operand2, postfix, inputs, pipeline_inputs)
+ postfix.append(operator)
+ return PipelineExpression(postfix, inputs)
+
+ @property
+ def _string_concatenation(self) -> bool:
+ """If all operands are string and operations are addition, it is a string concatenation expression.
+
+ :return: Whether this represents string concatenation
+ :rtype: bool
+ """
+ for token in self._postfix:
+ # operator can only be "+" for string concatenation
+ if token in _SUPPORTED_OPERATORS:
+ if token != PipelineExpressionOperator.ADD:
+ return False
+ continue
+ # constants and PipelineInput should be of type string
+ if token in self._inputs:
+ if self._inputs[token].type != ComponentParameterTypes.STRING:
+ return False
+ else:
+ if not isinstance(eval(token), str): # pylint: disable=eval-used # nosec
+ return False
+ return True
+
+ def _to_data_binding(self) -> str:
+ """Convert operands to data binding and concatenate them in the order of postfix expression.
+
+ :return: The data binding
+ :rtype: str
+ """
+ if not self._string_concatenation:
+ error_message = (
+ "Only string concatenation expression is supported to convert to data binding, "
+ f"current expression is '{self.expression}'."
+ )
+ raise UserErrorException(message=error_message, no_personal_data_message=error_message)
+
+ stack = []
+ for token in self._postfix:
+ if token != PipelineExpressionOperator.ADD:
+ if token in self._inputs:
+ stack.append(self._inputs[token].value._data_binding())
+ else:
+ stack.append(eval(token)) # pylint: disable=eval-used # nosec
+ continue
+ operand2, operand1 = stack.pop(), stack.pop()
+ stack.append(operand1 + operand2)
+ res: str = stack.pop()
+ return res
+
+ def resolve(self) -> Union[str, "BaseNode"]:
+ """Resolve expression to data binding or component, depend on the operations.
+
+ :return: The data binding string or the component
+ :rtype: Union[str, BaseNode]
+ """
+ if self._string_concatenation:
+ return self._to_data_binding()
+ return cast(Union[str, "BaseNode"], self._create_component())
+
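+ # Note (editorial): a hedged sketch of the two resolution paths, assuming string pipeline
+ # inputs `a`, `b` and an integer pipeline input `n` inside a @pipeline function:
+ #
+ #     (a + b).resolve()   # all operands are strings and all operators are "+", so this
+ #                         # returns a data binding such as
+ #                         # "${{parent.inputs.a}}${{parent.inputs.b}}"
+ #     (n + 1).resolve()   # not a pure string concatenation, so this returns the node
+ #                         # built by _create_component()
+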
+ @staticmethod
+ def parse_pipeline_inputs_from_data_binding(data_binding: str) -> List[str]:
+ """Parse all PipelineInputs name from data binding expression.
+
+ :param data_binding: Data binding expression
+ :type data_binding: str
+ :return: List of PipelineInput's name from given data binding expression
+ :rtype: List[str]
+ """
+ pipeline_input_names = []
+ for single_data_binding in get_all_data_binding_expressions(
+ value=data_binding,
+ binding_prefix=PipelineExpression._PIPELINE_INPUT_PREFIX,
+ is_singular=False,
+ ):
+ m = PipelineExpression._PIPELINE_INPUT_PATTERN.match(single_data_binding)
+ # `get_all_data_binding_expressions` should work as a pre-filter, so no need to worry about `m` being None
+ if m is not None:
+ pipeline_input_names.append(m.group(PipelineExpression._PIPELINE_INPUT_NAME_GROUP))
+ return pipeline_input_names
+
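+ # Note (editorial): a minimal usage sketch, assuming the standard pipeline input binding
+ # format "${{parent.inputs.<name>}}":
+ #
+ #     PipelineExpression.parse_pipeline_inputs_from_data_binding(
+ #         "${{parent.inputs.learning_rate}}_${{parent.inputs.epochs}}"
+ #     )
+ #     # -> ["learning_rate", "epochs"]
+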
+ @staticmethod
+ def _get_operation_result_type(type1: str, operator: str, type2: str) -> str:
+ def _validate_operand_type(_type: str) -> None:
+ if _type != NONE_PARAMETER_TYPE and _type not in PipelineExpression._SUPPORTED_PIPELINE_INPUT_TYPES:
+ error_message = (
+ f"Pipeline input type {_type!r} is not supported in expression; "
+ f"currently only support None, "
+ + ", ".join(PipelineExpression._SUPPORTED_PIPELINE_INPUT_TYPES)
+ + "."
+ )
+ raise UserErrorException(message=error_message, no_personal_data_message=error_message)
+
+ _validate_operand_type(type1)
+ _validate_operand_type(type2)
+ operation = f"{type1} {operator} {type2}"
+ lookup_value = _OPERATION_RESULT_TYPE_LOOKUP.get(operation)
+ if isinstance(lookup_value, str):
+ return lookup_value # valid operation, return result type
+ _user_exception: UserErrorException = lookup_value
+ raise _user_exception # invalid operation, raise UserErrorException
+
+ def _get_operand_type(self, operand: str) -> str:
+ if operand in self._inputs:
+ res: str = self._inputs[operand].type
+ return res
+ primitive_type = type(eval(operand)) # pylint: disable=eval-used # nosec
+ res_type: str = IOConstants.PRIMITIVE_TYPE_2_STR.get(primitive_type, NONE_PARAMETER_TYPE)
+ return res_type
+
+ @property
+ def _component_code(self) -> str:
+ def _generate_function_code_lines() -> Tuple[List[str], str]:
+ """Return lines of code and return type.
+
+ :return: A 2-tuple of (function body, return type name)
+ :rtype: Tuple[List[str], str]
+ """
+ _inter_id, _code, _stack = 0, [], []
+ _line_recorder: Dict = {}
+ for _token in self._postfix:
+ if _token not in _SUPPORTED_OPERATORS:
+ _type = self._get_operand_type(_token)
+ _stack.append((_token, _type))
+ continue
+ _operand2, _type2 = _stack.pop()
+ _operand1, _type1 = _stack.pop()
+ _current_line = f"{_operand1} {_token} {_operand2}"
+ if _current_line in _line_recorder:
+ _inter_var, _inter_var_type = _line_recorder[_current_line]
+ else:
+ _inter_var = f"inter_var_{_inter_id}"
+ _inter_id += 1
+ _inter_var_type = self._get_operation_result_type(_type1, _token, _type2)
+ _code.append(f"{self._INDENTATION}{_inter_var} = {_current_line}")
+ _line_recorder[_current_line] = (_inter_var, _inter_var_type)
+ _stack.append((_inter_var, _inter_var_type))
+ _return_var, _result_type = _stack.pop()
+ _code.append(f"{self._INDENTATION}return {_return_var}")
+ return _code, _result_type
+
+ def _generate_function_decorator_and_declaration_lines(_return_type: str) -> List[str]:
+ # decorator parameters
+ _display_name = f'{self._INDENTATION}display_name="Expression: {self.expression}",'
+ _decorator_parameters = "\n" + "\n".join([_display_name]) + "\n"
+ # component parameters
+ _component_parameters = []
+ for _name in sorted(self._inputs):
+ _type = self._TO_PYTHON_TYPE[self._inputs[_name].type]
+ _component_parameters.append(f"{_name}: {_type}")
+ _component_parameters_str = (
+ "\n"
+ + "\n".join(
+ [f"{self._INDENTATION}{_component_parameter}," for _component_parameter in _component_parameters]
+ )
+ + "\n"
+ )
+ return [
+ self._IMPORT_MLDESIGNER_LINE + "\n\n",
+ self._DECORATOR_LINE.replace("@@decorator_parameters@@", _decorator_parameters),
+ self._COMPONENT_FUNC_DECLARATION_LINE.replace(
+ "@@component_parameters@@", _component_parameters_str
+ ).replace("@@return_type@@", f'"{_return_type}"'),
+ ]
+
+ lines, result_type = _generate_function_code_lines()
+ self._result_type = result_type
+ code = _generate_function_decorator_and_declaration_lines(result_type) + lines
+ return "\n".join(code) + "\n"
+
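+ # Note (editorial): a hedged illustration of the code generated above for the expression
+ # `a + b` over two integer inputs; the import and decorator lines come from the class
+ # templates (_IMPORT_MLDESIGNER_LINE, _DECORATOR_LINE) and are omitted here:
+ #
+ #     def expression_func(a: int, b: int) -> "integer":
+ #         inter_var_0 = a + b
+ #         return inter_var_0
+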
+ def _create_component(self) -> Any:
+ def _generate_python_file(_folder: Path) -> None:
+ _folder.mkdir()
+ with open(_folder / "expression_component.py", "w", encoding=DefaultOpenEncoding.WRITE) as _f:
+ _f.write(self._component_code)
+
+ def _generate_yaml_file(_path: Path) -> None:
+ _data_folder = Path(__file__).parent / "data"
+ # update YAML content from template and dump
+ with open(_data_folder / "expression_component_template.yml", "r", encoding=DefaultOpenEncoding.READ) as _f:
+ _data = load_yaml(_f)
+ _data["display_name"] = f"Expression: {self.expression}"
+ _data["inputs"] = {}
+ _data["outputs"]["output"]["type"] = self._result_type
+ _command_inputs_items = []
+ for _name in sorted(self._inputs):
+ _type = self._inputs[_name].type
+ _data["inputs"][_name] = {"type": _type}
+ _command_inputs_items.append(_name + '="${{inputs.' + _name + '}}"')
+ _command_inputs_string = " ".join(_command_inputs_items)
+ _command_output_string = 'output="${{outputs.output}}"'
+ _command = (
+ "mldesigner execute --source expression_component.py --name expression_func"
+ " --inputs " + _command_inputs_string + " --outputs " + _command_output_string
+ )
+ _data["command"] = _data["command"].format(command_placeholder=_command)
+ dump_yaml_to_file(_path, _data)
+
+ if self._created_component is None:
+ tmp_folder = Path(tempfile.mkdtemp())
+ code_folder = tmp_folder / "src"
+ yaml_path = tmp_folder / "component_spec.yml"
+ _generate_python_file(code_folder)
+ _generate_yaml_file(yaml_path)
+
+ from azure.ai.ml import load_component
+
+ component_func = load_component(yaml_path)
+ component_kwargs = {k: v.value for k, v in self._inputs.items()}
+ self._created_component = component_func(**component_kwargs)
+ if self._created_component is not None:
+ self._created_component.environment_variables = {AZUREML_PRIVATE_FEATURES_ENV_VAR: "true"}
+ return self._created_component
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_job_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_job_helpers.py
new file mode 100644
index 00000000..3a7d89e7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/_pipeline_job_helpers.py
@@ -0,0 +1,182 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import re
+from typing import Dict, List, Tuple, Type, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import InputDeliveryMode
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobInput as RestJobInput
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobOutput as RestJobOutput
+from azure.ai.ml._restclient.v2023_04_01_preview.models import Mpi, PyTorch, Ray, TensorFlow
+from azure.ai.ml.constants._component import ComponentJobConstants
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job._input_output_helpers import (
+ INPUT_MOUNT_MAPPING_FROM_REST,
+ INPUT_MOUNT_MAPPING_TO_REST,
+ OUTPUT_MOUNT_MAPPING_FROM_REST,
+ OUTPUT_MOUNT_MAPPING_TO_REST,
+)
+from azure.ai.ml.entities._util import normalize_job_input_output_type
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+def process_sdk_component_job_io(
+ io: Dict,
+ io_binding_regex_list: List[str],
+) -> Tuple:
+ """Separates SDK ComponentJob inputs that are data bindings (i.e. string inputs prefixed with 'inputs.' or
+ 'outputs.') and dataset and literal inputs/outputs.
+
+ :param io: Input or output dictionary of an SDK ComponentJob
+ :type io: Dict[str, Union[str, float, bool, Input]]
+ :param io_binding_regex_list: A list of regexes for io bindings
+ :type io_binding_regex_list: List[str]
+ :return: A tuple of dictionaries:
+ * One mapping inputs to REST formatted ComponentJobInput/ComponentJobOutput for data binding io.
+ * The other mapping any IO that is not a data binding and has yet to be converted to REST form.
+ :rtype: Tuple[Dict[str, str], Dict[str, Union[str, float, bool, Input]]]
+ """
+ io_bindings: Dict = {}
+ dataset_literal_io: Dict = {}
+ legacy_io_binding_regex_list = [
+ ComponentJobConstants.LEGACY_INPUT_PATTERN,
+ ComponentJobConstants.LEGACY_OUTPUT_PATTERN,
+ ]
+ for io_name, io_value in io.items():
+ if isinstance(io_value, (Input, Output)) and isinstance(io_value.path, str):
+ mode = io_value.mode
+ path = io_value.path
+ name = io_value.name if hasattr(io_value, "name") else None
+ version = io_value.version if hasattr(io_value, "version") else None
+ if any(re.match(item, path) for item in io_binding_regex_list):
+ # Yaml syntax requires using ${{}} to enclose inputs and outputs bindings
+ # io_bindings[io_name] = io_value
+ io_bindings.update({io_name: {"value": path}})
+ # add mode to literal value for binding input
+ if mode:
+ if isinstance(io_value, Input):
+ io_bindings[io_name].update({"mode": INPUT_MOUNT_MAPPING_TO_REST[mode]})
+ else:
+ io_bindings[io_name].update({"mode": OUTPUT_MOUNT_MAPPING_TO_REST[mode]})
+ if name or version:
+ assert isinstance(io_value, Output)
+ if name:
+ io_bindings[io_name].update({"name": name})
+ if version:
+ io_bindings[io_name].update({"version": version})
+ if isinstance(io_value, Output) and io_value.name is not None:
+ # when the output should be registered,
+ # we add io_value to dataset_literal_io for further to_rest_data_outputs
+ dataset_literal_io[io_name] = io_value
+ elif any(re.match(item, path) for item in legacy_io_binding_regex_list):
+ new_format = path.replace("{{", "{{parent.")
+ msg = "{} has changed to {}, please change to use new format."
+ raise ValidationException(
+ message=msg.format(path, new_format),
+ no_personal_data_message=msg.format("[io_value]", "[io_value_new_format]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ else:
+ dataset_literal_io[io_name] = io_value
+ else:
+ # Collect non-input data inputs
+ dataset_literal_io[io_name] = io_value
+ return io_bindings, dataset_literal_io
+
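+# Note (editorial): a hedged usage sketch with a hypothetical binding regex; the real
+# pattern list normally comes from ComponentJobConstants:
+#
+#     bindings, literals = process_sdk_component_job_io(
+#         {
+#             "train_data": Input(path="${{parent.inputs.raw_data}}", mode="ro_mount"),
+#             "max_epochs": 10,
+#         },
+#         io_binding_regex_list=[r"\$\{\{parent\..*\}\}"],
+#     )
+#     # bindings roughly -> {"train_data": {"value": "${{parent.inputs.raw_data}}",
+#     #                                     "mode": INPUT_MOUNT_MAPPING_TO_REST["ro_mount"]}}
+#     # literals -> {"max_epochs": 10}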
+
+def from_dict_to_rest_io(
+ io: Dict[str, Union[str, dict]],
+ rest_object_class: Union[Type[RestJobInput], Type[RestJobOutput]],
+ io_binding_regex_list: List[str],
+) -> Tuple[Dict[str, str], Dict[str, Union[RestJobInput, RestJobOutput]]]:
+ """Translate rest JObject dictionary to rest inputs/outputs and bindings.
+
+ :param io: Input or output dictionary.
+ :type io: Dict[str, Union[str, dict]]
+ :param rest_object_class: RestJobInput or RestJobOutput
+ :type rest_object_class: Union[Type[RestJobInput], Type[RestJobOutput]]
+ :param io_binding_regex_list: A list of regexes for io bindings
+ :type io_binding_regex_list: List[str]
+ :return: Map from IO name to IO bindings and Map from IO name to IO objects.
+ :rtype: Tuple[Dict[str, str], Dict[str, Union[RestJobInput, RestJobOutput]]]
+ """
+ io_bindings: dict = {}
+ rest_io_objects = {}
+ DIRTY_MODE_MAPPING = {
+ "Mount": InputDeliveryMode.READ_ONLY_MOUNT,
+ "RoMount": InputDeliveryMode.READ_ONLY_MOUNT,
+ "RwMount": InputDeliveryMode.READ_WRITE_MOUNT,
+ }
+ for key, val in io.items():
+ if isinstance(val, dict):
+ # convert the input from camelCase to snake_case to be compatible with the Jun API
+ # TODO: backend should help convert node-level input/output types
+ normalize_job_input_output_type(val)
+
+ # Cast to str since the value may sometimes be a non-string such as 1 (int)
+ io_value = str(val.get("value", ""))
+ io_mode = val.get("mode", None)
+ io_name = val.get("name", None)
+ io_version = val.get("version", None)
+ if any(re.match(item, io_value) for item in io_binding_regex_list):
+ io_bindings.update({key: {"path": io_value}})
+ # add mode to literal value for binding input
+ if io_mode:
+ # deal with dirty mode data submitted before
+ if io_mode in DIRTY_MODE_MAPPING:
+ io_mode = DIRTY_MODE_MAPPING[io_mode]
+ val["mode"] = io_mode
+ if io_mode in OUTPUT_MOUNT_MAPPING_FROM_REST:
+ io_bindings[key].update({"mode": OUTPUT_MOUNT_MAPPING_FROM_REST[io_mode]})
+ else:
+ io_bindings[key].update({"mode": INPUT_MOUNT_MAPPING_FROM_REST[io_mode]})
+ # add name and version for binding input
+ if io_name or io_version:
+ assert rest_object_class.__name__ == "JobOutput"
+ # the current code only supports dumping name and version for JobOutput;
+ # this assert can be deleted if we need to dump name/version for JobInput
+ if io_name:
+ io_bindings[key].update({"name": io_name})
+ if io_version:
+ io_bindings[key].update({"version": io_version})
+ if not io_mode and not io_name and not io_version:
+ io_bindings[key] = io_value
+ else:
+ if rest_object_class.__name__ == "JobOutput":
+ # the current code only supports dumping name and version for JobOutput;
+ # this condition can be deleted if we need to dump name/version for JobInput
+ if "name" in val.keys():
+ val["asset_name"] = val.pop("name")
+ if "version" in val.keys():
+ val["asset_version"] = val.pop("version")
+ rest_obj = rest_object_class.from_dict(val)
+ rest_io_objects[key] = rest_obj
+ else:
+ msg = "Got unsupported type of input/output: {}:" + f"{type(val)}"
+ raise ValidationException(
+ message=msg.format(val),
+ no_personal_data_message=msg.format("[val]"),
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return io_bindings, rest_io_objects
+
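+# Note (editorial): a hedged sketch of the reverse direction with a hypothetical binding
+# regex; non-binding entries are deserialized via rest_object_class.from_dict:
+#
+#     bindings, rest_objects = from_dict_to_rest_io(
+#         {
+#             "train_data": {"value": "${{parent.inputs.raw_data}}", "mode": "ReadOnlyMount"},
+#             "max_epochs": {"job_input_type": "literal", "value": "10"},
+#         },
+#         RestJobInput,
+#         io_binding_regex_list=[r"\$\{\{parent\..*\}\}"],
+#     )
+#     # bindings roughly -> {"train_data": {"path": "${{parent.inputs.raw_data}}", "mode": "ro_mount"}}
+#     # rest_objects -> {"max_epochs": <RestJobInput>}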
+
+def from_dict_to_rest_distribution(distribution_dict: Dict) -> Union[PyTorch, Mpi, TensorFlow, Ray]:
+ target_type = distribution_dict["distribution_type"].lower()
+ if target_type == "pytorch":
+ return PyTorch(**distribution_dict)
+ if target_type == "mpi":
+ return Mpi(**distribution_dict)
+ if target_type == "tensorflow":
+ return TensorFlow(**distribution_dict)
+ if target_type == "ray":
+ return Ray(**distribution_dict)
+ msg = "Distribution type must be pytorch, mpi, tensorflow or ray: {}".format(target_type)
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.PIPELINE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
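+
+# Note (editorial): a minimal usage sketch; the function forwards the dict as keyword
+# arguments, and the generated REST models are expected to tolerate the extra
+# "distribution_type" key:
+#
+#     dist = from_dict_to_rest_distribution(
+#         {"distribution_type": "pytorch", "process_count_per_instance": 2}
+#     )
+#     # -> a PyTorch REST object with process_count_per_instance=2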
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/data/expression_component_template.yml b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/data/expression_component_template.yml
new file mode 100644
index 00000000..10d391aa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/data/expression_component_template.yml
@@ -0,0 +1,16 @@
+$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
+type: command
+
+name: expression_component
+version: 1
+
+outputs:
+ output:
+ is_control: true
+
+code: ./src
+
+environment: azureml://registries/azureml/environments/mldesigner/labels/latest
+
+command: >-
+ {command_placeholder}
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job.py
new file mode 100644
index 00000000..7ddbbc46
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job.py
@@ -0,0 +1,711 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+import itertools
+import logging
+import typing
+from functools import partial
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Optional, Union, cast
+
+from typing_extensions import Literal
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase
+from azure.ai.ml._restclient.v2024_01_01_preview.models import PipelineJob as RestPipelineJob
+from azure.ai.ml._schema import PathAwareSchema
+from azure.ai.ml._schema.pipeline.pipeline_job import PipelineJobSchema
+from azure.ai.ml._utils._arm_id_utils import get_resource_name_from_arm_id_safe
+from azure.ai.ml._utils.utils import (
+ camel_to_snake,
+ is_data_binding_expression,
+ is_private_preview_enabled,
+ transform_dict_keys,
+)
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import AZUREML_PRIVATE_FEATURES_ENV_VAR, BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource
+from azure.ai.ml.constants._job.pipeline import ValidationErrorCode
+from azure.ai.ml.entities._builders import BaseNode
+from azure.ai.ml.entities._builders.condition_node import ConditionNode
+from azure.ai.ml.entities._builders.control_flow_node import LoopNode
+from azure.ai.ml.entities._builders.import_node import Import
+from azure.ai.ml.entities._builders.parallel import Parallel
+from azure.ai.ml.entities._builders.pipeline import Pipeline
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._component.pipeline_component import PipelineComponent
+
+# from azure.ai.ml.entities._job.identity import AmlToken, Identity, ManagedIdentity, UserIdentity
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+ _BaseJobIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._inputs_outputs.group_input import GroupInput
+from azure.ai.ml.entities._job._input_output_helpers import (
+ from_rest_data_outputs,
+ from_rest_inputs_to_dataset_literal,
+ to_rest_data_outputs,
+ to_rest_dataset_literal_inputs,
+)
+from azure.ai.ml.entities._job.import_job import ImportJob
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.job_service import JobServiceBase
+from azure.ai.ml.entities._job.pipeline._io import PipelineInput, PipelineJobIOMixin
+from azure.ai.ml.entities._job.pipeline.pipeline_job_settings import PipelineJobSettings
+from azure.ai.ml.entities._mixins import YamlTranslatableMixin
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin
+from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineJob(Job, YamlTranslatableMixin, PipelineJobIOMixin, PathAwareSchemaValidatableMixin):
+ """Pipeline job.
+
+ You should not instantiate this class directly. Instead, you should
+ use the `@pipeline` decorator to create a `PipelineJob`.
+
+ :param component: Pipeline component version. The field is mutually exclusive with 'jobs'.
+ :type component: Union[str, ~azure.ai.ml.entities._component.pipeline_component.PipelineComponent]
+ :param inputs: Inputs to the pipeline job.
+ :type inputs: dict[str, Union[~azure.ai.ml.entities.Input, str, bool, int, float]]
+ :param outputs: Outputs of the pipeline job.
+ :type outputs: dict[str, ~azure.ai.ml.entities.Output]
+ :param name: Name of the PipelineJob. Defaults to None.
+ :type name: str
+ :param description: Description of the pipeline job. Defaults to None
+ :type description: str
+ :param display_name: Display name of the pipeline job. Defaults to None
+ :type display_name: str
+ :param experiment_name: Name of the experiment the job will be created under.
+ If None is provided, the experiment name defaults to the name of the current directory. Defaults to None
+ :type experiment_name: str
+ :param jobs: Pipeline component node name to component object. Defaults to None
+ :type jobs: dict[str, ~azure.ai.ml.entities._builders.BaseNode]
+ :param settings: Settings of the pipeline job. Defaults to None
+ :type settings: ~azure.ai.ml.entities.PipelineJobSettings
+ :param identity: Identity that the training job will use while running on compute. Defaults to None
+ :type identity: Union[
+ ~azure.ai.ml.entities._credentials.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities._credentials.AmlTokenConfiguration,
+ ~azure.ai.ml.entities._credentials.UserIdentityConfiguration
+
+ ]
+ :param compute: Compute target name of the built pipeline. Defaults to None
+ :type compute: str
+ :param tags: Tag dictionary. Tags can be added, removed, and updated. Defaults to None
+ :type tags: dict[str, str]
+ :param kwargs: A dictionary of additional configuration parameters. Defaults to None
+ :type kwargs: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_pipeline_job_configurations.py
+ :start-after: [START configure_pipeline_job_and_settings]
+ :end-before: [END configure_pipeline_job_and_settings]
+ :language: python
+ :dedent: 8
+ :caption: Shows how to create a pipeline using this class.
+ """
+
+ def __init__(
+ self,
+ *,
+ component: Optional[Union[str, PipelineComponent, Component]] = None,
+ inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
+ outputs: Optional[Dict[str, Output]] = None,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ jobs: Optional[Dict[str, BaseNode]] = None,
+ settings: Optional[PipelineJobSettings] = None,
+ identity: Optional[
+ Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ compute: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ # initialize io
+ inputs, outputs = inputs or {}, outputs or {}
+ if isinstance(component, PipelineComponent) and component._source in [
+ ComponentSource.DSL,
+ ComponentSource.YAML_COMPONENT,
+ ]:
+ self._inputs = self._build_inputs_dict(inputs, input_definition_dict=component.inputs)
+ # for pipeline jobs created from a pipeline component,
+ # the outputs should have the same values as the component outputs,
+ # then be overridden by the given outputs (filtering out None values)
+ pipeline_outputs = {k: v for k, v in (outputs or {}).items() if v}
+ self._outputs = self._build_pipeline_outputs_dict({**component.outputs, **pipeline_outputs})
+ else:
+ # Build inputs/outputs dict without meta when definition not available
+ self._inputs = self._build_inputs_dict(inputs)
+ # for pipeline jobs created from nodes,
+ # the outputs should have the same values as the given outputs
+ self._outputs = self._build_pipeline_outputs_dict(outputs=outputs)
+ source = kwargs.pop("_source", ComponentSource.CLASS)
+ if component is None:
+ component = PipelineComponent(
+ jobs=jobs,
+ description=description,
+ display_name=display_name,
+ base_path=kwargs.get(BASE_PATH_CONTEXT_KEY),
+ _source=source,
+ )
+
+ # If component is Pipeline component, jobs will be component.jobs
+ self._jobs = (jobs or {}) if isinstance(component, str) else {}
+
+ self.component: Union[PipelineComponent, str] = cast(Union[PipelineComponent, str], component)
+ if "type" not in kwargs:
+ kwargs["type"] = JobType.PIPELINE
+ if isinstance(component, PipelineComponent):
+ description = component.description if description is None else description
+ display_name = component.display_name if display_name is None else display_name
+ super(PipelineJob, self).__init__(
+ name=name,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ experiment_name=experiment_name,
+ compute=compute,
+ **kwargs,
+ )
+
+ self._remove_pipeline_input()
+ self.compute = compute
+ self._settings: Any = None
+ self.settings = settings
+ self.identity = identity
+ # TODO: remove default code & environment?
+ self._default_code = None
+ self._default_environment = None
+
+ @property
+ def inputs(self) -> Dict:
+ """Inputs of the pipeline job.
+
+ :return: Inputs of the pipeline job.
+ :rtype: dict[str, Union[~azure.ai.ml.entities.Input, str, bool, int, float]]
+ """
+ return self._inputs
+
+ @property
+ def outputs(self) -> Dict[str, Union[str, Output]]:
+ """Outputs of the pipeline job.
+
+ :return: Outputs of the pipeline job.
+ :rtype: dict[str, Union[str, ~azure.ai.ml.entities.Output]]
+ """
+ return self._outputs
+
+ @property
+ def jobs(self) -> Dict:
+ """Return jobs of pipeline job.
+
+ :return: Jobs of pipeline job.
+ :rtype: dict
+ """
+ res: dict = self.component.jobs if isinstance(self.component, PipelineComponent) else self._jobs
+ return res
+
+ @property
+ def settings(self) -> Optional[PipelineJobSettings]:
+ """Settings of the pipeline job.
+
+ :return: Settings of the pipeline job.
+ :rtype: ~azure.ai.ml.entities.PipelineJobSettings
+ """
+ if self._settings is None:
+ self._settings = PipelineJobSettings()
+ res: Optional[PipelineJobSettings] = self._settings
+ return res
+
+ @settings.setter
+ def settings(self, value: Union[Dict, PipelineJobSettings]) -> None:
+ """Set the pipeline job settings.
+
+ :param value: The pipeline job settings.
+ :type value: Union[dict, ~azure.ai.ml.entities.PipelineJobSettings]
+ """
+ if value is not None:
+ if isinstance(value, PipelineJobSettings):
+ # since PipelineJobSettings inherits from _AttrDict, we need this branch to distinguish it from dict
+ pass
+ elif isinstance(value, dict):
+ value = PipelineJobSettings(**value)
+ else:
+ raise TypeError("settings must be PipelineJobSettings or dict but got {}".format(type(value)))
+ self._settings = value
+
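+ # Note (editorial): a hedged usage sketch of the setter above; a plain dict is coerced
+ # into PipelineJobSettings ("cpu-cluster" is a hypothetical compute name):
+ #
+ #     pipeline_job.settings = {"default_compute": "cpu-cluster", "continue_on_step_failure": True}
+ #     pipeline_job.settings.default_compute   # -> "cpu-cluster"
+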
+ @classmethod
+ def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException:
+ return ValidationException(
+ message=message,
+ no_personal_data_message=no_personal_data_message,
+ target=ErrorTarget.PIPELINE,
+ )
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema:
+ # import this to ensure that nodes are registered before schema is created.
+
+ return PipelineJobSchema(context=context)
+
+ @classmethod
+ def _get_skip_fields_in_schema_validation(cls) -> typing.List[str]:
+ # jobs validations are done in _customized_validate()
+ return ["component", "jobs"]
+
+ @property
+ def _skip_required_compute_missing_validation(self) -> Literal[True]:
+ return True
+
+ def _validate_compute_is_set(self) -> MutableValidationResult:
+ validation_result = self._create_empty_validation_result()
+ if self.compute is not None:
+ return validation_result
+ if self.settings is not None and self.settings.default_compute is not None:
+ return validation_result
+
+ if not isinstance(self.component, str):
+ validation_result.merge_with(self.component._validate_compute_is_set())
+ return validation_result
+
+ def _customized_validate(self) -> MutableValidationResult:
+ """Validate that all provided inputs and parameters are valid for current pipeline and components in it.
+
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ validation_result = super(PipelineJob, self)._customized_validate()
+
+ if isinstance(self.component, PipelineComponent):
+ # Merge with pipeline component validate result for structure validation.
+ # Skip top level parameter missing type error
+ validation_result.merge_with(
+ self.component._customized_validate(),
+ condition_skip=lambda x: x.error_code == ValidationErrorCode.PARAMETER_TYPE_UNKNOWN
+ and x.yaml_path.startswith("inputs"),
+ )
+ # Validate compute
+ validation_result.merge_with(self._validate_compute_is_set())
+ # Validate Input
+ validation_result.merge_with(self._validate_input())
+ # Validate initialization & finalization jobs
+ validation_result.merge_with(self._validate_init_finalize_job())
+
+ return validation_result
+
+ def _validate_input(self) -> MutableValidationResult:
+ validation_result = self._create_empty_validation_result()
+ if not isinstance(self.component, str):
+ # TODO(1979547): refine this logic: not all nodes have `_get_input_binding_dict` method
+ used_pipeline_inputs = set(
+ itertools.chain(
+ *[
+ self.component._get_input_binding_dict(node if not isinstance(node, LoopNode) else node.body)[0]
+ for node in self.jobs.values()
+ if not isinstance(node, ConditionNode)
+ # condition node has no inputs
+ ]
+ )
+ )
+ # validate inputs
+ if not isinstance(self.component, Component):
+ return validation_result
+ for key, meta in self.component.inputs.items():
+ if key not in used_pipeline_inputs: # pylint: disable=possibly-used-before-assignment
+ # Only validate inputs certainly used.
+ continue
+ # raise error when required input with no default value not set
+ if (
+ self.inputs.get(key, None) is None # input not provided
+ and meta.optional is not True # and it's required
+ and meta.default is None # and it does not have default
+ ):
+ name = self.name or self.display_name
+ name = f"{name!r} " if name else ""
+ validation_result.append_error(
+ yaml_path=f"inputs.{key}",
+ message=f"Required input {key!r} for pipeline {name}not provided.",
+ )
+ return validation_result
+
+ def _validate_init_finalize_job(self) -> MutableValidationResult: # pylint: disable=too-many-statements
+ from azure.ai.ml.entities._job.pipeline._io import InputOutputBase, _GroupAttrDict
+
+ validation_result = self._create_empty_validation_result()
+ # subgraph (PipelineComponent) should not have on_init/on_finalize set
+ for job_name, job in self.jobs.items():
+ if job.type != "pipeline":
+ continue
+ if job.settings.on_init:
+ validation_result.append_error(
+ yaml_path=f"jobs.{job_name}.settings.on_init",
+ message="On_init is not supported for pipeline component.",
+ )
+ if job.settings.on_finalize:
+ validation_result.append_error(
+ yaml_path=f"jobs.{job_name}.settings.on_finalize",
+ message="On_finalize is not supported for pipeline component.",
+ )
+
+ on_init = None
+ on_finalize = None
+
+ if self.settings is not None:
+ # quick return if neither on_init nor on_finalize is set
+ if self.settings.on_init is None and self.settings.on_finalize is None:
+ return validation_result
+
+ on_init, on_finalize = self.settings.on_init, self.settings.on_finalize
+
+ append_on_init_error = partial(validation_result.append_error, "settings.on_init")
+ append_on_finalize_error = partial(validation_result.append_error, "settings.on_finalize")
+ # on_init and on_finalize cannot be same
+ if on_init == on_finalize:
+ append_on_init_error(f"Invalid on_init job {on_init}, it should be different from on_finalize.")
+ append_on_finalize_error(f"Invalid on_finalize job {on_finalize}, it should be different from on_init.")
+ # pipeline should have at least one normal node
+ if len(set(self.jobs.keys()) - {on_init, on_finalize}) == 0:
+ validation_result.append_error(yaml_path="jobs", message="No other job except for on_init/on_finalize job.")
+
+ def _is_control_flow_node(_validate_job_name: str) -> bool:
+ from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode
+
+ _validate_job = self.jobs[_validate_job_name]
+ return issubclass(type(_validate_job), ControlFlowNode)
+
+ def _is_isolated_job(_validate_job_name: str) -> bool:
+ def _try_get_data_bindings(
+ _name: str, _input_output_data: Union["_GroupAttrDict", "InputOutputBase"]
+ ) -> Optional[List]:
+ """Try to get data bindings from input/output data, return None if not found.
+ :param _name: The name to use when flattening GroupAttrDict
+ :type _name: str
+ :param _input_output_data: The input/output data
+ :type _input_output_data: Union[_GroupAttrDict, str, InputOutputBase]
+ :return: A list of data bindings, or None if not found
+ :rtype: Optional[List[str]]
+ """
+ # handle group input
+ if GroupInput._is_group_attr_dict(_input_output_data):
+ _new_input_output_data: _GroupAttrDict = cast(_GroupAttrDict, _input_output_data)
+ # flatten to avoid nested cases
+ flattened_values: List[Input] = list(_new_input_output_data.flatten(_name).values())
+ # handle invalid empty group
+ if len(flattened_values) == 0:
+ return None
+ return [_value.path for _value in flattened_values]
+ _input_output_data = _input_output_data._data
+ if isinstance(_input_output_data, str):
+ return [_input_output_data]
+ if not hasattr(_input_output_data, "_data_binding"):
+ return None
+ return [_input_output_data._data_binding()]
+
+ _validate_job = self.jobs[_validate_job_name]
+ # no input to validate job
+ for _input_name in _validate_job.inputs:
+ _data_bindings = _try_get_data_bindings(_input_name, _validate_job.inputs[_input_name])
+ if _data_bindings is None:
+ continue
+ for _data_binding in _data_bindings:
+ if is_data_binding_expression(_data_binding, ["parent", "jobs"]):
+ return False
+ # no output from the job being validated - iterate over other jobs' inputs to validate
+ for _job_name, _job in self.jobs.items():
+ # exclude control flow node as it does not have inputs
+ if _is_control_flow_node(_job_name):
+ continue
+ for _input_name in _job.inputs:
+ _data_bindings = _try_get_data_bindings(_input_name, _job.inputs[_input_name])
+ if _data_bindings is None:
+ continue
+ for _data_binding in _data_bindings:
+ if is_data_binding_expression(_data_binding, ["parent", "jobs", _validate_job_name]):
+ return False
+ return True
+
+ # validate on_init
+ if on_init is not None:
+ if on_init not in self.jobs:
+ append_on_init_error(f"On_init job name {on_init} not exists in jobs.")
+ else:
+ if _is_control_flow_node(on_init):
+ append_on_init_error("On_init job should not be a control flow node.")
+ elif not _is_isolated_job(on_init):
+ append_on_init_error("On_init job should not have connection to other execution node.")
+ # validate on_finalize
+ if on_finalize is not None:
+ if on_finalize not in self.jobs:
+ append_on_finalize_error(f"On_finalize job name {on_finalize} not exists in jobs.")
+ else:
+ if _is_control_flow_node(on_finalize):
+ append_on_finalize_error("On_finalize job should not be a control flow node.")
+ elif not _is_isolated_job(on_finalize):
+ append_on_finalize_error("On_finalize job should not have connection to other execution node.")
+ return validation_result
+
+ def _remove_pipeline_input(self) -> None:
+ """Remove None pipeline input.If not remove, it will pass "None" to backend."""
+ redundant_pipeline_inputs = []
+ for pipeline_input_name, pipeline_input in self._inputs.items():
+ if isinstance(pipeline_input, PipelineInput) and pipeline_input._data is None:
+ redundant_pipeline_inputs.append(pipeline_input_name)
+ for redundant_pipeline_input in redundant_pipeline_inputs:
+ self._inputs.pop(redundant_pipeline_input)
+
+ def _check_private_preview_features(self) -> None:
+ """Checks is private preview features included in pipeline.
+
+ If private preview environment not set, raise exception.
+ """
+ if not is_private_preview_enabled():
+ error_msg = (
+ "{} is a private preview feature, "
+ f"please set environment variable {AZUREML_PRIVATE_FEATURES_ENV_VAR} to true to use it."
+ )
+ # check for unsupported nodes
+ for _, node in self.jobs.items():
+ # TODO: Remove in PuP
+ if isinstance(node, (ImportJob, Import)):
+ msg = error_msg.format("Import job in pipeline")
+ raise UserErrorException(message=msg, no_personal_data_message=msg)
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Pipeline":
+ """Translate a command job to a pipeline node when load schema.
+
+ (Write a pipeline job as node in yaml is not supported presently.)
+
+ :param context: Context of command job YAML file.
+ :type context: dict
+ :return: Translated command component.
+ :rtype: Pipeline
+ """
+ component = self._to_component(context, **kwargs)
+
+ return Pipeline(
+ component=component,
+ compute=self.compute,
+ # Need to supply the inputs with double curly.
+ inputs=self.inputs,
+ outputs=self.outputs,
+ description=self.description,
+ tags=self.tags,
+ display_name=self.display_name,
+ properties=self.properties,
+ )
+
+ def _to_rest_object(self) -> JobBase:
+ """Build current parameterized pipeline instance to a pipeline job object before submission.
+
+ :return: Rest pipeline job.
+ :rtype: JobBase
+ """
+ # Check if there are private preview features in it
+ self._check_private_preview_features()
+
+ # Build the inputs to dict. Handle both value & binding assignment.
+ # Example: {
+ # "input_data": {"data": {"path": "path/to/input/data"}, "mode"="Mount"},
+ # "input_value": 10,
+ # "learning_rate": "${{jobs.step1.inputs.learning_rate}}"
+ # }
+ built_inputs = self._build_inputs()
+
+ # Build the outputs to dict
+ # example: {"eval_output": "${{jobs.eval.outputs.eval_output}}"}
+ built_outputs = self._build_outputs()
+
+ if self.settings is not None:
+ settings_dict = self.settings._to_dict()
+
+ if isinstance(self.component, PipelineComponent):
+ source = self.component._source
+ # Build the jobs to dict
+ rest_component_jobs = self.component._build_rest_component_jobs()
+ else:
+ source = ComponentSource.REMOTE_WORKSPACE_JOB
+ rest_component_jobs = {}
+ # add _source on pipeline job.settings
+ if "_source" not in settings_dict: # pylint: disable=possibly-used-before-assignment
+ settings_dict.update({"_source": source})
+
+ # TODO: Revisit this logic when multiple types of component jobs are supported
+ rest_compute = self.compute
+ # This will be resolved in job_operations _resolve_arm_id_or_upload_dependencies.
+ component_id = self.component if isinstance(self.component, str) else self.component.id
+
+ # TODO: remove this in the future.
+ # MFE does not support passing None or empty input values. Remove the empty inputs from the pipeline job.
+ built_inputs = {k: v for k, v in built_inputs.items() if v is not None and v != ""}
+
+ pipeline_job = RestPipelineJob(
+ compute_id=rest_compute,
+ component_id=component_id,
+ display_name=self.display_name,
+ tags=self.tags,
+ description=self.description,
+ properties=self.properties,
+ experiment_name=self.experiment_name,
+ jobs=rest_component_jobs,
+ inputs=to_rest_dataset_literal_inputs(built_inputs, job_type=self.type),
+ outputs=to_rest_data_outputs(built_outputs),
+ settings=settings_dict,
+ services={k: v._to_rest_object() for k, v in self.services.items()} if self.services else None,
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ )
+
+ rest_job = JobBase(properties=pipeline_job)
+ rest_job.name = self.name
+ return rest_job
+
+ @classmethod
+ def _load_from_rest(cls, obj: JobBase) -> "PipelineJob":
+ """Build a pipeline instance from rest pipeline object.
+
+ :param obj: The REST Pipeline Object
+ :type obj: JobBase
+ :return: pipeline job.
+ :rtype: PipelineJob
+ """
+ properties: RestPipelineJob = obj.properties
+ # Workaround for BatchEndpoint as these fields are not filled in
+ # Unpack the inputs
+ from_rest_inputs = from_rest_inputs_to_dataset_literal(properties.inputs) or {}
+ from_rest_outputs = from_rest_data_outputs(properties.outputs) or {}
+ # Unpack the component jobs
+ sub_nodes = PipelineComponent._resolve_sub_nodes(properties.jobs) if properties.jobs else {}
+ # backend may still store camelCase settings, e.g. DefaultDatastore; translate them to snake_case when loading back
+ settings_dict = transform_dict_keys(properties.settings, camel_to_snake) if properties.settings else None
+ settings_sdk = PipelineJobSettings(**settings_dict) if settings_dict else PipelineJobSettings()
+ # Create component or use component id
+ if getattr(properties, "component_id", None):
+ component = properties.component_id
+ else:
+ component = PipelineComponent._load_from_rest_pipeline_job(
+ {
+ "inputs": from_rest_inputs,
+ "outputs": from_rest_outputs,
+ "display_name": properties.display_name,
+ "description": properties.description,
+ "jobs": sub_nodes,
+ }
+ )
+
+ job = PipelineJob(
+ component=component,
+ inputs=from_rest_inputs,
+ outputs=from_rest_outputs,
+ name=obj.name,
+ id=obj.id,
+ jobs=sub_nodes,
+ display_name=properties.display_name,
+ tags=properties.tags,
+ properties=properties.properties,
+ experiment_name=properties.experiment_name,
+ status=properties.status,
+ creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ services=JobServiceBase._from_rest_job_services(properties.services) if properties.services else None,
+ compute=get_resource_name_from_arm_id_safe(properties.compute_id),
+ settings=settings_sdk,
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ )
+
+ return job
+
+ def _to_dict(self) -> Dict:
+ res: dict = self._dump_for_validation()
+ return res
+
+ @classmethod
+ def _component_items_from_path(cls, data: Dict) -> Generator:
+ if "jobs" in data:
+ for node_name, job_instance in data["jobs"].items():
+ potential_component_path = job_instance["component"] if "component" in job_instance else None
+ if isinstance(potential_component_path, str) and potential_component_path.startswith("file:"):
+ yield node_name, potential_component_path
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "PipelineJob":
+ path_first_occurrence: dict = {}
+ component_first_occurrence = {}
+ for node_name, component_path in cls._component_items_from_path(data):
+ if component_path in path_first_occurrence:
+ component_first_occurrence[node_name] = path_first_occurrence[component_path]
+ # setting components to be replaced here may break the validation logic
+ else:
+ path_first_occurrence[component_path] = node_name
+
+ # use this instead of azure.ai.ml.entities._util.load_from_dict to avoid parsing
+ loaded_schema = cls._create_schema_for_validation(context=context).load(data, **kwargs)
+
+ # replace repeat component with first occurrence to reduce arm id resolution
+ # current load yaml file logic is in azure.ai.ml._schema.core.schema.YamlFileSchema.load_from_file
+ # is it possible to load the same yaml file only once in 1 pipeline loading?
+ for node_name, first_occurrence in component_first_occurrence.items():
+ job = loaded_schema["jobs"][node_name]
+ job._component = loaded_schema["jobs"][first_occurrence].component
+ # For Parallel job, should also align task attribute which is usually from component.task
+ if isinstance(job, Parallel):
+ job.task = job._component.task
+ # parallel.task.code is based on parallel._component.base_path, so need to update it
+ job._base_path = job._component.base_path
+ return PipelineJob(
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ _source=ComponentSource.YAML_JOB,
+ **loaded_schema,
+ )
+
+ def __str__(self) -> str:
+ try:
+ res_to_yaml: str = self._to_yaml()
+ return res_to_yaml
+ except BaseException: # pylint: disable=W0718
+ res: str = super(PipelineJob, self).__str__()
+ return res
+
+ def _get_telemetry_values(self) -> Dict:
+ telemetry_values: dict = super()._get_telemetry_values()
+ if isinstance(self.component, PipelineComponent):
+ telemetry_values.update(self.component._get_telemetry_values())
+ else:
+ telemetry_values.update({"source": ComponentSource.REMOTE_WORKSPACE_JOB})
+ telemetry_values.pop("is_anonymous")
+ return telemetry_values
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "PipelineComponent":
+ """Translate a pipeline job to pipeline component.
+
+ :param context: Context of pipeline job YAML file.
+ :type context: dict
+ :return: Translated pipeline component.
+ :rtype: PipelineComponent
+ """
+ ignored_keys = PipelineComponent._check_ignored_keys(self)
+ if ignored_keys:
+ name = self.name or self.display_name
+ name = f"{name!r} " if name else ""
+ module_logger.warning("%s ignored when translating PipelineJob %sto PipelineComponent.", ignored_keys, name)
+ pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+ context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}
+
+ # Create anonymous pipeline component with default version as 1
+ return PipelineComponent(
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ display_name=self.display_name,
+ inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
+ outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
+ jobs=self.jobs,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job_settings.py
new file mode 100644
index 00000000..0fe41e2e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/pipeline/pipeline_job_settings.py
@@ -0,0 +1,75 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Generator, Optional
+
+from azure.ai.ml.entities._job.pipeline._attr_dict import _AttrDict
+
+
+class PipelineJobSettings(_AttrDict):
+ """Settings of PipelineJob.
+
+ :param default_datastore: The default datastore of the pipeline.
+ :type default_datastore: str
+ :param default_compute: The default compute target of the pipeline.
+ :type default_compute: str
+ :param continue_on_step_failure: Flag indicating whether to continue pipeline execution if a step fails.
+ :type continue_on_step_failure: bool
+ :param force_rerun: Flag indicating whether to force rerun pipeline execution.
+ :type force_rerun: bool
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_pipeline_job_configurations.py
+ :start-after: [START configure_pipeline_job_and_settings]
+ :end-before: [END configure_pipeline_job_and_settings]
+ :language: python
+ :dedent: 8
+ :caption: Shows how to set pipeline properties using this class.
+ """
+
+ def __init__(
+ self,
+ default_datastore: Optional[str] = None,
+ default_compute: Optional[str] = None,
+ continue_on_step_failure: Optional[bool] = None,
+ force_rerun: Optional[bool] = None,
+ **kwargs: Any
+ ) -> None:
+ self._init = True
+ super().__init__()
+ self.default_compute: Any = default_compute
+ self.default_datastore: Any = default_datastore
+ self.continue_on_step_failure = continue_on_step_failure
+ self.force_rerun = force_rerun
+ self.on_init = kwargs.get("on_init", None)
+ self.on_finalize = kwargs.get("on_finalize", None)
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+ self._init = False
+
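+ # Note (editorial): a hedged usage sketch; on_init/on_finalize and any extra keys flow in
+ # through **kwargs, on top of the documented parameters ("prepare_node" and "cpu-cluster"
+ # are hypothetical names):
+ #
+ #     settings = PipelineJobSettings(
+ #         default_compute="cpu-cluster",
+ #         continue_on_step_failure=True,
+ #         on_init="prepare_node",
+ #     )
+ #     settings._to_dict()
+ #     # -> {"default_compute": "cpu-cluster", "continue_on_step_failure": True,
+ #     #     "on_init": "prepare_node"}
+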
+ def _get_valid_keys(self) -> Generator[str, Any, None]:
+ for k, v in self.__dict__.items():
+ if v is None:
+ continue
+ # skip private attributes inherited from _AttrDict
+ if k in ["_logger", "_allowed_keys", "_init", "_key_restriction"]:
+ continue
+ yield k
+
+ def _to_dict(self) -> Dict:
+ result = {}
+ for k in self._get_valid_keys():
+ result[k] = self.__dict__[k]
+ result.update(self._get_attrs())
+ return result
+
+ def _initializing(self) -> bool:
+ return self._init
+
+ def __bool__(self) -> bool:
+ for _ in self._get_valid_keys():
+ return True
+ # _attr_dict will return False if no extra attributes are set
+ return self.__len__() > 0
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/queue_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/queue_settings.py
new file mode 100644
index 00000000..5b51fb6e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/queue_settings.py
@@ -0,0 +1,87 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any, Dict, Optional, Union
+
+from ..._restclient.v2023_04_01_preview.models import QueueSettings as RestQueueSettings
+from ..._utils._experimental import experimental
+from ..._utils.utils import is_data_binding_expression
+from ...constants._job.job import JobPriorityValues, JobTierNames
+from ...entities._mixins import DictMixin, RestTranslatableMixin
+from ...exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+
+@experimental
+class QueueSettings(RestTranslatableMixin, DictMixin):
+ """Queue settings for a pipeline job.
+
+ :ivar job_tier: Enum to determine the job tier. Possible values include: "Spot", "Basic",
+ "Standard", "Premium", "Null".
+ :vartype job_tier: str or ~azure.mgmt.machinelearningservices.models.JobTier
+ :ivar priority: Controls the priority of the job on a compute.
+ :vartype priority: str
+ :keyword job_tier: The job tier. Accepted values are "Spot", "Basic", "Standard", and "Premium".
+ :paramtype job_tier: Optional[Literal]
+ :keyword priority: The priority of the job on a compute. Accepted values are "low", "medium", and "high".
+ Defaults to "medium".
+ :paramtype priority: Optional[Literal]
+ :keyword kwargs: Additional properties for QueueSettings.
+ :paramtype kwargs: Optional[dict]
+ """
+
+ def __init__(
+ self, # pylint: disable=unused-argument
+ *,
+ job_tier: Optional[str] = None,
+ priority: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.job_tier = job_tier
+ self.priority = priority
+
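+ # Note (editorial): a minimal usage sketch; values are validated and mapped to REST enums
+ # by _to_rest_object below:
+ #
+ #     qs = QueueSettings(job_tier="Spot", priority="low")
+ #     rest_qs = qs._to_rest_object()   # RestQueueSettings with mapped enum values
+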
+ def _to_rest_object(self) -> RestQueueSettings:
+ self._validate()
+ job_tier = JobTierNames.ENTITY_TO_REST.get(self.job_tier.lower(), None) if self.job_tier else None
+ priority = JobPriorityValues.ENTITY_TO_REST.get(self.priority.lower(), None) if self.priority else None
+ return RestQueueSettings(job_tier=job_tier, priority=priority)
+
+ @classmethod
+ def _from_rest_object(cls, obj: Union[Dict[str, Any], RestQueueSettings, None]) -> Optional["QueueSettings"]:
+ if obj is None:
+ return None
+ if isinstance(obj, dict):
+ queue_settings = RestQueueSettings.from_dict(obj)
+ return cls._from_rest_object(queue_settings)
+ job_tier = JobTierNames.REST_TO_ENTITY.get(obj.job_tier, None) if obj.job_tier else None
+ priority = JobPriorityValues.REST_TO_ENTITY.get(obj.priority, None) if hasattr(obj, "priority") else None
+ return cls(job_tier=job_tier, priority=priority)
+
+ def _validate(self) -> None:
+ for key, enum_class in [("job_tier", JobTierNames), ("priority", JobPriorityValues)]:
+ value = getattr(self, key)
+ if is_data_binding_expression(value):
+ msg = (
+ f"do not support data binding expression on {key} as it involves value mapping "
+ f"when transformed to rest object, but received '{value}'."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ valid_keys = list(enum_class.ENTITY_TO_REST.keys()) # type: ignore[attr-defined]
+ if value and value.lower() not in valid_keys:
+ msg = f"{key} should be one of {valid_keys}, but received '{value}'."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/resource_configuration.py
new file mode 100644
index 00000000..a10d4a66
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/resource_configuration.py
@@ -0,0 +1,98 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import json
+import logging
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ResourceConfiguration as RestResourceConfiguration
+from azure.ai.ml.constants._job.job import JobComputePropertyFields
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class ResourceConfiguration(RestTranslatableMixin, DictMixin):
+ """Resource configuration for a job.
+
+ This class should not be instantiated directly. Instead, use its subclasses.
+
+ :keyword instance_count: The number of instances to use for the job.
+ :paramtype instance_count: Optional[int]
+ :keyword instance_type: The type of instance to use for the job.
+ :paramtype instance_type: Optional[str]
+ :keyword properties: The resource's property dictionary.
+ :paramtype properties: Optional[dict[str, Any]]
+ """
+
+ def __init__(
+ self, # pylint: disable=unused-argument
+ *,
+ instance_count: Optional[int] = None,
+ instance_type: Optional[str] = None,
+ properties: Optional[Dict[str, Any]] = None,
+ **kwargs: Any
+ ) -> None:
+ self.instance_count = instance_count
+ self.instance_type = instance_type
+ self.properties = {}
+ if properties is not None:
+ for key, value in properties.items():
+ if key == JobComputePropertyFields.AISUPERCOMPUTER:
+ self.properties[JobComputePropertyFields.SINGULARITY.lower()] = value
+ else:
+ self.properties[key] = value
+
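+ # Note (editorial): a hedged usage sketch; note how __init__ above re-keys the
+ # AISUPERCOMPUTER property onto the lower-cased SINGULARITY field ("Standard_DS3_v2" is a
+ # hypothetical instance type, and the property payload is illustrative):
+ #
+ #     rc = ResourceConfiguration(
+ #         instance_count=2,
+ #         instance_type="Standard_DS3_v2",
+ #         properties={JobComputePropertyFields.AISUPERCOMPUTER: {"enabled": True}},
+ #     )
+ #     list(rc.properties)   # -> [JobComputePropertyFields.SINGULARITY.lower()]
+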
+ def _to_rest_object(self) -> RestResourceConfiguration:
+ serialized_properties = {}
+ if self.properties:
+ for key, value in self.properties.items():
+ try:
+ if (
+ key.lower() == JobComputePropertyFields.SINGULARITY.lower()
+ or key.lower() == JobComputePropertyFields.AISUPERCOMPUTER.lower()
+ ):
+ # Map Singularity -> AISupercomputer in SDK until MFE does mapping
+ key = JobComputePropertyFields.AISUPERCOMPUTER
+ # recursively convert OrderedDict to a plain dictionary
+ serialized_properties[key] = json.loads(json.dumps(value))
+ except Exception: # pylint: disable=W0718
+ pass
+ return RestResourceConfiguration(
+ instance_count=self.instance_count,
+ instance_type=self.instance_type,
+ properties=serialized_properties,
+ )
+
+ @classmethod
+ def _from_rest_object( # pylint: disable=arguments-renamed
+ cls, rest_obj: Optional[RestResourceConfiguration]
+ ) -> Optional["ResourceConfiguration"]:
+ if rest_obj is None:
+ return None
+ return ResourceConfiguration(
+ instance_count=rest_obj.instance_count,
+ instance_type=rest_obj.instance_type,
+ properties=rest_obj.properties,
+ deserialize_properties=True,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ResourceConfiguration):
+ return NotImplemented
+ return self.instance_count == other.instance_count and self.instance_type == other.instance_type
+
+ def __ne__(self, other: object) -> bool:
+ if not isinstance(other, ResourceConfiguration):
+ return NotImplemented
+ return not self.__eq__(other)
+
+ def _merge_with(self, other: "ResourceConfiguration") -> None:
+ if other:
+ if other.instance_count:
+ self.instance_count = other.instance_count
+ if other.instance_type:
+ self.instance_type = other.instance_type
+ if other.properties:
+ self.properties = other.properties
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/service_instance.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/service_instance.py
new file mode 100644
index 00000000..0e5ba6c6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/service_instance.py
@@ -0,0 +1,59 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._restclient.runhistory.models import ServiceInstanceResult
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class ServiceInstance(RestTranslatableMixin, DictMixin):
+ """Service Instance Result.
+
+ :keyword type: The type of service.
+ :paramtype type: Optional[str]
+ :keyword port: The port used by the service.
+ :paramtype port: Optional[int]
+ :keyword status: The status of the service.
+ :paramtype status: Optional[str]
+ :keyword error: The error message.
+ :paramtype error: Optional[str]
+ :keyword endpoint: The service endpoint.
+ :paramtype endpoint: Optional[str]
+ :keyword properties: The service instance's properties.
+ :paramtype properties: Optional[dict[str, str]]
+ """
+
+ def __init__(
+ self, # pylint: disable=unused-argument
+ *,
+ type: Optional[str] = None, # pylint: disable=redefined-builtin
+ port: Optional[int] = None,
+ status: Optional[str] = None,
+ error: Optional[str] = None,
+ endpoint: Optional[str] = None,
+ properties: Optional[Dict[str, str]] = None,
+ **kwargs: Any
+ ) -> None:
+ self.type = type
+ self.port = port
+ self.status = status
+ self.error = error
+ self.endpoint = endpoint
+ self.properties = properties
+
+ @classmethod
+ # pylint: disable=arguments-differ
+ def _from_rest_object(cls, obj: ServiceInstanceResult, node_index: int) -> "ServiceInstance": # type: ignore
+ return cls(
+ type=obj.type,
+ port=obj.port,
+ status=obj.status,
+ error=obj.error.error.message if obj.error and obj.error.error else None,
+ endpoint=obj.endpoint.replace("<nodeIndex>", str(node_index)) if obj.endpoint else obj.endpoint,
+ properties=obj.properties,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_helpers.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_helpers.py
new file mode 100644
index 00000000..d3fdf9dc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_helpers.py
@@ -0,0 +1,210 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+import re
+from typing import Any
+
+from azure.ai.ml.constants import InputOutputModes
+from azure.ai.ml.constants._component import ComponentJobConstants
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job.pipeline._io import NodeInput, NodeOutput
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+def _validate_spark_configurations(obj: Any) -> None:
+ # skip validation when component of node is from remote
+ if hasattr(obj, "component") and isinstance(obj.component, str):
+ return
+ if obj.dynamic_allocation_enabled in ["True", "true", True]:
+ if (
+ obj.driver_cores is None
+ or obj.driver_memory is None
+ or obj.executor_cores is None
+ or obj.executor_memory is None
+ ):
+ msg = (
+ "spark.driver.cores, spark.driver.memory, spark.executor.cores and spark.executor.memory are "
+ "mandatory fields."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if obj.dynamic_allocation_min_executors is None or obj.dynamic_allocation_max_executors is None:
+ msg = (
+ "spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors are required "
+ "when dynamic allocation is enabled."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if not (
+ obj.dynamic_allocation_min_executors > 0
+ and obj.dynamic_allocation_min_executors <= obj.dynamic_allocation_max_executors
+ ):
+ msg = (
+                "Dynamic allocation min executors must be greater than 0 and less than or equal to "
+                "max executors."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if obj.executor_instances and (
+ obj.executor_instances > obj.dynamic_allocation_max_executors
+ or obj.executor_instances < obj.dynamic_allocation_min_executors
+ ):
+ msg = (
+ "Executor instances must be a valid non-negative integer and must be between "
+ "spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors"
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ else:
+ if (
+ obj.driver_cores is None
+ or obj.driver_memory is None
+ or obj.executor_cores is None
+ or obj.executor_memory is None
+ or obj.executor_instances is None
+ ):
+ msg = (
+ "spark.driver.cores, spark.driver.memory, spark.executor.cores, spark.executor.memory and "
+ "spark.executor.instances are mandatory fields."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if obj.dynamic_allocation_min_executors is not None or obj.dynamic_allocation_max_executors is not None:
+ msg = "Should not specify min or max executors when dynamic allocation is disabled."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
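+# A hedged sketch of the two configuration shapes accepted by _validate_spark_configurations above;
+# `job` stands for any SparkJob-like object carrying these promoted attributes, and all numbers are
+# illustrative rather than recommended values.
+#
+#     # Dynamic allocation enabled: min/max executors are required and bound executor_instances.
+#     job.dynamic_allocation_enabled = True
+#     job.driver_cores, job.driver_memory = 1, "2g"
+#     job.executor_cores, job.executor_memory = 2, "2g"
+#     job.executor_instances = 2
+#     job.dynamic_allocation_min_executors, job.dynamic_allocation_max_executors = 1, 4
+#     _validate_spark_configurations(job)
+#
+#     # Dynamic allocation disabled: executor_instances is mandatory and min/max must stay unset.
+#     job.dynamic_allocation_enabled = False
+#     job.dynamic_allocation_min_executors = job.dynamic_allocation_max_executors = None
+#     _validate_spark_configurations(job)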
+
+def _validate_compute_or_resources(compute: Any, resources: Any) -> None:
+ # if resources is set, then ensure it is valid before
+ # checking mutual exclusiveness against compute existence
+ if compute is None and resources is None:
+ msg = "One of either compute or resources must be specified for Spark job"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if compute and resources:
+ msg = "Only one of either compute or resources may be specified for Spark job"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+
+# Only "direct" mode is supported for spark job inputs and outputs
+# pylint: disable=no-else-raise, too-many-boolean-expressions
+def _validate_input_output_mode(inputs: Any, outputs: Any) -> None:
+ for input_name, input_value in inputs.items():
+ if isinstance(input_value, Input) and input_value.mode != InputOutputModes.DIRECT:
+ # For standalone job input
+ msg = "Input '{}' is using '{}' mode, only '{}' is supported for Spark job"
+ raise ValidationException(
+ message=msg.format(input_name, input_value.mode, InputOutputModes.DIRECT),
+ no_personal_data_message=msg.format("[input_name]", "[input_value.mode]", "direct"),
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ elif (
+ isinstance(input_value, NodeInput)
+ and (
+ isinstance(input_value._data, Input)
+ and not (
+ isinstance(input_value._data.path, str)
+ and bool(re.search(ComponentJobConstants.INPUT_PATTERN, input_value._data.path))
+ )
+ and input_value._data.mode != InputOutputModes.DIRECT
+ )
+ and (isinstance(input_value._meta, Input) and input_value._meta.mode != InputOutputModes.DIRECT)
+ ):
+            # For a node input in a pipeline job, the client side can only validate node inputs that are not bound
+            # to a pipeline input or to another node's output.
+            # 1. If the node input is bound to a pipeline input, the pipeline-level input mode is not available
+            # during node-level validation. Even though the component input mode (_meta) could be inspected, the
+            # pipeline-level input mode takes priority over the component level, so a component input set to "Mount"
+            # can still run successfully when the pipeline input is "Direct".
+            # 2. If the node input is bound to a previous node's output, the input mode is decoupled from the output
+            # mode, so the node-level mode is always None. In that case, a correct "Direct" mode defined in the
+            # component YAML takes effect and the job runs successfully. Otherwise, the mode must be set at the node
+            # level, e.g. input1: path: ${{parent.jobs.sample_word.outputs.output1}} mode: direct.
+ msg = "Input '{}' is using '{}' mode, only '{}' is supported for Spark job"
+ raise ValidationException(
+ message=msg.format(
+ input_name, input_value._data.mode or input_value._meta.mode, InputOutputModes.DIRECT
+ ),
+ no_personal_data_message=msg.format("[input_name]", "[input_value.mode]", "direct"),
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ for output_name, output_value in outputs.items():
+ if (
+ isinstance(output_value, Output)
+ and output_name != "default"
+ and output_value.mode != InputOutputModes.DIRECT
+ ):
+ # For standalone job output
+ msg = "Output '{}' is using '{}' mode, only '{}' is supported for Spark job"
+ raise ValidationException(
+ message=msg.format(output_name, output_value.mode, InputOutputModes.DIRECT),
+ no_personal_data_message=msg.format("[output_name]", "[output_value.mode]", "direct"),
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ elif (
+ isinstance(output_value, NodeOutput)
+ and output_name != "default"
+ and (
+ isinstance(output_value._data, Output)
+ and not (
+ isinstance(output_value._data.path, str)
+ and bool(re.search(ComponentJobConstants.OUTPUT_PATTERN, output_value._data.path))
+ )
+ and output_value._data.mode != InputOutputModes.DIRECT
+ )
+ and (isinstance(output_value._meta, Output) and output_value._meta.mode != InputOutputModes.DIRECT)
+ ):
+            # For a node output in a pipeline job, the client side can only validate node outputs that are not bound
+            # to a pipeline output.
+            # 1. If the node output is bound to a pipeline output, the pipeline-level output mode is not available
+            # during node-level validation. Even though the component output mode (_meta) could be inspected, the
+            # pipeline-level output mode takes priority over the component level, so a component output set to
+            # "upload" can still run successfully when the pipeline output is "Direct".
+ msg = "Output '{}' is using '{}' mode, only '{}' is supported for Spark job"
+ raise ValidationException(
+ message=msg.format(
+ output_name, output_value._data.mode or output_value._meta.mode, InputOutputModes.DIRECT
+ ),
+ no_personal_data_message=msg.format("[output_name]", "[output_value.mode]", "direct"),
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
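+
+# A hedged sketch of the only supported I/O mode for Spark jobs, per the checks above; the storage
+# URI is a placeholder.
+#
+#     from azure.ai.ml import Input, Output
+#
+#     inputs = {"raw": Input(type="uri_file", path="abfss://data@account.dfs.core.windows.net/in.csv", mode="direct")}
+#     outputs = {"processed": Output(type="uri_folder", mode="direct")}
+#     _validate_input_output_mode(inputs, outputs)  # passes; any other mode raises ValidationException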
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job.py
new file mode 100644
index 00000000..10930fb4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job.py
@@ -0,0 +1,393 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access, too-many-instance-attributes
+
+import copy
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from marshmallow import INCLUDE
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SparkJob as RestSparkJob
+from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
+from azure.ai.ml._schema.job.parameterized_spark import CONF_KEY_MAP
+from azure.ai.ml._schema.job.spark_job import SparkJobSchema
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.constants._job.job import SparkConfKey
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+ _BaseJobIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._job._input_output_helpers import (
+ from_rest_data_outputs,
+ from_rest_inputs_to_dataset_literal,
+ to_rest_data_outputs,
+ to_rest_dataset_literal_inputs,
+ validate_inputs_for_args,
+)
+from azure.ai.ml.entities._job.parameterized_spark import ParameterizedSpark
+from azure.ai.ml.entities._util import load_from_dict
+
+from ..._schema import NestedField, UnionField
+from .job import Job
+from .job_io_mixin import JobIOMixin
+from .spark_helpers import _validate_compute_or_resources, _validate_input_output_mode, _validate_spark_configurations
+from .spark_job_entry import SparkJobEntry
+from .spark_job_entry_mixin import SparkJobEntryMixin
+from .spark_resource_configuration import SparkResourceConfiguration
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities import SparkComponent
+ from azure.ai.ml.entities._builders import Spark
+
+module_logger = logging.getLogger(__name__)
+
+
+class SparkJob(Job, ParameterizedSpark, JobIOMixin, SparkJobEntryMixin):
+ """A standalone Spark job.
+
+ :keyword driver_cores: The number of cores to use for the driver process, only in cluster mode.
+ :paramtype driver_cores: Optional[int]
+ :keyword driver_memory: The amount of memory to use for the driver process, formatted as strings with a size unit
+ suffix ("k", "m", "g" or "t") (e.g. "512m", "2g").
+ :paramtype driver_memory: Optional[str]
+ :keyword executor_cores: The number of cores to use on each executor.
+ :paramtype executor_cores: Optional[int]
+ :keyword executor_memory: The amount of memory to use per executor process, formatted as strings with a size unit
+ suffix ("k", "m", "g" or "t") (e.g. "512m", "2g").
+ :paramtype executor_memory: Optional[str]
+ :keyword executor_instances: The initial number of executors.
+ :paramtype executor_instances: Optional[int]
+ :keyword dynamic_allocation_enabled: Whether to use dynamic resource allocation, which scales the number of
+ executors registered with this application up and down based on the workload.
+ :paramtype dynamic_allocation_enabled: Optional[bool]
+ :keyword dynamic_allocation_min_executors: The lower bound for the number of executors if dynamic allocation is
+ enabled.
+ :paramtype dynamic_allocation_min_executors: Optional[int]
+ :keyword dynamic_allocation_max_executors: The upper bound for the number of executors if dynamic allocation is
+ enabled.
+ :paramtype dynamic_allocation_max_executors: Optional[int]
+ :keyword inputs: The mapping of input data bindings used in the job.
+ :paramtype inputs: Optional[dict[str, ~azure.ai.ml.Input]]
+ :keyword outputs: The mapping of output data bindings used in the job.
+ :paramtype outputs: Optional[dict[str, ~azure.ai.ml.Output]]
+ :keyword compute: The compute resource the job runs on.
+ :paramtype compute: Optional[str]
+ :keyword identity: The identity that the Spark job will use while running on compute.
+ :paramtype identity: Optional[Union[dict[str, str], ~azure.ai.ml.ManagedIdentityConfiguration,
+ ~azure.ai.ml.AmlTokenConfiguration, ~azure.ai.ml.UserIdentityConfiguration]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_job_configuration]
+ :end-before: [END spark_job_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a SparkJob.
+ """
+
+ def __init__(
+ self,
+ *,
+ driver_cores: Optional[Union[int, str]] = None,
+ driver_memory: Optional[str] = None,
+ executor_cores: Optional[Union[int, str]] = None,
+ executor_memory: Optional[str] = None,
+ executor_instances: Optional[Union[int, str]] = None,
+ dynamic_allocation_enabled: Optional[Union[bool, str]] = None,
+ dynamic_allocation_min_executors: Optional[Union[int, str]] = None,
+ dynamic_allocation_max_executors: Optional[Union[int, str]] = None,
+ inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
+ outputs: Optional[Dict[str, Output]] = None,
+ compute: Optional[str] = None,
+ identity: Optional[
+ Union[Dict[str, str], ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ resources: Optional[Union[Dict, SparkResourceConfiguration]] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = JobType.SPARK
+
+ super().__init__(**kwargs)
+ self.conf: Dict = self.conf or {}
+ self.properties_sparkJob = self.properties or {}
+ self.driver_cores = driver_cores
+ self.driver_memory = driver_memory
+ self.executor_cores = executor_cores
+ self.executor_memory = executor_memory
+ self.executor_instances = executor_instances
+ self.dynamic_allocation_enabled = dynamic_allocation_enabled
+ self.dynamic_allocation_min_executors = dynamic_allocation_min_executors
+ self.dynamic_allocation_max_executors = dynamic_allocation_max_executors
+ self.inputs = inputs # type: ignore[assignment]
+ self.outputs = outputs # type: ignore[assignment]
+ self.compute = compute
+ self.resources = resources
+ self.identity = identity
+ if self.executor_instances is None and str(self.dynamic_allocation_enabled).lower() == "true":
+ self.executor_instances = self.dynamic_allocation_min_executors
+
+ @property
+ def resources(self) -> Optional[Union[Dict, SparkResourceConfiguration]]:
+ """The compute resource configuration for the job.
+
+ :return: The compute resource configuration for the job.
+ :rtype: Optional[~azure.ai.ml.entities.SparkResourceConfiguration]
+ """
+ return self._resources
+
+ @resources.setter
+ def resources(self, value: Optional[Union[Dict[str, str], SparkResourceConfiguration]]) -> None:
+ """Sets the compute resource configuration for the job.
+
+ :param value: The compute resource configuration for the job.
+ :type value: Optional[Union[dict[str, str], ~azure.ai.ml.entities.SparkResourceConfiguration]]
+ """
+ if isinstance(value, dict):
+ value = SparkResourceConfiguration(**value)
+ self._resources = value
+
+ @property
+ def identity(
+ self,
+ ) -> Optional[Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]]:
+ """The identity that the Spark job will use while running on compute.
+
+ :return: The identity that the Spark job will use while running on compute.
+ :rtype: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
+ ~azure.ai.ml.UserIdentityConfiguration]]
+ """
+ return self._identity
+
+ @identity.setter
+ def identity(
+ self,
+ value: Optional[
+ Union[Dict[str, str], ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ],
+ ) -> None:
+ """Sets the identity that the Spark job will use while running on compute.
+
+ :param value: The identity that the Spark job will use while running on compute.
+ :type value: Optional[Union[dict[str, str], ~azure.ai.ml.ManagedIdentityConfiguration,
+ ~azure.ai.ml.AmlTokenConfiguration, ~azure.ai.ml.UserIdentityConfiguration]]
+ """
+ if isinstance(value, dict):
+ identify_schema = UnionField(
+ [
+ NestedField(ManagedIdentitySchema, unknown=INCLUDE),
+ NestedField(AMLTokenIdentitySchema, unknown=INCLUDE),
+ NestedField(UserIdentitySchema, unknown=INCLUDE),
+ ]
+ )
+ value = identify_schema._deserialize(value=value, attr=None, data=None)
+ self._identity = value
+
+ def _to_dict(self) -> Dict:
+ res: dict = SparkJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def filter_conf_fields(self) -> Dict[str, str]:
+ """Filters out the fields of the conf attribute that are not among the Spark configuration fields
+ listed in ~azure.ai.ml._schema.job.parameterized_spark.CONF_KEY_MAP and returns them in their own dictionary.
+
+ :return: A dictionary of the conf fields that are not Spark configuration fields.
+ :rtype: dict[str, str]
+ """
+ if self.conf is None:
+ return {}
+ data_conf = {}
+ for conf_key, conf_val in self.conf.items():
+            if conf_key not in CONF_KEY_MAP:
+ data_conf[conf_key] = conf_val
+ return data_conf
+
+ def _to_rest_object(self) -> JobBase:
+ self._validate()
+ conf = {
+ **(self.filter_conf_fields()),
+ "spark.driver.cores": self.driver_cores,
+ "spark.driver.memory": self.driver_memory,
+ "spark.executor.cores": self.executor_cores,
+ "spark.executor.memory": self.executor_memory,
+ }
+ if self.dynamic_allocation_enabled in ["True", "true", True]:
+ conf["spark.dynamicAllocation.enabled"] = True
+ conf["spark.dynamicAllocation.minExecutors"] = self.dynamic_allocation_min_executors
+ conf["spark.dynamicAllocation.maxExecutors"] = self.dynamic_allocation_max_executors
+ if self.executor_instances is not None:
+ conf["spark.executor.instances"] = self.executor_instances
+
+ properties = RestSparkJob(
+ experiment_name=self.experiment_name,
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ code_id=self.code,
+ entry=self.entry._to_rest_object() if self.entry is not None and not isinstance(self.entry, dict) else None,
+ py_files=self.py_files,
+ jars=self.jars,
+ files=self.files,
+ archives=self.archives,
+ identity=(
+ self.identity._to_job_rest_object() if self.identity and not isinstance(self.identity, dict) else None
+ ),
+ conf=conf,
+ properties=self.properties_sparkJob,
+ environment_id=self.environment,
+ inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=self.type),
+ outputs=to_rest_data_outputs(self.outputs),
+ args=self.args,
+ compute_id=self.compute,
+ resources=(
+ self.resources._to_rest_object() if self.resources and not isinstance(self.resources, Dict) else None
+ ),
+ )
+ result = JobBase(properties=properties)
+ result.name = self.name
+ return result
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "SparkJob":
+ loaded_data = load_from_dict(SparkJobSchema, data, context, additional_message, **kwargs)
+ return SparkJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+ @classmethod
+ def _load_from_rest(cls, obj: JobBase) -> "SparkJob":
+ rest_spark_job: RestSparkJob = obj.properties
+ rest_spark_conf = copy.copy(rest_spark_job.conf) or {}
+ spark_job = SparkJob(
+ name=obj.name,
+ entry=SparkJobEntry._from_rest_object(rest_spark_job.entry),
+ experiment_name=rest_spark_job.experiment_name,
+ id=obj.id,
+ display_name=rest_spark_job.display_name,
+ description=rest_spark_job.description,
+ tags=rest_spark_job.tags,
+ properties=rest_spark_job.properties,
+ services=rest_spark_job.services,
+ status=rest_spark_job.status,
+ creation_context=obj.system_data,
+ code=rest_spark_job.code_id,
+ compute=rest_spark_job.compute_id,
+ environment=rest_spark_job.environment_id,
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(rest_spark_job.identity)
+ if rest_spark_job.identity
+ else None
+ ),
+ args=rest_spark_job.args,
+ conf=rest_spark_conf,
+ driver_cores=rest_spark_conf.get(
+ SparkConfKey.DRIVER_CORES, None
+            ),  # copy fields from conf into the promoted attributes on the Spark job
+ driver_memory=rest_spark_conf.get(SparkConfKey.DRIVER_MEMORY, None),
+ executor_cores=rest_spark_conf.get(SparkConfKey.EXECUTOR_CORES, None),
+ executor_memory=rest_spark_conf.get(SparkConfKey.EXECUTOR_MEMORY, None),
+ executor_instances=rest_spark_conf.get(SparkConfKey.EXECUTOR_INSTANCES, None),
+ dynamic_allocation_enabled=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_ENABLED, None),
+ dynamic_allocation_min_executors=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_MIN_EXECUTORS, None),
+ dynamic_allocation_max_executors=rest_spark_conf.get(SparkConfKey.DYNAMIC_ALLOCATION_MAX_EXECUTORS, None),
+ resources=SparkResourceConfiguration._from_rest_object(rest_spark_job.resources),
+ inputs=from_rest_inputs_to_dataset_literal(rest_spark_job.inputs),
+ outputs=from_rest_data_outputs(rest_spark_job.outputs),
+ )
+ return spark_job
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "SparkComponent":
+        """Translate a Spark job to a component.
+
+ :param context: Context of spark job YAML file.
+ :type context: dict
+ :return: Translated spark component.
+ :rtype: SparkComponent
+ """
+ from azure.ai.ml.entities import SparkComponent
+
+ pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+ context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}
+
+ # Create anonymous spark component with default version as 1
+ return SparkComponent(
+ tags=self.tags,
+ is_anonymous=True,
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ description=self.description,
+ code=self.code,
+ entry=self.entry,
+ py_files=self.py_files,
+ jars=self.jars,
+ files=self.files,
+ archives=self.archives,
+ driver_cores=self.driver_cores,
+ driver_memory=self.driver_memory,
+ executor_cores=self.executor_cores,
+ executor_memory=self.executor_memory,
+ executor_instances=self.executor_instances,
+ dynamic_allocation_enabled=self.dynamic_allocation_enabled,
+ dynamic_allocation_min_executors=self.dynamic_allocation_min_executors,
+ dynamic_allocation_max_executors=self.dynamic_allocation_max_executors,
+ conf=self.conf,
+ properties=self.properties_sparkJob,
+ environment=self.environment,
+ inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
+ outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
+ args=self.args,
+ )
+
+ def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Spark":
+ """Translate a spark job to a pipeline node.
+
+ :param context: Context of spark job YAML file.
+ :type context: dict
+        :return: Translated Spark node.
+ :rtype: Spark
+ """
+ from azure.ai.ml.entities._builders import Spark
+
+ component = self._to_component(context, **kwargs)
+
+ return Spark(
+ display_name=self.display_name,
+ description=self.description,
+ tags=self.tags,
+ # code, entry, py_files, jars, files, archives, environment and args are static and not allowed to be
+ # overwritten. And we will always get them from component.
+ component=component,
+ identity=self.identity,
+ driver_cores=self.driver_cores,
+ driver_memory=self.driver_memory,
+ executor_cores=self.executor_cores,
+ executor_memory=self.executor_memory,
+ executor_instances=self.executor_instances,
+ dynamic_allocation_enabled=self.dynamic_allocation_enabled,
+ dynamic_allocation_min_executors=self.dynamic_allocation_min_executors,
+ dynamic_allocation_max_executors=self.dynamic_allocation_max_executors,
+ conf=self.conf,
+ inputs=self.inputs, # type: ignore[arg-type]
+ outputs=self.outputs, # type: ignore[arg-type]
+ compute=self.compute,
+ resources=self.resources,
+ properties=self.properties_sparkJob,
+ )
+
+ def _validate(self) -> None:
+ # TODO: make spark job schema validatable?
+ if self.resources and not isinstance(self.resources, Dict):
+ self.resources._validate()
+ _validate_compute_or_resources(self.compute, self.resources)
+ _validate_input_output_mode(self.inputs, self.outputs)
+ _validate_spark_configurations(self)
+ self._validate_entry()
+
+ if self.args:
+ validate_inputs_for_args(self.args, self.inputs)
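+
+# A minimal construction sketch for this class; the code path, entry file, instance type and runtime
+# version are placeholders and assume serverless Spark compute supplied through `resources`.
+#
+#     job = SparkJob(
+#         code="./src",
+#         entry={"file": "main.py"},
+#         driver_cores=1,
+#         driver_memory="2g",
+#         executor_cores=2,
+#         executor_memory="2g",
+#         executor_instances=2,
+#         resources=SparkResourceConfiguration(instance_type="standard_e4s_v3", runtime_version="3.4"),
+#     )
+#     job._validate()  # enforces the compute/resources, I/O mode and Spark conf rules defined above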
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry.py
new file mode 100644
index 00000000..ed8d3ca7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry.py
@@ -0,0 +1,59 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=redefined-builtin
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SparkJobEntry as RestSparkJobEntry
+from azure.ai.ml._restclient.v2023_04_01_preview.models import SparkJobPythonEntry, SparkJobScalaEntry
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class SparkJobEntryType:
+ """Type of Spark job entry. Possibilities are Python file entry or Scala class entry."""
+
+ SPARK_JOB_FILE_ENTRY = "SparkJobPythonEntry"
+ SPARK_JOB_CLASS_ENTRY = "SparkJobScalaEntry"
+
+
+class SparkJobEntry(RestTranslatableMixin):
+ """Entry for Spark job.
+
+ :keyword entry: The file or class entry point.
+ :paramtype entry: str
+ :keyword type: The entry type. Accepted values are SparkJobEntryType.SPARK_JOB_FILE_ENTRY or
+ SparkJobEntryType.SPARK_JOB_CLASS_ENTRY. Defaults to SparkJobEntryType.SPARK_JOB_FILE_ENTRY.
+ :paramtype type: ~azure.ai.ml.entities.SparkJobEntryType
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_component_definition]
+ :end-before: [END spark_component_definition]
+ :language: python
+ :dedent: 8
+ :caption: Creating SparkComponent.
+ """
+
+ def __init__(self, *, entry: str, type: str = SparkJobEntryType.SPARK_JOB_FILE_ENTRY) -> None:
+ self.entry_type = type
+ self.entry = entry
+
+ @classmethod
+ def _from_rest_object(cls, obj: Union[SparkJobPythonEntry, SparkJobScalaEntry]) -> Optional["SparkJobEntry"]:
+ if obj is None:
+ return None
+ if isinstance(obj, dict):
+ obj = RestSparkJobEntry.from_dict(obj)
+ if obj.spark_job_entry_type == SparkJobEntryType.SPARK_JOB_FILE_ENTRY:
+ return SparkJobEntry(
+ entry=obj.__dict__.get("file", None),
+ type=SparkJobEntryType.SPARK_JOB_FILE_ENTRY,
+ )
+ return SparkJobEntry(entry=obj.class_name, type=SparkJobEntryType.SPARK_JOB_CLASS_ENTRY)
+
+ def _to_rest_object(self) -> Union[SparkJobPythonEntry, SparkJobScalaEntry]:
+ if self.entry_type == SparkJobEntryType.SPARK_JOB_FILE_ENTRY:
+ return SparkJobPythonEntry(file=self.entry)
+ return SparkJobScalaEntry(class_name=self.entry)
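+
+# A short round-trip sketch; "main.py" is a placeholder script name.
+#
+#     entry = SparkJobEntry(entry="main.py")   # defaults to SparkJobEntryType.SPARK_JOB_FILE_ENTRY
+#     rest = entry._to_rest_object()           # -> SparkJobPythonEntry(file="main.py")
+#     assert SparkJobEntry._from_rest_object(rest).entry == "main.py"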
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry_mixin.py
new file mode 100644
index 00000000..2a1ff549
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_job_entry_mixin.py
@@ -0,0 +1,64 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import re
+from typing import Any, Dict, Optional, Union, cast
+
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+from .spark_job_entry import SparkJobEntry, SparkJobEntryType
+
+
+class SparkJobEntryMixin:
+ CODE_ID_RE_PATTERN = re.compile(
+ (
+ r"\/subscriptions\/(?P<subscription>[\w,-]+)\/resourceGroups\/(?P<resource_group>[\w,-]+)"
+ r"\/providers\/Microsoft\.MachineLearningServices\/workspaces\/(?P<workspace>[\w,-]+)"
+ r"\/codes\/(?P<code_id>[\w,-]+)" # fmt: skip
+ )
+ )
+
+ def __init__(self, **kwargs: Any):
+ self._entry = None
+ self.entry = kwargs.get("entry", None)
+
+ @property
+ def entry(self) -> Optional[Union[Dict[str, str], SparkJobEntry]]:
+ return self._entry
+
+ @entry.setter
+ def entry(self, value: Optional[Union[Dict[str, str], SparkJobEntry]]) -> None:
+ if isinstance(value, dict):
+ if value.get("file", None):
+ _entry = cast(str, value.get("file"))
+ self._entry = SparkJobEntry(entry=_entry, type=SparkJobEntryType.SPARK_JOB_FILE_ENTRY)
+ return
+ if value.get("class_name", None):
+ _entry = cast(str, value.get("class_name"))
+ self._entry = SparkJobEntry(entry=_entry, type=SparkJobEntryType.SPARK_JOB_CLASS_ENTRY)
+ return
+ self._entry = value
+
+ def _validate_entry(self) -> None:
+ if self.entry is None:
+            # Entry is a required field for a local component. When a remote job is loaded, the component is an ARM
+            # id and the entry comes from the node level returned by the service. Entry is only None when an existing
+            # remote component is referenced by name and version through a component function.
+ return
+ if not isinstance(self.entry, SparkJobEntry):
+            msg = f"Unsupported type {type(self.entry)} detected when validating entry; entry should be SparkJobEntry."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if self.entry.entry_type == SparkJobEntryType.SPARK_JOB_CLASS_ENTRY:
+ msg = "Classpath is not supported, please use 'file' to define the entry file."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
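+
+# A hedged sketch of the dict coercion performed by the `entry` setter above.
+#
+#     mixin = SparkJobEntryMixin(entry={"file": "main.py"})
+#     assert isinstance(mixin.entry, SparkJobEntry)
+#     assert mixin.entry.entry_type == SparkJobEntryType.SPARK_JOB_FILE_ENTRY
+#     mixin._validate_entry()  # passes; a {"class_name": ...} entry would be rejected as unsupported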
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_resource_configuration.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_resource_configuration.py
new file mode 100644
index 00000000..138fc7ed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/spark_resource_configuration.py
@@ -0,0 +1,91 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ SparkResourceConfiguration as RestSparkResourceConfiguration,
+)
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class SparkResourceConfiguration(RestTranslatableMixin, DictMixin):
+ """Compute resource configuration for Spark component or job.
+
+ :keyword instance_type: The type of VM to be used by the compute target.
+ :paramtype instance_type: Optional[str]
+ :keyword runtime_version: The Spark runtime version.
+ :paramtype runtime_version: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_resource_configuration]
+ :end-before: [END spark_resource_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a SparkJob with SparkResourceConfiguration.
+ """
+
+ instance_type_list = [
+ "standard_e4s_v3",
+ "standard_e8s_v3",
+ "standard_e16s_v3",
+ "standard_e32s_v3",
+ "standard_e64s_v3",
+ ]
+
+ def __init__(self, *, instance_type: Optional[str] = None, runtime_version: Optional[str] = None) -> None:
+ self.instance_type = instance_type
+ self.runtime_version = runtime_version
+
+ def _to_rest_object(self) -> RestSparkResourceConfiguration:
+ return RestSparkResourceConfiguration(instance_type=self.instance_type, runtime_version=self.runtime_version)
+
+ @classmethod
+ def _from_rest_object(
+ cls, obj: Union[dict, None, RestSparkResourceConfiguration]
+ ) -> Optional["SparkResourceConfiguration"]:
+ if obj is None:
+ return None
+ if isinstance(obj, dict):
+ return SparkResourceConfiguration(**obj)
+ return SparkResourceConfiguration(instance_type=obj.instance_type, runtime_version=obj.runtime_version)
+
+ def _validate(self) -> None:
+        # TODO: the logic below duplicates SparkResourceConfigurationSchema; consider making the SparkJob schema validatable
+ if self.instance_type is None or self.instance_type == "":
+ msg = "Instance type must be specified for SparkResourceConfiguration"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if self.instance_type.lower() not in self.instance_type_list:
+            msg = "Instance type must be one of the following: {}".format(", ".join(self.instance_type_list))
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SPARK_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, SparkResourceConfiguration):
+ return NotImplemented
+ return self.instance_type == other.instance_type and self.runtime_version == other.runtime_version
+
+ def __ne__(self, other: object) -> bool:
+ if not isinstance(other, SparkResourceConfiguration):
+ return NotImplemented
+ return not self.__eq__(other)
+
+ def _merge_with(self, other: "SparkResourceConfiguration") -> None:
+ if other:
+ if other.instance_type:
+ self.instance_type = other.instance_type
+ if other.runtime_version:
+ self.runtime_version = other.runtime_version
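+
+# A minimal sketch of validation and merging; the runtime version is a placeholder, while the
+# instance type must match `instance_type_list` case-insensitively.
+#
+#     base = SparkResourceConfiguration(instance_type="Standard_E8S_V3", runtime_version="3.4")
+#     base._validate()                                    # passes: matches "standard_e8s_v3"
+#     base._merge_with(SparkResourceConfiguration(runtime_version="3.5"))
+#     # base.instance_type is unchanged; base.runtime_version is now "3.5"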
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/early_termination_policy.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/early_termination_policy.py
new file mode 100644
index 00000000..b1b928fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/early_termination_policy.py
@@ -0,0 +1,191 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from abc import ABC
+from typing import Any, Optional, cast
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import BanditPolicy as RestBanditPolicy
+from azure.ai.ml._restclient.v2023_04_01_preview.models import EarlyTerminationPolicy as RestEarlyTerminationPolicy
+from azure.ai.ml._restclient.v2023_04_01_preview.models import EarlyTerminationPolicyType
+from azure.ai.ml._restclient.v2023_04_01_preview.models import MedianStoppingPolicy as RestMedianStoppingPolicy
+from azure.ai.ml._restclient.v2023_04_01_preview.models import (
+ TruncationSelectionPolicy as RestTruncationSelectionPolicy,
+)
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class EarlyTerminationPolicy(ABC, RestTranslatableMixin):
+ def __init__(
+ self,
+ *,
+ delay_evaluation: int,
+ evaluation_interval: int,
+ ):
+ self.type = None
+ self.delay_evaluation = delay_evaluation
+ self.evaluation_interval = evaluation_interval
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestEarlyTerminationPolicy) -> Optional["EarlyTerminationPolicy"]:
+ if not obj:
+ return None
+
+ policy: Any = None
+ if obj.policy_type == EarlyTerminationPolicyType.BANDIT:
+ policy = BanditPolicy._from_rest_object(obj) # pylint: disable=protected-access
+
+ if obj.policy_type == EarlyTerminationPolicyType.MEDIAN_STOPPING:
+ policy = MedianStoppingPolicy._from_rest_object(obj) # pylint: disable=protected-access
+
+ if obj.policy_type == EarlyTerminationPolicyType.TRUNCATION_SELECTION:
+ policy = TruncationSelectionPolicy._from_rest_object(obj) # pylint: disable=protected-access
+
+ return cast(Optional["EarlyTerminationPolicy"], policy)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, EarlyTerminationPolicy):
+            return NotImplemented
+ res: bool = self._to_rest_object() == other._to_rest_object()
+ return res
+
+
+class BanditPolicy(EarlyTerminationPolicy):
+ """Defines an early termination policy based on slack criteria and a frequency and delay interval for evaluation.
+
+ :keyword delay_evaluation: Number of intervals by which to delay the first evaluation. Defaults to 0.
+ :paramtype delay_evaluation: int
+ :keyword evaluation_interval: Interval (number of runs) between policy evaluations. Defaults to 0.
+ :paramtype evaluation_interval: int
+ :keyword slack_amount: Absolute distance allowed from the best performing run. Defaults to 0.
+ :paramtype slack_amount: float
+ :keyword slack_factor: Ratio of the allowed distance from the best performing run. Defaults to 0.
+ :paramtype slack_factor: float
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_bandit_policy]
+ :end-before: [END configure_sweep_job_bandit_policy]
+ :language: python
+ :dedent: 8
+ :caption: Configuring BanditPolicy early termination of a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(
+ self,
+ *,
+ delay_evaluation: int = 0,
+ evaluation_interval: int = 0,
+ slack_amount: float = 0,
+ slack_factor: float = 0,
+ ) -> None:
+ super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval)
+ self.type = EarlyTerminationPolicyType.BANDIT.lower()
+ self.slack_factor = slack_factor
+ self.slack_amount = slack_amount
+
+ def _to_rest_object(self) -> RestBanditPolicy:
+ return RestBanditPolicy(
+ delay_evaluation=self.delay_evaluation,
+ evaluation_interval=self.evaluation_interval,
+ slack_factor=self.slack_factor,
+ slack_amount=self.slack_amount,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestBanditPolicy) -> "BanditPolicy":
+ return cls(
+ delay_evaluation=obj.delay_evaluation,
+ evaluation_interval=obj.evaluation_interval,
+ slack_factor=obj.slack_factor,
+ slack_amount=obj.slack_amount,
+ )
+
+
+class MedianStoppingPolicy(EarlyTerminationPolicy):
+ """Defines an early termination policy based on a running average of the primary metric of all runs.
+
+ :keyword delay_evaluation: Number of intervals by which to delay the first evaluation. Defaults to 0.
+ :paramtype delay_evaluation: int
+ :keyword evaluation_interval: Interval (number of runs) between policy evaluations. Defaults to 1.
+ :paramtype evaluation_interval: int
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_median_stopping_policy]
+ :end-before: [END configure_sweep_job_median_stopping_policy]
+ :language: python
+ :dedent: 8
+ :caption: Configuring an early termination policy for a hyperparameter sweep job using MedianStoppingPolicy
+ """
+
+ def __init__(
+ self,
+ *,
+ delay_evaluation: int = 0,
+ evaluation_interval: int = 1,
+ ) -> None:
+ super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval)
+ self.type = camel_to_snake(EarlyTerminationPolicyType.MEDIAN_STOPPING)
+
+ def _to_rest_object(self) -> RestMedianStoppingPolicy:
+ return RestMedianStoppingPolicy(
+ delay_evaluation=self.delay_evaluation, evaluation_interval=self.evaluation_interval
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMedianStoppingPolicy) -> "MedianStoppingPolicy":
+ return cls(
+ delay_evaluation=obj.delay_evaluation,
+ evaluation_interval=obj.evaluation_interval,
+ )
+
+
+class TruncationSelectionPolicy(EarlyTerminationPolicy):
+ """Defines an early termination policy that cancels a given percentage of runs at each evaluation interval.
+
+ :keyword delay_evaluation: Number of intervals by which to delay the first evaluation. Defaults to 0.
+ :paramtype delay_evaluation: int
+ :keyword evaluation_interval: Interval (number of runs) between policy evaluations. Defaults to 0.
+ :paramtype evaluation_interval: int
+ :keyword truncation_percentage: The percentage of runs to cancel at each evaluation interval. Defaults to 0.
+ :paramtype truncation_percentage: int
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_truncation_selection_policy]
+ :end-before: [END configure_sweep_job_truncation_selection_policy]
+ :language: python
+ :dedent: 8
+ :caption: Configuring an early termination policy for a hyperparameter sweep job
+        using TruncationSelectionPolicy
+ """
+
+ def __init__(
+ self,
+ *,
+ delay_evaluation: int = 0,
+ evaluation_interval: int = 0,
+ truncation_percentage: int = 0,
+ ) -> None:
+ super().__init__(delay_evaluation=delay_evaluation, evaluation_interval=evaluation_interval)
+ self.type = camel_to_snake(EarlyTerminationPolicyType.TRUNCATION_SELECTION)
+ self.truncation_percentage = truncation_percentage
+
+ def _to_rest_object(self) -> RestTruncationSelectionPolicy:
+ return RestTruncationSelectionPolicy(
+ delay_evaluation=self.delay_evaluation,
+ evaluation_interval=self.evaluation_interval,
+ truncation_percentage=self.truncation_percentage,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestTruncationSelectionPolicy) -> "TruncationSelectionPolicy":
+ return cls(
+ delay_evaluation=obj.delay_evaluation,
+ evaluation_interval=obj.evaluation_interval,
+ truncation_percentage=obj.truncation_percentage,
+ )
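+
+# A hedged sketch of constructing the three policies defined above; all numbers are illustrative.
+#
+#     bandit = BanditPolicy(slack_factor=0.1, evaluation_interval=1, delay_evaluation=5)
+#     median = MedianStoppingPolicy(evaluation_interval=1, delay_evaluation=5)
+#     truncation = TruncationSelectionPolicy(truncation_percentage=20, evaluation_interval=1)
+#     # Each policy serializes with _to_rest_object(), and EarlyTerminationPolicy._from_rest_object
+#     # dispatches back to the matching subclass based on policy_type.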
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/objective.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/objective.py
new file mode 100644
index 00000000..45e13332
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/objective.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Optional
+
+from azure.ai.ml._restclient.v2023_08_01_preview.models import Objective as RestObjective
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class Objective(RestTranslatableMixin):
+ """Optimization objective.
+
+ :param goal: Defines supported metric goals for hyperparameter tuning. Accepted values
+ are: "minimize", "maximize".
+ :type goal: str
+ :param primary_metric: The name of the metric to optimize.
+ :type primary_metric: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_bayesian_sampling_algorithm]
+ :end-before: [END configure_sweep_job_bayesian_sampling_algorithm]
+ :language: python
+ :dedent: 8
+ :caption: Assigning an objective to a SweepJob.
+ """
+
+ def __init__(self, goal: Optional[str], primary_metric: Optional[str] = None) -> None:
+ """Optimization objective.
+
+ :param goal: Defines supported metric goals for hyperparameter tuning. Acceptable values
+ are: "minimize" or "maximize".
+ :type goal: str
+ :param primary_metric: The name of the metric to optimize.
+ :type primary_metric: str
+ """
+ if goal is not None:
+ self.goal = goal.lower()
+ self.primary_metric = primary_metric
+
+ def _to_rest_object(self) -> RestObjective:
+ return RestObjective(
+ goal=self.goal,
+ primary_metric=self.primary_metric,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestObjective) -> Optional["Objective"]:
+ if not obj:
+ return None
+
+ return cls(goal=obj.goal, primary_metric=obj.primary_metric)
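+
+# A minimal sketch; "accuracy" is a placeholder for whatever metric the trial script logs.
+#
+#     objective = Objective(goal="Maximize", primary_metric="accuracy")
+#     assert objective.goal == "maximize"   # the constructor lower-cases the goal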
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/parameterized_sweep.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/parameterized_sweep.py
new file mode 100644
index 00000000..5d69201f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/parameterized_sweep.py
@@ -0,0 +1,341 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Any, Dict, List, Optional, Type, Union
+
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from ..job_limits import SweepJobLimits
+from ..job_resource_configuration import JobResourceConfiguration
+from ..queue_settings import QueueSettings
+from .early_termination_policy import (
+ BanditPolicy,
+ EarlyTerminationPolicy,
+ EarlyTerminationPolicyType,
+ MedianStoppingPolicy,
+ TruncationSelectionPolicy,
+)
+from .objective import Objective
+from .sampling_algorithm import (
+ BayesianSamplingAlgorithm,
+ GridSamplingAlgorithm,
+ RandomSamplingAlgorithm,
+ RestBayesianSamplingAlgorithm,
+ RestGridSamplingAlgorithm,
+ RestRandomSamplingAlgorithm,
+ RestSamplingAlgorithm,
+ SamplingAlgorithm,
+ SamplingAlgorithmType,
+)
+
+SAMPLING_ALGORITHM_TO_REST_CONSTRUCTOR: Dict[SamplingAlgorithmType, Type[RestSamplingAlgorithm]] = {
+ SamplingAlgorithmType.RANDOM: RestRandomSamplingAlgorithm,
+ SamplingAlgorithmType.GRID: RestGridSamplingAlgorithm,
+ SamplingAlgorithmType.BAYESIAN: RestBayesianSamplingAlgorithm,
+}
+
+SAMPLING_ALGORITHM_CONSTRUCTOR: Dict[SamplingAlgorithmType, Type[SamplingAlgorithm]] = {
+ SamplingAlgorithmType.RANDOM: RandomSamplingAlgorithm,
+ SamplingAlgorithmType.GRID: GridSamplingAlgorithm,
+ SamplingAlgorithmType.BAYESIAN: BayesianSamplingAlgorithm,
+}
+
+
+class ParameterizedSweep: # pylint:disable=too-many-instance-attributes
+ """Shared logic for standalone and pipeline sweep job."""
+
+ def __init__(
+ self,
+ limits: Optional[SweepJobLimits] = None,
+ sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None,
+ objective: Optional[Union[Dict, Objective]] = None,
+ early_termination: Optional[Any] = None,
+ search_space: Optional[Dict] = None,
+ queue_settings: Optional[QueueSettings] = None,
+ resources: Optional[Union[dict, JobResourceConfiguration]] = None,
+ ) -> None:
+ """
+ :param limits: Limits for sweep job.
+ :type limits: ~azure.ai.ml.sweep.SweepJobLimits
+ :param sampling_algorithm: Sampling algorithm for sweep job.
+ :type sampling_algorithm: ~azure.ai.ml.sweep.SamplingAlgorithm
+ :param objective: Objective for sweep job.
+ :type objective: ~azure.ai.ml.sweep.Objective
+ :param early_termination: Early termination policy for sweep job.
+ :type early_termination: ~azure.ai.ml.entities._job.sweep.early_termination_policy.EarlyTerminationPolicy
+ :param search_space: Search space for sweep job.
+ :type search_space: Dict[str, Union[
+ ~azure.ai.ml.sweep.Choice,
+ ~azure.ai.ml.sweep.LogNormal,
+ ~azure.ai.ml.sweep.LogUniform,
+ ~azure.ai.ml.sweep.Normal,
+ ~azure.ai.ml.sweep.QLogNormal,
+ ~azure.ai.ml.sweep.QLogUniform,
+ ~azure.ai.ml.sweep.QNormal,
+ ~azure.ai.ml.sweep.QUniform,
+ ~azure.ai.ml.sweep.Randint,
+            ~azure.ai.ml.sweep.Uniform]]
+ :param queue_settings: Queue settings for sweep job.
+ :type queue_settings: ~azure.ai.ml.entities.QueueSettings
+ :param resources: Compute Resource configuration for the job.
+ :type resources: ~azure.ai.ml.entities.ResourceConfiguration
+ """
+ self.sampling_algorithm = sampling_algorithm
+ self.early_termination = early_termination # type: ignore[assignment]
+ self._limits = limits
+ self.search_space = search_space
+ self.queue_settings = queue_settings
+ self.objective: Optional[Objective] = None
+ self.resources = resources
+
+ if isinstance(objective, Dict):
+ self.objective = Objective(**objective)
+ else:
+ self.objective = objective
+
+ @property
+ def resources(self) -> Optional[Union[dict, JobResourceConfiguration]]:
+ """Resources for sweep job.
+
+ :returns: Resources for sweep job.
+ :rtype: ~azure.ai.ml.entities.ResourceConfiguration
+ """
+ return self._resources
+
+ @resources.setter
+ def resources(self, value: Optional[Union[dict, JobResourceConfiguration]]) -> None:
+ """Set Resources for sweep job.
+
+ :param value: Compute Resource configuration for the job.
+ :type value: ~azure.ai.ml.entities.ResourceConfiguration
+ """
+ if isinstance(value, dict):
+ value = JobResourceConfiguration(**value)
+ self._resources = value
+
+ @property
+ def limits(self) -> Optional[SweepJobLimits]:
+ """Limits for sweep job.
+
+ :returns: Limits for sweep job.
+ :rtype: ~azure.ai.ml.sweep.SweepJobLimits
+ """
+ return self._limits
+
+ @limits.setter
+ def limits(self, value: SweepJobLimits) -> None:
+ """Set limits for sweep job.
+
+ :param value: Limits for sweep job.
+ :type value: ~azure.ai.ml.sweep.SweepJobLimits
+ """
+ if not isinstance(value, SweepJobLimits):
+            msg = f"limits must be SweepJobLimits but got {type(value)} instead"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SWEEP_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ self._limits = value
+
+ def set_resources(
+ self,
+ *,
+ instance_type: Optional[Union[str, List[str]]] = None,
+ instance_count: Optional[int] = None,
+ locations: Optional[List[str]] = None,
+ properties: Optional[Dict] = None,
+ docker_args: Optional[str] = None,
+ shm_size: Optional[str] = None,
+ ) -> None:
+ """Set resources for Sweep.
+
+ :keyword instance_type: The instance type to use for the job.
+ :paramtype instance_type: Optional[Union[str, List[str]]]
+ :keyword instance_count: The number of instances to use for the job.
+ :paramtype instance_count: Optional[int]
+ :keyword locations: The locations to use for the job.
+ :paramtype locations: Optional[List[str]]
+ :keyword properties: The properties for the job.
+ :paramtype properties: Optional[Dict]
+ :keyword docker_args: The docker arguments for the job.
+ :paramtype docker_args: Optional[str]
+ :keyword shm_size: The shared memory size for the job.
+ :paramtype shm_size: Optional[str]
+ """
+ if self.resources is None:
+ self.resources = JobResourceConfiguration()
+
+ if not isinstance(self.resources, dict):
+ if locations is not None:
+ self.resources.locations = locations
+ if instance_type is not None:
+ self.resources.instance_type = instance_type
+ if instance_count is not None:
+ self.resources.instance_count = instance_count
+ if properties is not None:
+ self.resources.properties = properties
+ if docker_args is not None:
+ self.resources.docker_args = docker_args
+ if shm_size is not None:
+ self.resources.shm_size = shm_size
+
+ def set_limits(
+ self,
+ *,
+ max_concurrent_trials: Optional[int] = None,
+ max_total_trials: Optional[int] = None,
+ timeout: Optional[int] = None,
+ trial_timeout: Optional[int] = None,
+ ) -> None:
+ """Set limits for Sweep node. Leave parameters as None if you don't want to update corresponding values.
+
+ :keyword max_concurrent_trials: maximum concurrent trial number.
+ :paramtype max_concurrent_trials: int
+ :keyword max_total_trials: maximum total trial number.
+ :paramtype max_total_trials: int
+ :keyword timeout: total timeout in seconds for sweep node
+ :paramtype timeout: int
+ :keyword trial_timeout: timeout in seconds for each trial
+ :paramtype trial_timeout: int
+ """
+ # Looks related to https://github.com/pylint-dev/pylint/issues/3502, still an open issue
+ # pylint:disable=attribute-defined-outside-init
+ if self._limits is None:
+ self._limits = SweepJobLimits(
+ max_concurrent_trials=max_concurrent_trials,
+ max_total_trials=max_total_trials,
+ timeout=timeout,
+ trial_timeout=trial_timeout,
+ )
+ else:
+ if self.limits is not None:
+ if max_concurrent_trials is not None:
+ self.limits.max_concurrent_trials = max_concurrent_trials
+ if max_total_trials is not None:
+ self.limits.max_total_trials = max_total_trials
+ if timeout is not None:
+ self.limits.timeout = timeout
+ if trial_timeout is not None:
+ self.limits.trial_timeout = trial_timeout
+
+ def set_objective(self, *, goal: Optional[str] = None, primary_metric: Optional[str] = None) -> None:
+        """Set the sweep objective. Leave parameters as None if you don't want to update the corresponding values.
+
+ :keyword goal: Defines supported metric goals for hyperparameter tuning. Acceptable values are:
+ "minimize" and "maximize".
+ :paramtype goal: str
+ :keyword primary_metric: Name of the metric to optimize.
+ :paramtype primary_metric: str
+ """
+
+ if self.objective is not None:
+ if goal:
+ self.objective.goal = goal
+ if primary_metric:
+ self.objective.primary_metric = primary_metric
+ else:
+ self.objective = Objective(goal=goal, primary_metric=primary_metric)
+
+ @property
+ def sampling_algorithm(self) -> Optional[Union[str, SamplingAlgorithm]]:
+ """Sampling algorithm for sweep job.
+
+ :returns: Sampling algorithm for sweep job.
+ :rtype: ~azure.ai.ml.sweep.SamplingAlgorithm
+ """
+ return self._sampling_algorithm
+
+ @sampling_algorithm.setter
+ def sampling_algorithm(self, value: Optional[Union[SamplingAlgorithm, str]] = None) -> None:
+ """Set sampling algorithm for sweep job.
+
+ :param value: Sampling algorithm for sweep job.
+ :type value: ~azure.ai.ml.sweep.SamplingAlgorithm
+ """
+ if value is None:
+ self._sampling_algorithm = None
+ elif isinstance(value, SamplingAlgorithm) or (
+ isinstance(value, str) and value.lower().capitalize() in SAMPLING_ALGORITHM_CONSTRUCTOR
+ ):
+ self._sampling_algorithm = value
+ else:
+ msg = f"unsupported sampling algorithm: {value}"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SWEEP_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def _get_rest_sampling_algorithm(self) -> RestSamplingAlgorithm:
+ # TODO: self.sampling_algorithm will always return SamplingAlgorithm
+ if isinstance(self.sampling_algorithm, SamplingAlgorithm):
+ return self.sampling_algorithm._to_rest_object() # pylint: disable=protected-access
+
+ if isinstance(self.sampling_algorithm, str):
+ return SAMPLING_ALGORITHM_CONSTRUCTOR[ # pylint: disable=protected-access
+ SamplingAlgorithmType(self.sampling_algorithm.lower().capitalize())
+ ]()._to_rest_object()
+
+ msg = f"Received unsupported value {self._sampling_algorithm} as the sampling algorithm"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SWEEP_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ @property
+ def early_termination(self) -> Optional[Union[str, EarlyTerminationPolicy]]:
+ """Early termination policy for sweep job.
+
+ :returns: Early termination policy for sweep job.
+ :rtype: ~azure.ai.ml.entities._job.sweep.early_termination_policy.EarlyTerminationPolicy
+ """
+ return self._early_termination
+
+ @early_termination.setter
+ def early_termination(self, value: Any) -> None:
+ """Set early termination policy for sweep job.
+
+ :param value: Early termination policy for sweep job.
+ :type value: ~azure.ai.ml.entities._job.sweep.early_termination_policy.EarlyTerminationPolicy
+ """
+ self._early_termination: Optional[Union[str, EarlyTerminationPolicy]]
+ if value is None:
+ self._early_termination = None
+ elif isinstance(value, EarlyTerminationPolicy):
+ self._early_termination = value
+ elif isinstance(value, str):
+ value = value.lower().capitalize()
+ if value == EarlyTerminationPolicyType.BANDIT:
+ self._early_termination = BanditPolicy()
+ elif value == EarlyTerminationPolicyType.MEDIAN_STOPPING:
+ self._early_termination = MedianStoppingPolicy()
+ elif value == EarlyTerminationPolicyType.TRUNCATION_SELECTION:
+ self._early_termination = TruncationSelectionPolicy()
+ else:
+ msg = f"Received unsupported value {value} as the early termination policy"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SWEEP_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ else:
+ msg = f"Received unsupported value of type {type(value)} as the early termination policy"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SWEEP_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
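+
+
+# A minimal sketch (illustrative only; guarded so it never runs on import): the
+# sampling_algorithm and early_termination setters above accept plain strings,
+# normalizing them with .lower().capitalize() before validating or expanding them.
+if __name__ == "__main__":
+    # "RANDOM" -> "Random", which is a key of SAMPLING_ALGORITHM_CONSTRUCTOR
+    print("RANDOM".lower().capitalize() in SAMPLING_ALGORITHM_CONSTRUCTOR)  # True
+    # "bandit" -> "Bandit", which the early_termination setter expands to a default BanditPolicy()
+    print("bandit".lower().capitalize() == EarlyTerminationPolicyType.BANDIT)  # True
+    print(isinstance(BanditPolicy(), EarlyTerminationPolicy))  # True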
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sampling_algorithm.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sampling_algorithm.py
new file mode 100644
index 00000000..d0bf795d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sampling_algorithm.py
@@ -0,0 +1,141 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from abc import ABC
+from typing import Any, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2023_08_01_preview.models import (
+ BayesianSamplingAlgorithm as RestBayesianSamplingAlgorithm,
+)
+from azure.ai.ml._restclient.v2023_08_01_preview.models import GridSamplingAlgorithm as RestGridSamplingAlgorithm
+from azure.ai.ml._restclient.v2023_08_01_preview.models import RandomSamplingAlgorithm as RestRandomSamplingAlgorithm
+from azure.ai.ml._restclient.v2023_08_01_preview.models import SamplingAlgorithm as RestSamplingAlgorithm
+from azure.ai.ml._restclient.v2023_08_01_preview.models import SamplingAlgorithmType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class SamplingAlgorithm(ABC, RestTranslatableMixin):
+ """Base class for sampling algorithms.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+ """
+
+ def __init__(self) -> None:
+ self.type = None
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestSamplingAlgorithm) -> Optional["SamplingAlgorithm"]:
+ if not obj:
+ return None
+
+ sampling_algorithm: Any = None
+ if obj.sampling_algorithm_type == SamplingAlgorithmType.RANDOM:
+ sampling_algorithm = RandomSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access
+
+ if obj.sampling_algorithm_type == SamplingAlgorithmType.GRID:
+ sampling_algorithm = GridSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access
+
+ if obj.sampling_algorithm_type == SamplingAlgorithmType.BAYESIAN:
+ sampling_algorithm = BayesianSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access
+
+ return cast(Optional["SamplingAlgorithm"], sampling_algorithm)
+
+
+class RandomSamplingAlgorithm(SamplingAlgorithm):
+ """Random Sampling Algorithm.
+
+ :keyword rule: The specific type of random algorithm. Accepted values are: "random" and "sobol".
+    :paramtype rule: str
+ :keyword seed: The seed for random number generation.
+ :paramtype seed: int
+ :keyword logbase: A positive number or the number "e" in string format to be used as the base for log
+ based random sampling.
+ :paramtype logbase: Union[float, str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_random_sampling_algorithm]
+ :end-before: [END configure_sweep_job_random_sampling_algorithm]
+ :language: python
+ :dedent: 8
+ :caption: Assigning a random sampling algorithm for a SweepJob
+ """
+
+ def __init__(
+ self,
+ *,
+ rule: Optional[str] = None,
+ seed: Optional[int] = None,
+ logbase: Optional[Union[float, str]] = None,
+ ) -> None:
+ super().__init__()
+ self.type = SamplingAlgorithmType.RANDOM.lower()
+ self.rule = rule
+ self.seed = seed
+ self.logbase = logbase
+
+ def _to_rest_object(self) -> RestRandomSamplingAlgorithm:
+ return RestRandomSamplingAlgorithm(
+ rule=self.rule,
+ seed=self.seed,
+ logbase=self.logbase,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestRandomSamplingAlgorithm) -> "RandomSamplingAlgorithm":
+ return cls(
+ rule=obj.rule,
+ seed=obj.seed,
+ logbase=obj.logbase,
+ )
+
+
+class GridSamplingAlgorithm(SamplingAlgorithm):
+ """Grid Sampling Algorithm.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_grid_sampling_algorithm]
+ :end-before: [END configure_sweep_job_grid_sampling_algorithm]
+ :language: python
+ :dedent: 8
+ :caption: Assigning a grid sampling algorithm for a SweepJob
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.type = SamplingAlgorithmType.GRID.lower()
+
+ def _to_rest_object(self) -> RestGridSamplingAlgorithm:
+ return RestGridSamplingAlgorithm()
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestGridSamplingAlgorithm) -> "GridSamplingAlgorithm":
+ return cls()
+
+
+class BayesianSamplingAlgorithm(SamplingAlgorithm):
+ """Bayesian Sampling Algorithm.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_bayesian_sampling_algorithm]
+ :end-before: [END configure_sweep_job_bayesian_sampling_algorithm]
+ :language: python
+ :dedent: 8
+ :caption: Assigning a Bayesian sampling algorithm for a SweepJob
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.type = SamplingAlgorithmType.BAYESIAN.lower()
+
+ def _to_rest_object(self) -> RestBayesianSamplingAlgorithm:
+ return RestBayesianSamplingAlgorithm()
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestBayesianSamplingAlgorithm) -> "BayesianSamplingAlgorithm":
+ return cls()
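+
+
+# A minimal usage sketch (illustrative only; guarded so it never runs on import):
+# each concrete algorithm round-trips between the SDK class and its REST model.
+if __name__ == "__main__":
+    algo = RandomSamplingAlgorithm(rule="sobol", seed=42, logbase="e")
+    rest_algo = algo._to_rest_object()  # a RestRandomSamplingAlgorithm instance
+    # _from_rest_object dispatches on sampling_algorithm_type and rebuilds the SDK class
+    restored = SamplingAlgorithm._from_rest_object(rest_algo)
+    print(type(restored).__name__, restored.rule, restored.seed)  # RandomSamplingAlgorithm sobol 42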
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/search_space.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/search_space.py
new file mode 100644
index 00000000..bbc08d98
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/search_space.py
@@ -0,0 +1,393 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from abc import ABC
+from typing import Any, List, Optional, Union
+
+from azure.ai.ml.constants._common import TYPE
+from azure.ai.ml.constants._job.sweep import SearchSpace
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException
+
+
+class SweepDistribution(ABC, RestTranslatableMixin):
+ """Base class for sweep distribution configuration.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+ :keyword type: Type of distribution.
+ :paramtype type: str
+ """
+
+ def __init__(self, *, type: Optional[str] = None) -> None: # pylint: disable=redefined-builtin
+ self.type = type
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "SweepDistribution":
+ mapping = {
+ SearchSpace.CHOICE: Choice,
+ SearchSpace.NORMAL: Normal,
+ SearchSpace.LOGNORMAL: LogNormal,
+ SearchSpace.QNORMAL: QNormal,
+ SearchSpace.QLOGNORMAL: QLogNormal,
+ SearchSpace.RANDINT: Randint,
+ SearchSpace.UNIFORM: Uniform,
+ SearchSpace.QUNIFORM: QUniform,
+ SearchSpace.LOGUNIFORM: LogUniform,
+ SearchSpace.QLOGUNIFORM: QLogUniform,
+ }
+
+ ss_class: Any = mapping.get(obj[0], None)
+ if ss_class:
+ res: SweepDistribution = ss_class._from_rest_object(obj)
+ return res
+
+ msg = f"Unknown search space type: {obj[0]}"
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SWEEP_JOB,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, SweepDistribution):
+ return NotImplemented
+ res: bool = self._to_rest_object() == other._to_rest_object()
+ return res
+
+
+class Choice(SweepDistribution):
+ """Choice distribution configuration.
+
+ :param values: List of values to choose from.
+ :type values: list[Union[float, str, dict]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_choice_loguniform]
+ :end-before: [END configure_sweep_job_choice_loguniform]
+ :language: python
+ :dedent: 8
+ :caption: Using Choice distribution to set values for a hyperparameter sweep
+ """
+
+ def __init__(self, values: Optional[List[Union[float, str, dict]]] = None, **kwargs: Any) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.CHOICE)
+ super().__init__(**kwargs)
+ self.values = values
+
+ def _to_rest_object(self) -> List:
+ items: List = []
+ if self.values is not None:
+ for value in self.values:
+ if isinstance(value, dict):
+ rest_dict = {}
+ for k, v in value.items():
+ if isinstance(v, SweepDistribution):
+ rest_dict[k] = v._to_rest_object()
+ else:
+ rest_dict[k] = v
+ items.append(rest_dict)
+ else:
+ items.append(value)
+ return [self.type, [items]]
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "Choice":
+ rest_values = obj[1][0]
+ from_rest_values = []
+ for rest_value in rest_values:
+ if isinstance(rest_value, dict):
+ from_rest_dict = {}
+ for k, v in rest_value.items():
+ try:
+ # first assume that any dictionary value is a valid distribution (i.e. normal, uniform, etc)
+                        # and try to deserialize it into the correct SDK distribution object
+ from_rest_dict[k] = SweepDistribution._from_rest_object(v)
+ except Exception: # pylint: disable=W0718
+ # if an exception is raised, assume that the value was not a valid distribution and use the
+ # value as it is for deserialization
+ from_rest_dict[k] = v
+ from_rest_values.append(from_rest_dict)
+ else:
+ from_rest_values.append(rest_value)
+ return Choice(values=from_rest_values) # type: ignore[arg-type]
+
+
+class Normal(SweepDistribution):
+ """Normal distribution configuration.
+
+ :param mu: Mean of the distribution.
+ :type mu: float
+ :param sigma: Standard deviation of the distribution.
+ :type sigma: float
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_randint_normal]
+ :end-before: [END configure_sweep_job_randint_normal]
+ :language: python
+ :dedent: 8
+ :caption: Configuring Normal distributions for a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(self, mu: Optional[float] = None, sigma: Optional[float] = None, **kwargs: Any) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.NORMAL)
+ super().__init__(**kwargs)
+ self.mu = mu
+ self.sigma = sigma
+
+ def _to_rest_object(self) -> List:
+ return [self.type, [self.mu, self.sigma]]
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "Normal":
+ return cls(mu=obj[1][0], sigma=obj[1][1])
+
+
+class LogNormal(Normal):
+ """LogNormal distribution configuration.
+
+ :param mu: Mean of the log of the distribution.
+ :type mu: float
+ :param sigma: Standard deviation of the log of the distribution.
+ :type sigma: float
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_lognormal_qlognormal]
+ :end-before: [END configure_sweep_job_lognormal_qlognormal]
+ :language: python
+ :dedent: 8
+ :caption: Configuring LogNormal distributions for a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(self, mu: Optional[float] = None, sigma: Optional[float] = None, **kwargs: Any) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.LOGNORMAL)
+ super().__init__(mu=mu, sigma=sigma, **kwargs)
+
+
+class QNormal(Normal):
+ """QNormal distribution configuration.
+
+ :param mu: Mean of the distribution.
+ :type mu: float
+ :param sigma: Standard deviation of the distribution.
+ :type sigma: float
+ :param q: Quantization factor.
+ :type q: int
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_qloguniform_qnormal]
+ :end-before: [END configure_sweep_job_qloguniform_qnormal]
+ :language: python
+ :dedent: 8
+ :caption: Configuring QNormal distributions for a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(
+ self, mu: Optional[float] = None, sigma: Optional[float] = None, q: Optional[int] = None, **kwargs: Any
+ ) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.QNORMAL)
+ super().__init__(mu=mu, sigma=sigma, **kwargs)
+ self.q = q
+
+ def _to_rest_object(self) -> List:
+ return [self.type, [self.mu, self.sigma, self.q]]
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "QNormal":
+ return cls(mu=obj[1][0], sigma=obj[1][1], q=obj[1][2])
+
+
+class QLogNormal(QNormal):
+ """QLogNormal distribution configuration.
+
+ :param mu: Mean of the log of the distribution.
+ :type mu: Optional[float]
+ :param sigma: Standard deviation of the log of the distribution.
+ :type sigma: Optional[float]
+ :param q: Quantization factor.
+ :type q: Optional[int]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_lognormal_qlognormal]
+ :end-before: [END configure_sweep_job_lognormal_qlognormal]
+ :language: python
+ :dedent: 8
+ :caption: Configuring QLogNormal distributions for a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(
+ self, mu: Optional[float] = None, sigma: Optional[float] = None, q: Optional[int] = None, **kwargs: Any
+ ) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.QLOGNORMAL)
+ super().__init__(mu=mu, sigma=sigma, q=q, **kwargs)
+
+
+class Randint(SweepDistribution):
+ """Randint distribution configuration.
+
+ :param upper: Upper bound of the distribution.
+ :type upper: int
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_randint_normal]
+ :end-before: [END configure_sweep_job_randint_normal]
+ :language: python
+ :dedent: 8
+ :caption: Configuring Randint distributions for a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(self, upper: Optional[int] = None, **kwargs: Any) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.RANDINT)
+ super().__init__(**kwargs)
+ self.upper = upper
+
+ def _to_rest_object(self) -> List:
+ return [self.type, [self.upper]]
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "Randint":
+ return cls(upper=obj[1][0])
+
+
+class Uniform(SweepDistribution):
+ """
+
+ Uniform distribution configuration.
+
+ :param min_value: Minimum value of the distribution.
+ :type min_value: float
+ :param max_value: Maximum value of the distribution.
+ :type max_value: float
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_uniform]
+ :end-before: [END configure_sweep_job_uniform]
+ :language: python
+ :dedent: 8
+ :caption: Configuring Uniform distributions for learning rates and momentum
+ during a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(self, min_value: Optional[float] = None, max_value: Optional[float] = None, **kwargs: Any) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.UNIFORM)
+ super().__init__(**kwargs)
+ self.min_value = min_value
+ self.max_value = max_value
+
+ def _to_rest_object(self) -> List:
+ return [self.type, [self.min_value, self.max_value]]
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "Uniform":
+ return cls(min_value=obj[1][0], max_value=obj[1][1])
+
+
+class LogUniform(Uniform):
+ """LogUniform distribution configuration.
+
+ :param min_value: Minimum value of the log of the distribution.
+ :type min_value: float
+ :param max_value: Maximum value of the log of the distribution.
+ :type max_value: float
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_choice_loguniform]
+ :end-before: [END configure_sweep_job_choice_loguniform]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a LogUniform distribution for a hyperparameter sweep job learning rate
+ """
+
+ def __init__(self, min_value: Optional[float] = None, max_value: Optional[float] = None, **kwargs: Any) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.LOGUNIFORM)
+ super().__init__(min_value=min_value, max_value=max_value, **kwargs)
+
+
+class QUniform(Uniform):
+ """QUniform distribution configuration.
+
+ :param min_value: Minimum value of the distribution.
+ :type min_value: Optional[Union[int, float]]
+ :param max_value: Maximum value of the distribution.
+ :type max_value: Optional[Union[int, float]]
+ :param q: Quantization factor.
+ :type q: Optional[int]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_truncation_selection_policy]
+ :end-before: [END configure_sweep_job_truncation_selection_policy]
+ :language: python
+ :dedent: 8
+ :caption: Configuring QUniform distributions for a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(
+ self,
+ min_value: Optional[Union[int, float]] = None,
+ max_value: Optional[Union[int, float]] = None,
+ q: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.QUNIFORM)
+ super().__init__(min_value=min_value, max_value=max_value, **kwargs)
+ self.q = q
+
+ def _to_rest_object(self) -> List:
+ return [self.type, [self.min_value, self.max_value, self.q]]
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "QUniform":
+ return cls(min_value=obj[1][0], max_value=obj[1][1], q=obj[1][2])
+
+
+class QLogUniform(QUniform):
+ """QLogUniform distribution configuration.
+
+ :param min_value: Minimum value of the log of the distribution.
+ :type min_value: Optional[float]
+ :param max_value: Maximum value of the log of the distribution.
+ :type max_value: Optional[float]
+ :param q: Quantization factor.
+ :type q: Optional[int]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_qloguniform_qnormal]
+ :end-before: [END configure_sweep_job_qloguniform_qnormal]
+ :language: python
+ :dedent: 8
+ :caption: Configuring QLogUniform distributions for a hyperparameter sweep on a Command job.
+ """
+
+ def __init__(
+ self,
+ min_value: Optional[float] = None,
+ max_value: Optional[float] = None,
+ q: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs.setdefault(TYPE, SearchSpace.QLOGUNIFORM)
+ super().__init__(min_value=min_value, max_value=max_value, q=q, **kwargs)
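+
+
+# A minimal usage sketch (illustrative only; guarded so it never runs on import):
+# every distribution serializes to a REST list of the form [<type>, [<arguments>]],
+# and SweepDistribution._from_rest_object maps such a list back to the matching class.
+if __name__ == "__main__":
+    lr = Uniform(min_value=0.01, max_value=0.1)
+    batch = Choice(values=[16, 32, 64])
+    print(lr._to_rest_object())     # e.g. ["uniform", [0.01, 0.1]]
+    print(batch._to_rest_object())  # e.g. ["choice", [[16, 32, 64]]]
+    restored = SweepDistribution._from_rest_object(lr._to_rest_object())
+    print(type(restored).__name__, restored == lr)  # Uniform True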
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sweep_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sweep_job.py
new file mode 100644
index 00000000..0a99bb39
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/sweep/sweep_job.py
@@ -0,0 +1,361 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Any, Dict, NoReturn, Optional, Union
+
+from azure.ai.ml._restclient.v2023_08_01_preview.models import JobBase
+from azure.ai.ml._restclient.v2023_08_01_preview.models import SweepJob as RestSweepJob
+from azure.ai.ml._restclient.v2023_08_01_preview.models import TrialComponent
+from azure.ai.ml._schema._sweep.sweep_job import SweepJobSchema
+from azure.ai.ml._utils.utils import map_single_brackets_and_warn
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._component.command_component import CommandComponent
+from azure.ai.ml.entities._credentials import (
+ AmlTokenConfiguration,
+ ManagedIdentityConfiguration,
+ UserIdentityConfiguration,
+ _BaseJobIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job._input_output_helpers import (
+ from_rest_data_outputs,
+ from_rest_inputs_to_dataset_literal,
+ to_rest_data_outputs,
+ to_rest_dataset_literal_inputs,
+ validate_inputs_for_command,
+ validate_key_contains_allowed_characters,
+)
+from azure.ai.ml.entities._job.command_job import CommandJob
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin
+from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
+from azure.ai.ml.entities._job.sweep.sampling_algorithm import SamplingAlgorithm
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException
+
+# from ..identity import AmlToken, Identity, ManagedIdentity, UserIdentity
+from ..job_limits import SweepJobLimits
+from ..parameterized_command import ParameterizedCommand
+from ..queue_settings import QueueSettings
+from .early_termination_policy import (
+ BanditPolicy,
+ EarlyTerminationPolicy,
+ MedianStoppingPolicy,
+ TruncationSelectionPolicy,
+)
+from .objective import Objective
+from .parameterized_sweep import ParameterizedSweep
+from .search_space import (
+ Choice,
+ LogNormal,
+ LogUniform,
+ Normal,
+ QLogNormal,
+ QLogUniform,
+ QNormal,
+ QUniform,
+ Randint,
+ SweepDistribution,
+ Uniform,
+)
+
+module_logger = logging.getLogger(__name__)
+
+
+class SweepJob(Job, ParameterizedSweep, JobIOMixin):
+ """Sweep job for hyperparameter tuning.
+
+ .. note::
+ For sweep jobs, inputs, outputs, and parameters are accessible as environment variables using the prefix
+ ``AZUREML_SWEEP_``. For example, if you have a parameter named "learning_rate", you can access it as
+ ``AZUREML_SWEEP_learning_rate``.
+
+ :keyword name: Name of the job.
+ :paramtype name: str
+ :keyword display_name: Display name of the job.
+ :paramtype display_name: str
+ :keyword description: Description of the job.
+ :paramtype description: str
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated.
+ :paramtype tags: dict[str, str]
+ :keyword properties: The asset property dictionary.
+ :paramtype properties: dict[str, str]
+ :keyword experiment_name: Name of the experiment the job will be created under. If None is provided,
+        the job will be created under the experiment 'Default'.
+ :paramtype experiment_name: str
+ :keyword identity: Identity that the training job will use while running on compute.
+ :paramtype identity: Union[
+ ~azure.ai.ml.ManagedIdentityConfiguration,
+ ~azure.ai.ml.AmlTokenConfiguration,
+ ~azure.ai.ml.UserIdentityConfiguration
+
+ ]
+
+ :keyword inputs: Inputs to the command.
+ :paramtype inputs: dict
+ :keyword outputs: Mapping of output data bindings used in the job.
+ :paramtype outputs: dict[str, ~azure.ai.ml.Output]
+ :keyword sampling_algorithm: The hyperparameter sampling algorithm to use over the `search_space`. Defaults to
+ "random".
+
+ :paramtype sampling_algorithm: str
+ :keyword search_space: Dictionary of the hyperparameter search space. The key is the name of the hyperparameter
+ and the value is the parameter expression.
+
+ :paramtype search_space: Dict
+ :keyword objective: Metric to optimize for.
+ :paramtype objective: Objective
+ :keyword compute: The compute target the job runs on.
+ :paramtype compute: str
+ :keyword trial: The job configuration for each trial. Each trial will be provided with a different combination
+ of hyperparameter values that the system samples from the search_space.
+
+ :paramtype trial: Union[
+ ~azure.ai.ml.entities.CommandJob,
+ ~azure.ai.ml.entities.CommandComponent
+
+ ]
+
+ :keyword early_termination: The early termination policy to use. A trial job is canceled
+ when the criteria of the specified policy are met. If omitted, no early termination policy will be applied.
+
+ :paramtype early_termination: Union[
+ ~azure.mgmt.machinelearningservices.models.BanditPolicy,
+ ~azure.mgmt.machinelearningservices.models.MedianStoppingPolicy,
+ ~azure.mgmt.machinelearningservices.models.TruncationSelectionPolicy
+
+ ]
+
+ :keyword limits: Limits for the sweep job.
+ :paramtype limits: ~azure.ai.ml.entities.SweepJobLimits
+ :keyword queue_settings: Queue settings for the job.
+ :paramtype queue_settings: ~azure.ai.ml.entities.QueueSettings
+    :keyword resources: Compute resource configuration for the job.
+    :paramtype resources: Optional[Union[dict, ~azure.ai.ml.entities.JobResourceConfiguration]]
+ :keyword kwargs: A dictionary of additional configuration parameters.
+ :paramtype kwargs: dict
+
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_sweep_configurations.py
+ :start-after: [START configure_sweep_job_bayesian_sampling_algorithm]
+ :end-before: [END configure_sweep_job_bayesian_sampling_algorithm]
+ :language: python
+ :dedent: 8
+ :caption: Creating a SweepJob
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ display_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ identity: Optional[
+ Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
+ ] = None,
+ inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
+ outputs: Optional[Dict] = None,
+ compute: Optional[str] = None,
+ limits: Optional[SweepJobLimits] = None,
+ sampling_algorithm: Optional[Union[str, SamplingAlgorithm]] = None,
+ search_space: Optional[
+ Dict[
+ str,
+ Union[
+ Choice, LogNormal, LogUniform, Normal, QLogNormal, QLogUniform, QNormal, QUniform, Randint, Uniform
+ ],
+ ]
+ ] = None,
+ objective: Optional[Objective] = None,
+ trial: Optional[Union[CommandJob, CommandComponent]] = None,
+ early_termination: Optional[
+ Union[EarlyTerminationPolicy, BanditPolicy, MedianStoppingPolicy, TruncationSelectionPolicy]
+ ] = None,
+ queue_settings: Optional[QueueSettings] = None,
+ resources: Optional[Union[dict, JobResourceConfiguration]] = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs[TYPE] = JobType.SWEEP
+
+ Job.__init__(
+ self,
+ name=name,
+ description=description,
+ tags=tags,
+ display_name=display_name,
+ experiment_name=experiment_name,
+ compute=compute,
+ **kwargs,
+ )
+ self.inputs = inputs # type: ignore[assignment]
+ self.outputs = outputs # type: ignore[assignment]
+ self.trial = trial
+ self.identity = identity
+
+ ParameterizedSweep.__init__(
+ self,
+ limits=limits,
+ sampling_algorithm=sampling_algorithm,
+ objective=objective,
+ early_termination=early_termination,
+ search_space=search_space,
+ queue_settings=queue_settings,
+ resources=resources,
+ )
+
+ def _to_dict(self) -> Dict:
+ res: dict = SweepJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ def _to_rest_object(self) -> JobBase:
+ self._override_missing_properties_from_trial()
+ if self.trial is not None:
+ self.trial.command = map_single_brackets_and_warn(self.trial.command)
+
+ if self.search_space is not None:
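+            # Serialize each distribution to its REST list form: [<type>, [<arguments>]].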
+ search_space = {param: space._to_rest_object() for (param, space) in self.search_space.items()}
+
+ if self.trial is not None:
+ validate_inputs_for_command(self.trial.command, self.inputs)
+ for key in search_space.keys(): # pylint: disable=possibly-used-before-assignment
+ validate_key_contains_allowed_characters(key)
+
+ if self.trial is not None:
+ trial_component = TrialComponent(
+ code_id=self.trial.code,
+ distribution=(
+ self.trial.distribution._to_rest_object()
+ if self.trial.distribution and not isinstance(self.trial.distribution, Dict)
+ else None
+ ),
+ environment_id=self.trial.environment,
+ command=self.trial.command,
+ environment_variables=self.trial.environment_variables,
+ resources=(
+ self.trial.resources._to_rest_object()
+ if self.trial.resources and not isinstance(self.trial.resources, Dict)
+ else None
+ ),
+ )
+
+ sweep_job = RestSweepJob(
+ display_name=self.display_name,
+ description=self.description,
+ experiment_name=self.experiment_name,
+ search_space=search_space,
+ sampling_algorithm=self._get_rest_sampling_algorithm() if self.sampling_algorithm else None,
+ limits=self.limits._to_rest_object() if self.limits else None,
+ early_termination=(
+ self.early_termination._to_rest_object()
+ if self.early_termination and not isinstance(self.early_termination, str)
+ else None
+ ),
+ properties=self.properties,
+ compute_id=self.compute,
+ objective=self.objective._to_rest_object() if self.objective else None,
+ trial=trial_component, # pylint: disable=possibly-used-before-assignment
+ tags=self.tags,
+ inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=self.type),
+ outputs=to_rest_data_outputs(self.outputs),
+ identity=self.identity._to_job_rest_object() if self.identity else None,
+ queue_settings=self.queue_settings._to_rest_object() if self.queue_settings else None,
+ resources=(
+ self.resources._to_rest_object() if self.resources and not isinstance(self.resources, dict) else None
+ ),
+ )
+
+ if not sweep_job.resources and sweep_job.trial.resources:
+ sweep_job.resources = sweep_job.trial.resources
+
+ sweep_job_resource = JobBase(properties=sweep_job)
+ sweep_job_resource.name = self.name
+ return sweep_job_resource
+
+ def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> NoReturn:
+ msg = "no sweep component entity"
+ raise JobException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SWEEP_JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "SweepJob":
+ loaded_schema = load_from_dict(SweepJobSchema, data, context, additional_message, **kwargs)
+ loaded_schema["trial"] = ParameterizedCommand(**(loaded_schema["trial"]))
+ sweep_job = SweepJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_schema)
+ return sweep_job
+
+ @classmethod
+ def _load_from_rest(cls, obj: JobBase) -> "SweepJob":
+ properties: RestSweepJob = obj.properties
+
+ # Unpack termination schema
+ early_termination = EarlyTerminationPolicy._from_rest_object(properties.early_termination)
+
+ # Unpack sampling algorithm
+ sampling_algorithm = SamplingAlgorithm._from_rest_object(properties.sampling_algorithm)
+
+ trial = ParameterizedCommand._load_from_sweep_job(obj.properties)
+        # Compute also appears in both layers of the YAML, but in only one layer of the REST object.
+        # It should be a required field in one place, but it cannot be while it is optional in two.
+
+ _search_space = {}
+ for param, dist in properties.search_space.items():
+ _search_space[param] = SweepDistribution._from_rest_object(dist)
+
+ return SweepJob(
+ name=obj.name,
+ id=obj.id,
+ display_name=properties.display_name,
+ description=properties.description,
+ properties=properties.properties,
+ tags=properties.tags,
+ experiment_name=properties.experiment_name,
+ services=properties.services,
+ status=properties.status,
+ creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ trial=trial, # type: ignore[arg-type]
+ compute=properties.compute_id,
+ sampling_algorithm=sampling_algorithm,
+ search_space=_search_space, # type: ignore[arg-type]
+ limits=SweepJobLimits._from_rest_object(properties.limits),
+ early_termination=early_termination,
+ objective=Objective._from_rest_object(properties.objective) if properties.objective else None,
+ inputs=from_rest_inputs_to_dataset_literal(properties.inputs),
+ outputs=from_rest_data_outputs(properties.outputs),
+ identity=(
+ _BaseJobIdentityConfiguration._from_rest_object(properties.identity) if properties.identity else None
+ ),
+ queue_settings=properties.queue_settings,
+ resources=properties.resources if hasattr(properties, "resources") else None,
+ )
+
+ def _override_missing_properties_from_trial(self) -> None:
+ if not isinstance(self.trial, CommandJob):
+ return
+
+ if not self.compute:
+ self.compute = self.trial.compute
+ if not self.inputs:
+ self.inputs = self.trial.inputs
+ if not self.outputs:
+ self.outputs = self.trial.outputs
+
+ has_trial_limits_timeout = self.trial.limits and self.trial.limits.timeout
+ if has_trial_limits_timeout and not self.limits:
+ time_out = self.trial.limits.timeout if self.trial.limits is not None else None
+ self.limits = SweepJobLimits(trial_timeout=time_out)
+ elif has_trial_limits_timeout and self.limits is not None and not self.limits.trial_timeout:
+ self.limits.trial_timeout = self.trial.limits.timeout if self.trial.limits is not None else None
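+
+
+# A minimal construction sketch (illustrative only; guarded so it never runs on import).
+# The command, environment, and compute names below are hypothetical placeholders; a
+# real job would reference resources that exist in your workspace.
+if __name__ == "__main__":
+    trial = CommandJob(
+        command="python train.py --lr ${{search_space.learning_rate}}",
+        environment="AzureML-sklearn-1.0-ubuntu20.04-py38-cpu@latest",
+        compute="cpu-cluster",
+    )
+    job = SweepJob(
+        trial=trial,
+        search_space={"learning_rate": Uniform(min_value=0.01, max_value=0.1)},
+        sampling_algorithm="random",
+        objective=Objective(goal="minimize", primary_metric="validation_loss"),
+        limits=SweepJobLimits(max_total_trials=20, max_concurrent_trials=4),
+    )
+    print(job.type, job.sampling_algorithm, job.limits.max_total_trials)  # e.g. sweep random 20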
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/to_rest_functions.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/to_rest_functions.py
new file mode 100644
index 00000000..472cbc91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/to_rest_functions.py
@@ -0,0 +1,82 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from functools import singledispatch
+from pathlib import Path
+from typing import Any
+
+from azure.ai.ml._restclient.v2023_08_01_preview.models import JobBase as JobBaseData
+from azure.ai.ml._restclient.v2025_01_01_preview.models import JobBase as JobBaseData202501
+from azure.ai.ml.constants._common import DEFAULT_EXPERIMENT_NAME
+from azure.ai.ml.entities._builders.command import Command
+from azure.ai.ml.entities._builders.pipeline import Pipeline
+from azure.ai.ml.entities._builders.spark import Spark
+from azure.ai.ml.entities._builders.sweep import Sweep
+from azure.ai.ml.entities._job.job_name_generator import generate_job_name
+
+from .import_job import ImportJob
+from .job import Job
+
+
+def generate_defaults(job: Job, rest_job: JobBaseData) -> None:
+    # Default the name to a generated, user-friendly name.
+ if not job.name:
+ rest_job.name = generate_job_name()
+
+ if not job.display_name:
+ rest_job.properties.display_name = rest_job.name
+
+ # Default experiment to current folder name or "Default"
+ if not job.experiment_name:
+ rest_job.properties.experiment_name = Path("./").resolve().stem.replace(" ", "") or DEFAULT_EXPERIMENT_NAME
+
+
+@singledispatch
+def to_rest_job_object(something: Any) -> JobBaseData:
+ raise NotImplementedError()
+
+
+@to_rest_job_object.register(Job)
+def _(job: Job) -> JobBaseData:
+ # TODO: Bug Item number: 2883432
+ rest_job = job._to_rest_object() # type: ignore
+ generate_defaults(job, rest_job)
+ return rest_job
+
+
+@to_rest_job_object.register(Command)
+def _(command: Command) -> JobBaseData202501:
+ rest_job = command._to_job()._to_rest_object()
+ generate_defaults(command, rest_job)
+ return rest_job
+
+
+@to_rest_job_object.register(Sweep)
+def _(sweep: Sweep) -> JobBaseData:
+ rest_job = sweep._to_job()._to_rest_object()
+ generate_defaults(sweep, rest_job)
+ return rest_job
+
+
+@to_rest_job_object.register(Pipeline)
+def _(pipeline: Pipeline) -> JobBaseData:
+ rest_job = pipeline._to_job()._to_rest_object()
+ generate_defaults(pipeline, rest_job)
+ return rest_job
+
+
+@to_rest_job_object.register(Spark)
+def _(spark: Spark) -> JobBaseData:
+ rest_job = spark._to_job()._to_rest_object()
+ generate_defaults(spark, rest_job)
+ return rest_job
+
+
+@to_rest_job_object.register(ImportJob)
+def _(importJob: ImportJob) -> JobBaseData:
+ rest_job = importJob._to_rest_object()
+ generate_defaults(importJob, rest_job)
+ return rest_job
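+
+
+# A minimal sketch (illustrative only; guarded so it never runs on import):
+# to_rest_job_object is a functools.singledispatch function, so its registry maps each
+# supported entity type to the overload that converts it into a REST JobBase payload.
+if __name__ == "__main__":
+    print(sorted(registered.__name__ for registered in to_rest_job_object.registry))
+    # e.g. ['Command', 'ImportJob', 'Job', 'Pipeline', 'Spark', 'Sweep', 'object']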
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_load_functions.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_load_functions.py
new file mode 100644
index 00000000..81417792
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_load_functions.py
@@ -0,0 +1,1103 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=too-many-lines
+
+import logging
+import warnings
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Union, cast
+
+from marshmallow import ValidationError
+from pydash import objects
+
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.entities._assets._artifacts._package.model_package import ModelPackage
+from azure.ai.ml.entities._assets._artifacts.code import Code
+from azure.ai.ml.entities._assets._artifacts.data import Data
+from azure.ai.ml.entities._assets._artifacts.feature_set import FeatureSet
+from azure.ai.ml.entities._assets._artifacts.index import Index
+from azure.ai.ml.entities._assets._artifacts.model import Model
+from azure.ai.ml.entities._assets.environment import Environment
+from azure.ai.ml.entities._autogen_entities.models import MarketplaceSubscription, ServerlessEndpoint
+from azure.ai.ml.entities._component.command_component import CommandComponent
+from azure.ai.ml.entities._component.component import Component
+from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+from azure.ai.ml.entities._component.pipeline_component import PipelineComponent
+from azure.ai.ml.entities._compute.compute import Compute
+from azure.ai.ml.entities._datastore.datastore import Datastore
+from azure.ai.ml.entities._deployment.batch_deployment import BatchDeployment
+from azure.ai.ml.entities._deployment.model_batch_deployment import ModelBatchDeployment
+from azure.ai.ml.entities._deployment.online_deployment import OnlineDeployment
+from azure.ai.ml.entities._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment
+from azure.ai.ml.entities._endpoint.batch_endpoint import BatchEndpoint
+from azure.ai.ml.entities._endpoint.online_endpoint import OnlineEndpoint
+from azure.ai.ml.entities._feature_set.feature_set_backfill_request import FeatureSetBackfillRequest
+from azure.ai.ml.entities._feature_store.feature_store import FeatureStore
+from azure.ai.ml.entities._feature_store_entity.feature_store_entity import FeatureStoreEntity
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._registry.registry import Registry
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._schedule.schedule import Schedule
+from azure.ai.ml.entities._validation import PathAwareSchemaValidatableMixin, ValidationResultBuilder
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+from azure.ai.ml.entities._workspace.workspace import Workspace
+from azure.ai.ml.entities._workspace._ai_workspaces.capability_host import CapabilityHost
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+module_logger = logging.getLogger(__name__)
+
+_DEFAULT_RELATIVE_ORIGIN = "./"
+
+
+def load_common(
+ cls: Any,
+ source: Union[str, PathLike, IO[AnyStr]],
+ relative_origin: Optional[str] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+) -> Resource:
+ """Private function to load a yaml file to an entity object.
+
+ :param cls: The entity class type.
+ :type cls: type[Resource]
+ :param source: A source of yaml.
+ :type source: Union[str, PathLike, IO[AnyStr]]
+    :param relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Must be provided, and is assumed to be assigned by other internal
+ functions that call this.
+ :type relative_origin: str
+ :param params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :type params_override: List[Dict]
+ :return: The loaded resource
+ :rtype: Resource
+ """
+
+ path = kwargs.pop("path", None)
+ # Check for deprecated path input, either named or as first unnamed input
+ if source is None and path is not None:
+ source = path
+ warnings.warn(
+ "the 'path' input for load functions is deprecated. Please use 'source' instead.", DeprecationWarning
+ )
+
+ if relative_origin is None:
+ if isinstance(source, (str, PathLike)):
+ relative_origin = str(source)
+ else:
+ try:
+ relative_origin = source.name
+ except AttributeError: # input is a stream or something
+ relative_origin = _DEFAULT_RELATIVE_ORIGIN
+
+ params_override = params_override or []
+ yaml_dict = _try_load_yaml_dict(source)
+
+ # pylint: disable=protected-access
+ cls, type_str = cls._resolve_cls_and_type(data=yaml_dict, params_override=params_override)
+
+ try:
+ return _load_common_raising_marshmallow_error(cls, yaml_dict, relative_origin, params_override, **kwargs)
+ except ValidationError as e:
+ if issubclass(cls, PathAwareSchemaValidatableMixin):
+ validation_result = ValidationResultBuilder.from_validation_error(e, source_path=relative_origin)
+ schema = cls._create_schema_for_validation(context={BASE_PATH_CONTEXT_KEY: Path.cwd()})
+ if type_str is None:
+ additional_message = ""
+ else:
+ additional_message = (
+ f"If you are trying to configure an entity that is not "
+ f"of type {type_str}, please specify the correct "
+ f"type in the 'type' property."
+ )
+
+ def build_error(message: str, _: Any) -> ValidationError:
+ from azure.ai.ml.entities._util import decorate_validation_error
+
+ return ValidationError(
+ message=decorate_validation_error(
+ schema=schema.__class__,
+ pretty_error=message,
+ additional_message=additional_message,
+ ),
+ )
+
+ validation_result.try_raise(error_func=build_error)
+ raise e
+
+
+def _try_load_yaml_dict(source: Union[str, PathLike, IO[AnyStr]]) -> dict:
+ yaml_dict = load_yaml(source)
+ if yaml_dict is None: # This happens when a YAML is empty.
+ msg = "Target yaml file is empty"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.CANNOT_PARSE,
+ )
+    if not isinstance(yaml_dict, dict):  # This happens when a YAML file is malformed.
+ msg = "Expect dict but get {} after parsing yaml file"
+ raise ValidationException(
+ message=msg.format(type(yaml_dict)),
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg.format(type(yaml_dict)),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.CANNOT_PARSE,
+ )
+ return yaml_dict
+
+
+def _load_common_raising_marshmallow_error(
+ cls: Any,
+ yaml_dict: Dict,
+ relative_origin: Optional[Union[PathLike, str, IO[AnyStr]]],
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+) -> Resource:
+ # pylint: disable=protected-access
+ res: Resource = cls._load(data=yaml_dict, yaml_path=relative_origin, params_override=params_override, **kwargs)
+ return res
+
+
+def add_param_overrides(data, param_overrides) -> None:
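+    """Apply '--set' style overrides onto a loaded YAML dict in place.
+
+    Each override maps a dotted path to a value; for example {"limits.max_total_trials": 10}
+    sets data["limits"]["max_total_trials"] = 10 via pydash's objects.set_. Overrides whose
+    path traverses a string reference are rejected with an exception.
+    """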
+ if param_overrides is not None:
+ for override in param_overrides:
+ for param, val in override.items():
+ # Check that none of the intermediary levels are string references (azureml/file)
+ param_tokens = param.split(".")
+ test_layer = data
+ for layer in param_tokens:
+ if test_layer is None:
+ continue
+ if isinstance(test_layer, str):
+ # pylint: disable=broad-exception-raised
+ raise Exception(f"Cannot use '--set' on properties defined by reference strings: --set {param}")
+ test_layer = test_layer.get(layer, None)
+ objects.set_(data, param, val)
+
+
+def load_from_autogen_entity(cls, source: Union[str, PathLike, IO[AnyStr]], **kwargs):
+ loaded_dict = _try_load_yaml_dict(source)
+ add_param_overrides(loaded_dict, param_overrides=kwargs.get("params_override", None))
+ entity = cls(loaded_dict)
+ try:
+ entity._validate() # pylint: disable=protected-access
+ except ValueError as e:
+ validation_result = ValidationResultBuilder.from_single_message(singular_error_message=str(e))
+ validation_result.try_raise()
+ return entity
+
+
+def load_job(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Job:
+ """Constructs a Job object from a YAML file.
+
+ :param source: A path to a local YAML file or an already-open file object containing a job configuration.
+ If the source is a path, it will be opened and read. If the source is an open file, the file will be read
+ directly.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing
+ the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if
+ source is a file or file path input. Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: Optional[str]
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Job cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: A loaded Job object.
+ :rtype: ~azure.ai.ml.entities.Job
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START load_job]
+ :end-before: [END load_job]
+ :language: python
+ :dedent: 8
+ :caption: Loading a Job from a YAML config file.
+ """
+ return cast(Job, load_common(Job, source, relative_origin, params_override, **kwargs))
+
+
+@experimental
+def load_index(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Index:
+ """Constructs a Index object from a YAML file.
+
+ :param source: A path to a local YAML file or an already-open file object containing an index configuration.
+ If the source is a path, it will be opened and read. If the source is an open file, the file will be read
+ directly.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing
+ the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if
+ source is a file or file path input. Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: Optional[str]
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Index cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: A loaded Index object.
+ :rtype: ~azure.ai.ml.entities.Index
+ """
+ return cast(Index, load_common(Index, source, relative_origin, params_override, **kwargs))
+
+
+@experimental
+def load_serverless_endpoint(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None, # pylint: disable=unused-argument
+ **kwargs: Any,
+) -> ServerlessEndpoint:
+ return load_from_autogen_entity(ServerlessEndpoint, source, **kwargs)
+
+
+@experimental
+def load_marketplace_subscription(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None, # pylint: disable=unused-argument
+ **kwargs: Any,
+) -> MarketplaceSubscription:
+ return load_from_autogen_entity(MarketplaceSubscription, source, **kwargs)
+
+
+def load_workspace(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Workspace:
+ """Load a workspace object from a yaml file. This includes workspace sub-classes
+ like hubs and projects.
+
+ :param source: The local yaml source of a workspace. Must be either a
+ path to a local file, or an already-open file.
+        If the source is a path, it will be opened and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Loaded workspace object.
+ :rtype: Workspace
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START load_workspace]
+ :end-before: [END load_workspace]
+ :language: python
+ :dedent: 8
+ :caption: Loading a Workspace from a YAML config file.
+ """
+ return cast(Workspace, load_common(Workspace, source, relative_origin, params_override, **kwargs))
+
+
+def load_registry(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Registry:
+ """Load a registry object from a yaml file.
+
+ :param source: The local yaml source of a registry. Must be either a
+ path to a local file, or an already-open file.
+        If the source is a path, it will be opened and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Loaded registry object.
+ :rtype: Registry
+ """
+ return cast(Registry, load_common(Registry, source, relative_origin, params_override, **kwargs))
+
+
+def load_datastore(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Datastore:
+ """Construct a datastore object from a yaml file.
+
+ :param source: The local yaml source of a datastore. Must be either a
+ path to a local file, or an already-open file.
+        If the source is a path, it will be opened and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Datastore cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Loaded datastore object.
+ :rtype: Datastore
+ """
+ return cast(Datastore, load_common(Datastore, source, relative_origin, params_override, **kwargs))
+
+
+def load_code(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Code:
+ """Construct a code object from a yaml file.
+
+ :param source: The local yaml source of a code object. Must be either a
+ path to a local file, or an already-open file.
+        If the source is a path, it will be opened and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: Optional[str]
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Code cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Loaded code object.
+ :rtype: ~azure.ai.ml.entities._assets._artifacts.code.Code
+ """
+ return cast(Code, load_common(Code, source, relative_origin, params_override, **kwargs))
+
+
+def load_compute(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict[str, str]]] = None,
+ **kwargs: Any,
+) -> Compute:
+ """Construct a compute object from a yaml file.
+
+ :param source: The local yaml source of a compute. Must be either a
+ path to a local file, or an already-open file.
+        If the source is a path, it will be opened and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: Optional[str]
+ :keyword params_override: Optional parameters to override in the loaded yaml.
+    :paramtype params_override: Optional[List[Dict[str, str]]]
+ :return: Loaded compute object.
+ :rtype: ~azure.ai.ml.entities.Compute
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START load_compute]
+ :end-before: [END load_compute]
+ :language: python
+ :dedent: 8
+ :caption: Loading a Compute object from a YAML file and overriding its description.
+ """
+ return cast(Compute, load_common(Compute, source, relative_origin, params_override, **kwargs))
+
+
+def load_component(
+ source: Optional[Union[str, PathLike, IO[AnyStr]]] = None,
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Union[CommandComponent, ParallelComponent, PipelineComponent]:
+ """Load component from local or remote to a component function.
+
+ :param source: The local yaml source of a component. Must be either a
+ path to a local file, or an already-open file.
+        If the source is a path, it will be opened and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: A Component object
+ :rtype: Union[CommandComponent, ParallelComponent, PipelineComponent]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_component_configurations.py
+ :start-after: [START configure_load_component]
+ :end-before: [END configure_load_component]
+ :language: python
+ :dedent: 8
+ :caption: Loading a Component object from a YAML file, overriding its version to "1.0.2", and
+ registering it remotely.
+ """
+
+ client = kwargs.pop("client", None)
+ name = kwargs.pop("name", None)
+ version = kwargs.pop("version", None)
+
+ if source:
+ component_entity = load_common(Component, source, relative_origin, params_override, **kwargs)
+ elif client and name and version:
+ component_entity = client.components.get(name, version)
+ else:
+ msg = "One of (client, name, version), (source) should be provided."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ return cast(Union[CommandComponent, ParallelComponent, PipelineComponent], component_entity)
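The branch above accepts either a YAML source or a client/name/version triple passed through **kwargs; a small sketch of both call patterns (editor's illustration; the workspace identifiers and YAML path are placeholders):

    from azure.ai.ml import MLClient, load_component
    from azure.identity import DefaultAzureCredential

    # 1) Load a component function from a local YAML definition.
    train_func = load_component(source="./train_component.yml")

    # 2) Load an already-registered component through a client; omitting 'source'
    #    requires all of client, name, and version, otherwise a ValidationException is raised.
    ml_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<workspace-name>",
    )
    registered_func = load_component(client=ml_client, name="train_model", version="1")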
+
+
+def load_model(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Model:
+ """Constructs a Model object from a YAML file.
+
+ :param source: A path to a local YAML file or an already-open file object containing a model configuration.
+ If the source is a path, it will be opened and read. If the source is an open file, the file will be read
+ directly.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing
+ the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if
+ source is a file or file path input. Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: Optional[str]
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: A loaded Model object.
+ :rtype: ~azure.ai.ml.entities.Model
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START load_model]
+ :end-before: [END load_model]
+ :language: python
+ :dedent: 8
+ :caption: Loading a Model from a YAML config file, overriding the name and version parameters.
+ """
+ return cast(Model, load_common(Model, source, relative_origin, params_override, **kwargs))
+
+
+def load_data(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Data:
+ """Construct a data object from yaml file.
+
+ :param source: The local yaml source of a data object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Data cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Constructed Data or DataImport object.
+ :rtype: Data
+ """
+ return cast(Data, load_common(Data, source, relative_origin, params_override, **kwargs))
+
+
+def load_environment(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Environment:
+ """Construct a environment object from yaml file.
+
+ :param source: The local yaml source of an environment. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Environment cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Constructed environment object.
+ :rtype: Environment
+ """
+ return cast(Environment, load_common(Environment, source, relative_origin, params_override, **kwargs))
+
+
+def load_online_deployment(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> OnlineDeployment:
+ """Construct a online deployment object from yaml file.
+
+ :param source: The local yaml source of an online deployment object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Online Deployment cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Constructed online deployment object.
+ :rtype: OnlineDeployment
+ """
+ return cast(OnlineDeployment, load_common(OnlineDeployment, source, relative_origin, params_override, **kwargs))
+
+
+def load_batch_deployment(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> BatchDeployment:
+ """Construct a batch deployment object from yaml file.
+
+ :param source: The local yaml source of a batch deployment object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Constructed batch deployment object.
+ :rtype: BatchDeployment
+ """
+ return cast(BatchDeployment, load_common(BatchDeployment, source, relative_origin, params_override, **kwargs))
+
+
+def load_model_batch_deployment(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> ModelBatchDeployment:
+ """Construct a model batch deployment object from yaml file.
+
+ :param source: The local yaml source of a batch deployment object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Constructed model batch deployment object.
+ :rtype: ModelBatchDeployment
+ """
+ return cast(
+ ModelBatchDeployment, load_common(ModelBatchDeployment, source, relative_origin, params_override, **kwargs)
+ )
+
+
+def load_pipeline_component_batch_deployment(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> PipelineComponentBatchDeployment:
+ """Construct a pipeline component batch deployment object from yaml file.
+
+ :param source: The local yaml source of a batch deployment object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Constructed pipeline component batch deployment object.
+ :rtype: PipelineComponentBatchDeployment
+ """
+ return cast(
+ PipelineComponentBatchDeployment,
+ load_common(PipelineComponentBatchDeployment, source, relative_origin, params_override, **kwargs),
+ )
+
+
+def load_online_endpoint(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> OnlineEndpoint:
+ """Construct a online endpoint object from yaml file.
+
+ :param source: The local yaml source of an online endpoint object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Online Endpoint cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Constructed online endpoint object.
+ :rtype: OnlineEndpoint
+ """
+ return cast(OnlineEndpoint, load_common(OnlineEndpoint, source, relative_origin, params_override, **kwargs))
+
+
+def load_batch_endpoint(
+ source: Union[str, PathLike, IO[AnyStr]],
+ relative_origin: Optional[str] = None,
+ *,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> BatchEndpoint:
+ """Construct a batch endpoint object from yaml file.
+
+ :param source: The local yaml source of a batch endpoint object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :param relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :type relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Constructed batch endpoint object.
+ :rtype: BatchEndpoint
+ """
+ return cast(BatchEndpoint, load_common(BatchEndpoint, source, relative_origin, params_override, **kwargs))
+
+
+def load_connection(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> WorkspaceConnection:
+ """Construct a connection object from yaml file.
+
+ :param source: The local yaml source of a connection object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Constructed connection object.
+ :rtype: Connection
+
+ """
+ return cast(
+ WorkspaceConnection, load_common(WorkspaceConnection, source, relative_origin, params_override, **kwargs)
+ )
+
+
+# Unlike other aspects of connections, this wasn't made experimental, and thus couldn't simply be replaced
+# during the renaming from 'workspace connection' to just 'connection'.
+def load_workspace_connection(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ **kwargs: Any,
+) -> WorkspaceConnection:
+ """Deprecated - use 'load_connection' instead. Construct a connection object from yaml file.
+
+ :param source: The local yaml source of a connection object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+
+ :return: Constructed connection object.
+ :rtype: Connection
+
+ """
+ warnings.warn(
+ "the 'load_workspace_connection' function is deprecated. Use 'load_connection' instead.", DeprecationWarning
+ )
+ return load_connection(source, relative_origin=relative_origin, **kwargs)
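Because the deprecated wrapper simply forwards to load_connection, migrating callers only need to switch the import. A sketch of the preferred call and of surfacing the DeprecationWarning in a test (editor's illustration; it assumes both names are re-exported at the package root, as in recent releases, and the YAML path is a placeholder):

    import warnings

    from azure.ai.ml import load_connection, load_workspace_connection

    # Preferred going forward.
    conn = load_connection(source="./connection.yml")

    # Legacy path still works but warns; tests can assert the warning explicitly.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy_conn = load_workspace_connection(source="./connection.yml")
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)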
+
+
+def load_schedule(
+ source: Union[str, PathLike, IO[AnyStr]],
+ relative_origin: Optional[str] = None,
+ *,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> Schedule:
+ """Construct a schedule object from yaml file.
+
+ :param source: The local yaml source of a schedule object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :param relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :type relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Constructed schedule object.
+ :rtype: Schedule
+ """
+ return cast(Schedule, load_common(Schedule, source, relative_origin, params_override, **kwargs))
+
+
+def load_feature_store(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> FeatureStore:
+ """Load a feature store object from a yaml file.
+
+ :param source: The local yaml source of a feature store. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be opened and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :return: Loaded feature store object.
+ :rtype: FeatureStore
+ """
+ return cast(FeatureStore, load_common(FeatureStore, source, relative_origin, params_override, **kwargs))
+
+
+def load_feature_set(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> FeatureSet:
+ """Construct a FeatureSet object from yaml file.
+
+ :param source: The local yaml source of a FeatureSet object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureSet cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Constructed FeatureSet object.
+ :rtype: FeatureSet
+ """
+ return cast(FeatureSet, load_common(FeatureSet, source, relative_origin, params_override, **kwargs))
+
+
+def load_feature_store_entity(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> FeatureStoreEntity:
+ """Construct a FeatureStoreEntity object from yaml file.
+
+ :param source: The local yaml source of a FeatureStoreEntity object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureStoreEntity cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Constructed FeatureStoreEntity object.
+ :rtype: FeatureStoreEntity
+ """
+ return cast(FeatureStoreEntity, load_common(FeatureStoreEntity, source, relative_origin, params_override, **kwargs))
+
+
+@experimental
+def load_model_package(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> ModelPackage:
+ """Constructs a ModelPackage object from a YAML file.
+
+ :param source: A path to a local YAML file or an already-open file object containing a model package configuration.
+ If the source is a path, it will be opened and read. If the source is an open file, the file will be read
+ directly.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing
+ the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if
+ source is a file or file path input. Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: Optional[str]
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ModelPackage cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: A loaded ModelPackage object.
+ :rtype: ~azure.ai.ml.entities.ModelPackage
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START load_model_package]
+ :end-before: [END load_model_package]
+ :language: python
+ :dedent: 8
+ :caption: Loading a ModelPackage from a YAML config file.
+ """
+ return cast(ModelPackage, load_common(ModelPackage, source, relative_origin, params_override, **kwargs))
+
+
+def load_feature_set_backfill_request(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> FeatureSetBackfillRequest:
+ """Construct a FeatureSetBackfillRequest object from yaml file.
+
+ :param source: The local yaml source of a FeatureSetBackfillRequest object. Must be either a
+ path to a local file, or an already-open file.
+ If the source is a path, it will be open and read.
+ An exception is raised if the file does not exist.
+ If the source is an open file, the file will be read directly,
+ and an exception is raised if the file is not readable.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The origin to be used when deducing
+ the relative locations of files referenced in the parsed yaml.
+ Defaults to the inputted source's directory if it is a file or file path input.
+ Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: str
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureSetBackfillRequest
+ cannot be successfully validated. Details will be provided in the error message.
+ :return: Constructed FeatureSetBackfillRequest object.
+ :rtype: FeatureSetBackfillRequest
+ """
+ return cast(
+ FeatureSetBackfillRequest,
+ load_common(FeatureSetBackfillRequest, source, relative_origin, params_override, **kwargs),
+ )
+
+
+def load_capability_host(
+ source: Union[str, PathLike, IO[AnyStr]],
+ *,
+ relative_origin: Optional[str] = None,
+ params_override: Optional[List[Dict]] = None,
+ **kwargs: Any,
+) -> CapabilityHost:
+ """Constructs a CapabilityHost object from a YAML file.
+
+ :param source: A path to a local YAML file or an already-open file object containing a capability host configuration.
+ If the source is a path, it will be opened and read. If the source is an open file, the file will be read
+ directly.
+ :type source: Union[PathLike, str, io.TextIOWrapper]
+ :keyword relative_origin: The root directory for the YAML. This directory will be used as the origin for deducing
+ the relative locations of files referenced in the parsed YAML. Defaults to the same directory as source if
+ source is a file or file path input. Defaults to "./" if the source is a stream input with no name value.
+ :paramtype relative_origin: Optional[str]
+ :keyword params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :paramtype params_override: List[Dict]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if CapabilityHost cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Loaded CapabilityHost object.
+ :rtype: ~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost
+
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_capability_host.py
+ :start-after: [START load_capability_host]
+ :end-before: [END load_capability_host]
+ :language: python
+ :dedent: 8
+ :caption: Loading a CapabilityHost from a YAML config file.
+ """
+ return cast(CapabilityHost, load_common(CapabilityHost, source, relative_origin, params_override, **kwargs))
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_mixins.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_mixins.py
new file mode 100644
index 00000000..5b7306f9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_mixins.py
@@ -0,0 +1,163 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from abc import abstractmethod
+from typing import Any, Dict, Iterator, Optional
+
+from azure.ai.ml._utils.utils import dump_yaml
+
+
+class RestTranslatableMixin:
+ def _to_rest_object(self) -> Any:
+ pass
+
+ @classmethod
+ def _from_rest_object(cls, obj: Any) -> Any:
+ pass
+
+
+class DictMixin(object):
+ def __contains__(self, item: Any) -> bool:
+ return self.__dict__.__contains__(item)
+
+ def __iter__(self) -> Iterator[str]:
+ return self.__dict__.__iter__()
+
+ def __setitem__(self, key: Any, item: Any) -> None:
+ self.__dict__[key] = item
+
+ def __getitem__(self, key: Any) -> Any:
+ return self.__dict__[key]
+
+ def __repr__(self) -> str:
+ return str(self)
+
+ def __len__(self) -> int:
+ return len(self.keys())
+
+ def __delitem__(self, key: Any) -> None:
+ self.__dict__[key] = None
+
+ def __eq__(self, other: Any) -> bool:
+ """Compare objects by comparing all attributes.
+
+ :param other: The other object
+ :type other: Any
+ :return: True if both objects are of the same class and have matching __dict__, False otherwise
+ :rtype: bool
+ """
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ return False
+
+ def __ne__(self, other: Any) -> bool:
+ """Compare objects by comparing all attributes.
+
+ :param other: The other object
+ :type other: Any
+ :return: not self.__eq__(other)
+ :rtype: bool
+ """
+ return not self.__eq__(other)
+
+ def __str__(self) -> str:
+ return str({k: v for k, v in self.__dict__.items() if not k.startswith("_") and v is not None})
+
+ def has_key(self, k: Any) -> bool:
+ return k in self.__dict__
+
+ def update(self, *args: Any, **kwargs: Any) -> None:
+ return self.__dict__.update(*args, **kwargs)
+
+ def keys(self) -> list:
+ return [k for k in self.__dict__ if not k.startswith("_")]
+
+ def values(self) -> list:
+ return [v for k, v in self.__dict__.items() if not k.startswith("_")]
+
+ def items(self) -> list:
+ return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")]
+
+ def get(self, key: Any, default: Optional[Any] = None) -> Any:
+ if key in self.__dict__:
+ return self.__dict__[key]
+ return default
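DictMixin simply proxies the instance __dict__ and hides underscore-prefixed attributes from keys()/values()/items(); a standalone sketch of the resulting behavior (editor's illustration with a toy subclass; note the import is from a private module whose path may change):

    from azure.ai.ml.entities._mixins import DictMixin  # private module


    class _Example(DictMixin):
        def __init__(self) -> None:
            self.name = "demo"
            self.tags = {"team": "ml"}
            self._internal = "hidden"  # excluded from keys()/items()/str()


    e = _Example()
    print(e["name"])                # "demo" -- item access reads __dict__
    print(list(e.keys()))           # ["name", "tags"]
    print(e.get("missing", "n/a"))  # "n/a"
    del e["tags"]                   # note: __delitem__ sets the value to None rather than removing the key
    print(e["tags"])                # None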
+
+
+class TelemetryMixin:
+ # pylint: disable-next=docstring-missing-param
+ def _get_telemetry_values(self, *args: Any, **kwargs: Any) -> Dict: # pylint: disable=unused-argument
+ """Return the telemetry values of object.
+
+ :return: The telemetry values
+ :rtype: Dict
+ """
+ return {}
+
+
+class YamlTranslatableMixin:
+ @abstractmethod
+ def _to_dict(self) -> Dict:
+ """Dump the object into a dictionary."""
+
+ def _to_ordered_dict_for_yaml_dump(self) -> Dict:
+ """Dump the object into a dictionary with a specific key order.
+
+ :return: The ordered dict
+ :rtype: Dict
+ """
+ order_keys = [
+ "$schema",
+ "name",
+ "version",
+ "display_name",
+ "description",
+ "tags",
+ "type",
+ "inputs",
+ "outputs",
+ "command",
+ "environment",
+ "code",
+ "resources",
+ "limits",
+ "schedule",
+ "jobs",
+ ]
+ nested_keys = ["component", "trial"]
+
+ def _sort_dict_according_to_list(order_keys: Any, dict_value: Any) -> dict:
+ for nested_key in nested_keys:
+ if nested_key in dict_value and isinstance(dict_value[nested_key], dict):
+ dict_value[nested_key] = _sort_dict_according_to_list(order_keys, dict_value[nested_key])
+ if "jobs" in dict_value:
+ for node_name, node in dict_value["jobs"].items():
+ dict_value["jobs"][node_name] = _sort_dict_according_to_list(order_keys, node)
+ difference = list(set(dict_value.keys()).difference(set(order_keys)))
+ # keys not already in order_keys are appended to the end in alphabetical order
+ order_keys.extend(sorted(difference))
+ return dict(
+ sorted(
+ dict_value.items(),
+ key=lambda dict_value_: order_keys.index(dict_value_[0]),
+ )
+ )
+
+ return _sort_dict_according_to_list(order_keys, self._to_dict())
+
+ def _to_yaml(self) -> str:
+ """Dump the object content into a sorted yaml string.
+
+ :return: YAML formatted string
+ :rtype: str
+ """
+ return str(dump_yaml(self._to_ordered_dict_for_yaml_dump(), sort_keys=False))
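The ordering helper places keys from the fixed list first (in that order), appends unknown keys alphabetically, and recurses into "component", "trial", and nested "jobs" entries; a tiny standalone illustration (editor's sketch, private-module import assumed):

    from azure.ai.ml.entities._mixins import YamlTranslatableMixin  # private module


    class _ToyJob(YamlTranslatableMixin):
        def _to_dict(self) -> dict:
            # Deliberately unordered; "extra_field" is not in the known-key list.
            return {"command": "echo hi", "extra_field": 1, "name": "toy", "type": "command"}


    print(_ToyJob()._to_yaml())
    # Expected key order in the output: name, type, command, extra_field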
+
+
+class LocalizableMixin:
+ def _localize(self, base_path: str) -> None:
+ """Called on an asset got from service to clean up remote attributes like id, creation_context, etc.
+
+ :param base_path: The base path
+ :type base_path: str
+ """
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/alert_notification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/alert_notification.py
new file mode 100644
index 00000000..2df0d055
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/alert_notification.py
@@ -0,0 +1,54 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import List, Optional
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ EmailMonitoringAlertNotificationSettings,
+ EmailNotificationEnableType,
+ NotificationSetting,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class AlertNotification(RestTranslatableMixin):
+ """Alert notification configuration for monitoring jobs
+
+ :keyword emails: A list of email addresses that will receive notifications for monitoring alerts.
+ Defaults to None.
+ :paramtype emails: Optional[List[str]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_monitor_definition]
+ :end-before: [END spark_monitor_definition]
+ :language: python
+ :dedent: 8
+ :caption: Configuring alert notifications for a monitored job.
+ """
+
+ def __init__(
+ self,
+ *,
+ emails: Optional[List[str]] = None,
+ ) -> None:
+ self.emails = emails
+
+ def _to_rest_object(
+ self,
+ ) -> EmailMonitoringAlertNotificationSettings:
+ return EmailMonitoringAlertNotificationSettings(
+ email_notification_setting=NotificationSetting(
+ emails=self.emails,
+ email_on=[
+ EmailNotificationEnableType.JOB_FAILED,
+ EmailNotificationEnableType.JOB_COMPLETED,
+ ],
+ )
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: EmailMonitoringAlertNotificationSettings) -> "AlertNotification":
+ return cls(emails=obj.email_notification_setting.emails)
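A minimal construction sketch (editor's illustration; the addresses are placeholders and the entity export from azure.ai.ml.entities is assumed). The REST conversion above always enables notifications for both job-failed and job-completed events for the listed recipients.

    from azure.ai.ml.entities import AlertNotification

    alerts = AlertNotification(emails=["ml-oncall@example.com", "data-science@example.com"])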
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/compute.py
new file mode 100644
index 00000000..ff91a814
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/compute.py
@@ -0,0 +1,55 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2023_06_01_preview.models import AmlTokenComputeIdentity, MonitorServerlessSparkCompute
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+
+class ServerlessSparkCompute:
+ """Serverless Spark compute.
+
+ :param runtime_version: The runtime version of the compute.
+ :type runtime_version: str
+ :param instance_type: The instance type of the compute.
+ :type instance_type: str
+ """
+
+ def __init__(
+ self,
+ *,
+ runtime_version: str,
+ instance_type: str,
+ ):
+ self.runtime_version = runtime_version
+ self.instance_type = instance_type
+
+ def _to_rest_object(self) -> MonitorServerlessSparkCompute:
+ self._validate()
+ return MonitorServerlessSparkCompute(
+ runtime_version=self.runtime_version,
+ instance_type=self.instance_type,
+ compute_identity=AmlTokenComputeIdentity(
+ compute_identity_type="AmlToken",
+ ),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: MonitorServerlessSparkCompute) -> "ServerlessSparkCompute":
+ return cls(
+ runtime_version=obj.runtime_version,
+ instance_type=obj.instance_type,
+ )
+
+ def _validate(self) -> None:
+ if self.runtime_version != "3.4":
+ msg = "Compute runtime version must be 3.4"
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.MODEL_MONITORING,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ log_and_raise_error(err)
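Construction sketch (editor's illustration; the instance type is a placeholder value). Only runtime_version "3.4" passes the _validate() check invoked during REST conversion, so any other version raises a ValidationException.

    from azure.ai.ml.entities import ServerlessSparkCompute

    spark_compute = ServerlessSparkCompute(
        runtime_version="3.4",            # the only value accepted by _validate()
        instance_type="standard_e4s_v3",  # illustrative instance type
    )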
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/definition.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/definition.py
new file mode 100644
index 00000000..3b81be1e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/definition.py
@@ -0,0 +1,162 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Optional, Union
+
+from typing_extensions import Literal
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import AzMonMonitoringAlertNotificationSettings
+from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitorDefinition as RestMonitorDefinition
+from azure.ai.ml.constants._monitoring import (
+ AZMONITORING,
+ DEFAULT_DATA_DRIFT_SIGNAL_NAME,
+ DEFAULT_DATA_QUALITY_SIGNAL_NAME,
+ DEFAULT_PREDICTION_DRIFT_SIGNAL_NAME,
+ DEFAULT_TOKEN_USAGE_SIGNAL_NAME,
+ MonitorTargetTasks,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._monitoring.alert_notification import AlertNotification
+from azure.ai.ml.entities._monitoring.compute import ServerlessSparkCompute
+from azure.ai.ml.entities._monitoring.signals import (
+ CustomMonitoringSignal,
+ DataDriftSignal,
+ DataQualitySignal,
+ FeatureAttributionDriftSignal,
+ GenerationSafetyQualitySignal,
+ GenerationTokenStatisticsSignal,
+ MonitoringSignal,
+ PredictionDriftSignal,
+)
+from azure.ai.ml.entities._monitoring.target import MonitoringTarget
+
+
+class MonitorDefinition(RestTranslatableMixin):
+ """Monitor definition
+
+ :keyword compute: The serverless Spark compute configuration to be associated with the monitor.
+ :paramtype compute: ~azure.ai.ml.entities.ServerlessSparkCompute
+ :keyword monitoring_target: The ARM ID object associated with the model or deployment that is being monitored.
+ :paramtype monitoring_target: Optional[~azure.ai.ml.entities.MonitoringTarget]
+ :keyword monitoring_signals: The dictionary of signals to monitor. The key is the name of the signal and the value
+ is the DataSignal object. Accepted values for the DataSignal objects are DataDriftSignal, DataQualitySignal,
+ PredictionDriftSignal, FeatureAttributionDriftSignal, and CustomMonitoringSignal.
+ :paramtype monitoring_signals: Optional[Dict[str, Union[~azure.ai.ml.entities.DataDriftSignal
+ , ~azure.ai.ml.entities.DataQualitySignal, ~azure.ai.ml.entities.PredictionDriftSignal
+ , ~azure.ai.ml.entities.FeatureAttributionDriftSignal
+ , ~azure.ai.ml.entities.CustomMonitoringSignal
+ , ~azure.ai.ml.entities.GenerationSafetyQualitySignal
+ , ~azure.ai.ml.entities.GenerationTokenStatisticsSignal
+ , ~azure.ai.ml.entities.ModelPerformanceSignal]]]
+ :keyword alert_notification: The alert configuration for the monitor.
+ :paramtype alert_notification: Optional[Union[Literal['azmonitoring'], ~azure.ai.ml.entities.AlertNotification]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_monitor_definition]
+ :end-before: [END spark_monitor_definition]
+ :language: python
+ :dedent: 8
+ :caption: Creating Monitor definition.
+ """
+
+ def __init__(
+ self,
+ *,
+ compute: ServerlessSparkCompute,
+ monitoring_target: Optional[MonitoringTarget] = None,
+ monitoring_signals: Dict[
+ str,
+ Union[
+ DataDriftSignal,
+ DataQualitySignal,
+ PredictionDriftSignal,
+ FeatureAttributionDriftSignal,
+ CustomMonitoringSignal,
+ GenerationSafetyQualitySignal,
+ GenerationTokenStatisticsSignal,
+ ],
+ ] = None, # type: ignore[assignment]
+ alert_notification: Optional[Union[Literal["azmonitoring"], AlertNotification]] = None,
+ ) -> None:
+ self.compute = compute
+ self.monitoring_target = monitoring_target
+ self.monitoring_signals = monitoring_signals
+ self.alert_notification = alert_notification
+
+ def _to_rest_object(self, **kwargs: Any) -> RestMonitorDefinition:
+ default_data_window_size = kwargs.get("default_data_window_size")
+ ref_data_window_size = kwargs.get("ref_data_window_size")
+ rest_alert_notification = None
+ if self.alert_notification:
+ if isinstance(self.alert_notification, str) and self.alert_notification.lower() == AZMONITORING:
+ rest_alert_notification = AzMonMonitoringAlertNotificationSettings()
+ else:
+ if not isinstance(self.alert_notification, str):
+ rest_alert_notification = self.alert_notification._to_rest_object()
+
+ _signals = None
+ if self.monitoring_signals is not None:
+ _signals = {
+ signal_name: signal._to_rest_object(
+ default_data_window_size=default_data_window_size,
+ ref_data_window_size=ref_data_window_size,
+ )
+ for signal_name, signal in self.monitoring_signals.items()
+ }
+ return RestMonitorDefinition(
+ compute_configuration=self.compute._to_rest_object(),
+ monitoring_target=self.monitoring_target._to_rest_object() if self.monitoring_target else None,
+ signals=_signals,
+ alert_notification_setting=rest_alert_notification,
+ )
+
+ @classmethod
+ def _from_rest_object(
+ cls, # pylint: disable=unused-argument
+ obj: RestMonitorDefinition,
+ **kwargs: Any,
+ ) -> "MonitorDefinition":
+ from_rest_alert_notification: Any = None
+ if obj.alert_notification_setting:
+ if isinstance(obj.alert_notification_setting, AzMonMonitoringAlertNotificationSettings):
+ from_rest_alert_notification = AZMONITORING
+ else:
+ from_rest_alert_notification = AlertNotification._from_rest_object(obj.alert_notification_setting)
+
+ _monitoring_signals = {}
+ for signal_name, signal in obj.signals.items():
+ _monitoring_signals[signal_name] = MonitoringSignal._from_rest_object(signal)
+
+ return cls(
+ compute=ServerlessSparkCompute._from_rest_object(obj.compute_configuration),
+ monitoring_target=(
+ MonitoringTarget(
+ endpoint_deployment_id=obj.monitoring_target.deployment_id, ml_task=obj.monitoring_target.task_type
+ )
+ if obj.monitoring_target
+ else None
+ ),
+ monitoring_signals=_monitoring_signals, # type: ignore[arg-type]
+ alert_notification=from_rest_alert_notification,
+ )
+
+ def _populate_default_signal_information(self) -> None:
+ if (
+ isinstance(self.monitoring_target, MonitoringTarget)
+ and self.monitoring_target.ml_task is not None
+ and self.monitoring_target.ml_task.lower()
+ == MonitorTargetTasks.QUESTION_ANSWERING.lower() # type: ignore[union-attr]
+ ):
+ self.monitoring_signals = {
+ DEFAULT_TOKEN_USAGE_SIGNAL_NAME: GenerationTokenStatisticsSignal._get_default_token_statistics_signal(),
+ }
+ else:
+ self.monitoring_signals = {
+ DEFAULT_DATA_DRIFT_SIGNAL_NAME: DataDriftSignal._get_default_data_drift_signal(),
+ DEFAULT_PREDICTION_DRIFT_SIGNAL_NAME: PredictionDriftSignal._get_default_prediction_drift_signal(),
+ DEFAULT_DATA_QUALITY_SIGNAL_NAME: DataQualitySignal._get_default_data_quality_signal(),
+ }
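Putting the pieces together, a sketch of a MonitorDefinition (editor's illustration; the identifiers are placeholders and the entity exports from azure.ai.ml.entities are assumed). When monitoring_signals is left unset, the default-population helper above fills in data drift, prediction drift, and data quality signals, or a token-statistics signal for question-answering targets.

    from azure.ai.ml.entities import (
        AlertNotification,
        MonitorDefinition,
        MonitoringTarget,
        ServerlessSparkCompute,
    )

    monitor_definition = MonitorDefinition(
        compute=ServerlessSparkCompute(runtime_version="3.4", instance_type="standard_e4s_v3"),
        monitoring_target=MonitoringTarget(
            ml_task="classification",
            endpoint_deployment_id="azureml:my-endpoint:my-deployment",
        ),
        # monitoring_signals omitted: defaults are filled in when the schedule is created.
        alert_notification=AlertNotification(emails=["ml-oncall@example.com"]),
    )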
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/input_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/input_data.py
new file mode 100644
index 00000000..10d80531
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/input_data.py
@@ -0,0 +1,206 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import datetime
+from typing import Dict, Optional
+
+import isodate
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import FixedInputData as RestFixedInputData
+from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringInputDataBase as RestMonitorInputBase
+from azure.ai.ml._restclient.v2023_06_01_preview.models import StaticInputData as RestStaticInputData
+from azure.ai.ml._restclient.v2023_06_01_preview.models import TrailingInputData as RestTrailingInputData
+from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
+from azure.ai.ml.constants._monitoring import MonitorDatasetContext, MonitorInputDataType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class MonitorInputData(RestTranslatableMixin):
+ """Monitor input data.
+
+ :keyword type: Specifies the type of monitoring input data.
+ :paramtype type: Optional[MonitorInputDataType]
+ :keyword data_context: The context of the input dataset. Accepted values are "model_inputs",
+ "model_outputs", "training", "test", "validation", and "ground_truth".
+ :paramtype data_context: Optional[~azure.ai.ml.constants.MonitorDatasetContext]
+ :keyword target_columns: The target columns of the given input dataset.
+ :paramtype target_columns: Optional[Dict]
+ :keyword job_type: The job input type of the data.
+ :paramtype job_type: Optional[str]
+ :keyword uri: The URI of the input data.
+ :paramtype uri: Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ type: Optional[MonitorInputDataType] = None,
+ data_context: Optional[MonitorDatasetContext] = None,
+ target_columns: Optional[Dict] = None,
+ job_type: Optional[str] = None,
+ uri: Optional[str] = None,
+ ):
+ self.type = type
+ self.data_context = data_context
+ self.target_columns = target_columns
+ self.job_type = job_type
+ self.uri = uri
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitorInputBase) -> Optional["MonitorInputData"]:
+ if obj.input_data_type == MonitorInputDataType.FIXED:
+ return FixedInputData._from_rest_object(obj)
+ if obj.input_data_type == MonitorInputDataType.TRAILING:
+ return TrailingInputData._from_rest_object(obj)
+ if obj.input_data_type == MonitorInputDataType.STATIC:
+ return StaticInputData._from_rest_object(obj)
+
+ return None
+
+
+class FixedInputData(MonitorInputData):
+ """
+ :ivar type: Specifies the type of monitoring input data. Set automatically to "Fixed" for this class.
+ :vartype type: MonitorInputDataType
+ """
+
+ def __init__(
+ self,
+ *,
+ data_context: Optional[MonitorDatasetContext] = None,
+ target_columns: Optional[Dict] = None,
+ job_type: Optional[str] = None,
+ uri: Optional[str] = None,
+ ):
+ super().__init__(
+ type=MonitorInputDataType.FIXED,
+ data_context=data_context,
+ target_columns=target_columns,
+ job_type=job_type,
+ uri=uri,
+ )
+
+ def _to_rest_object(self) -> RestFixedInputData:
+ return RestFixedInputData(
+ data_context=camel_to_snake(self.data_context),
+ columns=self.target_columns,
+ job_input_type=self.job_type,
+ uri=self.uri,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestFixedInputData) -> "FixedInputData":
+ return cls(
+ data_context=camel_to_snake(obj.data_context),
+ target_columns=obj.columns,
+ job_type=obj.job_input_type,
+ uri=obj.uri,
+ )
+
+
+class TrailingInputData(MonitorInputData):
+ """
+ :ivar type: Specifies the type of monitoring input data. Set automatically to "Trailing" for this class.
+ :vartype type: MonitorInputDataType
+ """
+
+ def __init__(
+ self,
+ *,
+ data_context: Optional[MonitorDatasetContext] = None,
+ target_columns: Optional[Dict] = None,
+ job_type: Optional[str] = None,
+ uri: Optional[str] = None,
+ window_size: Optional[str] = None,
+ window_offset: Optional[str] = None,
+ pre_processing_component_id: Optional[str] = None,
+ ):
+ super().__init__(
+ type=MonitorInputDataType.TRAILING,
+ data_context=data_context,
+ target_columns=target_columns,
+ job_type=job_type,
+ uri=uri,
+ )
+ self.window_size = window_size
+ self.window_offset = window_offset
+ self.pre_processing_component_id = pre_processing_component_id
+
+ def _to_rest_object(self) -> RestTrailingInputData:
+ return RestTrailingInputData(
+ data_context=camel_to_snake(self.data_context),
+ columns=self.target_columns,
+ job_input_type=self.job_type,
+ uri=self.uri,
+ window_size=self.window_size,
+ window_offset=self.window_offset,
+ preprocessing_component_id=self.pre_processing_component_id,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestTrailingInputData) -> "TrailingInputData":
+ return cls(
+ data_context=snake_to_camel(obj.data_context),
+ target_columns=obj.columns,
+ job_type=obj.job_input_type,
+ uri=obj.uri,
+ window_size=str(isodate.duration_isoformat(obj.window_size)),
+ window_offset=str(isodate.duration_isoformat(obj.window_offset)),
+ pre_processing_component_id=obj.preprocessing_component_id,
+ )
+
+
+class StaticInputData(MonitorInputData):
+ """
+ :ivar type: Specifies the type of monitoring input data. Set automatically to "Static" for this class.
+ :vartype type: MonitorInputDataType
+ """
+
+ def __init__(
+ self,
+ *,
+ data_context: Optional[MonitorDatasetContext] = None,
+ target_columns: Optional[Dict] = None,
+ job_type: Optional[str] = None,
+ uri: Optional[str] = None,
+ pre_processing_component_id: Optional[str] = None,
+ window_start: Optional[str] = None,
+ window_end: Optional[str] = None,
+ ):
+ super().__init__(
+ type=MonitorInputDataType.STATIC,
+ data_context=data_context,
+ target_columns=target_columns,
+ job_type=job_type,
+ uri=uri,
+ )
+ self.pre_processing_component_id = pre_processing_component_id
+ self.window_start = window_start
+ self.window_end = window_end
+
+ def _to_rest_object(self) -> RestStaticInputData:
+ return RestStaticInputData(
+ data_context=camel_to_snake(self.data_context),
+ columns=self.target_columns,
+ job_input_type=self.job_type,
+ uri=self.uri,
+ preprocessing_component_id=self.pre_processing_component_id,
+ window_start=datetime.datetime.strptime(str(self.window_start), "%Y-%m-%d"),
+ window_end=datetime.datetime.strptime(str(self.window_end), "%Y-%m-%d"),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestStaticInputData) -> "StaticInputData":
+ return cls(
+ data_context=snake_to_camel(obj.data_context),
+ target_columns=obj.columns,
+ job_type=obj.job_input_type,
+ uri=obj.uri,
+ pre_processing_component_id=obj.preprocessing_component_id,
+ window_start=str(datetime.datetime.strftime(obj.window_start, "%Y-%m-%d")),
+ window_end=datetime.datetime.strftime(obj.window_end, "%Y-%m-%d"),
+ )
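A construction sketch for the static variant (editor's illustration; the data URI and job type are placeholder values, and the import is from a private module). window_start and window_end stay as strings and are parsed with "%Y-%m-%d" during REST conversion, so they must be plain ISO dates.

    from azure.ai.ml.entities._monitoring.input_data import StaticInputData  # private module

    reference_data = StaticInputData(
        data_context="training",           # a MonitorDatasetContext value
        job_type="uri_folder",             # illustrative job input type
        uri="azureml:my_training_data:1",  # placeholder data asset reference
        window_start="2024-01-01",
        window_end="2024-01-31",
    )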
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/schedule.py
new file mode 100644
index 00000000..f23c4e3e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/schedule.py
@@ -0,0 +1,175 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import CreateMonitorAction, RecurrenceFrequency
+from azure.ai.ml._restclient.v2023_06_01_preview.models import Schedule as RestSchedule
+from azure.ai.ml._restclient.v2023_06_01_preview.models import ScheduleProperties
+from azure.ai.ml._schema.monitoring.schedule import MonitorScheduleSchema
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ScheduleType
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._monitoring.definition import MonitorDefinition
+from azure.ai.ml.entities._schedule.schedule import Schedule
+from azure.ai.ml.entities._schedule.trigger import CronTrigger, RecurrenceTrigger, TriggerBase
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+module_logger = logging.getLogger(__name__)
+
+
+class MonitorSchedule(Schedule, RestTranslatableMixin):
+ """Monitor schedule.
+
+ :keyword name: The schedule name.
+ :paramtype name: str
+ :keyword trigger: The schedule trigger.
+ :paramtype trigger: Union[~azure.ai.ml.entities.CronTrigger, ~azure.ai.ml.entities.RecurrenceTrigger]
+ :keyword create_monitor: The schedule action monitor definition.
+ :paramtype create_monitor: ~azure.ai.ml.entities.MonitorDefinition
+ :keyword display_name: The display name of the schedule.
+ :paramtype display_name: Optional[str]
+ :keyword description: A description of the schedule.
+ :paramtype description: Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated.
+ :paramtype tags: Optional[dict[str, str]]
+ :keyword properties: The job property dictionary.
+ :paramtype properties: Optional[dict[str, str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ trigger: Optional[Union[CronTrigger, RecurrenceTrigger]],
+ create_monitor: MonitorDefinition,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ name=name,
+ trigger=trigger,
+ display_name=display_name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+ self.create_monitor = create_monitor
+ self._type = ScheduleType.MONITOR
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "MonitorSchedule":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ return cls(
+ base_path=cast(Dict, context[BASE_PATH_CONTEXT_KEY]),
+ **load_from_dict(MonitorScheduleSchema, data, context, **kwargs),
+ )
+
+ def _to_rest_object(self) -> RestSchedule:
+ tags = {**self.tags} if self.tags is not None else None
+ # default data window size is calculated based on the trigger frequency
+ # by default 7 days if user provides incorrect recurrence frequency
+ # or a cron expression
+ default_data_window_size = "P7D"
+ ref_data_window_size = "P14D"
+ if isinstance(self.trigger, RecurrenceTrigger):
+ frequency = self.trigger.frequency.lower()
+ interval = self.trigger.interval
+ if frequency == RecurrenceFrequency.MINUTE.lower() or frequency == RecurrenceFrequency.HOUR.lower():
+ default_data_window_size = "P1D"
+ ref_data_window_size = "P2D"
+ elif frequency == RecurrenceFrequency.DAY.lower():
+ default_data_window_size = f"P{interval}D"
+ ref_data_window_size = f"P{interval * 2}D"
+ elif frequency == RecurrenceFrequency.WEEK.lower():
+ default_data_window_size = f"P{interval * 7}D"
+ ref_data_window_size = f"P{(interval * 7) * 2}D"
+ elif frequency == RecurrenceFrequency.MONTH.lower():
+ default_data_window_size = f"P{interval * 30}D"
+ ref_data_window_size = f"P{(interval * 30) * 2}D"
+
+ return RestSchedule(
+ properties=ScheduleProperties(
+ description=self.description,
+ properties=self.properties,
+ tags=tags,
+ action=CreateMonitorAction(
+ monitor_definition=self.create_monitor._to_rest_object(
+ default_data_window_size=default_data_window_size, ref_data_window_size=ref_data_window_size
+ )
+ ),
+ display_name=self.display_name,
+ is_enabled=self._is_enabled,
+ trigger=self.trigger._to_rest_object() if self.trigger is not None else None,
+ )
+ )
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the asset content into a file in YAML format.
+
+ :param dest: The local path or file stream to write the YAML content to.
+ If dest is a file path, a new file will be created.
+ If dest is an open file, the file will be written to directly.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ :raises FileExistsError: Raised if dest is a file path and the file already exists.
+ :raises IOError: Raised if dest is an open file and the file is not writable.
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ def _to_dict(self) -> Dict:
+ res: dict = MonitorScheduleSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestSchedule) -> "MonitorSchedule":
+ properties = obj.properties
+ return cls(
+ trigger=TriggerBase._from_rest_object(properties.trigger),
+ create_monitor=MonitorDefinition._from_rest_object(
+ properties.action.monitor_definition, tags=obj.properties.tags
+ ),
+ name=obj.name,
+ id=obj.id,
+ display_name=properties.display_name,
+ description=properties.description,
+ tags=properties.tags,
+ properties=properties.properties,
+ provisioning_state=properties.provisioning_state,
+ is_enabled=properties.is_enabled,
+ creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
+ )
+
+ def _create_default_monitor_definition(self) -> None:
+ self.create_monitor._populate_default_signal_information()
+
+ def _set_baseline_data_trailing_tags_for_signal(self, signal_name: str) -> None:
+ if self.tags is not None:
+ self.tags[f"{signal_name}.baselinedata.datarange.type"] = "Trailing"
+ self.tags[f"{signal_name}.baselinedata.datarange.window_size"] = "P7D"
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/signals.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/signals.py
new file mode 100644
index 00000000..5a9e1df7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/signals.py
@@ -0,0 +1,1338 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access, too-many-lines
+
+import datetime
+from typing import Any, Dict, List, Optional, Union
+
+import isodate
+from typing_extensions import Literal
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2023_06_01_preview.models import AllFeatures as RestAllFeatures
+from azure.ai.ml._restclient.v2023_06_01_preview.models import CustomMonitoringSignal as RestCustomMonitoringSignal
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ DataDriftMonitoringSignal as RestMonitoringDataDriftSignal,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ DataQualityMonitoringSignal as RestMonitoringDataQualitySignal,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ FeatureAttributionDriftMonitoringSignal as RestFeatureAttributionDriftMonitoringSignal,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import FeatureSubset as RestFeatureSubset
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ GenerationSafetyQualityMonitoringSignal as RestGenerationSafetyQualityMonitoringSignal,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ GenerationTokenStatisticsSignal as RestGenerationTokenStatisticsSignal,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import ModelPerformanceSignal as RestModelPerformanceSignal
+from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringDataSegment as RestMonitoringDataSegment
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ MonitoringFeatureFilterBase as RestMonitoringFeatureFilterBase,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringInputDataBase as RestMonitoringInputData
+from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringNotificationMode
+from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringSignalBase as RestMonitoringSignalBase
+from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringSignalType
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ MonitoringWorkspaceConnection as RestMonitoringWorkspaceConnection,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ PredictionDriftMonitoringSignal as RestPredictionDriftMonitoringSignal,
+)
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ TopNFeaturesByAttribution as RestTopNFeaturesByAttribution,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._monitoring import (
+ ALL_FEATURES,
+ MonitorDatasetContext,
+ MonitorFeatureDataType,
+ MonitorSignalType,
+)
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job._input_output_helpers import (
+ from_rest_inputs_to_dataset_literal,
+ to_rest_dataset_literal_inputs,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+from azure.ai.ml.entities._monitoring.input_data import FixedInputData, StaticInputData, TrailingInputData
+from azure.ai.ml.entities._monitoring.thresholds import (
+ CustomMonitoringMetricThreshold,
+ DataDriftMetricThreshold,
+ DataQualityMetricThreshold,
+ FeatureAttributionDriftMetricThreshold,
+ GenerationSafetyQualityMonitoringMetricThreshold,
+ GenerationTokenStatisticsMonitorMetricThreshold,
+ MetricThreshold,
+ ModelPerformanceMetricThreshold,
+ PredictionDriftMetricThreshold,
+)
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+
+class DataSegment(RestTranslatableMixin):
+ """Data segment for monitoring.
+
+ :keyword feature_name: The feature to segment the data on.
+ :paramtype feature_name: str
+ :keyword feature_values: A list of values for the given segmented feature to filter.
+ :paramtype feature_values: List[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ feature_name: Optional[str] = None,
+ feature_values: Optional[List[str]] = None,
+ ) -> None:
+ self.feature_name = feature_name
+ self.feature_values = feature_values
+
+ def _to_rest_object(self) -> RestMonitoringDataSegment:
+ return RestMonitoringDataSegment(feature=self.feature_name, values=self.feature_values)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringDataSegment) -> "DataSegment":
+ return cls(
+ feature_name=obj.feature,
+ feature_values=obj.values,
+ )
+
+
+class MonitorFeatureFilter(RestTranslatableMixin):
+ """Monitor feature filter
+
+ :keyword top_n_feature_importance: The number of top features to include. Defaults to 10.
+ :paramtype top_n_feature_importance: int
+ """
+
+ def __init__(
+ self,
+ *,
+ top_n_feature_importance: int = 10,
+ ) -> None:
+ self.top_n_feature_importance = top_n_feature_importance
+
+ def _to_rest_object(self) -> RestTopNFeaturesByAttribution:
+ return RestTopNFeaturesByAttribution(
+ top=self.top_n_feature_importance,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestTopNFeaturesByAttribution) -> "MonitorFeatureFilter":
+ return cls(top_n_feature_importance=obj.top)
+
+
+class BaselineDataRange:
+ """Baseline data range for monitoring.
+
+    This class is used to initialize the data_window of a ProductionData or ReferenceData object.
+    For a trailing window, set lookback_window_size and lookback_window_offset (ISO 8601 durations).
+    For a static window, set window_start and window_end to the desired dates.
+ """
+
+ def __init__(
+ self,
+ *,
+ window_start: Optional[str] = None,
+ window_end: Optional[str] = None,
+ lookback_window_size: Optional[str] = None,
+ lookback_window_offset: Optional[str] = None,
+ ):
+ self.window_start = window_start
+ self.window_end = window_end
+ self.lookback_window_size = lookback_window_size
+ self.lookback_window_offset = lookback_window_offset
+
+
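+# Illustrative sketch (not part of the SDK source): the two ways a BaselineDataRange is
+# typically populated. Durations are ISO 8601 strings; the dates below are hypothetical.
+def _example_baseline_data_ranges() -> tuple:
+    # Trailing window: look back 7 days from the monitor run time, with no offset.
+    trailing = BaselineDataRange(lookback_window_size="P7D", lookback_window_offset="P0D")
+    # Static window: a fixed date range, typically used for reference data.
+    static = BaselineDataRange(window_start="2024-01-01", window_end="2024-02-01")
+    return trailing, static
+
+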
+class ProductionData(RestTranslatableMixin):
+ """Production Data
+
+    :param input_data: The data for which drift will be calculated.
+    :type input_data: ~azure.ai.ml.entities.Input
+    :param data_context: The context of the input dataset. Possible values
+        include: model_inputs, model_outputs, training, test, validation, ground_truth
+    :type data_context: ~azure.ai.ml.constants.MonitorDatasetContext
+    :param pre_processing_component: ARM resource ID of the component resource used to
+        preprocess the data.
+    :type pre_processing_component: str
+    :param data_window: The number of days or a time frame that a single monitor looks back over the target.
+    :type data_window: BaselineDataRange
+    :param data_column_names: The names of the columns in the input data.
+    :type data_column_names: Dict[str, str]
+ """
+
+ def __init__(
+ self,
+ *,
+ input_data: Input,
+ data_context: Optional[MonitorDatasetContext] = None,
+ pre_processing_component: Optional[str] = None,
+ data_window: Optional[BaselineDataRange] = None,
+ data_column_names: Optional[Dict[str, str]] = None,
+ ):
+ self.input_data = input_data
+ self.data_context = data_context
+ self.pre_processing_component = pre_processing_component
+ self.data_window = data_window
+ self.data_column_names = data_column_names
+
+ def _to_rest_object(self, **kwargs: Any) -> RestMonitoringInputData:
+ self._validate()
+ default_data_window_size = kwargs.get("default_data_window_size")
+ if self.data_window is None:
+ self.data_window = BaselineDataRange(
+ lookback_window_size=default_data_window_size, lookback_window_offset="P0D"
+ )
+ if self.data_window.lookback_window_size in ["default", None]:
+ self.data_window.lookback_window_size = default_data_window_size
+ uri = self.input_data.path
+ job_type = self.input_data.type
+ monitoring_input_data = TrailingInputData(
+ data_context=self.data_context,
+ target_columns=self.data_column_names,
+ job_type=job_type,
+ uri=uri,
+ pre_processing_component_id=self.pre_processing_component,
+ window_size=self.data_window.lookback_window_size,
+ window_offset=(
+ self.data_window.lookback_window_offset
+ if self.data_window.lookback_window_offset is not None
+ else "P0D"
+ ),
+ )
+ return monitoring_input_data._to_rest_object()
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringInputData) -> "ProductionData":
+ data_window = BaselineDataRange(
+ lookback_window_size=isodate.duration_isoformat(obj.window_size),
+ lookback_window_offset=isodate.duration_isoformat(obj.window_offset),
+ )
+ return cls(
+ input_data=Input(
+ path=obj.uri,
+ type=obj.job_input_type,
+ ),
+ data_context=obj.data_context,
+ pre_processing_component=obj.preprocessing_component_id,
+ data_window=data_window,
+ data_column_names=obj.columns,
+ )
+
+ def _validate(self) -> None:
+ if self.data_window:
+ if self.data_window.window_start or self.data_window.window_end:
+ msg = "ProductionData only accepts lookback_window_size and lookback_window_offset."
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.MODEL_MONITORING,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ log_and_raise_error(err)
+
+
+class ReferenceData(RestTranslatableMixin):
+ """Reference Data
+
+    :param input_data: The data to calculate drift against.
+    :type input_data: ~azure.ai.ml.entities.Input
+    :param data_context: The context of the input dataset. Possible values
+        include: model_inputs, model_outputs, training, test, validation, ground_truth
+    :type data_context: ~azure.ai.ml.constants.MonitorDatasetContext
+    :param pre_processing_component: ARM resource ID of the component resource used to
+        preprocess the data.
+    :type pre_processing_component: str
+    :param data_column_names: The names of the columns in the input data.
+    :type data_column_names: Dict[str, str]
+    :param data_window: The number of days or a time frame that a single monitor looks back over the target.
+    :type data_window: BaselineDataRange
+ """
+
+ def __init__(
+ self,
+ *,
+ input_data: Input,
+ data_context: Optional[MonitorDatasetContext] = None,
+ pre_processing_component: Optional[str] = None,
+ data_window: Optional[BaselineDataRange] = None,
+ data_column_names: Optional[Dict[str, str]] = None,
+ ):
+ self.input_data = input_data
+ self.data_context = data_context
+ self.pre_processing_component = pre_processing_component
+ self.data_window = data_window
+ self.data_column_names = data_column_names
+
+ def _to_rest_object(self, **kwargs: Any) -> RestMonitoringInputData:
+ default_data_window = kwargs.get("default_data_window")
+ ref_data_window_size = kwargs.get("ref_data_window_size")
+ if self.data_window is not None:
+ if self.data_window.lookback_window_size is not None:
+ if self.data_window.lookback_window_size == "default":
+ self.data_window.lookback_window_size = ref_data_window_size
+ if self.data_window.lookback_window_offset == "default":
+ self.data_window.lookback_window_offset = default_data_window
+ return TrailingInputData(
+ data_context=self.data_context,
+ target_columns=self.data_column_names,
+ job_type=self.input_data.type,
+ uri=self.input_data.path,
+ pre_processing_component_id=self.pre_processing_component,
+ window_size=self.data_window.lookback_window_size,
+ window_offset=(
+ self.data_window.lookback_window_offset
+ if self.data_window.lookback_window_offset is not None
+ else "P0D"
+ ),
+ )._to_rest_object()
+ if self.data_window.window_start is not None and self.data_window.window_end is not None:
+ return StaticInputData(
+ data_context=self.data_context,
+ target_columns=self.data_column_names,
+ job_type=self.input_data.type,
+ uri=self.input_data.path,
+ pre_processing_component_id=self.pre_processing_component,
+ window_start=self.data_window.window_start,
+ window_end=self.data_window.window_end,
+ )._to_rest_object()
+
+ return FixedInputData(
+ data_context=self.data_context,
+ target_columns=self.data_column_names,
+ job_type=self.input_data.type,
+ uri=self.input_data.path,
+ )._to_rest_object()
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringInputData) -> "ReferenceData":
+ data_window = None
+ if obj.input_data_type == "Static":
+ data_window = BaselineDataRange(
+ window_start=datetime.datetime.strftime(obj.window_start, "%Y-%m-%d"),
+ window_end=datetime.datetime.strftime(obj.window_end, "%Y-%m-%d"),
+ )
+ if obj.input_data_type == "Trailing":
+ data_window = BaselineDataRange(
+ lookback_window_size=isodate.duration_isoformat(obj.window_size),
+ lookback_window_offset=isodate.duration_isoformat(obj.window_offset),
+ )
+
+ return cls(
+ input_data=Input(
+ path=obj.uri,
+ type=obj.job_input_type,
+ ),
+ data_context=obj.data_context,
+ pre_processing_component=obj.preprocessing_component_id if obj.input_data_type != "Fixed" else None,
+ data_window=data_window,
+ data_column_names=obj.columns,
+ )
+
+
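+# Illustrative sketch (not part of the SDK source): pairing ProductionData and ReferenceData
+# for a signal. Asset paths and the window values are hypothetical; the MonitorDatasetContext
+# member names are assumed from the documented values (model_inputs, training, ...).
+def _example_production_and_reference_data() -> tuple:
+    production = ProductionData(
+        input_data=Input(path="azureml:my_model_inputs:1", type="mltable"),
+        data_context=MonitorDatasetContext.MODEL_INPUTS,
+        data_window=BaselineDataRange(lookback_window_size="P7D"),
+    )
+    reference = ReferenceData(
+        input_data=Input(path="azureml:my_training_data:1", type="mltable"),
+        data_context=MonitorDatasetContext.TRAINING,
+        data_window=BaselineDataRange(window_start="2024-01-01", window_end="2024-02-01"),
+    )
+    return production, reference
+
+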
+class MonitoringSignal(RestTranslatableMixin):
+ """
+ Base class for monitoring signals.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+    :keyword production_data: The production data definition for monitor input.
+    :paramtype production_data: ~azure.ai.ml.entities.ProductionData
+    :keyword reference_data: The reference data definition for monitor input.
+    :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData
+ :keyword metric_thresholds: The metric thresholds for the signal.
+ :paramtype metric_thresholds: Union[
+ ~azure.ai.ml.entities.DataDriftMetricThreshold,
+ ~azure.ai.ml.entities.DataQualityMetricThreshold,
+ ~azure.ai.ml.entities.PredictionDriftMetricThreshold,
+ ~azure.ai.ml.entities.FeatureAttributionDriftMetricThreshold,
+ ~azure.ai.ml.entities.CustomMonitoringMetricThreshold,
+ ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold,
+ List[Union[
+ ~azure.ai.ml.entities.DataDriftMetricThreshold,
+ ~azure.ai.ml.entities.DataQualityMetricThreshold,
+ ~azure.ai.ml.entities.PredictionDriftMetricThreshold,
+ ~azure.ai.ml.entities.FeatureAttributionDriftMetricThreshold,
+ ~azure.ai.ml.entities.CustomMonitoringMetricThreshold,
+            ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold,
+        ]]]
+ :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: Optional[ProductionData] = None,
+ reference_data: Optional[ReferenceData] = None,
+ metric_thresholds: Optional[Union[MetricThreshold, List[MetricThreshold]]],
+ properties: Optional[Dict[str, str]] = None,
+ alert_enabled: bool = False,
+ ):
+ self.production_data = production_data
+ self.reference_data = reference_data
+ self.metric_thresholds = metric_thresholds
+ self.alert_enabled = alert_enabled
+ self.properties = properties
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringSignalBase) -> Optional[ # pylint: disable=too-many-return-statements
+ Union[
+ "DataDriftSignal",
+ "DataQualitySignal",
+ "PredictionDriftSignal",
+ "ModelPerformanceSignal",
+ "FeatureAttributionDriftSignal",
+ "CustomMonitoringSignal",
+ "GenerationSafetyQualitySignal",
+ "GenerationTokenStatisticsSignal",
+ ]
+ ]:
+ if obj.signal_type == MonitoringSignalType.DATA_DRIFT:
+ return DataDriftSignal._from_rest_object(obj)
+ if obj.signal_type == MonitoringSignalType.DATA_QUALITY:
+ return DataQualitySignal._from_rest_object(obj)
+ if obj.signal_type == MonitoringSignalType.PREDICTION_DRIFT:
+ return PredictionDriftSignal._from_rest_object(obj)
+ if obj.signal_type == "ModelPerformanceSignalBase":
+ return ModelPerformanceSignal._from_rest_object(obj)
+ if obj.signal_type == MonitoringSignalType.FEATURE_ATTRIBUTION_DRIFT:
+ return FeatureAttributionDriftSignal._from_rest_object(obj)
+ if obj.signal_type == MonitoringSignalType.CUSTOM:
+ return CustomMonitoringSignal._from_rest_object(obj)
+ if obj.signal_type == MonitoringSignalType.GENERATION_SAFETY_QUALITY:
+ return GenerationSafetyQualitySignal._from_rest_object(obj)
+ if obj.signal_type == MonitoringSignalType.MODEL_PERFORMANCE:
+ return ModelPerformanceSignal._from_rest_object(obj)
+ if obj.signal_type == MonitoringSignalType.GENERATION_TOKEN_STATISTICS:
+ return GenerationTokenStatisticsSignal._from_rest_object(obj)
+
+ return None
+
+
+class DataSignal(MonitoringSignal):
+ """Base class for data signals.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+    :keyword production_data: The production data definition for monitor input.
+    :paramtype production_data: ~azure.ai.ml.entities.ProductionData
+    :keyword reference_data: The reference data definition for monitor input.
+    :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData
+ :keyword features: The features to include in the signal.
+ :paramtype features: Union[List[str], ~azure.ai.ml.entities.MonitorFeatureFilter, Literal[ALL_FEATURES]]
+ :keyword metric_thresholds: The metric thresholds for the signal.
+ :paramtype metric_thresholds: List[Union[
+ ~azure.ai.ml.entities.DataDriftMetricThreshold,
+ ~azure.ai.ml.entities.DataQualityMetricThreshold,
+ ~azure.ai.ml.entities.PredictionDriftMetricThreshold,
+ ~azure.ai.ml.entities.FeatureAttributionDriftMetricThreshold,
+ ~azure.ai.ml.entities.CustomMonitoringMetricThreshold,
+        ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold,
+        ]]
+ :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: Optional[ProductionData] = None,
+ reference_data: Optional[ReferenceData] = None,
+ features: Optional[Union[List[str], MonitorFeatureFilter, Literal["all_features"]]] = None,
+ feature_type_override: Optional[Dict[str, Union[str, MonitorFeatureDataType]]] = None,
+ metric_thresholds: Optional[Union[MetricThreshold, List[MetricThreshold]]],
+ alert_enabled: bool = False,
+ properties: Optional[Dict[str, str]] = None,
+ ):
+ super().__init__(
+ production_data=production_data,
+ reference_data=reference_data,
+ metric_thresholds=metric_thresholds,
+ alert_enabled=alert_enabled,
+ properties=properties,
+ )
+ self.features = features
+ self.feature_type_override = feature_type_override
+
+
+class DataDriftSignal(DataSignal):
+ """Data drift signal.
+
+ :ivar type: The type of the signal, set to "data_drift" for this class.
+ :vartype type: str
+ :param production_data: The data for which drift will be calculated
+ :paramtype production_data: ~azure.ai.ml.entities.ProductionData
+ :param reference_data: The data to calculate drift against
+ :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData
+ :param metric_thresholds: Metrics to calculate and their associated thresholds
+ :paramtype metric_thresholds: ~azure.ai.ml.entities.DataDriftMetricThreshold
+ :param alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ :param data_segment: The data segment used for scoping on a subset of the data population.
+ :paramtype data_segment: ~azure.ai.ml.entities.DataSegment
+ :keyword features: The feature filter identifying which feature(s) to calculate drift over.
+ :paramtype features: Union[List[str], ~azure.ai.ml.entities.MonitorFeatureFilter, Literal['all_features']]
+ :param feature_type_override: Dictionary of features and what they should be overridden to.
+ :paramtype feature_type_override: dict[str, str]
+ :param properties: Dictionary of additional properties.
+ :paramtype properties: dict[str, str]
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: Optional[ProductionData] = None,
+ reference_data: Optional[ReferenceData] = None,
+ features: Optional[Union[List[str], MonitorFeatureFilter, Literal["all_features"]]] = None,
+ feature_type_override: Optional[Dict[str, Union[str, MonitorFeatureDataType]]] = None,
+ metric_thresholds: Optional[Union[DataDriftMetricThreshold, List[MetricThreshold]]] = None,
+ alert_enabled: bool = False,
+ data_segment: Optional[DataSegment] = None,
+ properties: Optional[Dict[str, str]] = None,
+ ):
+ super().__init__(
+ production_data=production_data,
+ reference_data=reference_data,
+ metric_thresholds=metric_thresholds,
+ features=features,
+ feature_type_override=feature_type_override,
+ alert_enabled=alert_enabled,
+ properties=properties,
+ )
+ self.type = MonitorSignalType.DATA_DRIFT
+ self.data_segment = data_segment
+
+ def _to_rest_object(self, **kwargs: Any) -> RestMonitoringDataDriftSignal:
+ default_data_window_size = kwargs.get("default_data_window_size")
+ ref_data_window_size = kwargs.get("ref_data_window_size")
+ if self.production_data is not None and self.production_data.data_window is None:
+ self.production_data.data_window = BaselineDataRange(lookback_window_size=default_data_window_size)
+ rest_features = _to_rest_features(self.features) if self.features else None
+ return RestMonitoringDataDriftSignal(
+ production_data=(
+ self.production_data._to_rest_object(default_data_window_size=default_data_window_size)
+ if self.production_data is not None
+ else None
+ ),
+ reference_data=(
+ self.reference_data._to_rest_object(
+ default_data_window=default_data_window_size, ref_data_window_size=ref_data_window_size
+ )
+ if self.reference_data is not None
+ else None
+ ),
+ features=rest_features,
+ feature_data_type_override=self.feature_type_override,
+ metric_thresholds=(
+ self.metric_thresholds._to_rest_object()
+ if isinstance(self.metric_thresholds, MetricThreshold)
+ else None
+ ),
+ mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
+ data_segment=self.data_segment._to_rest_object() if self.data_segment else None,
+ properties=self.properties,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringDataDriftSignal) -> "DataDriftSignal":
+ return cls(
+ production_data=ProductionData._from_rest_object(obj.production_data),
+ reference_data=ReferenceData._from_rest_object(obj.reference_data),
+ features=_from_rest_features(obj.features),
+ feature_type_override=obj.feature_data_type_override,
+ metric_thresholds=DataDriftMetricThreshold._from_rest_object(obj.metric_thresholds),
+ alert_enabled=(
+ False
+ if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
+ else MonitoringNotificationMode.ENABLED
+ ),
+ data_segment=DataSegment._from_rest_object(obj.data_segment) if obj.data_segment else None,
+ properties=obj.properties,
+ )
+
+ @classmethod
+ def _get_default_data_drift_signal(cls) -> "DataDriftSignal":
+ return cls(
+ features=ALL_FEATURES, # type: ignore[arg-type]
+ metric_thresholds=DataDriftMetricThreshold._get_default_thresholds(),
+ )
+
+
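+# Illustrative sketch (not part of the SDK source): a DataDriftSignal that watches the ten
+# most important features and omits metric_thresholds, which this class accepts as Optional.
+# It reuses the hypothetical production/reference objects sketched above.
+def _example_data_drift_signal(production: ProductionData, reference: ReferenceData) -> DataDriftSignal:
+    return DataDriftSignal(
+        production_data=production,
+        reference_data=reference,
+        features=MonitorFeatureFilter(top_n_feature_importance=10),
+        alert_enabled=True,
+    )
+
+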
+class PredictionDriftSignal(MonitoringSignal):
+ """Prediction drift signal.
+
+ :ivar type: The type of the signal, set to "prediction_drift" for this class.
+ :vartype type: str
+ :param production_data: The data for which drift will be calculated
+ :paramtype production_data: ~azure.ai.ml.entities.ProductionData
+ :param reference_data: The data to calculate drift against
+ :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData
+    :param metric_thresholds: Metrics to calculate and their associated thresholds.
+    :paramtype metric_thresholds: ~azure.ai.ml.entities.PredictionDriftMetricThreshold
+ :param alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ :param properties: Dictionary of additional properties.
+ :paramtype properties: dict[str, str]
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: Optional[ProductionData] = None,
+ reference_data: Optional[ReferenceData] = None,
+ metric_thresholds: PredictionDriftMetricThreshold,
+ alert_enabled: bool = False,
+ properties: Optional[Dict[str, str]] = None,
+ ):
+ super().__init__(
+ production_data=production_data,
+ reference_data=reference_data,
+ metric_thresholds=metric_thresholds,
+ alert_enabled=alert_enabled,
+ properties=properties,
+ )
+ self.type = MonitorSignalType.PREDICTION_DRIFT
+
+ def _to_rest_object(self, **kwargs: Any) -> RestPredictionDriftMonitoringSignal:
+ default_data_window_size = kwargs.get("default_data_window_size")
+ ref_data_window_size = kwargs.get("ref_data_window_size")
+ if self.production_data is not None and self.production_data.data_window is None:
+ self.production_data.data_window = BaselineDataRange(lookback_window_size=default_data_window_size)
+ return RestPredictionDriftMonitoringSignal(
+ production_data=(
+ self.production_data._to_rest_object(default_data_window_size=default_data_window_size)
+ if self.production_data is not None
+ else None
+ ),
+ reference_data=(
+ self.reference_data._to_rest_object(
+ default_data_window=default_data_window_size, ref_data_window_size=ref_data_window_size
+ )
+ if self.reference_data is not None
+ else None
+ ),
+ metric_thresholds=(
+ self.metric_thresholds._to_rest_object()
+ if isinstance(self.metric_thresholds, MetricThreshold)
+ else None
+ ),
+ properties=self.properties,
+ mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
+ model_type="classification",
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestPredictionDriftMonitoringSignal) -> "PredictionDriftSignal":
+ return cls(
+ production_data=ProductionData._from_rest_object(obj.production_data),
+ reference_data=ReferenceData._from_rest_object(obj.reference_data),
+ metric_thresholds=PredictionDriftMetricThreshold._from_rest_object(obj.metric_thresholds),
+ alert_enabled=(
+ False
+ if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
+ else MonitoringNotificationMode.ENABLED
+ ),
+ properties=obj.properties,
+ )
+
+ @classmethod
+ def _get_default_prediction_drift_signal(cls) -> "PredictionDriftSignal":
+ return cls(
+ metric_thresholds=PredictionDriftMetricThreshold._get_default_thresholds(),
+ )
+
+
+class DataQualitySignal(DataSignal):
+ """Data quality signal
+
+ :ivar type: The type of the signal. Set to "data_quality" for this class.
+ :vartype type: str
+ :param production_data: The data for which drift will be calculated
+ :paramtype production_data: ~azure.ai.ml.entities.ProductionData
+ :param reference_data: The data to calculate drift against
+ :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData
+    :param metric_thresholds: Metrics to calculate and their associated thresholds.
+    :paramtype metric_thresholds: ~azure.ai.ml.entities.DataQualityMetricThreshold
+ :param alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ :keyword features: The feature filter identifying which feature(s) to calculate drift over.
+ :paramtype features: Union[List[str], ~azure.ai.ml.entities.MonitorFeatureFilter, Literal['all_features']]
+ :param feature_type_override: Dictionary of features and what they should be overridden to.
+ :paramtype feature_type_override: dict[str, str]
+ :param properties: Dictionary of additional properties.
+ :paramtype properties: dict[str, str]
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: Optional[ProductionData] = None,
+ reference_data: Optional[ReferenceData] = None,
+ features: Optional[Union[List[str], MonitorFeatureFilter, Literal["all_features"]]] = None,
+ feature_type_override: Optional[Dict[str, Union[str, MonitorFeatureDataType]]] = None,
+ metric_thresholds: Optional[Union[MetricThreshold, List[MetricThreshold]]] = None,
+ alert_enabled: bool = False,
+ properties: Optional[Dict[str, str]] = None,
+ ):
+ super().__init__(
+ production_data=production_data,
+ reference_data=reference_data,
+ metric_thresholds=metric_thresholds,
+ features=features,
+ feature_type_override=feature_type_override,
+ alert_enabled=alert_enabled,
+ properties=properties,
+ )
+ self.type = MonitorSignalType.DATA_QUALITY
+
+ def _to_rest_object(self, **kwargs: Any) -> RestMonitoringDataQualitySignal:
+ default_data_window_size = kwargs.get("default_data_window_size")
+ ref_data_window_size = kwargs.get("ref_data_window_size")
+ if self.production_data is not None and self.production_data.data_window is None:
+ self.production_data.data_window = BaselineDataRange(
+ lookback_window_size=default_data_window_size,
+ )
+ rest_features = _to_rest_features(self.features) if self.features else None
+ rest_metrics = (
+ # TODO: Bug Item number: 2883365
+ _to_rest_data_quality_metrics(
+ self.metric_thresholds.numerical, self.metric_thresholds.categorical # type: ignore
+ )
+ if isinstance(self.metric_thresholds, MetricThreshold)
+ else None
+ )
+ return RestMonitoringDataQualitySignal(
+ production_data=(
+ self.production_data._to_rest_object(default_data_window_size=default_data_window_size)
+ if self.production_data is not None
+ else None
+ ),
+ reference_data=(
+ self.reference_data._to_rest_object(
+ default_data_window=default_data_window_size, ref_data_window_size=ref_data_window_size
+ )
+ if self.reference_data is not None
+ else None
+ ),
+ features=rest_features,
+ feature_data_type_override=self.feature_type_override,
+ metric_thresholds=rest_metrics,
+ mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
+ properties=self.properties,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringDataQualitySignal) -> "DataQualitySignal":
+ return cls(
+ production_data=ProductionData._from_rest_object(obj.production_data),
+ reference_data=ReferenceData._from_rest_object(obj.reference_data),
+ features=_from_rest_features(obj.features),
+ feature_type_override=obj.feature_data_type_override,
+ metric_thresholds=DataQualityMetricThreshold._from_rest_object(obj.metric_thresholds),
+ alert_enabled=(
+ False
+ if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
+ else MonitoringNotificationMode.ENABLED
+ ),
+ properties=obj.properties,
+ )
+
+ @classmethod
+ def _get_default_data_quality_signal(
+ cls,
+ ) -> "DataQualitySignal":
+ return cls(
+ features=ALL_FEATURES, # type: ignore[arg-type]
+ metric_thresholds=DataQualityMetricThreshold._get_default_thresholds(),
+ )
+
+
+@experimental
+class FADProductionData(RestTranslatableMixin):
+ """Feature Attribution Production Data
+
+ :keyword input_data: Input data used by the monitor.
+ :paramtype input_data: ~azure.ai.ml.Input
+ :keyword data_context: The context of the input dataset. Accepted values are "model_inputs",
+ "model_outputs", "training", "test", "validation", and "ground_truth".
+    :paramtype data_context: ~azure.ai.ml.constants.MonitorDatasetContext
+ :keyword data_column_names: The names of the columns in the input data.
+ :paramtype data_column_names: Dict[str, str]
+ :keyword pre_processing_component: The ARM (Azure Resource Manager) resource ID of the component resource used to
+ preprocess the data.
+ :paramtype pre_processing_component: string
+    :keyword data_window: The number of days or a time frame that a single monitor looks back over the target.
+    :paramtype data_window: BaselineDataRange
+ """
+
+ def __init__(
+ self,
+ *,
+ input_data: Input,
+ data_context: Optional[MonitorDatasetContext] = None,
+ data_column_names: Optional[Dict[str, str]] = None,
+ pre_processing_component: Optional[str] = None,
+ data_window: Optional[BaselineDataRange] = None,
+ ):
+ self.input_data = input_data
+ self.data_context = data_context
+ self.data_column_names = data_column_names
+ self.pre_processing_component = pre_processing_component
+ self.data_window = data_window
+
+ def _to_rest_object(self, **kwargs: Any) -> RestMonitoringInputData:
+ default_data_window_size = kwargs.get("default")
+ if self.data_window is None:
+ self.data_window = BaselineDataRange(
+ lookback_window_size=default_data_window_size, lookback_window_offset="P0D"
+ )
+ if self.data_window.lookback_window_size == "default":
+ self.data_window.lookback_window_size = default_data_window_size
+ uri = self.input_data.path
+ job_type = self.input_data.type
+ monitoring_input_data = TrailingInputData(
+ data_context=self.data_context,
+ target_columns=self.data_column_names,
+ job_type=job_type,
+ uri=uri,
+ pre_processing_component_id=self.pre_processing_component,
+ window_size=self.data_window.lookback_window_size,
+ window_offset=(
+ self.data_window.lookback_window_offset
+ if self.data_window.lookback_window_offset is not None
+ else "P0D"
+ ),
+ )
+ return monitoring_input_data._to_rest_object()
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringInputData) -> "FADProductionData":
+ data_window = BaselineDataRange(
+ lookback_window_size=isodate.duration_isoformat(obj.window_size),
+ lookback_window_offset=isodate.duration_isoformat(obj.window_offset),
+ )
+ return cls(
+ input_data=Input(
+ path=obj.uri,
+ type=obj.job_input_type,
+ ),
+ data_context=obj.data_context,
+ data_column_names=obj.columns,
+ pre_processing_component=obj.preprocessing_component_id,
+ data_window=data_window,
+ )
+
+
+@experimental
+class FeatureAttributionDriftSignal(RestTranslatableMixin):
+ """Feature attribution drift signal
+
+ :ivar type: The type of the signal. Set to "feature_attribution_drift" for this class.
+ :vartype type: str
+ :keyword production_data: The data for which drift will be calculated.
+    :paramtype production_data: Optional[List[~azure.ai.ml.entities.FADProductionData]]
+ :keyword reference_data: The data to calculate drift against.
+ :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData
+ :keyword metric_thresholds: Metrics to calculate and their
+ associated thresholds.
+ :paramtype metric_thresholds: ~azure.ai.ml.entities.FeatureAttributionDriftMetricThreshold
+ :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: Optional[List[FADProductionData]] = None,
+ reference_data: ReferenceData,
+ metric_thresholds: FeatureAttributionDriftMetricThreshold,
+ alert_enabled: bool = False,
+ properties: Optional[Dict[str, str]] = None,
+ ):
+ self.production_data = production_data
+ self.reference_data = reference_data
+ self.metric_thresholds = metric_thresholds
+ self.alert_enabled = alert_enabled
+ self.properties = properties
+ self.type = MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT
+
+ def _to_rest_object(self, **kwargs: Any) -> RestFeatureAttributionDriftMonitoringSignal:
+ default_window_size = kwargs.get("default_data_window_size")
+ ref_data_window_size = kwargs.get("ref_data_window_size")
+ return RestFeatureAttributionDriftMonitoringSignal(
+ production_data=(
+ [data._to_rest_object(default=default_window_size) for data in self.production_data]
+ if self.production_data is not None
+ else None
+ ),
+ reference_data=self.reference_data._to_rest_object(
+ default_data_window=default_window_size, ref_data_window_size=ref_data_window_size
+ ),
+ metric_threshold=self.metric_thresholds._to_rest_object(),
+ mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
+ properties=self.properties,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestFeatureAttributionDriftMonitoringSignal) -> "FeatureAttributionDriftSignal":
+ return cls(
+ production_data=[FADProductionData._from_rest_object(data) for data in obj.production_data],
+ reference_data=ReferenceData._from_rest_object(obj.reference_data),
+ metric_thresholds=FeatureAttributionDriftMetricThreshold._from_rest_object(obj.metric_threshold),
+ alert_enabled=(
+ False
+ if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
+ else MonitoringNotificationMode.ENABLED
+ ),
+ properties=obj.properties,
+ )
+
+
+@experimental
+class ModelPerformanceSignal(RestTranslatableMixin):
+ """Model performance signal.
+
+    :keyword production_data: The production data to calculate performance against.
+    :paramtype production_data: ~azure.ai.ml.entities.ProductionData
+    :keyword reference_data: The reference data to calculate performance against.
+    :paramtype reference_data: ~azure.ai.ml.entities.ReferenceData
+    :keyword metric_thresholds: A list of metrics to calculate and their
+        associated thresholds.
+    :paramtype metric_thresholds: ~azure.ai.ml.entities.ModelPerformanceMetricThreshold
+ :keyword data_segment: The data segment to calculate performance against.
+ :paramtype data_segment: ~azure.ai.ml.entities.DataSegment
+ :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: ProductionData,
+ reference_data: ReferenceData,
+ metric_thresholds: ModelPerformanceMetricThreshold,
+ data_segment: Optional[DataSegment] = None,
+ alert_enabled: bool = False,
+ properties: Optional[Dict[str, str]] = None,
+ ) -> None:
+ self.production_data = production_data
+ self.reference_data = reference_data
+ self.metric_thresholds = metric_thresholds
+ self.alert_enabled = alert_enabled
+ self.type = MonitorSignalType.MODEL_PERFORMANCE
+ self.data_segment = data_segment
+ self.properties = properties
+
+ def _to_rest_object(self, **kwargs: Any) -> RestModelPerformanceSignal:
+ default_data_window_size = kwargs.get("default_data_window_size")
+ ref_data_window_size = kwargs.get("ref_data_window_size")
+ if self.properties is None:
+ self.properties = {}
+ self.properties["azureml.modelmonitor.model_performance_thresholds"] = self.metric_thresholds._to_str_object()
+ if self.production_data.data_window is None:
+ self.production_data.data_window = BaselineDataRange(
+ lookback_window_size=default_data_window_size,
+ )
+ return RestModelPerformanceSignal(
+ production_data=[self.production_data._to_rest_object(default_data_window_size=default_data_window_size)],
+ reference_data=self.reference_data._to_rest_object(
+ default_data_window_size=default_data_window_size, ref_data_window_size=ref_data_window_size
+ ),
+ metric_threshold=self.metric_thresholds._to_rest_object(),
+ data_segment=self.data_segment._to_rest_object() if self.data_segment else None,
+ mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
+ properties=self.properties,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestModelPerformanceSignal) -> "ModelPerformanceSignal":
+ return cls(
+ production_data=ProductionData._from_rest_object(obj.production_data[0]),
+ reference_data=ReferenceData._from_rest_object(obj.reference_data),
+ metric_thresholds=ModelPerformanceMetricThreshold._from_rest_object(obj.metric_threshold),
+ data_segment=DataSegment._from_rest_object(obj.data_segment) if obj.data_segment else None,
+ alert_enabled=(
+ False
+ if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
+ else MonitoringNotificationMode.ENABLED
+ ),
+ )
+
+
+@experimental
+class Connection(RestTranslatableMixin):
+ """Monitoring Connection
+
+ :param environment_variables: A dictionary of environment variables to set for the workspace.
+ :paramtype environment_variables: Optional[dict[str, str]]
+ :param secret_config: A dictionary of secrets to set for the workspace.
+ :paramtype secret_config: Optional[dict[str, str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ environment_variables: Optional[Dict[str, str]] = None,
+ secret_config: Optional[Dict[str, str]] = None,
+ ):
+ self.environment_variables = environment_variables
+ self.secret_config = secret_config
+
+ def _to_rest_object(self) -> RestMonitoringWorkspaceConnection:
+ return RestMonitoringWorkspaceConnection(
+ environment_variables=self.environment_variables,
+ secrets=self.secret_config,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringWorkspaceConnection) -> "Connection":
+ return cls(
+ environment_variables=obj.environment_variables,
+ secret_config=obj.secrets,
+ )
+
+
+@experimental
+class CustomMonitoringSignal(RestTranslatableMixin):
+ """Custom monitoring signal.
+
+ :ivar type: The type of the signal. Set to "custom" for this class.
+ :vartype type: str
+ :keyword input_data: A dictionary of input datasets for monitoring.
+ Each key is the component input port name, and its value is the data asset.
+ :paramtype input_data: Optional[dict[str, ~azure.ai.ml.entities.ReferenceData]]
+ :keyword metric_thresholds: A list of metrics to calculate and their
+ associated thresholds.
+ :paramtype metric_thresholds: List[~azure.ai.ml.entities.CustomMonitoringMetricThreshold]
+    :keyword inputs: A dictionary of extra inputs passed to the component that computes the custom metrics.
+    :paramtype inputs: Optional[dict[str, ~azure.ai.ml.entities.Input]]
+ :keyword component_id: The ARM (Azure Resource Manager) ID of the component resource used to
+ calculate the custom metrics.
+ :paramtype component_id: str
+    :keyword connection: The connection, specified as environment variables and secret configs.
+    :paramtype connection: Optional[Connection]
+ :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ :keyword properties: A dictionary of custom properties for the signal.
+ :paramtype properties: Optional[dict[str, str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ inputs: Optional[Dict[str, Input]] = None,
+ metric_thresholds: List[CustomMonitoringMetricThreshold],
+ component_id: str,
+ connection: Optional[Connection] = None,
+ input_data: Optional[Dict[str, ReferenceData]] = None,
+ alert_enabled: bool = False,
+ properties: Optional[Dict[str, str]] = None,
+ ):
+ self.type = MonitorSignalType.CUSTOM
+ self.inputs = inputs
+ self.metric_thresholds = metric_thresholds
+ self.component_id = component_id
+ self.alert_enabled = alert_enabled
+ self.input_data = input_data
+ self.properties = properties
+ self.connection = connection
+
+ def _to_rest_object(self, **kwargs: Any) -> RestCustomMonitoringSignal: # pylint:disable=unused-argument
+ if self.connection is None:
+ self.connection = Connection()
+ return RestCustomMonitoringSignal(
+ component_id=self.component_id,
+ metric_thresholds=[threshold._to_rest_object() for threshold in self.metric_thresholds],
+ inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=None) if self.inputs else None,
+ input_assets=(
+ {asset_name: asset_value._to_rest_object() for asset_name, asset_value in self.input_data.items()}
+ if self.input_data
+ else None
+ ),
+ workspace_connection=self.connection._to_rest_object(),
+ mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
+ properties=self.properties,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestCustomMonitoringSignal) -> "CustomMonitoringSignal":
+ return cls(
+ inputs=from_rest_inputs_to_dataset_literal(obj.inputs) if obj.inputs else None,
+ input_data={key: ReferenceData._from_rest_object(data) for key, data in obj.input_assets.items()},
+ metric_thresholds=[
+ CustomMonitoringMetricThreshold._from_rest_object(metric) for metric in obj.metric_thresholds
+ ],
+ component_id=obj.component_id,
+ alert_enabled=(
+ False
+ if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
+ else MonitoringNotificationMode.ENABLED
+ ),
+ properties=obj.properties,
+ connection=Connection._from_rest_object(obj.workspace_connection),
+ )
+
+
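+# Illustrative sketch (not part of the SDK source): a CustomMonitoringSignal wired to a
+# registered component. The component ID, port name, metric name, and threshold value are
+# hypothetical, and CustomMonitoringMetricThreshold is assumed to accept metric_name/threshold.
+def _example_custom_monitoring_signal() -> CustomMonitoringSignal:
+    return CustomMonitoringSignal(
+        component_id="azureml:my_custom_metrics_component:1",
+        metric_thresholds=[CustomMonitoringMetricThreshold(metric_name="std_deviation", threshold=2.0)],
+        input_data={
+            "production_data": ReferenceData(
+                input_data=Input(path="azureml:my_model_inputs:1", type="mltable"),
+            ),
+        },
+        connection=Connection(environment_variables={"LOG_LEVEL": "info"}),
+    )
+
+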
+@experimental
+class LlmData(RestTranslatableMixin):
+ """LLM Request Response Data
+
+    :keyword input_data: Input data used by the monitor.
+    :paramtype input_data: ~azure.ai.ml.entities.Input
+    :keyword data_column_names: The names of columns in the input data.
+    :paramtype data_column_names: Dict[str, str]
+    :keyword data_window: The number of days or a time frame that a single monitor looks back over the target.
+    :paramtype data_window: BaselineDataRange
+ """
+
+ def __init__(
+ self,
+ *,
+ input_data: Input,
+ data_column_names: Optional[Dict[str, str]] = None,
+ data_window: Optional[BaselineDataRange] = None,
+ ):
+ self.input_data = input_data
+ self.data_column_names = data_column_names
+ self.data_window = data_window
+
+ def _to_rest_object(self, **kwargs: Any) -> RestMonitoringInputData:
+ if self.data_window is None:
+ self.data_window = BaselineDataRange(
+ lookback_window_size=kwargs.get("default"),
+ )
+ return TrailingInputData(
+ target_columns=self.data_column_names,
+ job_type=self.input_data.type,
+ uri=self.input_data.path,
+ window_size=self.data_window.lookback_window_size,
+ window_offset=(
+ self.data_window.lookback_window_offset
+ if self.data_window.lookback_window_offset is not None
+ else "P0D"
+ ),
+ )._to_rest_object()
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringInputData) -> "LlmData":
+ data_window = BaselineDataRange(
+ lookback_window_size=isodate.duration_isoformat(obj.window_size),
+ lookback_window_offset=isodate.duration_isoformat(obj.window_offset),
+ )
+ return cls(
+ input_data=Input(
+ path=obj.uri,
+ type=obj.job_input_type,
+ ),
+ data_column_names=obj.columns,
+ data_window=data_window,
+ )
+
+
+@experimental
+class GenerationSafetyQualitySignal(RestTranslatableMixin):
+ """Generation Safety Quality monitoring signal.
+
+ :ivar type: The type of the signal. Set to "generationsafetyquality" for this class.
+ :vartype type: str
+    :keyword production_data: A list of input datasets for monitoring.
+    :paramtype production_data: Optional[List[~azure.ai.ml.entities.LlmData]]
+ :keyword metric_thresholds: Metrics to calculate and their associated thresholds.
+ :paramtype metric_thresholds: ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold
+ :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ :keyword connection_id: Gets or sets the connection ID used to connect to the
+ content generation endpoint.
+ :paramtype connection_id: str
+ :keyword properties: The properties of the signal
+ :paramtype properties: Dict[str, str]
+ :keyword sampling_rate: The sample rate of the target data, should be greater
+ than 0 and at most 1.
+ :paramtype sampling_rate: float
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: Optional[List[LlmData]] = None,
+ connection_id: Optional[str] = None,
+ metric_thresholds: GenerationSafetyQualityMonitoringMetricThreshold,
+ alert_enabled: bool = False,
+ properties: Optional[Dict[str, str]] = None,
+ sampling_rate: Optional[float] = None,
+ ):
+ self.type = MonitorSignalType.GENERATION_SAFETY_QUALITY
+ self.production_data = production_data
+ self.connection_id = connection_id
+ self.metric_thresholds = metric_thresholds
+ self.alert_enabled = alert_enabled
+ self.properties = properties
+ self.sampling_rate = sampling_rate
+
+ def _to_rest_object(self, **kwargs: Any) -> RestGenerationSafetyQualityMonitoringSignal:
+ data_window_size = kwargs.get("default_data_window_size")
+ return RestGenerationSafetyQualityMonitoringSignal(
+ production_data=(
+ [data._to_rest_object(default=data_window_size) for data in self.production_data]
+ if self.production_data is not None
+ else None
+ ),
+ workspace_connection_id=self.connection_id,
+ metric_thresholds=self.metric_thresholds._to_rest_object(),
+ mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
+ properties=self.properties,
+ sampling_rate=self.sampling_rate,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestGenerationSafetyQualityMonitoringSignal) -> "GenerationSafetyQualitySignal":
+ return cls(
+ production_data=[LlmData._from_rest_object(data) for data in obj.production_data],
+ connection_id=obj.workspace_connection_id,
+ metric_thresholds=GenerationSafetyQualityMonitoringMetricThreshold._from_rest_object(obj.metric_thresholds),
+ alert_enabled=(
+ False
+ if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
+ else MonitoringNotificationMode.ENABLED
+ ),
+ properties=obj.properties,
+ sampling_rate=obj.sampling_rate,
+ )
+
+
+@experimental
+class GenerationTokenStatisticsSignal(RestTranslatableMixin):
+ """Generation token statistics signal definition.
+
+ :ivar type: The type of the signal. Set to "generationtokenstatisticssignal" for this class.
+ :vartype type: str
+    :keyword production_data: Input dataset for monitoring.
+    :paramtype production_data: Optional[~azure.ai.ml.entities.LlmData]
+ :keyword metric_thresholds: Metrics to calculate and their associated thresholds. Defaults to App Traces
+ :paramtype metric_thresholds: Optional[~azure.ai.ml.entities.GenerationTokenStatisticsMonitorMetricThreshold]
+ :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to False.
+ :paramtype alert_enabled: bool
+ :keyword properties: The properties of the signal
+ :paramtype properties: Optional[Dict[str, str]]
+ :keyword sampling_rate: The sample rate of the target data, should be greater
+ than 0 and at most 1.
+ :paramtype sampling_rate: float
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_genAI_monitors_configuration.py
+ :start-after: [START default_monitoring]
+ :end-before: [END default_monitoring]
+ :language: python
+ :dedent: 8
+ :caption: Set Token Statistics Monitor.
+ """
+
+ def __init__(
+ self,
+ *,
+ production_data: Optional[LlmData] = None,
+ metric_thresholds: Optional[GenerationTokenStatisticsMonitorMetricThreshold] = None,
+ alert_enabled: bool = False,
+ properties: Optional[Dict[str, str]] = None,
+ sampling_rate: Optional[float] = None,
+ ):
+ self.type = MonitorSignalType.GENERATION_TOKEN_STATISTICS
+ self.production_data = production_data
+ self.metric_thresholds = metric_thresholds
+ self.alert_enabled = alert_enabled
+ self.properties = properties
+ self.sampling_rate = sampling_rate
+
+ def _to_rest_object(self, **kwargs: Any) -> RestGenerationTokenStatisticsSignal:
+ data_window_size = kwargs.get("default_data_window_size")
+ return RestGenerationTokenStatisticsSignal(
+ production_data=(
+ self.production_data._to_rest_object(default=data_window_size)
+ if self.production_data is not None
+ else None
+ ),
+ metric_thresholds=(
+ self.metric_thresholds._to_rest_object()
+ if self.metric_thresholds
+ else GenerationTokenStatisticsMonitorMetricThreshold._get_default_thresholds()._to_rest_object()
+ ),
+ mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
+ properties=self.properties,
+ sampling_rate=self.sampling_rate if self.sampling_rate else 0.1,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestGenerationTokenStatisticsSignal) -> "GenerationTokenStatisticsSignal":
+ return cls(
+ production_data=LlmData._from_rest_object(obj.production_data),
+ metric_thresholds=GenerationTokenStatisticsMonitorMetricThreshold._from_rest_object(obj.metric_thresholds),
+ alert_enabled=(
+ False
+ if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
+ else MonitoringNotificationMode.ENABLED
+ ),
+ properties=obj.properties,
+ sampling_rate=obj.sampling_rate,
+ )
+
+ @classmethod
+ def _get_default_token_statistics_signal(cls) -> "GenerationTokenStatisticsSignal":
+ return cls(
+ metric_thresholds=GenerationTokenStatisticsMonitorMetricThreshold._get_default_thresholds(),
+ sampling_rate=0.1,
+ )
+
+
+def _from_rest_features(
+ obj: RestMonitoringFeatureFilterBase,
+) -> Optional[Union[List[str], MonitorFeatureFilter, Literal["all_features"]]]:
+ if isinstance(obj, RestTopNFeaturesByAttribution):
+ return MonitorFeatureFilter(top_n_feature_importance=obj.top)
+ if isinstance(obj, RestFeatureSubset):
+ _restFeatureSubset: List[str] = obj.features
+ return _restFeatureSubset
+ if isinstance(obj, RestAllFeatures):
+ _restAllFeatures: Literal["all_features"] = ALL_FEATURES # type: ignore[assignment]
+ return _restAllFeatures
+
+ return None
+
+
+def _to_rest_features(
+ features: Union[List[str], MonitorFeatureFilter, Literal["all_features"]]
+) -> RestMonitoringFeatureFilterBase:
+ rest_features = None
+ if isinstance(features, list):
+ rest_features = RestFeatureSubset(features=features)
+ elif isinstance(features, MonitorFeatureFilter):
+ rest_features = features._to_rest_object()
+ elif isinstance(features, str) and features == ALL_FEATURES:
+ rest_features = RestAllFeatures()
+ return rest_features
+
+
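+# Illustrative sketch (not part of the SDK source): the three feature-selection forms the
+# signal classes accept, all of which _to_rest_features above can translate.
+def _example_feature_selections() -> list:
+    return [
+        ["age", "income"],  # an explicit feature subset (hypothetical feature names)
+        MonitorFeatureFilter(top_n_feature_importance=10),  # top-N features by attribution
+        ALL_FEATURES,  # monitor every feature
+    ]
+
+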
+def _to_rest_num_cat_metrics(numerical_metrics: Any, categorical_metrics: Any) -> List:
+ metrics = []
+ if numerical_metrics is not None:
+ metrics.append(numerical_metrics._to_rest_object())
+
+ if categorical_metrics is not None:
+ metrics.append(categorical_metrics._to_rest_object())
+
+ return metrics
+
+
+def _to_rest_data_quality_metrics(numerical_metrics: Any, categorical_metrics: Any) -> List:
+ metric_thresholds: List = []
+ if numerical_metrics is not None:
+ metric_thresholds = metric_thresholds + numerical_metrics._to_rest_object()
+
+ if categorical_metrics is not None:
+ metric_thresholds = metric_thresholds + categorical_metrics._to_rest_object()
+
+ return metric_thresholds
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/target.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/target.py
new file mode 100644
index 00000000..73a11895
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/target.py
@@ -0,0 +1,55 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional, Union
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitoringTarget as RestMonitoringTarget
+from azure.ai.ml.constants._monitoring import MonitorTargetTasks
+
+
+class MonitoringTarget:
+ """Monitoring target.
+
+ :keyword ml_task: Type of task. Allowed values: Classification, Regression, and QuestionAnswering
+ :paramtype ml_task: Optional[Union[str, MonitorTargetTasks]]
+ :keyword endpoint_deployment_id: The ARM ID of the target deployment. Mutually exclusive with model_id.
+ :paramtype endpoint_deployment_id: Optional[str]
+    :keyword model_id: The ARM ID of the target model. Mutually exclusive with endpoint_deployment_id.
+ :paramtype model_id: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_spark_configurations.py
+ :start-after: [START spark_monitor_definition]
+ :end-before: [END spark_monitor_definition]
+ :language: python
+ :dedent: 8
+ :caption: Setting a monitoring target using endpoint_deployment_id.
+ """
+
+ def __init__(
+ self,
+ *,
+ ml_task: Optional[Union[str, MonitorTargetTasks]] = None,
+ endpoint_deployment_id: Optional[str] = None,
+ model_id: Optional[str] = None,
+ ):
+ self.endpoint_deployment_id = endpoint_deployment_id
+ self.model_id = model_id
+ self.ml_task = ml_task
+
+ def _to_rest_object(self) -> RestMonitoringTarget:
+ return RestMonitoringTarget(
+ task_type=self.ml_task if self.ml_task else "classification",
+ deployment_id=self.endpoint_deployment_id,
+ model_id=self.model_id,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestMonitoringTarget) -> "MonitoringTarget":
+ return cls(
+ ml_task=obj.task_type,
+ endpoint_deployment_id=obj.endpoint_deployment_id,
+ model_id=obj.model_id,
+ )
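+
+ # A minimal usage sketch (the public import path is assumed and the ARM ID
+ # below is a hypothetical placeholder):
+ #
+ #     from azure.ai.ml.entities import MonitoringTarget
+ #
+ #     target = MonitoringTarget(
+ #         ml_task="Classification",
+ #         endpoint_deployment_id=(
+ #             "azureml:/subscriptions/<sub-id>/resourceGroups/<rg>/providers/"
+ #             "Microsoft.MachineLearningServices/workspaces/<ws>/"
+ #             "onlineEndpoints/<endpoint>/deployments/<deployment>"
+ #         ),
+ #     )
+ #     rest_target = target._to_rest_object()  # task_type falls back to "classification" if unset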
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/thresholds.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/thresholds.py
new file mode 100644
index 00000000..3e1c33b5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/thresholds.py
@@ -0,0 +1,954 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument, protected-access
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import (
+ CategoricalDataDriftMetricThreshold,
+ CategoricalDataQualityMetricThreshold,
+ CategoricalPredictionDriftMetricThreshold,
+ ClassificationModelPerformanceMetricThreshold,
+ CustomMetricThreshold,
+ DataDriftMetricThresholdBase,
+ DataQualityMetricThresholdBase,
+ FeatureAttributionMetricThreshold,
+ GenerationSafetyQualityMetricThreshold,
+ GenerationTokenStatisticsMetricThreshold,
+ ModelPerformanceMetricThresholdBase,
+ MonitoringThreshold,
+ NumericalDataDriftMetricThreshold,
+ NumericalDataQualityMetricThreshold,
+ NumericalPredictionDriftMetricThreshold,
+ PredictionDriftMetricThresholdBase,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
+from azure.ai.ml.constants._monitoring import MonitorFeatureType, MonitorMetricName
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class MetricThreshold(RestTranslatableMixin):
+ def __init__(self, *, threshold: Optional[float] = None):
+ self.data_type: Any = None
+ self.metric_name: Optional[str] = None
+ self.threshold = threshold
+
+
+class NumericalDriftMetrics(RestTranslatableMixin):
+ """Numerical Drift Metrics
+
+ :param jensen_shannon_distance: The Jensen-Shannon distance between the two distributions
+ :paramtype jensen_shannon_distance: float
+ :param normalized_wasserstein_distance: The normalized Wasserstein distance between the two distributions
+ :paramtype normalized_wasserstein_distance: float
+ :param population_stability_index: The population stability index between the two distributions
+ :paramtype population_stability_index: float
+ :param two_sample_kolmogorov_smirnov_test: The two sample Kolmogorov-Smirnov test between the two distributions
+ :paramtype two_sample_kolmogorov_smirnov_test: float
+ """
+
+ def __init__(
+ self,
+ *,
+ jensen_shannon_distance: Optional[float] = None,
+ normalized_wasserstein_distance: Optional[float] = None,
+ population_stability_index: Optional[float] = None,
+ two_sample_kolmogorov_smirnov_test: Optional[float] = None,
+ metric: Optional[str] = None,
+ metric_threshold: Optional[float] = None,
+ ):
+ self.jensen_shannon_distance = jensen_shannon_distance
+ self.normalized_wasserstein_distance = normalized_wasserstein_distance
+ self.population_stability_index = population_stability_index
+ self.two_sample_kolmogorov_smirnov_test = two_sample_kolmogorov_smirnov_test
+ self.metric = metric
+ self.metric_threshold = metric_threshold
+
+ def _find_name_and_threshold(self) -> Tuple:
+ metric_name = None
+ threshold = None
+ if self.jensen_shannon_distance:
+ metric_name = MonitorMetricName.JENSEN_SHANNON_DISTANCE
+ threshold = MonitoringThreshold(value=self.jensen_shannon_distance)
+ elif self.normalized_wasserstein_distance:
+ metric_name = MonitorMetricName.NORMALIZED_WASSERSTEIN_DISTANCE
+ threshold = MonitoringThreshold(value=self.normalized_wasserstein_distance)
+ elif self.population_stability_index:
+ metric_name = MonitorMetricName.POPULATION_STABILITY_INDEX
+ threshold = MonitoringThreshold(value=self.population_stability_index)
+ elif self.two_sample_kolmogorov_smirnov_test:
+ metric_name = MonitorMetricName.TWO_SAMPLE_KOLMOGOROV_SMIRNOV_TEST
+ threshold = MonitoringThreshold(value=self.two_sample_kolmogorov_smirnov_test)
+
+ return metric_name, threshold
+
+ @classmethod
+ # pylint: disable=arguments-differ
+ def _from_rest_object(cls, metric_name: str, threshold: Optional[float]) -> "NumericalDriftMetrics": # type: ignore
+ metric_name = camel_to_snake(metric_name)
+ if metric_name == MonitorMetricName.JENSEN_SHANNON_DISTANCE:
+ return cls(jensen_shannon_distance=threshold)
+ if metric_name == MonitorMetricName.NORMALIZED_WASSERSTEIN_DISTANCE:
+ return cls(normalized_wasserstein_distance=threshold)
+ if metric_name == MonitorMetricName.POPULATION_STABILITY_INDEX:
+ return cls(population_stability_index=threshold)
+ if metric_name == MonitorMetricName.TWO_SAMPLE_KOLMOGOROV_SMIRNOV_TEST:
+ return cls(two_sample_kolmogorov_smirnov_test=threshold)
+ return cls()
+
+ @classmethod
+ def _get_default_thresholds(cls) -> "NumericalDriftMetrics":
+ return cls(
+ normalized_wasserstein_distance=0.1,
+ )
+
+
+class CategoricalDriftMetrics(RestTranslatableMixin):
+ """Categorical Drift Metrics
+
+ :param jensen_shannon_distance: The Jensen-Shannon distance between the two distributions
+ :paramtype jensen_shannon_distance: float
+ :param population_stability_index: The population stability index between the two distributions
+ :paramtype population_stability_index: float
+ :param pearsons_chi_squared_test: The Pearson's Chi-Squared test between the two distributions
+ :paramtype pearsons_chi_squared_test: float
+ """
+
+ def __init__(
+ self,
+ *,
+ jensen_shannon_distance: Optional[float] = None,
+ population_stability_index: Optional[float] = None,
+ pearsons_chi_squared_test: Optional[float] = None,
+ ):
+ self.jensen_shannon_distance = jensen_shannon_distance
+ self.population_stability_index = population_stability_index
+ self.pearsons_chi_squared_test = pearsons_chi_squared_test
+
+ def _find_name_and_threshold(self) -> Tuple:
+ metric_name = None
+ threshold = None
+ if self.jensen_shannon_distance:
+ metric_name = MonitorMetricName.JENSEN_SHANNON_DISTANCE
+ threshold = MonitoringThreshold(value=self.jensen_shannon_distance)
+ if self.population_stability_index and threshold is None:
+ metric_name = MonitorMetricName.POPULATION_STABILITY_INDEX
+ threshold = MonitoringThreshold(value=self.population_stability_index)
+ if self.pearsons_chi_squared_test and threshold is None:
+ metric_name = MonitorMetricName.PEARSONS_CHI_SQUARED_TEST
+ threshold = MonitoringThreshold(value=self.pearsons_chi_squared_test)
+
+ return metric_name, threshold
+
+ @classmethod
+ # pylint: disable=arguments-differ
+ def _from_rest_object( # type: ignore
+ cls, metric_name: str, threshold: Optional[float]
+ ) -> "CategoricalDriftMetrics":
+ metric_name = camel_to_snake(metric_name)
+ if metric_name == MonitorMetricName.JENSEN_SHANNON_DISTANCE:
+ return cls(jensen_shannon_distance=threshold)
+ if metric_name == MonitorMetricName.POPULATION_STABILITY_INDEX:
+ return cls(population_stability_index=threshold)
+ if metric_name == MonitorMetricName.PEARSONS_CHI_SQUARED_TEST:
+ return cls(pearsons_chi_squared_test=threshold)
+ return cls()
+
+ @classmethod
+ def _get_default_thresholds(cls) -> "CategoricalDriftMetrics":
+ return cls(
+ jensen_shannon_distance=0.1,
+ )
+
+
+class DataDriftMetricThreshold(MetricThreshold):
+ """Data drift metric threshold
+
+ :param numerical: Numerical drift metrics
+ :paramtype numerical: ~azure.ai.ml.entities.NumericalDriftMetrics
+ :param categorical: Categorical drift metrics
+ :paramtype categorical: ~azure.ai.ml.entities.CategoricalDriftMetrics
+ """
+
+ def __init__(
+ self,
+ *,
+ data_type: Optional[MonitorFeatureType] = None,
+ threshold: Optional[float] = None,
+ metric: Optional[str] = None,
+ numerical: Optional[NumericalDriftMetrics] = None,
+ categorical: Optional[CategoricalDriftMetrics] = None,
+ ):
+ super().__init__(threshold=threshold)
+ self.data_type = data_type
+ self.metric = metric
+ self.numerical = numerical
+ self.categorical = categorical
+
+ def _to_rest_object(self) -> DataDriftMetricThresholdBase:
+ thresholds = []
+ if self.numerical:
+ num_metric_name, num_threshold = self.numerical._find_name_and_threshold()
+ thresholds.append(
+ NumericalDataDriftMetricThreshold(
+ metric=snake_to_camel(num_metric_name),
+ threshold=num_threshold,
+ )
+ )
+ if self.categorical:
+ cat_metric_name, cat_threshold = self.categorical._find_name_and_threshold()
+ thresholds.append(
+ CategoricalDataDriftMetricThreshold(
+ metric=snake_to_camel(cat_metric_name),
+ threshold=cat_threshold,
+ )
+ )
+
+ return thresholds
+
+ @classmethod
+ def _from_rest_object(cls, obj: DataDriftMetricThresholdBase) -> "DataDriftMetricThreshold":
+ num = None
+ cat = None
+ for threshold in obj:
+ if threshold.data_type == "Numerical":
+ num = NumericalDriftMetrics()._from_rest_object( # pylint: disable=protected-access
+ threshold.metric, threshold.threshold.value if threshold.threshold else None
+ )
+ elif threshold.data_type == "Categorical":
+ cat = CategoricalDriftMetrics()._from_rest_object( # pylint: disable=protected-access
+ threshold.metric, threshold.threshold.value if threshold.threshold else None
+ )
+
+ return cls(
+ numerical=num,
+ categorical=cat,
+ )
+
+ @classmethod
+ def _get_default_thresholds(cls) -> "DataDriftMetricThreshold":
+ return cls(
+ numerical=NumericalDriftMetrics._get_default_thresholds(),
+ categorical=CategoricalDriftMetrics._get_default_thresholds(),
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, DataDriftMetricThreshold):
+ return NotImplemented
+ return self.numerical == other.numerical and self.categorical == other.categorical
+
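+ # A minimal sketch of building a data drift threshold from the classes above
+ # (the 0.1 values mirror the _get_default_thresholds defaults):
+ #
+ #     threshold = DataDriftMetricThreshold(
+ #         numerical=NumericalDriftMetrics(normalized_wasserstein_distance=0.1),
+ #         categorical=CategoricalDriftMetrics(jensen_shannon_distance=0.1),
+ #     )
+ #     rest_thresholds = threshold._to_rest_object()
+ #     # -> [NumericalDataDriftMetricThreshold(...), CategoricalDataDriftMetricThreshold(...)]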
+
+class PredictionDriftMetricThreshold(MetricThreshold):
+ """Prediction drift metric threshold
+
+ :param numerical: Numerical drift metrics
+ :paramtype numerical: ~azure.ai.ml.entities.NumericalDriftMetrics
+ :param categorical: Categorical drift metrics
+ :paramtype categorical: ~azure.ai.ml.entities.CategoricalDriftMetrics
+ """
+
+ def __init__(
+ self,
+ *,
+ data_type: Optional[MonitorFeatureType] = None,
+ threshold: Optional[float] = None,
+ numerical: Optional[NumericalDriftMetrics] = None,
+ categorical: Optional[CategoricalDriftMetrics] = None,
+ ):
+ super().__init__(threshold=threshold)
+ self.data_type = data_type
+ self.numerical = numerical
+ self.categorical = categorical
+
+ def _to_rest_object(self) -> PredictionDriftMetricThresholdBase:
+ thresholds = []
+ if self.numerical:
+ num_metric_name, num_threshold = self.numerical._find_name_and_threshold()
+ thresholds.append(
+ NumericalPredictionDriftMetricThreshold(
+ metric=snake_to_camel(num_metric_name),
+ threshold=num_threshold,
+ )
+ )
+ if self.categorical:
+ cat_metric_name, cat_threshold = self.categorical._find_name_and_threshold()
+ thresholds.append(
+ CategoricalPredictionDriftMetricThreshold(
+ metric=snake_to_camel(cat_metric_name),
+ threshold=cat_threshold,
+ )
+ )
+
+ return thresholds
+
+ @classmethod
+ def _from_rest_object(cls, obj: PredictionDriftMetricThresholdBase) -> "PredictionDriftMetricThreshold":
+ num = None
+ cat = None
+ for threshold in obj:
+ if threshold.data_type == "Numerical":
+ num = NumericalDriftMetrics()._from_rest_object( # pylint: disable=protected-access
+ threshold.metric, threshold.threshold.value if threshold.threshold else None
+ )
+ elif threshold.data_type == "Categorical":
+ cat = CategoricalDriftMetrics()._from_rest_object( # pylint: disable=protected-access
+ threshold.metric, threshold.threshold.value if threshold.threshold else None
+ )
+
+ return cls(
+ numerical=num,
+ categorical=cat,
+ )
+
+ @classmethod
+ def _get_default_thresholds(cls) -> "PredictionDriftMetricThreshold":
+ return cls(
+ numerical=NumericalDriftMetrics._get_default_thresholds(),
+ categorical=CategoricalDriftMetrics._get_default_thresholds(),
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, PredictionDriftMetricThreshold):
+ return NotImplemented
+ return (
+ self.data_type == other.data_type
+ and self.metric_name == other.metric_name
+ and self.threshold == other.threshold
+ )
+
+
+class DataQualityMetricsNumerical(RestTranslatableMixin):
+ """Data Quality Numerical Metrics
+
+ :param null_value_rate: The null value rate
+ :paramtype null_value_rate: float
+ :param data_type_error_rate: The data type error rate
+ :paramtype data_type_error_rate: float
+ :param out_of_bounds_rate: The out of bounds rate
+ :paramtype out_of_bounds_rate: float
+ """
+
+ def __init__(
+ self,
+ *,
+ null_value_rate: Optional[float] = None,
+ data_type_error_rate: Optional[float] = None,
+ out_of_bounds_rate: Optional[float] = None,
+ ):
+ self.null_value_rate = null_value_rate
+ self.data_type_error_rate = data_type_error_rate
+ self.out_of_bounds_rate = out_of_bounds_rate
+
+ def _to_rest_object(self) -> List[NumericalDataQualityMetricThreshold]:
+ metric_thresholds = []
+ if self.null_value_rate is not None:
+ metric_name = MonitorMetricName.NULL_VALUE_RATE
+ threshold = MonitoringThreshold(value=self.null_value_rate)
+ metric_thresholds.append(
+ NumericalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+ )
+ if self.data_type_error_rate is not None:
+ metric_name = MonitorMetricName.DATA_TYPE_ERROR_RATE
+ threshold = MonitoringThreshold(value=self.data_type_error_rate)
+ metric_thresholds.append(
+ NumericalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+ )
+ if self.out_of_bounds_rate is not None:
+ metric_name = MonitorMetricName.OUT_OF_BOUND_RATE
+ threshold = MonitoringThreshold(value=self.out_of_bounds_rate)
+ metric_thresholds.append(
+ NumericalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+ )
+
+ return metric_thresholds
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "DataQualityMetricsNumerical":
+ null_value_rate_val = None
+ data_type_error_rate_val = None
+ out_of_bounds_rate_val = None
+ for thresholds in obj:
+ if thresholds.metric in ("NullValueRate", "nullValueRate"):
+ null_value_rate_val = thresholds.threshold.value
+ if thresholds.metric in ("DataTypeErrorRate", "dataTypeErrorRate"):
+ data_type_error_rate_val = thresholds.threshold.value
+ if thresholds.metric in ("OutOfBoundsRate", "outOfBoundsRate"):
+ out_of_bounds_rate_val = thresholds.threshold.value
+ return cls(
+ null_value_rate=null_value_rate_val,
+ data_type_error_rate=data_type_error_rate_val,
+ out_of_bounds_rate=out_of_bounds_rate_val,
+ )
+
+ @classmethod
+ def _get_default_thresholds(cls) -> "DataQualityMetricsNumerical":
+ return cls(
+ null_value_rate=0.0,
+ data_type_error_rate=0.0,
+ out_of_bounds_rate=0.0,
+ )
+
+
+class DataQualityMetricsCategorical(RestTranslatableMixin):
+ """Data Quality Categorical Metrics
+
+ :param null_value_rate: The null value rate
+ :paramtype null_value_rate: float
+ :param data_type_error_rate: The data type error rate
+ :paramtype data_type_error_rate: float
+ :param out_of_bounds_rate: The out of bounds rate
+ :paramtype out_of_bounds_rate: float
+ """
+
+ def __init__(
+ self,
+ *,
+ null_value_rate: Optional[float] = None,
+ data_type_error_rate: Optional[float] = None,
+ out_of_bounds_rate: Optional[float] = None,
+ ):
+ self.null_value_rate = null_value_rate
+ self.data_type_error_rate = data_type_error_rate
+ self.out_of_bounds_rate = out_of_bounds_rate
+
+ def _to_rest_object(self) -> List[CategoricalDataQualityMetricThreshold]:
+ metric_thresholds = []
+ if self.null_value_rate is not None:
+ metric_name = MonitorMetricName.NULL_VALUE_RATE
+ threshold = MonitoringThreshold(value=self.null_value_rate)
+ metric_thresholds.append(
+ CategoricalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+ )
+ if self.data_type_error_rate is not None:
+ metric_name = MonitorMetricName.DATA_TYPE_ERROR_RATE
+ threshold = MonitoringThreshold(value=self.data_type_error_rate)
+ metric_thresholds.append(
+ CategoricalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+ )
+ if self.out_of_bounds_rate is not None:
+ metric_name = MonitorMetricName.OUT_OF_BOUND_RATE
+ threshold = MonitoringThreshold(value=self.out_of_bounds_rate)
+ metric_thresholds.append(
+ CategoricalDataQualityMetricThreshold(metric=snake_to_camel(metric_name), threshold=threshold)
+ )
+
+ return metric_thresholds
+
+ @classmethod
+ def _from_rest_object(cls, obj: List) -> "DataQualityMetricsCategorical":
+ null_value_rate_val = None
+ data_type_error_rate_val = None
+ out_of_bounds_rate_val = None
+ for thresholds in obj:
+ if thresholds.metric in ("NullValueRate", "nullValueRate"):
+ null_value_rate_val = thresholds.threshold.value
+ if thresholds.metric in ("DataTypeErrorRate", "dataTypeErrorRate"):
+ data_type_error_rate_val = thresholds.threshold.value
+ if thresholds.metric in ("OutOfBoundsRate", "outOfBoundsRate"):
+ out_of_bounds_rate_val = thresholds.threshold.value
+ return cls(
+ null_value_rate=null_value_rate_val,
+ data_type_error_rate=data_type_error_rate_val,
+ out_of_bounds_rate=out_of_bounds_rate_val,
+ )
+
+ @classmethod
+ def _get_default_thresholds(cls) -> "DataQualityMetricsCategorical":
+ return cls(
+ null_value_rate=0.0,
+ data_type_error_rate=0.0,
+ out_of_bounds_rate=0.0,
+ )
+
+
+class DataQualityMetricThreshold(MetricThreshold):
+ """Data quality metric threshold
+
+ :param numerical: Numerical data quality metrics
+ :paramtype numerical: ~azure.ai.ml.entities.DataQualityMetricsNumerical
+ :param categorical: Categorical data quality metrics
+ :paramtype categorical: ~azure.ai.ml.entities.DataQualityMetricsCategorical
+ """
+
+ def __init__(
+ self,
+ *,
+ data_type: Optional[MonitorFeatureType] = None,
+ threshold: Optional[float] = None,
+ metric_name: Optional[str] = None,
+ numerical: Optional[DataQualityMetricsNumerical] = None,
+ categorical: Optional[DataQualityMetricsCategorical] = None,
+ ):
+ super().__init__(threshold=threshold)
+ self.data_type = data_type
+ self.metric_name = metric_name
+ self.numerical = numerical
+ self.categorical = categorical
+
+ def _to_rest_object(self) -> DataQualityMetricThresholdBase:
+ thresholds: list = []
+ if self.numerical:
+ thresholds = thresholds + (
+ DataQualityMetricsNumerical( # pylint: disable=protected-access
+ null_value_rate=self.numerical.null_value_rate,
+ data_type_error_rate=self.numerical.data_type_error_rate,
+ out_of_bounds_rate=self.numerical.out_of_bounds_rate,
+ )._to_rest_object()
+ )
+ if self.categorical:
+ thresholds = thresholds + (
+ DataQualityMetricsCategorical( # pylint: disable=protected-access
+ null_value_rate=self.categorical.null_value_rate,
+ data_type_error_rate=self.categorical.data_type_error_rate,
+ out_of_bounds_rate=self.categorical.out_of_bounds_rate,
+ )._to_rest_object()
+ )
+ return thresholds
+
+ @classmethod
+ def _from_rest_object(cls, obj: DataQualityMetricThresholdBase) -> "DataQualityMetricThreshold":
+ num = []
+ cat = []
+ for threshold in obj:
+ if threshold.data_type == "Numerical":
+ num.append(threshold)
+ elif threshold.data_type == "Categorical":
+ cat.append(threshold)
+
+ num_from_rest = DataQualityMetricsNumerical()._from_rest_object(num) # pylint: disable=protected-access
+ cat_from_rest = DataQualityMetricsCategorical()._from_rest_object(cat) # pylint: disable=protected-access
+ return cls(
+ numerical=num_from_rest,
+ categorical=cat_from_rest,
+ )
+
+ @classmethod
+ def _get_default_thresholds(cls) -> "DataQualityMetricThreshold":
+ return cls(
+ numerical=DataQualityMetricsNumerical()._get_default_thresholds(), # pylint: disable=protected-access
+ categorical=DataQualityMetricsCategorical()._get_default_thresholds(), # pylint: disable=protected-access
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, DataQualityMetricThreshold):
+ return NotImplemented
+ return (
+ self.data_type == other.data_type
+ and self.metric_name == other.metric_name
+ and self.threshold == other.threshold
+ )
+
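+ # A minimal sketch using the data quality classes above (the rates are
+ # illustrative values):
+ #
+ #     dq_threshold = DataQualityMetricThreshold(
+ #         numerical=DataQualityMetricsNumerical(null_value_rate=0.01, out_of_bounds_rate=0.05),
+ #         categorical=DataQualityMetricsCategorical(null_value_rate=0.01),
+ #     )
+ #     rest_list = dq_threshold._to_rest_object()
+ #     # -> a flat list of Numerical/CategoricalDataQualityMetricThreshold objects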
+
+@experimental
+class FeatureAttributionDriftMetricThreshold(MetricThreshold):
+ """Feature attribution drift metric threshold
+
+ :param normalized_discounted_cumulative_gain: The threshold value for the normalized discounted cumulative gain metric.
+ :paramtype normalized_discounted_cumulative_gain: float
+ """
+
+ def __init__(
+ self, *, normalized_discounted_cumulative_gain: Optional[float] = None, threshold: Optional[float] = None
+ ):
+ super().__init__(threshold=threshold)
+ self.data_type = MonitorFeatureType.ALL_FEATURE_TYPES
+ self.metric_name = MonitorMetricName.NORMALIZED_DISCOUNTED_CUMULATIVE_GAIN
+ self.normalized_discounted_cumulative_gain = normalized_discounted_cumulative_gain
+
+ def _to_rest_object(self) -> FeatureAttributionMetricThreshold:
+ return FeatureAttributionMetricThreshold(
+ metric=snake_to_camel(self.metric_name),
+ threshold=(
+ MonitoringThreshold(value=self.normalized_discounted_cumulative_gain)
+ if self.normalized_discounted_cumulative_gain
+ else None
+ ),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: FeatureAttributionMetricThreshold) -> "FeatureAttributionDriftMetricThreshold":
+ return cls(normalized_discounted_cumulative_gain=obj.threshold.value if obj.threshold else None)
+
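+ # A minimal sketch (the 0.9 value is illustrative); normalized discounted
+ # cumulative gain is the only metric this class carries:
+ #
+ #     fad_threshold = FeatureAttributionDriftMetricThreshold(
+ #         normalized_discounted_cumulative_gain=0.9,
+ #     )
+ #     rest_obj = fad_threshold._to_rest_object()  # FeatureAttributionMetricThreshold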
+
+@experimental
+class ModelPerformanceClassificationThresholds(RestTranslatableMixin):
+ def __init__(
+ self,
+ *,
+ accuracy: Optional[float] = None,
+ precision: Optional[float] = None,
+ recall: Optional[float] = None,
+ ):
+ self.accuracy = accuracy
+ self.precision = precision
+ self.recall = recall
+
+ def _to_str_object(self, **kwargs):
+ thresholds = []
+ if self.accuracy:
+ thresholds.append(
+ '{"modelType":"classification","metric":"Accuracy","threshold":{"value":' + f"{self.accuracy}" + "}}"
+ )
+ if self.precision:
+ thresholds.append(
+ '{"modelType":"classification","metric":"Precision","threshold":{"value":' + f"{self.precision}" + "}}"
+ )
+ if self.recall:
+ thresholds.append(
+ '{"modelType":"classification","metric":"Recall","threshold":{"value":' + f"{self.recall}" + "}}"
+ )
+
+ if not thresholds:
+ return None
+
+ return ", ".join(thresholds)
+
+ @classmethod
+ def _from_rest_object(cls, obj) -> "ModelPerformanceClassificationThresholds":
+ return cls(
+ accuracy=obj.threshold.value if obj.threshold else None,
+ )
+
+
+@experimental
+class ModelPerformanceRegressionThresholds(RestTranslatableMixin):
+ def __init__(
+ self,
+ *,
+ mean_absolute_error: Optional[float] = None,
+ mean_squared_error: Optional[float] = None,
+ root_mean_squared_error: Optional[float] = None,
+ ):
+ self.mean_absolute_error = mean_absolute_error
+ self.mean_squared_error = mean_squared_error
+ self.root_mean_squared_error = root_mean_squared_error
+
+ def _to_str_object(self, **kwargs):
+ thresholds = []
+ if self.mean_absolute_error:
+ thresholds.append(
+ '{"modelType":"regression","metric":"MeanAbsoluteError","threshold":{"value":'
+ + f"{self.mean_absolute_error}"
+ + "}}"
+ )
+ if self.mean_squared_error:
+ thresholds.append(
+ '{"modelType":"regression","metric":"MeanSquaredError","threshold":{"value":'
+ + f"{self.mean_squared_error}"
+ + "}}"
+ )
+ if self.root_mean_squared_error:
+ thresholds.append(
+ '{"modelType":"regression","metric":"RootMeanSquaredError","threshold":{"value":'
+ + f"{self.root_mean_squared_error}"
+ + "}}"
+ )
+
+ if not thresholds:
+ return None
+
+ return ", ".join(thresholds)
+
+
+@experimental
+class ModelPerformanceMetricThreshold(RestTranslatableMixin):
+ def __init__(
+ self,
+ *,
+ classification: Optional[ModelPerformanceClassificationThresholds] = None,
+ regression: Optional[ModelPerformanceRegressionThresholds] = None,
+ ):
+ self.classification = classification
+ self.regression = regression
+
+ def _to_str_object(self, **kwargs):
+ thresholds = []
+ if self.classification:
+ thresholds.append(self.classification._to_str_object(**kwargs))
+ if self.regression:
+ thresholds.append(self.regression._to_str_object(**kwargs))
+
+ if not thresholds:
+ return None
+ if len(thresholds) == 2:
+ result = "[" + ", ".join(thresholds) + "]"
+ else:
+ result = "[" + thresholds[0] + "]"
+ return result
+
+ def _to_rest_object(self, **kwargs) -> ModelPerformanceMetricThresholdBase:
+ threshold = MonitoringThreshold(value=0.9)
+ return ClassificationModelPerformanceMetricThreshold(
+ metric="Accuracy",
+ threshold=threshold,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: ModelPerformanceMetricThresholdBase) -> "ModelPerformanceMetricThreshold":
+ return cls(
+ classification=ModelPerformanceClassificationThresholds._from_rest_object(obj),
+ regression=None,
+ )
+
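+ # A minimal sketch of the string form produced by _to_str_object (values are
+ # illustrative; the result is one string, wrapped here for readability):
+ #
+ #     perf = ModelPerformanceMetricThreshold(
+ #         classification=ModelPerformanceClassificationThresholds(accuracy=0.9),
+ #         regression=ModelPerformanceRegressionThresholds(mean_absolute_error=5.0),
+ #     )
+ #     perf._to_str_object()
+ #     # '[{"modelType":"classification","metric":"Accuracy","threshold":{"value":0.9}},
+ #     #   {"modelType":"regression","metric":"MeanAbsoluteError","threshold":{"value":5.0}}]'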
+
+@experimental
+class CustomMonitoringMetricThreshold(MetricThreshold):
+ """Feature attribution drift metric threshold
+
+ :param metric_name: The metric to calculate
+ :type metric_name: str
+ :param threshold: The threshold value. If None, a default value will be set
+ depending on the selected metric.
+ :type threshold: float
+ """
+
+ def __init__(
+ self,
+ *,
+ metric_name: Optional[str],
+ threshold: Optional[float] = None,
+ ):
+ super().__init__(threshold=threshold)
+ self.metric_name = metric_name
+
+ def _to_rest_object(self) -> CustomMetricThreshold:
+ return CustomMetricThreshold(
+ metric=self.metric_name,
+ threshold=MonitoringThreshold(value=self.threshold) if self.threshold is not None else None,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: CustomMetricThreshold) -> "CustomMonitoringMetricThreshold":
+ return cls(metric_name=obj.metric, threshold=obj.threshold.value if obj.threshold else None)
+
+
+@experimental
+class GenerationSafetyQualityMonitoringMetricThreshold(RestTranslatableMixin): # pylint: disable=name-too-long
+ """Generation safety quality metric threshold
+
+ :param groundedness: The groundedness metric threshold
+ :paramtype groundedness: Dict[str, float]
+ :param relevance: The relevance metric threshold
+ :paramtype relevance: Dict[str, float]
+ :param coherence: The coherence metric threshold
+ :paramtype coherence: Dict[str, float]
+ :param fluency: The fluency metric threshold
+ :paramtype fluency: Dict[str, float]
+ :param similarity: The similarity metric threshold
+ :paramtype similarity: Dict[str, float]
+ """
+
+ def __init__(
+ self,
+ *,
+ groundedness: Optional[Dict[str, float]] = None,
+ relevance: Optional[Dict[str, float]] = None,
+ coherence: Optional[Dict[str, float]] = None,
+ fluency: Optional[Dict[str, float]] = None,
+ similarity: Optional[Dict[str, float]] = None,
+ ):
+ self.groundedness = groundedness
+ self.relevance = relevance
+ self.coherence = coherence
+ self.fluency = fluency
+ self.similarity = similarity
+
+ def _to_rest_object(self) -> GenerationSafetyQualityMetricThreshold:
+ metric_thresholds = []
+ if self.groundedness:
+ if "acceptable_groundedness_score_per_instance" in self.groundedness:
+ acceptable_threshold = MonitoringThreshold(
+ value=self.groundedness["acceptable_groundedness_score_per_instance"]
+ )
+ else:
+ acceptable_threshold = MonitoringThreshold(value=3)
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AcceptableGroundednessScorePerInstance", threshold=acceptable_threshold
+ )
+ )
+ aggregated_threshold = MonitoringThreshold(value=self.groundedness["aggregated_groundedness_pass_rate"])
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AggregatedGroundednessPassRate", threshold=aggregated_threshold
+ )
+ )
+ if self.relevance:
+ if "acceptable_relevance_score_per_instance" in self.relevance:
+ acceptable_threshold = MonitoringThreshold(
+ value=self.relevance["acceptable_relevance_score_per_instance"]
+ )
+ else:
+ acceptable_threshold = MonitoringThreshold(value=3)
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AcceptableRelevanceScorePerInstance", threshold=acceptable_threshold
+ )
+ )
+ aggregated_threshold = MonitoringThreshold(value=self.relevance["aggregated_relevance_pass_rate"])
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AggregatedRelevancePassRate", threshold=aggregated_threshold
+ )
+ )
+ if self.coherence:
+ if "acceptable_coherence_score_per_instance" in self.coherence:
+ acceptable_threshold = MonitoringThreshold(
+ value=self.coherence["acceptable_coherence_score_per_instance"]
+ )
+ else:
+ acceptable_threshold = MonitoringThreshold(value=3)
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AcceptableCoherenceScorePerInstance", threshold=acceptable_threshold
+ )
+ )
+ aggregated_threshold = MonitoringThreshold(value=self.coherence["aggregated_coherence_pass_rate"])
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AggregatedCoherencePassRate", threshold=aggregated_threshold
+ )
+ )
+ if self.fluency:
+ if "acceptable_fluency_score_per_instance" in self.fluency:
+ acceptable_threshold = MonitoringThreshold(value=self.fluency["acceptable_fluency_score_per_instance"])
+ else:
+ acceptable_threshold = MonitoringThreshold(value=3)
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AcceptableFluencyScorePerInstance", threshold=acceptable_threshold
+ )
+ )
+ aggregated_threshold = MonitoringThreshold(value=self.fluency["aggregated_fluency_pass_rate"])
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AggregatedFluencyPassRate", threshold=aggregated_threshold
+ )
+ )
+ if self.similarity:
+ if "acceptable_similarity_score_per_instance" in self.similarity:
+ acceptable_threshold = MonitoringThreshold(
+ value=self.similarity["acceptable_similarity_score_per_instance"]
+ )
+ else:
+ acceptable_threshold = MonitoringThreshold(value=3)
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AcceptableSimilarityScorePerInstance", threshold=acceptable_threshold
+ )
+ )
+ aggregated_threshold = MonitoringThreshold(value=self.similarity["aggregated_similarity_pass_rate"])
+ metric_thresholds.append(
+ GenerationSafetyQualityMetricThreshold(
+ metric="AggregatedSimilarityPassRate", threshold=aggregated_threshold
+ )
+ )
+ return metric_thresholds
+
+ @classmethod
+ def _from_rest_object(
+ cls, obj: GenerationSafetyQualityMetricThreshold
+ ) -> "GenerationSafetyQualityMonitoringMetricThreshold":
+ groundedness = {}
+ relevance = {}
+ coherence = {}
+ fluency = {}
+ similarity = {}
+
+ for threshold in obj:
+ if threshold.metric == "AcceptableGroundednessScorePerInstance":
+ groundedness["acceptable_groundedness_score_per_instance"] = threshold.threshold.value
+ if threshold.metric == "AcceptableRelevanceScorePerInstance":
+ relevance["acceptable_relevance_score_per_instance"] = threshold.threshold.value
+ if threshold.metric == "AcceptableCoherenceScorePerInstance":
+ coherence["acceptable_coherence_score_per_instance"] = threshold.threshold.value
+ if threshold.metric == "AcceptableFluencyScorePerInstance":
+ fluency["acceptable_fluency_score_per_instance"] = threshold.threshold.value
+ if threshold.metric == "AcceptableSimilarityScorePerInstance":
+ similarity["acceptable_similarity_score_per_instance"] = threshold.threshold.value
+ if threshold.metric == "AggregatedGroundednessPassRate":
+ groundedness["aggregated_groundedness_pass_rate"] = threshold.threshold.value
+ if threshold.metric == "AggregatedRelevancePassRate":
+ relevance["aggregated_relevance_pass_rate"] = threshold.threshold.value
+ if threshold.metric == "AggregatedCoherencePassRate":
+ coherence["aggregated_coherence_pass_rate"] = threshold.threshold.value
+ if threshold.metric == "AggregatedFluencyPassRate":
+ fluency["aggregated_fluency_pass_rate"] = threshold.threshold.value
+ if threshold.metric == "AggregatedSimilarityPassRate":
+ similarity["aggregated_similarity_pass_rate"] = threshold.threshold.value
+
+ return cls(
+ groundedness=groundedness if groundedness else None,
+ relevance=relevance if relevance else None,
+ coherence=coherence if coherence else None,
+ fluency=fluency if fluency else None,
+ similarity=similarity if similarity else None,
+ )
+
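+ # A minimal sketch (keys follow the dictionary names used above; values are
+ # illustrative). When a metric dict is supplied, its "aggregated_*_pass_rate"
+ # key is required by _to_rest_object:
+ #
+ #     gsq = GenerationSafetyQualityMonitoringMetricThreshold(
+ #         groundedness={
+ #             "acceptable_groundedness_score_per_instance": 4,
+ #             "aggregated_groundedness_pass_rate": 0.8,
+ #         },
+ #     )
+ #     rest_list = gsq._to_rest_object()
+ #     # -> two GenerationSafetyQualityMetricThreshold entries (per-instance + aggregate)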
+
+@experimental
+class GenerationTokenStatisticsMonitorMetricThreshold(RestTranslatableMixin): # pylint: disable=name-too-long
+ """Generation token statistics metric threshold definition.
+
+ All required parameters must be populated in order to send to Azure.
+
+ :ivar metric: Required. [Required] Gets or sets the token statistics metric to calculate.
+ Possible values include: "TotalTokenCount", "TotalTokenCountPerGroup".
+ :vartype metric: str or
+ ~azure.mgmt.machinelearningservices.models.GenerationTokenStatisticsMetric
+ :ivar threshold: Gets or sets the threshold value.
+ If null, a default value will be set depending on the selected metric.
+ :vartype threshold: ~azure.mgmt.machinelearningservices.models.MonitoringThreshold
+ """
+
+ def __init__(
+ self,
+ *,
+ totaltoken: Optional[Dict[str, float]] = None,
+ ):
+ self.totaltoken = totaltoken
+
+ def _to_rest_object(self) -> GenerationSafetyQualityMetricThreshold:
+ metric_thresholds = []
+ if self.totaltoken:
+ if "total_token_count" in self.totaltoken:
+ acceptable_threshold = MonitoringThreshold(value=self.totaltoken["total_token_count"])
+ else:
+ acceptable_threshold = MonitoringThreshold(value=3)
+ metric_thresholds.append(
+ GenerationTokenStatisticsMetricThreshold(metric="TotalTokenCount", threshold=acceptable_threshold)
+ )
+ acceptable_threshold_per_group = MonitoringThreshold(value=self.totaltoken["total_token_count_per_group"])
+ metric_thresholds.append(
+ GenerationTokenStatisticsMetricThreshold(
+ metric="TotalTokenCountPerGroup", threshold=acceptable_threshold_per_group
+ )
+ )
+ return metric_thresholds
+
+ @classmethod
+ def _from_rest_object(
+ cls, obj: GenerationTokenStatisticsMetricThreshold
+ ) -> "GenerationTokenStatisticsMonitorMetricThreshold":
+ totaltoken = {}
+ for threshold in obj:
+ if threshold.metric == "TotalTokenCount":
+ totaltoken["total_token_count"] = threshold.threshold.value
+ if threshold.metric == "TotalTokenCountPerGroup":
+ totaltoken["total_token_count_per_group"] = threshold.threshold.value
+
+ return cls(
+ totaltoken=totaltoken if totaltoken else None,
+ )
+
+ @classmethod
+ def _get_default_thresholds(cls) -> "GenerationTokenStatisticsMonitorMetricThreshold":
+ return cls(totaltoken={"total_token_count": 0, "total_token_count_per_group": 0})
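+
+ # A minimal sketch (values are illustrative; both keys are expected when
+ # totaltoken is supplied, mirroring _to_rest_object above):
+ #
+ #     token_stats = GenerationTokenStatisticsMonitorMetricThreshold(
+ #         totaltoken={"total_token_count": 1000, "total_token_count_per_group": 200},
+ #     )
+ #     rest_list = token_stats._to_rest_object()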
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/notification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/notification.py
new file mode 100644
index 00000000..91380870
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_notification/notification.py
@@ -0,0 +1,33 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import List, Optional
+
+from azure.ai.ml._restclient.v2023_02_01_preview.models import NotificationSetting as RestNotificationSetting
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class Notification(RestTranslatableMixin):
+ """Configuration for notification.
+
+ :param email_on: Send email notification to user on specified notification type. Accepted values are
+ "JobCompleted", "JobFailed", and "JobCancelled".
+ :type email_on: Optional[list[str]]
+ :param emails: The email recipient list. Note that this parameter has a character limit of 499, which
+ includes all of the recipient strings and each comma separator.
+ :type emails: Optional[list[str]]
+ """
+
+ def __init__(self, *, email_on: Optional[List[str]] = None, emails: Optional[List[str]] = None) -> None:
+ self.email_on = email_on
+ self.emails = emails
+
+ def _to_rest_object(self) -> RestNotificationSetting:
+ return RestNotificationSetting(email_on=self.email_on, emails=self.emails)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestNotificationSetting) -> Optional["Notification"]:
+ if not obj:
+ return None
+ return Notification(email_on=obj.email_on, emails=obj.emails)
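+
+ # A minimal usage sketch (the addresses are hypothetical placeholders):
+ #
+ #     notification = Notification(
+ #         email_on=["JobCompleted", "JobFailed"],
+ #         emails=["alice@example.com", "bob@example.com"],
+ #     )
+ #     rest_setting = notification._to_rest_object()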
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry.py
new file mode 100644
index 00000000..a01e70d3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry.py
@@ -0,0 +1,231 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ManagedServiceIdentity as RestManagedServiceIdentity
+from azure.ai.ml._restclient.v2022_10_01_preview.models import (
+ ManagedServiceIdentityType as RestManagedServiceIdentityType,
+)
+from azure.ai.ml._restclient.v2022_10_01_preview.models import Registry as RestRegistry
+from azure.ai.ml._restclient.v2022_10_01_preview.models import RegistryProperties
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._util import load_from_dict
+
+from .registry_support_classes import RegistryRegionDetails
+
+CONTAINER_REGISTRY = "container_registry"
+REPLICATION_LOCATIONS = "replication_locations"
+INTELLECTUAL_PROPERTY = "intellectual_property"
+
+
+class Registry(Resource):
+ def __init__(
+ self,
+ *,
+ name: str,
+ location: str,
+ identity: Optional[IdentityConfiguration] = None,
+ tags: Optional[Dict[str, str]] = None,
+ public_network_access: Optional[str] = None,
+ discovery_url: Optional[str] = None,
+ intellectual_property: Optional[IntellectualProperty] = None,
+ managed_resource_group: Optional[str] = None,
+ mlflow_registry_uri: Optional[str] = None,
+ replication_locations: Optional[List[RegistryRegionDetails]],
+ **kwargs: Any,
+ ):
+ """Azure ML registry.
+
+ :param name: Name of the registry. Must be globally unique and is immutable.
+ :type name: str
+ :param location: The location this registry resource is located in.
+ :type location: str
+ :param identity: registry's System Managed Identity
+ :type identity: ManagedServiceIdentity
+ :param tags: Tags of the registry.
+ :type tags: dict
+ :param public_network_access: Whether to allow public endpoint connectivity.
+ :type public_network_access: str
+ :param discovery_url: Backend service base url for the registry.
+ :type discovery_url: str
+ :param intellectual_property: **Experimental** Intellectual property publisher.
+ :type intellectual_property: ~azure.ai.ml.entities.IntellectualProperty
+ :param managed_resource_group: Managed resource group created for the registry.
+ :type managed_resource_group: str
+ :param mlflow_registry_uri: MLflow tracking URI for the registry.
+ :type mlflow_registry_uri: str
+ :param replication_locations: Details of each region the registry is in.
+ :type replication_locations: List[RegistryRegionDetails]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ super().__init__(name=name, tags=tags, **kwargs)
+
+ # self.display_name = name # Do we need a top-level visible name value?
+ self.location = location
+ self.identity = identity
+ self.replication_locations = replication_locations
+ self.public_network_access = public_network_access
+ self.intellectual_property = intellectual_property
+ self.managed_resource_group = managed_resource_group
+ self.discovery_url = discovery_url
+ self.mlflow_registry_uri = mlflow_registry_uri
+ self.container_registry = None
+
+ def dump(
+ self,
+ dest: Union[str, PathLike, IO[AnyStr]],
+ **kwargs: Any,
+ ) -> None:
+ """Dump the registry spec into a file in yaml format.
+
+ :param dest: Path to a local file as the target; a new file will be created, and an exception is raised if the file already exists.
+ :type dest: Union[str, PathLike, IO[AnyStr]]
+ """
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False)
+
+ # The internal structure of the registry object is closer to how it's
+ # represented by the registry API, which differs from how registries
+ # are represented in YAML. This function converts those differences.
+ def _to_dict(self) -> Dict:
+ # JIT import to avoid experimental warnings on unrelated calls
+ from azure.ai.ml._schema.registry.registry import RegistrySchema
+
+ schema = RegistrySchema(context={BASE_PATH_CONTEXT_KEY: "./"})
+
+ # Grab the first acr account of the first region and set that
+ # as the system-wide container registry.
+ # Although support for multiple ACRs per region, as well as
+ # different ACRs per region technically exist according to the
+ # API schema, we do not want to surface that as an option,
+ # since the use cases for variable/multiple ACRs are extremely
+ # limited, and would probably just confuse most users.
+ if self.replication_locations and len(self.replication_locations) > 0:
+ if self.replication_locations[0].acr_config and len(self.replication_locations[0].acr_config) > 0:
+ self.container_registry = self.replication_locations[0].acr_config[0] # type: ignore[assignment]
+
+ res: dict = schema.dump(self)
+ return res
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Registry":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ # JIT import to avoid experimental warnings on unrelated calls
+ from azure.ai.ml._schema.registry.registry import RegistrySchema
+
+ loaded_schema = load_from_dict(RegistrySchema, data, context, **kwargs)
+ cls._convert_yaml_dict_to_entity_input(loaded_schema)
+ return Registry(**loaded_schema)
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestRegistry) -> Optional["Registry"]:
+ if not rest_obj:
+ return None
+ real_registry = rest_obj.properties
+
+ # Convert from api name region_details to user-shown name "replication locations"
+ replication_locations = []
+ if real_registry and real_registry.region_details:
+ replication_locations = [
+ RegistryRegionDetails._from_rest_object(details) for details in real_registry.region_details
+ ]
+ identity = None
+ if rest_obj.identity and isinstance(rest_obj.identity, RestManagedServiceIdentity):
+ identity = IdentityConfiguration._from_rest_object(rest_obj.identity)
+ return Registry(
+ name=rest_obj.name,
+ identity=identity,
+ id=rest_obj.id,
+ tags=rest_obj.tags,
+ location=rest_obj.location,
+ public_network_access=real_registry.public_network_access,
+ discovery_url=real_registry.discovery_url,
+ intellectual_property=(
+ IntellectualProperty(publisher=real_registry.intellectual_property_publisher)
+ if real_registry.intellectual_property_publisher
+ else None
+ ),
+ managed_resource_group=real_registry.managed_resource_group,
+ mlflow_registry_uri=real_registry.ml_flow_registry_uri,
+ replication_locations=replication_locations, # type: ignore[arg-type]
+ )
+
+ # There are differences between what our registry validation schema
+ # accepts, and how we actually represent things internally.
+ # This is mostly due to the compromise required to balance
+ # the actual shape of registries as they're defined by
+ # autorest with how the spec wanted users to be able to
+ # configure them. This function converts the validated YAML dict
+ # into the entity's internal input format.
+ @classmethod
+ def _convert_yaml_dict_to_entity_input(
+ cls,
+ input: Dict, # pylint: disable=redefined-builtin
+ ) -> None:
+ # pop container_registry value.
+ global_acr_exists = False
+ if CONTAINER_REGISTRY in input:
+ acr_input = input.pop(CONTAINER_REGISTRY)
+ global_acr_exists = True
+ for region_detail in input[REPLICATION_LOCATIONS]:
+ # Apply container_registry as acr_config of each region detail
+ if global_acr_exists:
+ if not hasattr(region_detail, "acr_details") or len(region_detail.acr_details) == 0:
+ region_detail.acr_config = [acr_input] # pylint: disable=possibly-used-before-assignment
+
+ def _to_rest_object(self) -> RestRegistry:
+ """Build current parameterized schedule instance to a registry object before submission.
+
+ :return: Rest registry.
+ :rtype: RestRegistry
+ """
+ identity = RestManagedServiceIdentity(type=RestManagedServiceIdentityType.SYSTEM_ASSIGNED)
+ replication_locations = []
+ if self.replication_locations:
+ replication_locations = [details._to_rest_object() for details in self.replication_locations]
+ # Notes about this construction.
+ # RestRegistry.properties.tags: this property exists due to swagger inheritance
+ # issues, don't actually use it, use top level RestRegistry.tags instead
+ # RestRegistry.properties.managed_resource_group_tags: Registries create a
+ # managed resource group to manage their internal sub-resources.
+ # We always want the tags on this MRG to match those of the registry itself
+ # to keep janitor policies aligned.
+ return RestRegistry(
+ name=self.name,
+ location=self.location,
+ identity=identity,
+ tags=self.tags,
+ properties=RegistryProperties(
+ public_network_access=self.public_network_access,
+ discovery_url=self.discovery_url,
+ intellectual_property_publisher=(
+ (self.intellectual_property.publisher) if self.intellectual_property else None
+ ),
+ managed_resource_group=self.managed_resource_group,
+ ml_flow_registry_uri=self.mlflow_registry_uri,
+ region_details=replication_locations,
+ managed_resource_group_tags=self.tags,
+ ),
+ )
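+
+ # A minimal sketch of constructing and serializing a registry (the names and
+ # the single-region list are hypothetical; RegistryRegionDetails is defined in
+ # registry_support_classes.py alongside this module):
+ #
+ #     registry = Registry(
+ #         name="my-registry",
+ #         location="eastus",
+ #         replication_locations=[RegistryRegionDetails(location="eastus")],
+ #     )
+ #     registry.dump("./my-registry.yml")        # writes the YAML spec to disk
+ #     rest_registry = registry._to_rest_object()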
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry_support_classes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry_support_classes.py
new file mode 100644
index 00000000..810c5df5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/registry_support_classes.py
@@ -0,0 +1,273 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint:disable=protected-access,no-else-return
+
+from copy import deepcopy
+from functools import reduce
+from typing import List, Optional, Union
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2022_10_01_preview.models import AcrDetails as RestAcrDetails
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ArmResourceId as RestArmResourceId
+from azure.ai.ml._restclient.v2022_10_01_preview.models import RegistryRegionArmDetails as RestRegistryRegionArmDetails
+from azure.ai.ml._restclient.v2022_10_01_preview.models import StorageAccountDetails as RestStorageAccountDetails
+from azure.ai.ml._restclient.v2022_10_01_preview.models import SystemCreatedAcrAccount as RestSystemCreatedAcrAccount
+from azure.ai.ml._restclient.v2022_10_01_preview.models import (
+ SystemCreatedStorageAccount as RestSystemCreatedStorageAccount,
+)
+from azure.ai.ml._restclient.v2022_10_01_preview.models import UserCreatedAcrAccount as RestUserCreatedAcrAccount
+from azure.ai.ml.constants._registry import StorageAccountType
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .util import _make_rest_user_storage_from_id
+
+
+# This exists despite not being used by the schema validator because this entire
+# class is an output only value from the API.
+class SystemCreatedAcrAccount:
+ def __init__(
+ self,
+ *,
+ acr_account_sku: str,
+ arm_resource_id: Optional[str] = None,
+ ):
+ """Azure ML ACR account.
+
+ :param acr_account_sku: The ACR account service tier. Currently
+ only Premium is a valid option for registries.
+ :type acr_account_sku: str
+ :param arm_resource_id: Resource ID of the ACR account. Defaults to None.
+ :type arm_resource_id: Optional[str]
+ """
+ self.acr_account_sku = acr_account_sku
+ self.arm_resource_id = arm_resource_id
+
+ # acr should technically be a union between str and SystemCreatedAcrAccount,
+ # but python doesn't accept self class references apparently.
+ # Class method instead of normal function to accept possible
+ # string input.
+ @classmethod
+ def _to_rest_object(cls, acr: Union[str, "SystemCreatedAcrAccount"]) -> RestAcrDetails:
+ if hasattr(acr, "acr_account_sku") and acr.acr_account_sku is not None:
+ # SKU enum requires input to be a capitalized word,
+ # so we format the input to be acceptable as long as spelling is
+ # correct.
+ acr_account_sku = acr.acr_account_sku.capitalize()
+ # We DO NOT want to set the arm_resource_id. The backend provides very
+ # unhelpful errors if you provide an empty/null/invalid resource ID,
+ # and ignores the value otherwise. It's better to avoid setting it in
+ # the conversion in this direction at all.
+ return RestAcrDetails(
+ system_created_acr_account=RestSystemCreatedAcrAccount(
+ acr_account_sku=acr_account_sku,
+ )
+ )
+ else:
+ return RestAcrDetails(
+ user_created_acr_account=RestUserCreatedAcrAccount(arm_resource_id=RestArmResourceId(resource_id=acr))
+ )
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestAcrDetails) -> Optional["Union[str, SystemCreatedAcrAccount]"]:
+ if not rest_obj:
+ return None
+ if hasattr(rest_obj, "system_created_acr_account") and rest_obj.system_created_acr_account is not None:
+ resource_id = None
+ if rest_obj.system_created_acr_account.arm_resource_id:
+ resource_id = rest_obj.system_created_acr_account.arm_resource_id.resource_id
+ return SystemCreatedAcrAccount(
+ acr_account_sku=rest_obj.system_created_acr_account.acr_account_sku,
+ arm_resource_id=resource_id,
+ )
+ elif hasattr(rest_obj, "user_created_acr_account") and rest_obj.user_created_acr_account is not None:
+ res: Optional[str] = rest_obj.user_created_acr_account.arm_resource_id.resource_id
+ return res
+ else:
+ return None
+
+
+class SystemCreatedStorageAccount:
+ def __init__(
+ self,
+ *,
+ storage_account_hns: bool,
+ storage_account_type: Optional[StorageAccountType],
+ arm_resource_id: Optional[str] = None,
+ replicated_ids: Optional[List[str]] = None,
+ replication_count: int = 1,
+ ):
+ """
+ :param arm_resource_id: Resource ID of the storage account.
+ :type arm_resource_id: str
+ :param storage_account_hns: Whether or not this storage account
+ has hierarchical namespaces enabled.
+ :type storage_account_hns: bool
+ :param storage_account_type: Allowed values: "Standard_LRS",
+ "Standard_GRS, "Standard_RAGRS", "Standard_ZRS", "Standard_GZRS",
+ "Standard_RAGZRS", "Premium_LRS", "Premium_ZRS"
+ :type storage_account_type: StorageAccountType
+ :param replication_count: The number of replicas of this storage account
+ that should be created. Defaults to 1. Values less than 1 are invalid.
+ :type replication_count: int
+ :param replicated_ids: If this storage was replicated, then this is a
+ list of all storage IDs with these settings for this registry.
+ Defaults to none for un-replicated storage accounts.
+ :type replicated_ids: List[str]
+ """
+ self.arm_resource_id = arm_resource_id
+ self.storage_account_hns = storage_account_hns
+ self.storage_account_type = storage_account_type
+ self.replication_count = replication_count
+ self.replicated_ids = replicated_ids
+
+
+# Per-region information for registries.
+class RegistryRegionDetails:
+ def __init__(
+ self,
+ *,
+ acr_config: Optional[List[Union[str, SystemCreatedAcrAccount]]] = None,
+ location: Optional[str] = None,
+ storage_config: Optional[Union[List[str], SystemCreatedStorageAccount]] = None,
+ ):
+ """Details for each region a registry is in.
+
+ :param acr_config: List of ACR account details. Each value can either be a
+ single string representing the arm_resource_id of a user-created
+ ACR account, or an entire SystemCreatedAcrAccount object.
+ :type acr_config: List[Union[str, SystemCreatedAcrAccount]]
+ :param location: The location where the registry exists.
+ :type location: str
+ :param storage_config: List of storage accounts. Each value
+ can either be a single string representing the arm_resource_id of
+ a user-created storage account, or an entire
+ SystemCreatedStorageAccount object.
+ :type storage_config: Union[List[str], SystemCreatedStorageAccount]
+ """
+ self.acr_config = acr_config
+ self.location = location
+ self.storage_config = storage_config
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestRegistryRegionArmDetails) -> Optional["RegistryRegionDetails"]:
+ if not rest_obj:
+ return None
+ converted_acr_details = []
+ if rest_obj.acr_details:
+ converted_acr_details = [SystemCreatedAcrAccount._from_rest_object(acr) for acr in rest_obj.acr_details]
+ storages: Optional[Union[List[str], SystemCreatedStorageAccount]] = []
+ if rest_obj.storage_account_details:
+ storages = cls._storage_config_from_rest_object(rest_obj.storage_account_details)
+
+ return RegistryRegionDetails(
+ acr_config=converted_acr_details, # type: ignore[arg-type]
+ location=rest_obj.location,
+ storage_config=storages,
+ )
+
+ def _to_rest_object(self) -> RestRegistryRegionArmDetails:
+ converted_acr_details = []
+ if self.acr_config:
+ converted_acr_details = [SystemCreatedAcrAccount._to_rest_object(acr) for acr in self.acr_config]
+ storages = []
+ if self.storage_config:
+ storages = self._storage_config_to_rest_object()
+ return RestRegistryRegionArmDetails(
+ acr_details=converted_acr_details,
+ location=self.location,
+ storage_account_details=storages,
+ )
+
+ def _storage_config_to_rest_object(self) -> List[RestStorageAccountDetails]:
+ storage = self.storage_config
+ # storage_config can either be a single system-created storage account,
+ # or list of user-inputted id's.
+ if (
+ storage is not None
+ and not isinstance(storage, list)
+ and hasattr(storage, "storage_account_type")
+ and storage.storage_account_type is not None
+ ):
+ # We DO NOT want to set the arm_resource_id. The backend provides very
+ # unhelpful errors if you provide an empty/null/invalid resource ID,
+ # and ignores the value otherwise. It's better to avoid setting it in
+ # the conversion in this direction at all.
+ # We don't bother processing storage_account_type because the
+ # rest version is case insensitive.
+ account = RestStorageAccountDetails(
+ system_created_storage_account=RestSystemCreatedStorageAccount(
+ storage_account_hns_enabled=storage.storage_account_hns,
+ storage_account_type=storage.storage_account_type,
+ )
+ )
+ # duplicate this value based on the replication_count
+ count = storage.replication_count
+ if count < 1:
+ raise ValueError(f"Replication count cannot be less than 1. Value was: {count}.")
+ return [deepcopy(account) for _ in range(0, count)]
+ elif storage is not None and not isinstance(storage, SystemCreatedStorageAccount) and len(storage) > 0:
+ return [_make_rest_user_storage_from_id(user_id=user_id) for user_id in storage]
+ else:
+ return []
+
+ @classmethod
+ def _storage_config_from_rest_object(
+ cls, rest_configs: Optional[List]
+ ) -> Optional[Union[List[str], SystemCreatedStorageAccount]]:
+ if not rest_configs:
+ return None
+ num_configs = len(rest_configs)
+ if num_configs == 0:
+ return None
+ system_created_count = reduce(
+ # TODO: Bug Item number: 2883323
+ lambda x, y: int(x) + int(y), # type: ignore
+ [
+ hasattr(config, "system_created_storage_account") and config.system_created_storage_account is not None
+ for config in rest_configs
+ ],
+ )
+ # configs should be mono-typed. Either they're all system created
+ # or all user created.
+ if system_created_count == num_configs:
+ # System created case - assume all elements are duplicates
+ # of a single storage configuration.
+ # Convert back into a single local representation by
+ # combining id's into a list, and using the first element's
+ # account type and hns.
+ first_config = rest_configs[0].system_created_storage_account
+ resource_id = None
+ if first_config.arm_resource_id:
+ resource_id = first_config.arm_resource_id.resource_id
+ # account for the IDs of the duplicates if they exist
+ replicated_ids = None
+ if num_configs > 1:
+ replicated_ids = [
+ config.system_created_storage_account.arm_resource_id.resource_id for config in rest_configs
+ ]
+ return SystemCreatedStorageAccount(
+ storage_account_hns=first_config.storage_account_hns_enabled,
+ storage_account_type=(
+ (StorageAccountType(first_config.storage_account_type.lower()))
+ if first_config.storage_account_type
+ else None
+ ),
+ arm_resource_id=resource_id,
+ replication_count=num_configs,
+ replicated_ids=replicated_ids,
+ )
+ elif system_created_count == 0:
+ return [config.user_created_storage_account.arm_resource_id.resource_id for config in rest_configs]
+ else:
+ msg = f"""tried reading in a registry whose storage accounts were not
+ mono-managed or user-created. {system_created_count} out of {num_configs} were managed."""
+ err = ValidationException(
+ message=msg,
+ target=ErrorTarget.REGISTRY,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ log_and_raise_error(err)
+ return None
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/util.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/util.py
new file mode 100644
index 00000000..18f56169
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_registry/util.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._restclient.v2022_10_01_preview.models import ArmResourceId as RestArmResourceId
+from azure.ai.ml._restclient.v2022_10_01_preview.models import StorageAccountDetails as RestStorageAccountDetails
+from azure.ai.ml._restclient.v2022_10_01_preview.models import (
+ UserCreatedStorageAccount as RestUserCreatedStorageAccount,
+)
+
+
+def _make_rest_user_storage_from_id(*, user_id: str) -> RestStorageAccountDetails:
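+ """Wrap a user-provided storage account ARM resource id in a RestStorageAccountDetails object."""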
+ return RestStorageAccountDetails(
+ user_created_storage_account=RestUserCreatedStorageAccount(
+ arm_resource_id=RestArmResourceId(resource_id=user_id)
+ )
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_resource.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_resource.py
new file mode 100644
index 00000000..d20eaeff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_resource.py
@@ -0,0 +1,194 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import abc
+import os
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union, cast
+
+from msrest import Serializer
+
+from azure.ai.ml._restclient.v2022_10_01 import models
+from azure.ai.ml._telemetry.logging_handler import in_jupyter_notebook
+from azure.ai.ml._utils.utils import dump_yaml
+
+from ..constants._common import BASE_PATH_CONTEXT_KEY
+from ._system_data import SystemData
+
+
+class Resource(abc.ABC):
+ """Base class for entity classes.
+
+ Resource is an abstract object that serves as a base for creating resources. It contains common properties and
+ methods for all resources.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+ :param name: The name of the resource.
+ :type name: str
+ :param description: The description of the resource.
+ :type description: Optional[str]
+ :param tags: Tags can be added, removed, and updated.
+ :type tags: Optional[dict]
+ :param properties: The resource's property dictionary.
+ :type properties: Optional[dict]
+ :keyword print_as_yaml: Specifies if the resource should print out as a YAML-formatted object. If False,
+ the resource will print out in a more-compact style. By default, the YAML output is only used in Jupyter
+ notebooks. Be aware that some bookkeeping values are shown only in the non-YAML output.
+ :paramtype print_as_yaml: bool
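+
+ A minimal subclass sketch (illustrative only; the concrete entities in this package also
+ implement schema-based serialization):
+
+ .. code-block:: python
+
+     class MyResource(Resource):
+         def dump(self, dest, **kwargs):
+             ...  # e.g. serialize the entity and write it to dest
+
+         @classmethod
+         def _load(cls, data=None, yaml_path=None, params_override=None, **kwargs):
+             return cls(name=(data or {}).get("name"))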
+ """
+
+ def __init__(
+ self,
+ name: Optional[str],
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.name = name
+ self.description = description
+ self.tags: Optional[Dict] = dict(tags) if tags else {}
+ self.properties = dict(properties) if properties else {}
+ # Conditional assignment to prevent entity bloat when unused.
+ self._print_as_yaml = kwargs.pop("print_as_yaml", False)
+
+ # Hide read only properties in kwargs
+ self._id = kwargs.pop("id", None)
+ self.__source_path: Union[str, PathLike] = kwargs.pop("source_path", "")
+ self._base_path = kwargs.pop(BASE_PATH_CONTEXT_KEY, None) or os.getcwd() # base path should never be None
+ self._creation_context: Optional[SystemData] = kwargs.pop("creation_context", None)
+ client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)}
+ self._serialize = Serializer(client_models)
+ self._serialize.client_side_validation = False
+ super().__init__(**kwargs)
+
+ @property
+ def _source_path(self) -> Union[str, PathLike]:
+ # source path is added to display file location for validation error messages
+ # usually, base_path = Path(source_path).parent if source_path else os.getcwd()
+ return self.__source_path
+
+ @_source_path.setter
+ def _source_path(self, value: Union[str, PathLike]) -> None:
+ self.__source_path = Path(value).as_posix()
+
+ @property
+ def id(self) -> Optional[str]:
+ """The resource ID.
+
+ :return: The global ID of the resource, an Azure Resource Manager (ARM) ID.
+ :rtype: Optional[str]
+ """
+ if self._id is None:
+ return None
+ return str(self._id)
+
+ @property
+ def creation_context(self) -> Optional[SystemData]:
+ """The creation context of the resource.
+
+ :return: The creation metadata for the resource.
+ :rtype: Optional[~azure.ai.ml.entities.SystemData]
+ """
+ return cast(Optional[SystemData], self._creation_context)
+
+ @property
+ def base_path(self) -> str:
+ """The base path of the resource.
+
+ :return: The base path of the resource.
+ :rtype: str
+ """
+ return self._base_path
+
+ @abc.abstractmethod
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> Any:
+ """Dump the object content into a file.
+
+ :param dest: The local path or file stream to write the YAML content to.
+ If dest is a file path, a new file will be created.
+ If dest is an open file, the file will be written to directly.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ """
+
+ @classmethod
+ # pylint: disable=unused-argument
+ def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] = None) -> Tuple:
+ """Resolve the class to use for deserializing the data. Return current class if no override is provided.
+
+ :param data: Data to deserialize.
+ :type data: dict
+ :param params_override: Parameters to override, defaults to None
+ :type params_override: typing.Optional[list]
+ :return: Class to use for deserializing the data & its "type". Type will be None if no override is provided.
+ :rtype: tuple[class, typing.Optional[str]]
+ """
+ return cls, None
+
+ @classmethod
+ @abc.abstractmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Resource":
+ """Construct a resource object from a file. @classmethod.
+
+ :param cls: Indicates that this is a class method.
+ :type cls: class
+ :param data: Path to a local file as the source, defaults to None
+ :type data: typing.Optional[typing.Dict]
+ :param yaml_path: Path to a yaml file as the source, defaults to None
+ :type yaml_path: typing.Optional[typing.Union[typing.PathLike, str]]
+ :param params_override: Parameters to override, defaults to None
+ :type params_override: typing.Optional[list]
+ :return: Resource
+ :rtype: Resource
+ """
+
+ # pylint: disable=unused-argument
+ def _get_arm_resource(
+ self,
+ # pylint: disable=unused-argument
+ **kwargs: Any,
+ ) -> Dict:
+ """Get arm resource.
+
+ :return: Resource
+ :rtype: dict
+ """
+ from azure.ai.ml._arm_deployments.arm_helper import get_template
+
+ # pylint: disable=no-member
+ template = get_template(resource_type=self._arm_type) # type: ignore
+ # pylint: disable=no-member
+ template["copy"]["name"] = f"{self._arm_type}Deployment" # type: ignore
+ return dict(template)
+
+ def _get_arm_resource_and_params(self, **kwargs: Any) -> List:
+ """Get arm resource and parameters.
+
+ :return: Resource and parameters
+ :rtype: dict
+ """
+ resource = self._get_arm_resource(**kwargs)
+ # pylint: disable=no-member
+ param = self._to_arm_resource_param(**kwargs) # type: ignore
+ return [(resource, param)]
+
+ def __repr__(self) -> str:
+ var_dict = {k.strip("_"): v for (k, v) in vars(self).items()}
+ return f"{self.__class__.__name__}({var_dict})"
+
+ def __str__(self) -> str:
+ if self._print_as_yaml or in_jupyter_notebook():
+ # pylint: disable=no-member
+ yaml_serialized = self._to_dict() # type: ignore
+ return str(dump_yaml(yaml_serialized, default_flow_style=False))
+ return self.__repr__()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/schedule.py
new file mode 100644
index 00000000..93867a9e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/schedule.py
@@ -0,0 +1,513 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+import logging
+import typing
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union
+
+from typing_extensions import Literal
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import JobBase as RestJobBase
+from azure.ai.ml._restclient.v2023_06_01_preview.models import JobScheduleAction
+from azure.ai.ml._restclient.v2023_06_01_preview.models import PipelineJob as RestPipelineJob
+from azure.ai.ml._restclient.v2023_06_01_preview.models import Schedule as RestSchedule
+from azure.ai.ml._restclient.v2023_06_01_preview.models import ScheduleActionType as RestScheduleActionType
+from azure.ai.ml._restclient.v2023_06_01_preview.models import ScheduleProperties
+from azure.ai.ml._restclient.v2024_01_01_preview.models import TriggerRunSubmissionDto as RestTriggerRunSubmissionDto
+from azure.ai.ml._schema.schedule.schedule import JobScheduleSchema
+from azure.ai.ml._utils.utils import camel_to_snake, dump_yaml_to_file, is_private_preview_enabled
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import ARM_ID_PREFIX, BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, ScheduleType
+from azure.ai.ml.entities._job.command_job import CommandJob
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.entities._job.pipeline.pipeline_job import PipelineJob
+from azure.ai.ml.entities._job.spark_job import SparkJob
+from azure.ai.ml.entities._mixins import RestTranslatableMixin, TelemetryMixin, YamlTranslatableMixin
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.entities._validation import MutableValidationResult, PathAwareSchemaValidatableMixin
+
+from ...exceptions import ErrorCategory, ErrorTarget, ScheduleException, ValidationException
+from .._builders import BaseNode
+from .trigger import CronTrigger, RecurrenceTrigger, TriggerBase
+
+module_logger = logging.getLogger(__name__)
+
+
+class Schedule(YamlTranslatableMixin, PathAwareSchemaValidatableMixin, Resource):
+ """Schedule object used to create and manage schedules.
+
+ This class should not be instantiated directly. Instead, please use the subclasses.
+
+ :keyword name: The name of the schedule.
+ :paramtype name: str
+ :keyword trigger: The schedule trigger configuration.
+ :paramtype trigger: Union[~azure.ai.ml.entities.CronTrigger, ~azure.ai.ml.entities.RecurrenceTrigger]
+ :keyword display_name: The display name of the schedule.
+ :paramtype display_name: Optional[str]
+ :keyword description: The description of the schedule.
+ :paramtype description: Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated.
+ :paramtype tags: Optional[dict]
+ :keyword properties: A dictionary of properties to associate with the schedule.
+ :paramtype properties: Optional[dict[str, str]]
+ :keyword kwargs: Additional keyword arguments passed to the Resource constructor.
+ :paramtype kwargs: dict
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ trigger: Optional[Union[CronTrigger, RecurrenceTrigger]],
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ is_enabled = kwargs.pop("is_enabled", None)
+ provisioning_state = kwargs.pop("provisioning_state", None)
+ super().__init__(name=name, description=description, tags=tags, properties=properties, **kwargs)
+ self.trigger = trigger
+ self.display_name = display_name
+ self._is_enabled: bool = is_enabled
+ self._provisioning_state: str = provisioning_state
+ self._type: Any = None
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the schedule content into a file in YAML format.
+
+ :param dest: The local path or file stream to write the YAML content to.
+ If dest is a file path, a new file will be created.
+ If dest is an open file, the file will be written to directly.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ :raises FileExistsError: Raised if dest is a file path and the file already exists.
+ :raises IOError: Raised if dest is an open file and the file is not writable.
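+
+ A usage sketch (assuming ``schedule`` is an already-constructed schedule entity):
+
+ .. code-block:: python
+
+     schedule.dump("./my_schedule.yaml")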
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ @classmethod
+ def _create_validation_error(cls, message: str, no_personal_data_message: str) -> ValidationException:
+ return ValidationException(
+ message=message,
+ no_personal_data_message=no_personal_data_message,
+ target=ErrorTarget.SCHEDULE,
+ )
+
+ @classmethod
+ def _resolve_cls_and_type(cls, data: Dict, params_override: Optional[List[Dict]] = None) -> Tuple:
+ from azure.ai.ml.entities._data_import.schedule import ImportDataSchedule
+ from azure.ai.ml.entities._monitoring.schedule import MonitorSchedule
+
+ if "create_monitor" in data:
+ return MonitorSchedule, None
+ if "import_data" in data:
+ return ImportDataSchedule, None
+ return JobSchedule, None
+
+ @property
+ def create_job(self) -> Any: # pylint: disable=useless-return
+ """The create_job entity associated with the schedule if exists."""
+ module_logger.warning("create_job is not a valid property of %s", str(type(self)))
+ # return None here just to be explicit
+ return None
+
+ @create_job.setter
+ def create_job(self, value: Any) -> None: # pylint: disable=unused-argument
+ """Set the create_job entity associated with the schedule if exists.
+
+ :param value: The create_job entity associated with the schedule if exists.
+ :type value: Any
+ """
+ module_logger.warning("create_job is not a valid property of %s", str(type(self)))
+
+ @property
+ def is_enabled(self) -> bool:
+ """Specifies if the schedule is enabled or not.
+
+ :return: True if the schedule is enabled, False otherwise.
+ :rtype: bool
+ """
+ return self._is_enabled
+
+ @property
+ def provisioning_state(self) -> str:
+ """Returns the schedule's provisioning state. The possible values include
+ "Creating", "Updating", "Deleting", "Succeeded", "Failed", "Canceled".
+
+ :return: The schedule's provisioning state.
+ :rtype: str
+ """
+ return self._provisioning_state
+
+ @property
+ def type(self) -> Optional[str]:
+ """The schedule type. Accepted values are 'job' and 'monitor'.
+
+ :return: The schedule type.
+ :rtype: str
+ """
+ return self._type
+
+ def _to_dict(self) -> Dict:
+ res: dict = self._dump_for_validation()
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestSchedule) -> "Schedule":
+ from azure.ai.ml.entities._data_import.schedule import ImportDataSchedule
+ from azure.ai.ml.entities._monitoring.schedule import MonitorSchedule
+
+ if obj.properties.action.action_type == RestScheduleActionType.CREATE_JOB:
+ return JobSchedule._from_rest_object(obj)
+ if obj.properties.action.action_type == RestScheduleActionType.CREATE_MONITOR:
+ res_monitor_schedule: Schedule = MonitorSchedule._from_rest_object(obj)
+ return res_monitor_schedule
+ if obj.properties.action.action_type == RestScheduleActionType.IMPORT_DATA:
+ res_data_schedule: Schedule = ImportDataSchedule._from_rest_object(obj)
+ return res_data_schedule
+ msg = f"Unsupported schedule type {obj.properties.action.action_type}"
+ raise ScheduleException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SCHEDULE,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+
+class JobSchedule(RestTranslatableMixin, Schedule, TelemetryMixin):
+ """Class for managing job schedules.
+
+ :keyword name: The name of the schedule.
+ :paramtype name: str
+ :keyword trigger: The trigger configuration for the schedule.
+ :paramtype trigger: Union[~azure.ai.ml.entities.CronTrigger, ~azure.ai.ml.entities.RecurrenceTrigger]
+ :keyword create_job: The job definition or an existing job name.
+ :paramtype create_job: Union[~azure.ai.ml.entities.Job, str]
+ :keyword display_name: The display name of the schedule.
+ :paramtype display_name: Optional[str]
+ :keyword description: The description of the schedule.
+ :paramtype description: Optional[str]
+ :keyword tags: Tag dictionary. Tags can be added, removed, and updated.
+ :paramtype tags: Optional[dict[str, str]]
+ :keyword properties: A dictionary of properties to associate with the schedule.
+ :paramtype properties: Optional[dict[str, str]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_schedule_configuration]
+ :end-before: [END job_schedule_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a JobSchedule.
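+
+ A minimal construction sketch (assuming ``CronTrigger`` and ``JobSchedule`` are importable
+ from ``azure.ai.ml.entities``; the job name passed to ``create_job`` is a hypothetical,
+ already-existing job, and a Job entity may be passed instead):
+
+ .. code-block:: python
+
+     from azure.ai.ml.entities import CronTrigger, JobSchedule
+
+     schedule = JobSchedule(
+         name="nightly-training",
+         trigger=CronTrigger(expression="0 2 * * *"),
+         create_job="my_existing_job",
+     )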
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ trigger: Optional[Union[CronTrigger, RecurrenceTrigger]],
+ create_job: Union[Job, str],
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict] = None,
+ properties: Optional[Dict] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ name=name,
+ trigger=trigger,
+ display_name=display_name,
+ description=description,
+ tags=tags,
+ properties=properties,
+ **kwargs,
+ )
+ self._create_job = create_job
+ self._type = ScheduleType.JOB
+
+ @property
+ def create_job(self) -> Union[Job, str]:
+ """Return the job associated with the schedule.
+
+ :return: The job definition or an existing job name.
+ :rtype: Union[~azure.ai.ml.entities.Job, str]
+ """
+ return self._create_job
+
+ @create_job.setter
+ def create_job(self, value: Union[Job, str]) -> None:
+ """Sets the job that will be run when the schedule is triggered.
+
+ :param value: The job definition or an existing job name.
+ :type value: Union[~azure.ai.ml.entities.Job, str]
+ """
+ self._create_job = value
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "JobSchedule":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ return JobSchedule(
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ **load_from_dict(JobScheduleSchema, data, context, **kwargs),
+ )
+
+ @classmethod
+ def _load_from_rest_dict(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "JobSchedule":
+ """
+ Load job schedule from rest object dict.
+
+ This function is added because the user-facing schema is different from the REST one.
+
+ For example:
+
+ In the user YAML, create_job is a file reference with updates (not a full job definition):
+
+ .. code-block:: yaml
+
+ create_job:
+ job: ./job.yaml
+ inputs:
+ input: 10
+
+ while what we get from the REST API is a complete job definition:
+
+ .. code-block:: yaml
+
+ create_job:
+ name: xx
+ jobs:
+ node1: ...
+ inputs:
+ input: ..
+
+ :param data: The REST object to convert
+ :type data: Optional[Dict]
+ :param yaml_path: The yaml path
+ :type yaml_path: Optional[Union[PathLike, str]]
+ :param params_override: A list of parameter overrides
+ :type params_override: Optional[list]
+ :return: The job schedule
+ :rtype: JobSchedule
+ """
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ create_job_key = "create_job"
+ if create_job_key not in data:
+ msg = "Job definition for schedule '{}' can not be None."
+ raise ScheduleException(
+ message=msg.format(data["name"]),
+ no_personal_data_message=msg.format("[name]"),
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+ # Load the job definition separately
+ create_job_data = data.pop(create_job_key)
+ # Save the id for remote job reference before load job, as data dict will be changed
+ job_id = create_job_data.get("id")
+ if isinstance(job_id, str) and job_id.startswith(ARM_ID_PREFIX):
+ job_id = job_id[len(ARM_ID_PREFIX) :]
+ create_job = Job._load(
+ data=create_job_data,
+ **kwargs,
+ )
+ # Set id manually as it is a dump only field in schema
+ create_job._id = job_id
+ schedule = JobSchedule(
+ base_path=context[BASE_PATH_CONTEXT_KEY],
+ **load_from_dict(JobScheduleSchema, data, context, **kwargs),
+ )
+ schedule.create_job = create_job
+ return schedule
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: Any) -> JobScheduleSchema:
+ return JobScheduleSchema(context=context)
+
+ def _customized_validate(self) -> MutableValidationResult:
+ """Validate the resource with customized logic.
+
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ if isinstance(self.create_job, PipelineJob):
+ return self.create_job._validate()
+ return self._create_empty_validation_result()
+
+ @classmethod
+ def _get_skip_fields_in_schema_validation(cls) -> typing.List[str]:
+ """Get the fields that should be skipped in schema validation.
+
+ Override this method to add customized validation logic.
+
+ :return: The list of fields to skip in schema validation
+ :rtype: typing.List[str]
+ """
+ return ["create_job"]
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestSchedule) -> "JobSchedule":
+ properties = obj.properties
+ action: JobScheduleAction = properties.action
+ if action.job_definition is None:
+ msg = "Job definition for schedule '{}' can not be None."
+ raise ScheduleException(
+ message=msg.format(obj.name),
+ no_personal_data_message=msg.format("[name]"),
+ target=ErrorTarget.JOB,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+ if camel_to_snake(action.job_definition.job_type) not in [JobType.PIPELINE, JobType.COMMAND, JobType.SPARK]:
+ msg = f"Unsupported job type {action.job_definition.job_type} for schedule '{{}}'."
+ raise ScheduleException(
+ message=msg.format(obj.name),
+ no_personal_data_message=msg.format("[name]"),
+ target=ErrorTarget.JOB,
+ # Classified as user_error as we may support other types later.
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ # Wrap job definition with JobBase for Job._from_rest_object call.
+ create_job = RestJobBase(properties=action.job_definition)
+ # id is a readonly field so set it after init.
+ # TODO: Add this support after source job id move to JobBaseProperties
+ if hasattr(action.job_definition, "source_job_id"):
+ create_job.id = action.job_definition.source_job_id
+ create_job = Job._from_rest_object(create_job)
+ return cls(
+ trigger=TriggerBase._from_rest_object(properties.trigger),
+ create_job=create_job,
+ name=obj.name,
+ display_name=properties.display_name,
+ description=properties.description,
+ tags=properties.tags,
+ properties=properties.properties,
+ provisioning_state=properties.provisioning_state,
+ is_enabled=properties.is_enabled,
+ creation_context=SystemData._from_rest_object(obj.system_data),
+ )
+
+ def _to_rest_object(self) -> RestSchedule:
+ """Build current parameterized schedule instance to a schedule object before submission.
+
+ :return: Rest schedule.
+ :rtype: RestSchedule
+ """
+ if isinstance(self.create_job, BaseNode):
+ self.create_job = self.create_job._to_job()
+ private_enabled = is_private_preview_enabled()
+ if isinstance(self.create_job, PipelineJob):
+ job_definition = self.create_job._to_rest_object().properties
+ # Set the source job id, as it is used only for schedule scenario.
+ job_definition.source_job_id = self.create_job.id
+ elif private_enabled and isinstance(self.create_job, (CommandJob, SparkJob)):
+ job_definition = self.create_job._to_rest_object().properties
+ # TODO: Merge this branch with PipelineJob after source job id move to JobBaseProperties
+ # job_definition.source_job_id = self.create_job.id
+ elif isinstance(self.create_job, str): # arm id reference
+ # TODO: Update this after source job id move to JobBaseProperties
+ # Rest pipeline job will hold a 'Default' as experiment_name,
+ # MFE will add default if None, so pass an empty string here.
+ job_definition = RestPipelineJob(source_job_id=self.create_job, experiment_name="")
+ else:
+ msg = "Unsupported job type '{}' in schedule {}."
+ raise ValidationException(
+ message=msg.format(type(self.create_job).__name__, self.name),
+ no_personal_data_message=msg.format("[type]", "[name]"),
+ target=ErrorTarget.SCHEDULE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return RestSchedule(
+ properties=ScheduleProperties(
+ description=self.description,
+ properties=self.properties,
+ tags=self.tags,
+ action=JobScheduleAction(job_definition=job_definition),
+ display_name=self.display_name,
+ is_enabled=self._is_enabled,
+ trigger=self.trigger._to_rest_object() if self.trigger is not None else None,
+ )
+ )
+
+ def __str__(self) -> str:
+ try:
+ res_yaml: str = self._to_yaml()
+ return res_yaml
+ except BaseException: # pylint: disable=W0718
+ res_jobSchedule: str = super(JobSchedule, self).__str__()
+ return res_jobSchedule
+
+ # pylint: disable-next=docstring-missing-param
+ def _get_telemetry_values(self, *args: Any, **kwargs: Any) -> Dict[Literal["trigger_type"], str]:
+ """Return the telemetry values of schedule.
+
+ :return: A dictionary with telemetry values
+ :rtype: Dict[Literal["trigger_type"], str]
+ """
+ return {"trigger_type": type(self.trigger).__name__}
+
+
+class ScheduleTriggerResult:
+ """Schedule trigger result returned by trigger an enabled schedule once.
+
+ This class shouldn't be instantiated directly. Instead, it is used as the return type of schedule trigger.
+
+ :ivar str job_name:
+ :ivar str schedule_action_type:
+ """
+
+ def __init__(self, **kwargs):
+ self.job_name = kwargs.get("job_name", None)
+ self.schedule_action_type = kwargs.get("schedule_action_type", None)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestTriggerRunSubmissionDto) -> "ScheduleTriggerResult":
+ """Construct a ScheduleJob from a rest object.
+
+ :param obj: The rest object to construct from.
+ :type obj: ~azure.ai.ml._restclient.v2024_01_01_preview.models.TriggerRunSubmissionDto
+ :return: The constructed ScheduleTriggerResult.
+ :rtype: ScheduleTriggerResult
+ """
+ return cls(
+ schedule_action_type=obj.schedule_action_type,
+ job_name=obj.submission_id,
+ )
+
+ def _to_dict(self) -> dict:
+ """Convert the object to a dictionary.
+ :return: The dictionary representation of the object.
+ :rtype: dict
+ """
+ return {
+ "job_name": self.job_name,
+ "schedule_action_type": self.schedule_action_type,
+ }
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/trigger.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/trigger.py
new file mode 100644
index 00000000..855aac9e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_schedule/trigger.py
@@ -0,0 +1,290 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+import logging
+from abc import ABC
+from datetime import datetime
+from typing import List, Optional, Union
+
+from azure.ai.ml._restclient.v2023_04_01_preview.models import CronTrigger as RestCronTrigger
+from azure.ai.ml._restclient.v2023_04_01_preview.models import RecurrenceSchedule as RestRecurrencePattern
+from azure.ai.ml._restclient.v2023_04_01_preview.models import RecurrenceTrigger as RestRecurrenceTrigger
+from azure.ai.ml._restclient.v2023_04_01_preview.models import TriggerBase as RestTriggerBase
+from azure.ai.ml._restclient.v2023_04_01_preview.models import TriggerType as RestTriggerType
+from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
+from azure.ai.ml.constants import TimeZone
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+module_logger = logging.getLogger(__name__)
+
+
+class TriggerBase(RestTranslatableMixin, ABC):
+ """Base class of Trigger.
+
+ This class should not be instantiated directly. Instead, use one of its subclasses.
+
+ :keyword type: The type of trigger.
+ :paramtype type: str
+ :keyword start_time: Specifies the start time of the schedule in ISO 8601 format.
+ :paramtype start_time: Optional[Union[str, datetime]]
+ :keyword end_time: Specifies the end time of the schedule in ISO 8601 format.
+ Note that end_time is not supported for compute schedules.
+ :paramtype end_time: Optional[Union[str, datetime]]
+ :keyword time_zone: The time zone where the schedule will run. Defaults to UTC(+00:00).
+ Note that this applies to the start_time and end_time.
+ :paramtype time_zone: ~azure.ai.ml.constants.TimeZone
+ """
+
+ def __init__(
+ self,
+ *,
+ type: str, # pylint: disable=redefined-builtin
+ start_time: Optional[Union[str, datetime]] = None,
+ end_time: Optional[Union[str, datetime]] = None,
+ time_zone: Union[str, TimeZone] = TimeZone.UTC,
+ ) -> None:
+ super().__init__()
+ self.type = type
+ self.start_time = start_time
+ self.end_time = end_time
+ self.time_zone = time_zone
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestTriggerBase) -> Optional[Union["CronTrigger", "RecurrenceTrigger"]]:
+ if obj.trigger_type == RestTriggerType.RECURRENCE:
+ return RecurrenceTrigger._from_rest_object(obj)
+ if obj.trigger_type == RestTriggerType.CRON:
+ return CronTrigger._from_rest_object(obj)
+
+ return None
+
+
+class RecurrencePattern(RestTranslatableMixin):
+ """Recurrence pattern for a job schedule.
+
+ :keyword hours: The number of hours for the recurrence schedule pattern.
+ :paramtype hours: Union[int, List[int]]
+ :keyword minutes: The number of minutes for the recurrence schedule pattern.
+ :paramtype minutes: Union[int, List[int]]
+ :keyword week_days: A list of days of the week for the recurrence schedule pattern.
+ Acceptable values include: "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday"
+ :type week_days: Optional[Union[str, List[str]]]
+ :keyword month_days: A list of days of the month for the recurrence schedule pattern.
+ :paramtype month_days: Optional[Union[int, List[int]]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_schedule_configuration]
+ :end-before: [END job_schedule_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a JobSchedule to use a RecurrencePattern.
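+
+ A minimal construction sketch:
+
+ .. code-block:: python
+
+     pattern = RecurrencePattern(hours=10, minutes=15, week_days=["monday", "friday"])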
+ """
+
+ def __init__(
+ self,
+ *,
+ hours: Union[int, List[int]],
+ minutes: Union[int, List[int]],
+ week_days: Optional[Union[str, List[str]]] = None,
+ month_days: Optional[Union[int, List[int]]] = None,
+ ) -> None:
+ self.hours = hours
+ self.minutes = minutes
+ self.week_days = week_days
+ self.month_days = month_days
+
+ def _to_rest_object(self) -> RestRecurrencePattern:
+ return RestRecurrencePattern(
+ hours=[self.hours] if not isinstance(self.hours, list) else self.hours,
+ minutes=[self.minutes] if not isinstance(self.minutes, list) else self.minutes,
+ week_days=[self.week_days] if self.week_days and not isinstance(self.week_days, list) else self.week_days,
+ month_days=(
+ [self.month_days] if self.month_days and not isinstance(self.month_days, list) else self.month_days
+ ),
+ )
+
+ def _to_rest_compute_pattern_object(self) -> RestRecurrencePattern:
+ # This function is added because the compute trigger cannot share the same class
+ # as the schedule trigger on the service side.
+ if self.month_days:
+ module_logger.warning("'month_days' is ignored for not supported on compute recurrence schedule.")
+ return RestRecurrencePattern(
+ hours=[self.hours] if not isinstance(self.hours, list) else self.hours,
+ minutes=[self.minutes] if not isinstance(self.minutes, list) else self.minutes,
+ week_days=[self.week_days] if self.week_days and not isinstance(self.week_days, list) else self.week_days,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestRecurrencePattern) -> "RecurrencePattern":
+ return cls(
+ hours=obj.hours,
+ minutes=obj.minutes,
+ week_days=obj.week_days,
+ month_days=obj.month_days if hasattr(obj, "month_days") else None,
+ )
+
+
+class CronTrigger(TriggerBase):
+ """Cron Trigger for a job schedule.
+
+ :keyword expression: The cron expression of schedule, following NCronTab format.
+ :paramtype expression: str
+ :keyword start_time: The start time for the trigger. If using a datetime object, leave the tzinfo as None and use
+ the ``time_zone`` parameter to specify a time zone if needed. If using a string, use the format
+ YYYY-MM-DDThh:mm:ss. Defaults to running the first workload instantly and continuing future workloads
+ based on the schedule. If the start time is in the past, the first workload is run at the next calculated run
+ time.
+ :paramtype start_time: Optional[Union[str, datetime]]
+ :keyword end_time: The end time for the trigger. If using a datetime object, leave the tzinfo as None and use
+ the ``time_zone`` parameter to specify a time zone if needed. If using a string, use the format
+ YYYY-MM-DDThh:mm:ss. Note that end_time is not supported for compute schedules.
+ :paramtype end_time: Optional[Union[str, datetime]]
+ :keyword time_zone: The time zone where the schedule will run. Defaults to UTC(+00:00).
+ Note that this applies to the start_time and end_time.
+ :paramtype time_zone: Union[str, ~azure.ai.ml.constants.TimeZone]
+ :raises Exception: Raised if end_time is in the past.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START cron_trigger_configuration]
+ :end-before: [END cron_trigger_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a CronTrigger.
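+
+ A minimal construction sketch (the NCronTab expression below means 10:15 every Monday):
+
+ .. code-block:: python
+
+     trigger = CronTrigger(
+         expression="15 10 * * 1",
+         start_time="2024-01-01T00:00:00",
+     )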
+ """
+
+ def __init__(
+ self,
+ *,
+ expression: str,
+ start_time: Optional[Union[str, datetime]] = None,
+ end_time: Optional[Union[str, datetime]] = None,
+ time_zone: Union[str, TimeZone] = TimeZone.UTC,
+ ) -> None:
+ super().__init__(
+ type=RestTriggerType.CRON,
+ start_time=start_time,
+ end_time=end_time,
+ time_zone=time_zone,
+ )
+ self.expression = expression
+
+ def _to_rest_object(self) -> RestCronTrigger: # v2022_12_01.models.CronTrigger
+ return RestCronTrigger(
+ trigger_type=self.type,
+ expression=self.expression,
+ start_time=self.start_time,
+ end_time=self.end_time,
+ time_zone=self.time_zone,
+ )
+
+ def _to_rest_compute_cron_object(self) -> RestCronTrigger: # v2022_12_01_preview.models.CronTrigger
+ # This function is added because the compute trigger cannot share the same class
+ # as the schedule trigger on the service side.
+ if self.end_time:
+ module_logger.warning("'end_time' is ignored for not supported on compute schedule.")
+ return RestCronTrigger(
+ expression=self.expression,
+ start_time=self.start_time,
+ time_zone=self.time_zone,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestCronTrigger) -> "CronTrigger":
+ return cls(
+ expression=obj.expression,
+ start_time=obj.start_time,
+ end_time=obj.end_time,
+ time_zone=obj.time_zone,
+ )
+
+
+class RecurrenceTrigger(TriggerBase):
+ """Recurrence trigger for a job schedule.
+
+ :keyword start_time: Specifies the start time of the schedule in ISO 8601 format.
+ :paramtype start_time: Optional[Union[str, datetime]]
+ :keyword end_time: Specifies the end time of the schedule in ISO 8601 format.
+ Note that end_time is not supported for compute schedules.
+ :paramtype end_time: Optional[Union[str, datetime]]
+ :keyword time_zone: The time zone where the schedule will run. Defaults to UTC(+00:00).
+ Note that this applies to the start_time and end_time.
+ :paramtype time_zone: Union[str, ~azure.ai.ml.constants.TimeZone]
+ :keyword frequency: Specifies the frequency that the schedule should be triggered with.
+ Possible values include: "minute", "hour", "day", "week", "month".
+ :type frequency: str
+ :keyword interval: Specifies the interval in conjunction with the frequency that the schedule should be triggered
+ with.
+ :paramtype interval: int
+ :keyword schedule: Specifies the recurrence pattern.
+ :paramtype schedule: Optional[~azure.ai.ml.entities.RecurrencePattern]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_schedule_configuration]
+ :end-before: [END job_schedule_configuration]
+ :language: python
+ :dedent: 8
+ :caption: Configuring a JobSchedule to trigger recurrence every 4 weeks.
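+
+ A minimal construction sketch (triggering every four weeks at 10:15 on Mondays):
+
+ .. code-block:: python
+
+     trigger = RecurrenceTrigger(
+         frequency="week",
+         interval=4,
+         schedule=RecurrencePattern(hours=10, minutes=15, week_days=["monday"]),
+     )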
+ """
+
+ def __init__(
+ self,
+ *,
+ frequency: str,
+ interval: int,
+ schedule: Optional[RecurrencePattern] = None,
+ start_time: Optional[Union[str, datetime]] = None,
+ end_time: Optional[Union[str, datetime]] = None,
+ time_zone: Union[str, TimeZone] = TimeZone.UTC,
+ ) -> None:
+ super().__init__(
+ type=RestTriggerType.RECURRENCE,
+ start_time=start_time,
+ end_time=end_time,
+ time_zone=time_zone,
+ )
+ # Create empty pattern as schedule is required in rest model
+ self.schedule = schedule if schedule else RecurrencePattern(hours=[], minutes=[])
+ self.frequency = frequency
+ self.interval = interval
+
+ def _to_rest_object(self) -> RestRecurrenceTrigger: # v2022_12_01.models.RecurrenceTrigger
+ return RestRecurrenceTrigger(
+ frequency=snake_to_camel(self.frequency),
+ interval=self.interval,
+ schedule=self.schedule._to_rest_object(),
+ start_time=self.start_time,
+ end_time=self.end_time,
+ time_zone=self.time_zone,
+ )
+
+ def _to_rest_compute_recurrence_object(self) -> RestRecurrenceTrigger:
+ # v2022_12_01_preview.models.RecurrenceTrigger
+ # This function is added because the compute trigger cannot share the same class
+ # as the schedule trigger on the service side.
+ if self.end_time:
+ module_logger.warning("'end_time' is ignored for not supported on compute schedule.")
+ return RestRecurrenceTrigger(
+ frequency=snake_to_camel(self.frequency),
+ interval=self.interval,
+ schedule=self.schedule._to_rest_compute_pattern_object(),
+ start_time=self.start_time,
+ time_zone=self.time_zone,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestRecurrenceTrigger) -> "RecurrenceTrigger":
+ return cls(
+ frequency=camel_to_snake(obj.frequency),
+ interval=obj.interval,
+ schedule=RecurrencePattern._from_rest_object(obj.schedule) if obj.schedule else None,
+ start_time=obj.start_time,
+ end_time=obj.end_time,
+ time_zone=obj.time_zone,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_system_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_system_data.py
new file mode 100644
index 00000000..05020da2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_system_data.py
@@ -0,0 +1,77 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from typing import Any
+
+from azure.ai.ml._restclient.v2022_10_01.models import SystemData as RestSystemData
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class SystemData(RestTranslatableMixin):
+ """Metadata related to the creation and most recent modification of a resource.
+
+ :ivar created_by: The identity that created the resource.
+ :vartype created_by: str
+ :ivar created_by_type: The type of identity that created the resource. Possible values include:
+ "User", "Application", "ManagedIdentity", "Key".
+ :vartype created_by_type: str or ~azure.ai.ml.entities.CreatedByType
+ :ivar created_at: The timestamp of resource creation (UTC).
+ :vartype created_at: ~datetime.datetime
+ :ivar last_modified_by: The identity that last modified the resource.
+ :vartype last_modified_by: str
+ :ivar last_modified_by_type: The type of identity that last modified the resource. Possible
+ values include: "User", "Application", "ManagedIdentity", "Key".
+ :vartype last_modified_by_type: str or ~azure.ai.ml.entities.CreatedByType
+ :ivar last_modified_at: The timestamp of resource last modification (UTC).
+ :vartype last_modified_at: ~datetime.datetime
+ :keyword created_by: The identity that created the resource.
+ :paramtype created_by: str
+ :keyword created_by_type: The type of identity that created the resource. Accepted values are
+ "User", "Application", "ManagedIdentity", "Key".
+ :paramtype created_by_type: Union[str, ~azure.ai.ml.entities.CreatedByType]
+ :keyword created_at: The timestamp of resource creation (UTC).
+ :paramtype created_at: datetime
+ :keyword last_modified_by: The identity that last modified the resource.
+ :paramtype last_modified_by: str
+ :keyword last_modified_by_type: The type of identity that last modified the resource. Accepted values are
+ "User", "Application", "ManagedIdentity", "Key".
+ :paramtype last_modified_by_type: Union[str, ~azure.ai.ml.entities.CreatedByType]
+ :keyword last_modified_at: The timestamp of resource last modification in UTC.
+ :paramtype last_modified_at: datetime
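+
+ A minimal construction sketch (hypothetical values; all fields are optional keyword arguments):
+
+ .. code-block:: python
+
+     from datetime import datetime
+
+     metadata = SystemData(created_by="alice@contoso.com", created_at=datetime(2024, 1, 1))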
+ """
+
+ def __init__(self, **kwargs: Any) -> None:
+ self.created_by = kwargs.get("created_by", None)
+ self.created_by_type = kwargs.get("created_by_type", None)
+ self.created_at = kwargs.get("created_at", None)
+ self.last_modified_by = kwargs.get("last_modified_by", None)
+ self.last_modified_by_type = kwargs.get("last_modified_by_type", None)
+ self.last_modified_at = kwargs.get("last_modified_at", None)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestSystemData) -> "SystemData":
+ return cls(
+ created_by=obj.created_by,
+ created_at=obj.created_at,
+ created_by_type=obj.created_by_type,
+ last_modified_by=obj.last_modified_by,
+ last_modified_by_type=obj.last_modified_by_type,
+ last_modified_at=obj.last_modified_at,
+ )
+
+ def _to_rest_object(self) -> RestSystemData:
+ return RestSystemData(
+ created_by=self.created_by,
+ created_at=self.created_at,
+ created_by_type=self.created_by_type,
+ last_modified_by=self.last_modified_by,
+ last_modified_by_type=self.last_modified_by_type,
+ last_modified_at=self.last_modified_at,
+ )
+
+ def _to_dict(self) -> dict:
+ from azure.ai.ml._schema.job.creation_context import CreationContextSchema
+
+ return CreationContextSchema().dump(self) # pylint: disable=no-member
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_util.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_util.py
new file mode 100644
index 00000000..c487be6e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_util.py
@@ -0,0 +1,645 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import copy
+import hashlib
+import json
+import os
+import shutil
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, cast, overload
+from unittest import mock
+
+import msrest
+from marshmallow.exceptions import ValidationError
+
+from .._restclient.v2022_02_01_preview.models import JobInputType as JobInputType02
+from .._restclient.v2023_04_01_preview.models import JobInput as RestJobInput
+from .._restclient.v2023_04_01_preview.models import JobInputType as JobInputType10
+from .._restclient.v2023_04_01_preview.models import JobOutput as RestJobOutput
+from .._schema._datastore import AzureBlobSchema, AzureDataLakeGen1Schema, AzureDataLakeGen2Schema, AzureFileSchema
+from .._schema._deployment.batch.batch_deployment import BatchDeploymentSchema
+from .._schema._deployment.online.online_deployment import (
+ KubernetesOnlineDeploymentSchema,
+ ManagedOnlineDeploymentSchema,
+)
+from .._schema._endpoint.batch.batch_endpoint import BatchEndpointSchema
+from .._schema._endpoint.online.online_endpoint import KubernetesOnlineEndpointSchema, ManagedOnlineEndpointSchema
+from .._schema._sweep import SweepJobSchema
+from .._schema.assets.data import DataSchema
+from .._schema.assets.environment import EnvironmentSchema
+from .._schema.assets.model import ModelSchema
+from .._schema.component.command_component import CommandComponentSchema
+from .._schema.component.parallel_component import ParallelComponentSchema
+from .._schema.compute.aml_compute import AmlComputeSchema
+from .._schema.compute.compute_instance import ComputeInstanceSchema
+from .._schema.compute.virtual_machine_compute import VirtualMachineComputeSchema
+from .._schema.job import CommandJobSchema, ParallelJobSchema
+from .._schema.pipeline.pipeline_job import PipelineJobSchema
+from .._schema.schedule.schedule import JobScheduleSchema
+from .._schema.workspace import WorkspaceSchema
+from .._utils.utils import is_internal_component_data, try_enable_internal_components
+from ..constants._common import (
+ REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT,
+ CommonYamlFields,
+ YAMLRefDocLinks,
+ YAMLRefDocSchemaNames,
+)
+from ..constants._component import NodeType
+from ..constants._endpoint import EndpointYamlFields
+from ..entities._mixins import RestTranslatableMixin
+from ..exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+# avoid circular import error
+if TYPE_CHECKING:
+ from azure.ai.ml.entities._inputs_outputs import Output
+ from azure.ai.ml.entities._job.pipeline._io import NodeOutput
+
+# Maps schema class name to formatted error message pointing to Microsoft docs reference page for a schema's YAML
+REF_DOC_ERROR_MESSAGE_MAP = {
+ DataSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(YAMLRefDocSchemaNames.DATA, YAMLRefDocLinks.DATA),
+ EnvironmentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.ENVIRONMENT, YAMLRefDocLinks.ENVIRONMENT
+ ),
+ ModelSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(YAMLRefDocSchemaNames.MODEL, YAMLRefDocLinks.MODEL),
+ CommandComponentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.COMMAND_COMPONENT, YAMLRefDocLinks.COMMAND_COMPONENT
+ ),
+ ParallelComponentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.PARALLEL_COMPONENT, YAMLRefDocLinks.PARALLEL_COMPONENT
+ ),
+ AmlComputeSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.AML_COMPUTE, YAMLRefDocLinks.AML_COMPUTE
+ ),
+ ComputeInstanceSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.COMPUTE_INSTANCE, YAMLRefDocLinks.COMPUTE_INSTANCE
+ ),
+ VirtualMachineComputeSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.VIRTUAL_MACHINE_COMPUTE,
+ YAMLRefDocLinks.VIRTUAL_MACHINE_COMPUTE,
+ ),
+ AzureDataLakeGen1Schema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.DATASTORE_DATA_LAKE_GEN_1,
+ YAMLRefDocLinks.DATASTORE_DATA_LAKE_GEN_1,
+ ),
+ AzureBlobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.DATASTORE_BLOB, YAMLRefDocLinks.DATASTORE_BLOB
+ ),
+ AzureFileSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.DATASTORE_FILE, YAMLRefDocLinks.DATASTORE_FILE
+ ),
+ AzureDataLakeGen2Schema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.DATASTORE_DATA_LAKE_GEN_2,
+ YAMLRefDocLinks.DATASTORE_DATA_LAKE_GEN_2,
+ ),
+ BatchEndpointSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.BATCH_ENDPOINT, YAMLRefDocLinks.BATCH_ENDPOINT
+ ),
+ KubernetesOnlineEndpointSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.ONLINE_ENDPOINT, YAMLRefDocLinks.ONLINE_ENDPOINT
+ ),
+ ManagedOnlineEndpointSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.ONLINE_ENDPOINT, YAMLRefDocLinks.ONLINE_ENDPOINT
+ ),
+ BatchDeploymentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.BATCH_DEPLOYMENT, YAMLRefDocLinks.BATCH_DEPLOYMENT
+ ),
+ ManagedOnlineDeploymentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.MANAGED_ONLINE_DEPLOYMENT,
+ YAMLRefDocLinks.MANAGED_ONLINE_DEPLOYMENT,
+ ),
+ KubernetesOnlineDeploymentSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.KUBERNETES_ONLINE_DEPLOYMENT,
+ YAMLRefDocLinks.KUBERNETES_ONLINE_DEPLOYMENT,
+ ),
+ PipelineJobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.PIPELINE_JOB, YAMLRefDocLinks.PIPELINE_JOB
+ ),
+ JobScheduleSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.JOB_SCHEDULE, YAMLRefDocLinks.JOB_SCHEDULE
+ ),
+ SweepJobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.SWEEP_JOB, YAMLRefDocLinks.SWEEP_JOB
+ ),
+ CommandJobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.COMMAND_JOB, YAMLRefDocLinks.COMMAND_JOB
+ ),
+ ParallelJobSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.PARALLEL_JOB, YAMLRefDocLinks.PARALLEL_JOB
+ ),
+ WorkspaceSchema: REF_DOC_YAML_SCHEMA_ERROR_MSG_FORMAT.format(
+ YAMLRefDocSchemaNames.WORKSPACE, YAMLRefDocLinks.WORKSPACE
+ ),
+}
+
+
+def find_field_in_override(field: str, params_override: Optional[list] = None) -> Optional[str]:
+ """Find specific field in params override.
+
+ :param field: The name of the field to find
+ :type field: str
+ :param params_override: The params override
+ :type params_override: Optional[list]
+ :return: The value of the field if found, otherwise None
+ :rtype: Optional[str]
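+
+ A usage sketch:
+
+ .. code-block:: python
+
+     find_field_in_override("type", [{"type": "command"}])  # returns "command"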
+ """
+ params_override = params_override or []
+ for override in params_override:
+ if field in override:
+ res: Optional[str] = override[field]
+ return res
+ return None
+
+
+def find_type_in_override(params_override: Optional[list] = None) -> Optional[str]:
+ """Find type in params override.
+
+ :param params_override: The params override
+ :type params_override: Optional[list]
+ :return: The type
+ :rtype: Optional[str]
+ """
+ return find_field_in_override(CommonYamlFields.TYPE, params_override)
+
+
+def is_compute_in_override(params_override: Optional[list] = None) -> bool:
+ """Check if compute is in params override.
+
+ :param params_override: The params override
+ :type params_override: Optional[list]
+ :return: True if compute is in params override
+ :rtype: bool
+ """
+ if params_override is not None:
+ return any(EndpointYamlFields.COMPUTE in param for param in params_override)
+ return False
+
+
+def load_from_dict(schema: Any, data: Dict, context: Dict, additional_message: str = "", **kwargs: Any) -> Any:
+ """Load data from dict.
+
+ :param schema: The schema to load data with.
+ :type schema: Any
+ :param data: The data to load.
+ :type data: Dict
+ :param context: The context of the data.
+ :type context: Dict
+ :param additional_message: The additional message to add to the error message.
+ :type additional_message: str
+ :return: The loaded data.
+ :rtype: Any
+ """
+ try:
+ return schema(context=context).load(data, **kwargs)
+ except ValidationError as e:
+ pretty_error = json.dumps(e.normalized_messages(), indent=2)
+ raise ValidationError(decorate_validation_error(schema, pretty_error, additional_message)) from e
+
+
+def decorate_validation_error(schema: Any, pretty_error: str, additional_message: str = "") -> str:
+ """Decorate validation error with additional message.
+
+ :param schema: The schema that failed validation.
+ :type schema: Any
+ :param pretty_error: The pretty error message.
+ :type pretty_error: str
+ :param additional_message: The additional message to add.
+ :type additional_message: str
+ :return: The decorated error message.
+ :rtype: str
+ """
+ ref_doc_link_error_msg = REF_DOC_ERROR_MESSAGE_MAP.get(schema, "")
+ if ref_doc_link_error_msg:
+ additional_message += f"\n{ref_doc_link_error_msg}"
+ additional_message += (
+ "\nThe easiest way to author a specification file is using IntelliSense and auto-completion Azure ML VS "
+ "code extension provides: https://code.visualstudio.com/docs/datascience/azure-machine-learning. "
+ "To set up: https://learn.microsoft.com/azure/machine-learning/how-to-setup-vs-code"
+ )
+ return f"Validation for {schema.__name__} failed:\n\n {pretty_error} \n\n {additional_message}"
+
+
+def get_md5_string(text: Optional[str]) -> str:
+ """Get md5 string for a given text.
+
+ :param text: The text to get md5 string for.
+ :type text: str
+ :return: The md5 string.
+ :rtype: str
+ """
+ try:
+ if text is not None:
+ return hashlib.md5(text.encode("utf8")).hexdigest() # nosec
+ return ""
+ except Exception as ex:
+ raise ex
+
+
+def validate_attribute_type(attrs_to_check: Dict[str, Any], attr_type_map: Dict[str, Type]) -> None:
+ """Validate if attributes of object are set with valid types, raise error
+ if don't.
+
+ :param attrs_to_check: Mapping from attributes name to actual value.
+ :type attrs_to_check: Dict[str, Any]
+ :param attr_type_map: Mapping from attributes name to tuple of expecting type
+ :type attr_type_map: Dict[str, Type]
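+
+ A usage sketch (the mismatched type below raises a ValidationException):
+
+ .. code-block:: python
+
+     validate_attribute_type(
+         attrs_to_check={"name": 123},
+         attr_type_map={"name": str},
+     )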
+ """
+ kwargs = attrs_to_check.get("kwargs", {})
+ attrs_to_check.update(kwargs)
+ for attr, expecting_type in attr_type_map.items():
+ attr_val = attrs_to_check.get(attr, None)
+ if attr_val is not None and not isinstance(attr_val, expecting_type):
+ msg = "Expecting {} for {}, got {} instead."
+ raise ValidationException(
+ message=msg.format(expecting_type, attr, type(attr_val)),
+ no_personal_data_message=msg.format(expecting_type, "[attr]", type(attr_val)),
+ target=ErrorTarget.GENERAL,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+
+def is_empty_target(obj: Optional[Dict]) -> bool:
+ """Determines if it's empty target
+
+ :param obj: The object to check
+ :type obj: Optional[Dict]
+ :return: True if obj is None or an empty Dict
+ :rtype: bool
+ """
+ return (
+ obj is None
+ # some objs have overloaded "==" and will cause error. e.g CommandComponent obj
+ or (isinstance(obj, dict) and len(obj) == 0)
+ )
+
+
+def convert_ordered_dict_to_dict(target_object: Union[Dict, List], remove_empty: bool = True) -> Union[Dict, List]:
+ """Convert ordered dict to dict. Remove keys with None value.
+ This is a workaround for rest request must be in dict instead of
+ ordered dict.
+
+ :param target_object: The object to convert
+ :type target_object: Union[Dict, List]
+ :param remove_empty: Whether to omit values that are None or empty dictionaries. Defaults to True.
+ :type remove_empty: bool
+ :return: Converted ordered dict with removed None values
+ :rtype: Union[Dict, List]
+ """
+ # OrderedDict can appear nested in a list
+ if isinstance(target_object, list):
+ new_list = []
+ for item in target_object:
+ item = convert_ordered_dict_to_dict(item)
+ if not is_empty_target(item) or not remove_empty:
+ new_list.append(item)
+ return new_list
+ if isinstance(target_object, dict):
+ new_dict = {}
+ for key, value in target_object.items():
+ value = convert_ordered_dict_to_dict(value)
+ if not is_empty_target(value) or not remove_empty:
+ new_dict[key] = value
+ return new_dict
+ return target_object
+
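+# Illustrative example (not executed): nested OrderedDicts become plain dicts, and
+# None / empty-dict values are dropped when remove_empty is True (the default).
+#
+#     from collections import OrderedDict
+#     convert_ordered_dict_to_dict(OrderedDict([("a", 1), ("b", None), ("c", OrderedDict())]))
+#     # -> {"a": 1}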
+
+def _general_copy(src: Union[str, os.PathLike], dst: Union[str, os.PathLike], make_dirs: bool = True) -> None:
+ """Wrapped `shutil.copy2` function for possible "Function not implemented" exception raised by it.
+
+ Background: `shutil.copy2` will throw OSError when dealing with Azure File.
+ See https://stackoverflow.com/questions/51616058 for more information.
+
+ :param src: The source path to copy from
+ :type src: Union[str, os.PathLike]
+ :param dst: The destination path to copy to
+ :type dst: Union[str, os.PathLike]
+ :param make_dirs: Whether to ensure the destination path exists. Defaults to True.
+ :type make_dirs: bool
+ """
+ if make_dirs:
+ os.makedirs(os.path.dirname(dst), exist_ok=True)
+ if hasattr(os, "listxattr"):
+ with mock.patch("shutil._copyxattr", return_value=[]):
+ shutil.copy2(src, dst)
+ else:
+ shutil.copy2(src, dst)
+
+
+def _dump_data_binding_expression_in_fields(obj: Any) -> Any:
+ for key, value in obj.__dict__.items():
+ # PipelineInput is subclass of NodeInput
+ from ._job.pipeline._io import NodeInput
+
+ if isinstance(value, NodeInput):
+ obj.__dict__[key] = str(value)
+ elif isinstance(value, RestTranslatableMixin):
+ _dump_data_binding_expression_in_fields(value)
+ return obj
+
+
+T = TypeVar("T")
+
+
+def get_rest_dict_for_node_attrs(
+ target_obj: Union[T, str], clear_empty_value: bool = False
+) -> Union[T, Dict, List, str, int, float, bool]:
+ """Convert object to dict and convert OrderedDict to dict.
+ Allow data binding expression as value, disregarding of the type defined in rest object.
+
+ :param target_obj: The object to convert
+ :type target_obj: T
+ :param clear_empty_value: Whether to clear empty values. Defaults to False.
+ :type clear_empty_value: bool
+    :return: The translated dict, or the original object
+ :rtype: Union[T, Dict]
+ """
+ # pylint: disable=too-many-return-statements
+ from azure.ai.ml.entities._job.pipeline._io import PipelineInput
+
+ if target_obj is None:
+ return None
+ if isinstance(target_obj, dict):
+ result_dict: dict = {}
+ for key, value in target_obj.items():
+ if value is None:
+ continue
+ if key in ["additional_properties"]:
+ continue
+ result_dict[key] = get_rest_dict_for_node_attrs(value, clear_empty_value)
+ return result_dict
+ if isinstance(target_obj, list):
+ result_list: list = []
+ for item in target_obj:
+ result_list.append(get_rest_dict_for_node_attrs(item, clear_empty_value))
+ return result_list
+ if isinstance(target_obj, RestTranslatableMixin):
+ # note that the rest object may be invalid as data binding expression may not fit
+ # rest object structure
+ # pylint: disable=protected-access
+ _target_obj = _dump_data_binding_expression_in_fields(copy.deepcopy(target_obj))
+
+ from azure.ai.ml.entities._credentials import _BaseIdentityConfiguration
+
+ if isinstance(_target_obj, _BaseIdentityConfiguration):
+ # TODO: Bug Item number: 2883348
+ return get_rest_dict_for_node_attrs(
+ _target_obj._to_job_rest_object(), clear_empty_value=clear_empty_value # type: ignore
+ )
+ return get_rest_dict_for_node_attrs(_target_obj._to_rest_object(), clear_empty_value=clear_empty_value)
+
+ if isinstance(target_obj, msrest.serialization.Model):
+ # can't use result.as_dict() as data binding expression may not fit rest object structure
+ return get_rest_dict_for_node_attrs(target_obj.__dict__, clear_empty_value=clear_empty_value)
+
+ if isinstance(target_obj, PipelineInput):
+ return get_rest_dict_for_node_attrs(str(target_obj), clear_empty_value=clear_empty_value)
+
+ if not isinstance(target_obj, (str, int, float, bool)):
+ raise ValueError("Unexpected type {}".format(type(target_obj)))
+
+ return target_obj
+
+
+class _DummyRestModelFromDict(msrest.serialization.Model):
+ """A dummy rest model that can be initialized from dict, return base_dict[attr_name]
+ for getattr(self, attr_name) when attr_name is a public attrs; return None when trying to get
+ a non-existent public attribute.
+ """
+
+ def __init__(self, rest_dict: Optional[dict]):
+ self._rest_dict = rest_dict or {}
+ super().__init__()
+
+ def __getattribute__(self, item: str) -> Any:
+ if not item.startswith("_"):
+ return self._rest_dict.get(item, None)
+ return super().__getattribute__(item)
+
+
+def from_rest_dict_to_dummy_rest_object(rest_dict: Optional[Dict]) -> _DummyRestModelFromDict:
+ """Create a dummy rest object based on a rest dict, which is a primitive dict containing
+ attributes in a rest object.
+ For example, for a rest object class like:
+ class A(msrest.serialization.Model):
+ def __init__(self, a, b):
+ self.a = a
+ self.b = b
+ rest_object = A(1, None)
+ rest_dict = {"a": 1}
+    regenerated_rest_object = from_rest_dict_to_dummy_rest_object(rest_dict)
+ assert regenerated_rest_object.a == 1
+ assert regenerated_rest_object.b is None
+
+ :param rest_dict: The rest dict
+ :type rest_dict: Optional[Dict]
+ :return: A dummy rest object
+ :rtype: _DummyRestModelFromDict
+ """
+ if rest_dict is None or isinstance(rest_dict, dict):
+ return _DummyRestModelFromDict(rest_dict)
+ raise ValueError("Unexpected type {}".format(type(rest_dict)))
+
+
+def extract_label(input_str: str) -> Union[Tuple, List]:
+ """Extract label from input string.
+
+ :param input_str: The input string
+ :type input_str: str
+ :return: The rest of the string and the label
+ :rtype: Tuple[str, Optional[str]]
+ """
+ if not isinstance(input_str, str):
+ return None, None
+ if "@" in input_str:
+ return input_str.rsplit("@", 1)
+ return input_str, None
+
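+# Illustrative example (not executed): a trailing "@<label>" is split off, while plain
+# names are returned with a None label.
+#
+#     extract_label("my_component@latest")  # -> ["my_component", "latest"]
+#     extract_label("my_component")         # -> ("my_component", None)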
+
+@overload
+def resolve_pipeline_parameters(pipeline_parameters: None, remove_empty: bool = False) -> None: ...
+
+
+@overload
+def resolve_pipeline_parameters(
+ pipeline_parameters: Dict[str, T], remove_empty: bool = False
+) -> Dict[str, Union[T, str, "NodeOutput"]]: ...
+
+
+def resolve_pipeline_parameters(pipeline_parameters: Optional[Dict], remove_empty: bool = False) -> Optional[Dict]:
+ """Resolve pipeline parameters.
+
+ 1. Resolve BaseNode and OutputsAttrDict type to NodeOutput.
+ 2. Remove empty value (optional).
+
+ :param pipeline_parameters: The pipeline parameters
+ :type pipeline_parameters: Optional[Dict[str, T]]
+ :param remove_empty: Whether to remove None values. Defaults to False.
+ :type remove_empty: bool
+ :return:
+ * None if pipeline_parameters is None
+ * The resolved dict of pipeline parameters
+ :rtype: Optional[Dict[str, Union[T, str, "NodeOutput"]]]
+ """
+
+ if pipeline_parameters is None:
+ return None
+ if not isinstance(pipeline_parameters, dict):
+ raise ValidationException(
+ message="pipeline_parameters must in dict {parameter: value} format.",
+ no_personal_data_message="pipeline_parameters must in dict {parameter: value} format.",
+ target=ErrorTarget.PIPELINE,
+ )
+
+ updated_parameters = {}
+ for k, v in pipeline_parameters.items():
+ v = resolve_pipeline_parameter(v)
+ if v is None and remove_empty:
+ continue
+ updated_parameters[k] = v
+ pipeline_parameters = updated_parameters
+ return pipeline_parameters
+
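+# Illustrative example (not executed): plain values pass through unchanged, and None
+# values are dropped when remove_empty is True.
+#
+#     resolve_pipeline_parameters({"lr": 0.01, "data": None}, remove_empty=True)
+#     # -> {"lr": 0.01}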
+
+def resolve_pipeline_parameter(data: Any) -> Union[T, str, "NodeOutput"]:
+ """Resolve pipeline parameter.
+ 1. Resolve BaseNode and OutputsAttrDict type to NodeOutput.
+ 2. Remove empty value (optional).
+ :param data: The pipeline parameter
+ :type data: T
+ :return:
+ * None if data is None
+ * The resolved pipeline parameter
+ :rtype: Union[T, str, "NodeOutput"]
+ """
+ from azure.ai.ml.entities._builders.base_node import BaseNode
+ from azure.ai.ml.entities._builders.pipeline import Pipeline
+ from azure.ai.ml.entities._job.pipeline._io import NodeOutput, OutputsAttrDict
+ from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression
+
+ if isinstance(data, PipelineExpression):
+ data = cast(Union[str, BaseNode], data.resolve())
+ if isinstance(data, (BaseNode, Pipeline)):
+        # For the case where a node/pipeline node is used as the input, we use its single output as the real input.
+ # Here we set node = node.outputs, then the following logic will get the output object.
+ data = cast(OutputsAttrDict, data.outputs)
+ if isinstance(data, OutputsAttrDict):
+        # For the case where the outputs of another component are used as the input,
+        # we use its single output as the real input;
+        # if multiple outputs are provided, an exception is raised.
+ output_len = len(data)
+ if output_len != 1:
+ raise ValidationException(
+ message="Setting input failed: Exactly 1 output is required, got %d. (%s)" % (output_len, data),
+ no_personal_data_message="multiple output(s) found of specified outputs, exactly 1 output required.",
+ target=ErrorTarget.PIPELINE,
+ )
+ data = cast(NodeOutput, list(data.values())[0])
+ return cast(Union[T, str, "NodeOutput"], data)
+
+
+def normalize_job_input_output_type(input_output_value: Union[RestJobOutput, RestJobInput, Dict]) -> None:
+ """Normalizes the `job_input_type`, `job_output_type`, and `type` keys for REST job output and input objects.
+
+ :param input_output_value: Either a REST input or REST output of a job
+ :type input_output_value: Union[RestJobOutput, RestJobInput, Dict]
+
+ .. note::
+
+        The API changed starting with the v2022_06_01_preview version, and the interface changes mean that
+        pipelines submitted with v2022_02_01_preview cannot be parsed correctly, which blocks
+        `az ml job list/show`. So we convert the input/output type from camel case to snake case to stay
+        compatible with the Jun/Oct API.
+
+ """
+
+ FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING = {
+ JobInputType02.CUSTOM_MODEL: JobInputType10.CUSTOM_MODEL,
+ JobInputType02.LITERAL: JobInputType10.LITERAL,
+ JobInputType02.ML_FLOW_MODEL: JobInputType10.MLFLOW_MODEL,
+ JobInputType02.ML_TABLE: JobInputType10.MLTABLE,
+ JobInputType02.TRITON_MODEL: JobInputType10.TRITON_MODEL,
+ JobInputType02.URI_FILE: JobInputType10.URI_FILE,
+ JobInputType02.URI_FOLDER: JobInputType10.URI_FOLDER,
+ }
+ if (
+ hasattr(input_output_value, "job_input_type")
+ and input_output_value.job_input_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING
+ ):
+ input_output_value.job_input_type = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[input_output_value.job_input_type]
+ elif (
+ hasattr(input_output_value, "job_output_type")
+ and input_output_value.job_output_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING
+ ):
+ input_output_value.job_output_type = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[input_output_value.job_output_type]
+ elif isinstance(input_output_value, dict):
+ job_output_type = input_output_value.get("job_output_type", None)
+ job_input_type = input_output_value.get("job_input_type", None)
+ job_type = input_output_value.get("type", None)
+
+ if job_output_type and job_output_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING:
+ input_output_value["job_output_type"] = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[job_output_type]
+ if job_input_type and job_input_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING:
+ input_output_value["job_input_type"] = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[job_input_type]
+ if job_type and job_type in FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING:
+ input_output_value["type"] = FEB_JUN_JOB_INPUT_OUTPUT_TYPE_MAPPING[job_type]
+
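+# Illustrative example (not executed): a REST input dict produced by the Feb 2022 API
+# has its type value rewritten in place to the Jun/Oct 2022 equivalent.
+#
+#     value = {"job_input_type": JobInputType02.URI_FILE}
+#     normalize_job_input_output_type(value)
+#     # value["job_input_type"] is now JobInputType10.URI_FILE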
+
+def get_type_from_spec(data: dict, *, valid_keys: Iterable[str]) -> str:
+ """Get the type of the node or component from the yaml spec.
+
+    The yaml spec must have a key named "type"; an exception is raised if its value is not one of valid_keys.
+
+ If internal components are enabled, related factory and schema will be updated.
+
+ :param data: The data
+ :type data: dict
+ :keyword valid_keys: An iterable of valid types
+ :paramtype valid_keys: Iterable[str]
+ :return: The type of the node or component
+ :rtype: str
+ """
+ _type, _ = extract_label(data.get(CommonYamlFields.TYPE, None))
+
+ # we should keep at least 1 place outside _internal to enable internal components
+ # and this is the only place
+ try_enable_internal_components()
+    # TODO: refine; hard-coded for now to support different task types for the DataTransfer component
+ if _type == NodeType.DATA_TRANSFER:
+ _type = "_".join([NodeType.DATA_TRANSFER, data.get("task", " ")])
+ if _type not in valid_keys:
+ is_internal_component_data(data, raise_if_not_enabled=True)
+
+ raise ValidationException(
+ message="Unsupported component type: %s." % _type,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message="Unsupported component type",
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ res: str = _type
+ return res
+
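+# Illustrative example (not executed), assuming "command" is one of valid_keys:
+#
+#     get_type_from_spec({"type": "command"}, valid_keys=["command", "pipeline"])  # -> "command"
+#     # an unsupported type raises ValidationException("Unsupported component type: ...")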
+
+def copy_output_setting(source: Union["Output", "NodeOutput"], target: "NodeOutput") -> None:
+ """Copy node output setting from source to target.
+
+ Currently only path, name, version will be copied.
+
+ :param source: The Output to copy from
+ :type source: Union[Output, NodeOutput]
+ :param target: The Output to copy to
+ :type target: NodeOutput
+ """
+ # pylint: disable=protected-access
+ from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineOutput
+
+ if not isinstance(source, NodeOutput):
+ # Only copy when source is an output builder
+ return
+ source_data = source._data
+ if isinstance(source_data, PipelineOutput):
+ source_data = source_data._data
+ if source_data:
+ target._data = copy.deepcopy(source_data)
+ # copy pipeline component output's node output to subgraph builder
+ if source._binding_output is not None:
+ target._binding_output = source._binding_output
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validate_funcs.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validate_funcs.py
new file mode 100644
index 00000000..8a082cb5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validate_funcs.py
@@ -0,0 +1,94 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+from os import PathLike
+from typing import IO, Any, AnyStr, Callable, Dict, List, Optional, Union, cast
+
+from marshmallow import ValidationError
+
+from azure.ai.ml import MLClient
+
+from ..exceptions import ValidationException
+from . import Component, Job
+from ._load_functions import _load_common_raising_marshmallow_error, _try_load_yaml_dict
+from ._validation import PathAwareSchemaValidatableMixin, ValidationResult, ValidationResultBuilder
+
+
+def validate_common(
+ cls: Any,
+ path: Union[str, PathLike, IO[AnyStr]],
+ validate_func: Optional[Callable],
+ params_override: Optional[List[Dict]] = None,
+) -> ValidationResult:
+ params_override = params_override or []
+ yaml_dict = _try_load_yaml_dict(path)
+
+ try:
+ cls, _ = cls._resolve_cls_and_type(data=yaml_dict, params_override=params_override)
+
+ entity = _load_common_raising_marshmallow_error(
+ cls=cls, yaml_dict=yaml_dict, relative_origin=path, params_override=params_override
+ )
+
+ if validate_func is not None:
+ res = cast(ValidationResult, validate_func(entity))
+ return res
+ if isinstance(entity, PathAwareSchemaValidatableMixin):
+ return entity._validate()
+ return ValidationResultBuilder.success()
+ except ValidationException as err:
+ return ValidationResultBuilder.from_single_message(err.message)
+ except ValidationError as err:
+ return ValidationResultBuilder.from_validation_error(err, source_path=path)
+
+
+def validate_component(
+ path: Union[str, PathLike, IO[AnyStr]],
+ ml_client: Optional[MLClient] = None,
+ params_override: Optional[List[Dict]] = None,
+) -> ValidationResult:
+ """Validate a component defined in a local file.
+
+ :param path: The path to the component definition file.
+ :type path: Union[str, PathLike, IO[AnyStr]]
+ :param ml_client: The client to use for validation. Will skip remote validation if None.
+    :type ml_client: Optional[azure.ai.ml.MLClient]
+ :param params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :type params_override: List[Dict]
+ :return: The validation result.
+ :rtype: ValidationResult
+ """
+ return validate_common(
+ cls=Component,
+ path=path,
+ validate_func=ml_client.components.validate if ml_client is not None else None,
+ params_override=params_override,
+ )
+
+
+def validate_job(
+ path: Union[str, PathLike, IO[AnyStr]],
+ ml_client: Optional[MLClient] = None,
+ params_override: Optional[List[Dict]] = None,
+) -> ValidationResult:
+ """Validate a job defined in a local file.
+
+ :param path: The path to the job definition file.
+    :type path: Union[str, PathLike, IO[AnyStr]]
+    :param ml_client: The client to use for validation. Will skip remote validation if None.
+    :type ml_client: Optional[azure.ai.ml.MLClient]
+ :param params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}]
+ :type params_override: List[Dict]
+ :return: The validation result.
+ :rtype: ValidationResult
+ """
+ return validate_common(
+ cls=Job,
+ path=path,
+ validate_func=ml_client.jobs.validate if ml_client is not None else None,
+ params_override=params_override,
+ )
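+
+
+# Illustrative usage (not executed; assumes a local "./job.yml" and an MLClient
+# constructed elsewhere as `ml_client`):
+#
+#     result = validate_job(path="./job.yml", ml_client=ml_client)
+#     if not result.passed:
+#         print(result.error_messages)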
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py
new file mode 100644
index 00000000..29ba05c5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from .core import MutableValidationResult, ValidationResult, ValidationResultBuilder
+from .path_aware_schema import PathAwareSchemaValidatableMixin
+from .remote import RemoteValidatableMixin
+from .schema import SchemaValidatableMixin
+
+__all__ = [
+ "SchemaValidatableMixin",
+ "PathAwareSchemaValidatableMixin",
+ "RemoteValidatableMixin",
+ "MutableValidationResult",
+ "ValidationResult",
+ "ValidationResultBuilder",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py
new file mode 100644
index 00000000..a7516c1d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py
@@ -0,0 +1,531 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import copy
+import json
+import logging
+import os.path
+import typing
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union, cast
+
+import pydash
+import strictyaml
+from marshmallow import ValidationError
+
+module_logger = logging.getLogger(__name__)
+
+
+class _ValidationStatus:
+ """Validation status class.
+
+    Validation status is used to indicate the status of a validation result. It can be one of the following values:
+ Succeeded, Failed.
+ """
+
+ SUCCEEDED = "Succeeded"
+ """Succeeded."""
+ FAILED = "Failed"
+ """Failed."""
+
+
+class Diagnostic(object):
+ """Represents a diagnostic of an asset validation error with the location info."""
+
+ def __init__(self, yaml_path: str, message: Optional[str], error_code: Optional[str]) -> None:
+ """Init Diagnostic.
+
+ :keyword yaml_path: A dash path from root to the target element of the diagnostic. jobs.job_a.inputs.input_str
+ :paramtype yaml_path: str
+ :keyword message: Error message of diagnostic.
+ :paramtype message: str
+ :keyword error_code: Error code of diagnostic.
+ :paramtype error_code: str
+ """
+ self.yaml_path = yaml_path
+ self.message = message
+ self.error_code = error_code
+ self.local_path, self.value = None, None
+
+ def __repr__(self) -> str:
+ """The asset friendly name and error message.
+
+ :return: The formatted diagnostic
+ :rtype: str
+ """
+ return "{}: {}".format(self.yaml_path, self.message)
+
+ @classmethod
+ def create_instance(
+ cls,
+ yaml_path: str,
+ message: Optional[str] = None,
+ error_code: Optional[str] = None,
+ ) -> "Diagnostic":
+ """Create a diagnostic instance.
+
+ :param yaml_path: A dash path from root to the target element of the diagnostic. jobs.job_a.inputs.input_str
+ :type yaml_path: str
+ :param message: Error message of diagnostic.
+ :type message: str
+ :param error_code: Error code of diagnostic.
+ :type error_code: str
+ :return: The created instance
+ :rtype: Diagnostic
+ """
+ return cls(
+ yaml_path=yaml_path,
+ message=message,
+ error_code=error_code,
+ )
+
+
+class ValidationResult(object):
+ """Represents the result of job/asset validation.
+
+    This class is used to organize and parse diagnostics from both the client and server side before exposing them.
+    The result is immutable.
+ """
+
+ def __init__(self) -> None:
+ self._target_obj: Optional[Dict] = None
+ self._errors: List = []
+ self._warnings: List = []
+
+ @property
+ def error_messages(self) -> Dict:
+ """
+ Return all messages of errors in the validation result.
+
+ :return: A dictionary of error messages. The key is the yaml path of the error, and the value is the error
+ message.
+ :rtype: dict
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START validation_result]
+ :end-before: [END validation_result]
+ :language: markdown
+ :dedent: 8
+ """
+ messages = {}
+ for diagnostic in self._errors:
+ if diagnostic.yaml_path not in messages:
+ messages[diagnostic.yaml_path] = diagnostic.message
+ else:
+ messages[diagnostic.yaml_path] += "; " + diagnostic.message
+ return messages
+
+ @property
+ def passed(self) -> bool:
+ """Returns boolean indicating whether any errors were found.
+
+ :return: True if the validation passed, False otherwise.
+ :rtype: bool
+ """
+ return not self._errors
+
+ def _to_dict(self) -> typing.Dict[str, typing.Any]:
+ result: Dict = {
+ "result": _ValidationStatus.SUCCEEDED if self.passed else _ValidationStatus.FAILED,
+ }
+ for diagnostic_type, diagnostics in [
+ ("errors", self._errors),
+ ("warnings", self._warnings),
+ ]:
+ messages = []
+ for diagnostic in diagnostics:
+ message = {
+ "message": diagnostic.message,
+ "path": diagnostic.yaml_path,
+ "value": pydash.get(self._target_obj, diagnostic.yaml_path, diagnostic.value),
+ }
+ if diagnostic.local_path:
+ message["location"] = str(diagnostic.local_path)
+ messages.append(message)
+ if messages:
+ result[diagnostic_type] = messages
+ return result
+
+ def __repr__(self) -> str:
+ """Get the string representation of the validation result.
+
+ :return: The string representation
+ :rtype: str
+ """
+ return json.dumps(self._to_dict(), indent=2)
+
+
+class MutableValidationResult(ValidationResult):
+ """Used by the client side to construct a validation result.
+
+ The result is mutable and should not be exposed to the user.
+ """
+
+ def __init__(self, target_obj: Optional[Dict] = None):
+ super().__init__()
+ self._target_obj = target_obj
+
+ def merge_with(
+ self,
+ target: ValidationResult,
+ field_name: Optional[str] = None,
+ condition_skip: Optional[typing.Callable] = None,
+ overwrite: bool = False,
+ ) -> "MutableValidationResult":
+ """Merge errors & warnings in another validation results into current one.
+
+        Updates the current validation result in place.
+        If field_name is not None, the yaml_path of each merged diagnostic is prefixed accordingly,
+        e.g. * => field_name and jobs.job_a => field_name.jobs.job_a. If None, yaml paths are left unchanged.
+
+ :param target: Validation result to merge.
+ :type target: ValidationResult
+ :param field_name: The base field name for the target to merge.
+ :type field_name: str
+ :param condition_skip: A function to determine whether to skip the merge of a diagnostic in the target.
+ :type condition_skip: typing.Callable
+ :param overwrite: Whether to overwrite the current validation result. If False, all diagnostics will be kept;
+ if True, current diagnostics with the same yaml_path will be dropped.
+ :type overwrite: bool
+ :return: The current validation result.
+ :rtype: MutableValidationResult
+ """
+ for source_diagnostics, target_diagnostics in [
+ (target._errors, self._errors),
+ (target._warnings, self._warnings),
+ ]:
+ if overwrite:
+ keys_to_remove = set(map(lambda x: x.yaml_path, source_diagnostics))
+ target_diagnostics[:] = [
+ diagnostic for diagnostic in target_diagnostics if diagnostic.yaml_path not in keys_to_remove
+ ]
+ for diagnostic in source_diagnostics:
+ if condition_skip and condition_skip(diagnostic):
+ continue
+ new_diagnostic = copy.deepcopy(diagnostic)
+ if field_name:
+ if new_diagnostic.yaml_path == "*":
+ new_diagnostic.yaml_path = field_name
+ else:
+ new_diagnostic.yaml_path = field_name + "." + new_diagnostic.yaml_path
+ target_diagnostics.append(new_diagnostic)
+ return self
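+
+        # Illustrative example (not executed; hypothetical result names): errors from a child
+        # validation result are re-rooted under "jobs.job_a" before being merged into the parent.
+        #
+        #     parent_result.merge_with(child_result, field_name="jobs.job_a")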
+
+ def try_raise(
+ self,
+ raise_error: Optional[bool] = True,
+ *,
+ error_func: Optional[typing.Callable[[str, str], Exception]] = None,
+ ) -> "MutableValidationResult":
+ """Try to raise an error from the validation result.
+
+ If the validation is passed or raise_error is False, this method
+ will return the validation result.
+
+ :param raise_error: Whether to raise the error.
+ :type raise_error: bool
+ :keyword error_func: A function to create the error. If None, a marshmallow.ValidationError will be created.
+ The first parameter of the function is the string representation of the validation result,
+ and the second parameter is the error message without personal data.
+ :type error_func: typing.Callable[[str, str], Exception]
+ :return: The current validation result.
+ :rtype: MutableValidationResult
+ """
+ # pylint: disable=logging-not-lazy
+ if raise_error is False:
+ return self
+
+ if self._warnings:
+ module_logger.warning("Warnings: %s" % str(self._warnings))
+
+ if not self.passed:
+ if error_func is None:
+
+ def error_func(msg: Union[str, list, dict], _: Any) -> ValidationError:
+ return ValidationError(message=msg)
+
+ raise error_func(
+ self.__repr__(),
+ "validation failed on the following fields: " + ", ".join(self.error_messages),
+ )
+ return self
+
+ def append_error(
+ self,
+ yaml_path: str = "*",
+ message: Optional[str] = None,
+ error_code: Optional[str] = None,
+ ) -> "MutableValidationResult":
+ """Append an error to the validation result.
+
+ :param yaml_path: The yaml path of the error.
+ :type yaml_path: str
+ :param message: The message of the error.
+ :type message: str
+ :param error_code: The error code of the error.
+ :type error_code: str
+ :return: The current validation result.
+ :rtype: MutableValidationResult
+ """
+ self._errors.append(
+ Diagnostic.create_instance(
+ yaml_path=yaml_path,
+ message=message,
+ error_code=error_code,
+ )
+ )
+ return self
+
+ def resolve_location_for_diagnostics(self, source_path: str, resolve_value: bool = False) -> None:
+ """Resolve location/value for diagnostics based on the source path where the validatable object is loaded.
+
+        The location includes the local path of the exact file (which can differ from the source path) and the line
+        number of the invalid field. The value of a diagnostic is resolved from the validatable object while
+        converting it to a dict by default; however, when the validatable object is not available (e.g. the
+        validation result was created from marshmallow.ValidationError.messages), the value can be resolved from the
+        source path instead.
+
+ :param source_path: The path of the source file.
+ :type source_path: str
+ :param resolve_value: Whether to resolve the value of the invalid field from source file.
+ :type resolve_value: bool
+ """
+ resolver = _YamlLocationResolver(source_path)
+ for diagnostic in self._errors + self._warnings:
+ res = resolver.resolve(diagnostic.yaml_path)
+ if res is not None:
+ diagnostic.local_path, value = res
+ if value is not None and resolve_value:
+ diagnostic.value = value
+
+ def append_warning(
+ self,
+ yaml_path: str = "*",
+ message: Optional[str] = None,
+ error_code: Optional[str] = None,
+ ) -> "MutableValidationResult":
+ """Append a warning to the validation result.
+
+ :param yaml_path: The yaml path of the warning.
+ :type yaml_path: str
+ :param message: The message of the warning.
+ :type message: str
+ :param error_code: The error code of the warning.
+ :type error_code: str
+ :return: The current validation result.
+ :rtype: MutableValidationResult
+ """
+ self._warnings.append(
+ Diagnostic.create_instance(
+ yaml_path=yaml_path,
+ message=message,
+ error_code=error_code,
+ )
+ )
+ return self
+
+
+class ValidationResultBuilder:
+ """A helper class to create a validation result."""
+
+ UNKNOWN_MESSAGE = "Unknown field."
+
+ def __init__(self) -> None:
+ pass
+
+ @classmethod
+ def success(cls) -> MutableValidationResult:
+ """Create a validation result with success status.
+
+ :return: A validation result
+ :rtype: MutableValidationResult
+ """
+ return MutableValidationResult()
+
+ @classmethod
+ def from_single_message(
+ cls, singular_error_message: Optional[str] = None, yaml_path: str = "*", data: Optional[dict] = None
+ ) -> MutableValidationResult:
+ """Create a validation result with only 1 diagnostic.
+
+ :param singular_error_message: diagnostic.message.
+ :type singular_error_message: Optional[str]
+ :param yaml_path: diagnostic.yaml_path.
+ :type yaml_path: str
+        :param data: The serialized validation target.
+ :type data: Optional[Dict]
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ obj = MutableValidationResult(target_obj=data)
+ if singular_error_message:
+ obj.append_error(message=singular_error_message, yaml_path=yaml_path)
+ return obj
+
+ @classmethod
+ def from_validation_error(
+ cls,
+ error: ValidationError,
+ *,
+ source_path: Optional[Union[str, PathLike, IO[AnyStr]]] = None,
+ error_on_unknown_field: bool = False,
+ ) -> MutableValidationResult:
+ """Create a validation result from a ValidationError, which will be raised in marshmallow.Schema.load. Please
+ use this function only for exception in loading file.
+
+ :param error: ValidationError raised by marshmallow.Schema.load.
+ :type error: ValidationError
+ :keyword source_path: The path to the source file.
+ :paramtype source_path: Optional[Union[str, PathLike, IO[AnyStr]]]
+ :keyword error_on_unknown_field: whether to raise error if there are unknown field diagnostics.
+ :paramtype error_on_unknown_field: bool
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ obj = cls.from_validation_messages(
+ error.messages, data=error.data, error_on_unknown_field=error_on_unknown_field
+ )
+ if source_path:
+ obj.resolve_location_for_diagnostics(cast(str, source_path), resolve_value=True)
+ return obj
+
+ @classmethod
+ def from_validation_messages(
+ cls, errors: typing.Dict, data: typing.Dict, *, error_on_unknown_field: bool = False
+ ) -> MutableValidationResult:
+ """Create a validation result from error messages, which will be returned by marshmallow.Schema.validate.
+
+ :param errors: error message returned by marshmallow.Schema.validate.
+ :type errors: dict
+ :param data: serialized data to validate
+ :type data: dict
+ :keyword error_on_unknown_field: whether to raise error if there are unknown field diagnostics.
+ :paramtype error_on_unknown_field: bool
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ instance = MutableValidationResult(target_obj=data)
+ errors = copy.deepcopy(errors)
+ cls._from_validation_messages_recursively(errors, [], instance, error_on_unknown_field=error_on_unknown_field)
+ return instance
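+
+        # Illustrative example (not executed): marshmallow error messages keyed by field
+        # become diagnostics with dot-separated yaml paths, e.g.
+        #
+        #     ValidationResultBuilder.from_validation_messages(
+        #         {"inputs": {"learning_rate": ["Not a valid number."]}}, data={}
+        #     ).error_messages
+        #     # -> {"inputs.learning_rate": "Not a valid number."}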
+
+ @classmethod
+ def _from_validation_messages_recursively(
+ cls,
+ errors: typing.Union[typing.Dict, typing.List, str],
+ path_stack: typing.List[str],
+ instance: MutableValidationResult,
+ error_on_unknown_field: bool,
+ ) -> None:
+ cur_path = ".".join(path_stack) if path_stack else "*"
+ # single error message
+ if isinstance(errors, dict) and "_schema" in errors:
+ instance.append_error(
+ message=";".join(errors["_schema"]),
+ yaml_path=cur_path,
+ )
+ # errors on attributes
+ elif isinstance(errors, dict):
+ for field, msgs in errors.items():
+ # fields.Dict
+ if field in ["key", "value"]:
+ cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field)
+ else:
+                    # TODO: hack logic to deal with error messages in nested TypeSensitiveUnionField in
+                    # DataTransfer: they arrive as a nested dict with None as the dictionary key.
+                    # Opened an item to track: https://msdata.visualstudio.com/Vienna/_workitems/edit/2244262/
+ if field is None:
+ cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field)
+ else:
+ path_stack.append(field)
+ cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field)
+ path_stack.pop()
+
+ # detailed error message
+ elif isinstance(errors, list) and all(isinstance(msg, str) for msg in errors):
+ if cls.UNKNOWN_MESSAGE in errors and not error_on_unknown_field:
+ # Unknown field is not a real error, so we should remove it and append a warning.
+ errors.remove(cls.UNKNOWN_MESSAGE)
+ instance.append_warning(message=cls.UNKNOWN_MESSAGE, yaml_path=cur_path)
+ if errors:
+ instance.append_error(message=";".join(errors), yaml_path=cur_path)
+ # union field
+ elif isinstance(errors, list):
+
+ def msg2str(msg: Any) -> Any:
+ if isinstance(msg, str):
+ return msg
+ if isinstance(msg, dict) and len(msg) == 1 and "_schema" in msg and len(msg["_schema"]) == 1:
+ return str(msg["_schema"][0])
+
+ return str(msg)
+
+ instance.append_error(message="; ".join([msg2str(x) for x in errors]), yaml_path=cur_path)
+ # unknown error
+ else:
+ instance.append_error(message=str(errors), yaml_path=cur_path)
+
+
+class _YamlLocationResolver:
+ def __init__(self, source_path: str):
+ self._source_path = source_path
+
+ def resolve(self, yaml_path: str, source_path: Optional[str] = None) -> Optional[Tuple]:
+ """Resolve the location & value of a yaml path starting from source_path.
+
+ :param yaml_path: yaml path.
+ :type yaml_path: str
+ :param source_path: source path.
+ :type source_path: str
+ :return: the location & value of the yaml path based on source_path.
+ :rtype: Tuple[str, str]
+ """
+ source_path = source_path or self._source_path
+ if source_path is None or not os.path.isfile(source_path):
+ return None, None
+ if yaml_path is None or yaml_path == "*":
+ return source_path, None
+
+ attrs = yaml_path.split(".")
+ attrs.reverse()
+
+ res: Optional[Tuple] = self._resolve_recursively(attrs, Path(source_path))
+ return res
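+
+        # Illustrative example (not executed): resolving "jobs.job_a.command" against a
+        # pipeline yaml returns something like
+        # ("/abs/path/pipeline.yml#line 12", "echo hello"), i.e. the file location of
+        # the field plus its value when the full path can be resolved.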
+
+ def _resolve_recursively(self, attrs: List[str], source_path: Path) -> Optional[Tuple]:
+ with open(source_path, encoding="utf-8") as f:
+ try:
+ loaded_yaml = strictyaml.load(f.read())
+ except Exception as e: # pylint: disable=W0718
+ msg = "Can't load source file %s as a strict yaml:\n%s" % (source_path, str(e))
+ module_logger.debug(msg)
+ return None, None
+
+ while attrs:
+ attr = attrs[-1]
+ if loaded_yaml.is_mapping() and attr in loaded_yaml:
+ loaded_yaml = loaded_yaml.get(attr)
+ attrs.pop()
+ else:
+ try:
+ # if current object is a path of a valid yaml file, try to resolve location in new source file
+ next_path = Path(loaded_yaml.value)
+ if not next_path.is_absolute():
+ next_path = source_path.parent / next_path
+ if next_path.is_file():
+ return self._resolve_recursively(attrs, source_path=next_path)
+ except OSError:
+ pass
+ except TypeError:
+ pass
+ # if not, return current section
+ break
+ return (
+ f"{source_path.resolve().absolute()}#line {loaded_yaml.start_line}",
+ None if attrs else loaded_yaml.value,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py
new file mode 100644
index 00000000..959de310
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import typing
+from os import PathLike
+from pathlib import Path
+
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+
+from ..._schema import PathAwareSchema
+from .._job.pipeline._attr_dict import try_get_non_arbitrary_attr
+from .._util import convert_ordered_dict_to_dict
+from .schema import SchemaValidatableMixin
+
+
+class PathAwareSchemaValidatableMixin(SchemaValidatableMixin):
+ """The mixin class for schema validation. Entity classes inheriting from this class should have a base path
+ and a schema of PathAwareSchema.
+ """
+
+ @property
+ def __base_path_for_validation(self) -> typing.Union[str, PathLike]:
+ """Get the base path of the resource.
+
+        It will try to return self.base_path, then self._base_path, then Path.cwd() if the above attrs are
+        non-existent or None.
+
+ :return: The base path of the resource
+ :rtype: typing.Union[str, os.PathLike]
+ """
+ return (
+ try_get_non_arbitrary_attr(self, BASE_PATH_CONTEXT_KEY)
+ or try_get_non_arbitrary_attr(self, f"_{BASE_PATH_CONTEXT_KEY}")
+ or Path.cwd()
+ )
+
+ def _default_context(self) -> dict:
+ # Note that, although context can be passed, nested.schema will be initialized only once
+ # base_path works well because it's fixed after loaded
+ return {BASE_PATH_CONTEXT_KEY: self.__base_path_for_validation}
+
+ @classmethod
+ def _create_schema_for_validation(cls, context: typing.Any) -> PathAwareSchema:
+ raise NotImplementedError()
+
+ @classmethod
+ def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception:
+ raise NotImplementedError()
+
+ def _dump_for_validation(self) -> typing.Dict:
+ # this is not a necessary step but to keep the same behavior as before
+ # empty items will be removed when converting to dict
+ return typing.cast(dict, convert_ordered_dict_to_dict(super()._dump_for_validation()))
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py
new file mode 100644
index 00000000..06f022a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py
@@ -0,0 +1,162 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+import typing
+
+import msrest
+
+from azure.ai.ml._vendor.azure_resources.models import (
+ Deployment,
+ DeploymentProperties,
+ DeploymentValidateResult,
+ ErrorResponse,
+)
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+from .core import MutableValidationResult, ValidationResultBuilder
+
+module_logger = logging.getLogger(__name__)
+
+
+class PreflightResource(msrest.serialization.Model):
+ """Specified resource.
+
+ Variables are only populated by the server, and will be ignored when sending a request.
+
+ :ivar id: Resource ID.
+ :vartype id: str
+ :ivar name: Resource name.
+ :vartype name: str
+ :ivar type: Resource type.
+ :vartype type: str
+ :param location: Resource location.
+ :type location: str
+ :param tags: A set of tags. Resource tags.
+ :type tags: dict[str, str]
+ """
+
+ _attribute_map = {
+ "type": {"key": "type", "type": "str"},
+ "name": {"key": "name", "type": "str"},
+ "location": {"key": "location", "type": "str"},
+ "api_version": {"key": "apiversion", "type": "str"},
+ "properties": {"key": "properties", "type": "object"},
+ }
+
+ def __init__(self, **kwargs: typing.Any):
+ super(PreflightResource, self).__init__(**kwargs)
+ self.name = kwargs.get("name", None)
+ self.type = kwargs.get("type", None)
+ self.location = kwargs.get("location", None)
+ self.properties = kwargs.get("properties", None)
+ self.api_version = kwargs.get("api_version", None)
+
+
+class ValidationTemplateRequest(msrest.serialization.Model):
+ """Export resource group template request parameters.
+
+ :param resources: The rest objects to be validated.
+ :type resources: list[_models.Resource]
+ :param options: The export template options. A CSV-formatted list containing zero or more of
+ the following: 'IncludeParameterDefaultValue', 'IncludeComments',
+ 'SkipResourceNameParameterization', 'SkipAllParameterization'.
+ :type options: str
+ """
+
+ _attribute_map = {
+ "resources": {"key": "resources", "type": "[PreflightResource]"},
+ "content_version": {"key": "contentVersion", "type": "str"},
+ "parameters": {"key": "parameters", "type": "object"},
+ "_schema": {
+ "key": "$schema",
+ "type": "str",
+ "default": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
+ },
+ }
+
+ def __init__(self, **kwargs: typing.Any):
+ super(ValidationTemplateRequest, self).__init__(**kwargs)
+ self._schema = kwargs.get("_schema", None)
+ self.content_version = kwargs.get("content_version", None)
+ self.parameters = kwargs.get("parameters", None)
+ self.resources = kwargs.get("resources", None)
+
+
+class RemoteValidatableMixin(RestTranslatableMixin):
+ @classmethod
+ def _get_resource_type(cls) -> str:
+ """Return resource type to be used in remote validation.
+
+ Should be overridden by subclass.
+
+ :return: The resource type
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ def _get_resource_name_version(self) -> typing.Tuple:
+ """Return resource name and version to be used in remote validation.
+
+ Should be overridden by subclass.
+
+ :return: The name and version
+ :rtype: typing.Tuple[str, str]
+ """
+ raise NotImplementedError()
+
+ def _to_preflight_resource(self, location: str, workspace_name: str) -> PreflightResource:
+ """Return the preflight resource to be used in remote validation.
+
+ :param location: The location of the resource.
+ :type location: str
+ :param workspace_name: The workspace name
+ :type workspace_name: str
+ :return: The preflight resource
+ :rtype: PreflightResource
+ """
+ name, version = self._get_resource_name_version()
+ return PreflightResource(
+ type=self._get_resource_type(),
+ name=f"{workspace_name}/{name}/{version}",
+ location=location,
+ properties=self._to_rest_object().properties,
+ api_version="2023-03-01-preview",
+ )
+
+ def _build_rest_object_for_remote_validation(self, location: str, workspace_name: str) -> Deployment:
+ return Deployment(
+ properties=DeploymentProperties(
+ mode="Incremental",
+ template=ValidationTemplateRequest(
+ _schema="https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
+ content_version="1.0.0.0",
+ parameters={},
+ resources=[self._to_preflight_resource(location=location, workspace_name=workspace_name)],
+ ),
+ )
+ )
+
+ @classmethod
+ def _build_validation_result_from_rest_object(cls, rest_obj: DeploymentValidateResult) -> MutableValidationResult:
+ """Create a validation result from a rest object. Note that the created validation result does not have
+ target_obj so should only be used for merging.
+
+ :param rest_obj: The Deployment Validate REST obj
+ :type rest_obj: DeploymentValidateResult
+ :return: The validation result created from rest_obj
+ :rtype: MutableValidationResult
+ """
+ if not rest_obj.error or not rest_obj.error.details:
+ return ValidationResultBuilder.success()
+ result = MutableValidationResult(target_obj=None)
+ details: typing.List[ErrorResponse] = rest_obj.error.details
+ for detail in details:
+ result.append_error(
+ message=detail.message,
+ yaml_path=detail.target.replace("/", "."),
+ error_code=detail.code,
+ # will always be UserError for now, not sure if innerError can be passed back
+ )
+ return result
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py
new file mode 100644
index 00000000..9e34173d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py
@@ -0,0 +1,156 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import json
+import logging
+import typing
+
+from marshmallow import Schema, ValidationError
+
+from .core import MutableValidationResult, ValidationResultBuilder
+
+module_logger = logging.getLogger(__name__)
+
+
+class SchemaValidatableMixin:
+ """The mixin class for schema validation."""
+
+ @classmethod
+ def _create_empty_validation_result(cls) -> MutableValidationResult:
+ """Simply create an empty validation result
+
+ To reduce _ValidationResultBuilder importing, which is a private class.
+
+ :return: An empty validation result
+ :rtype: MutableValidationResult
+ """
+ return ValidationResultBuilder.success()
+
+ @classmethod
+ def _load_with_schema(
+ cls, data: typing.Any, *, context: typing.Any, raise_original_exception: bool = False, **kwargs: typing.Any
+ ) -> typing.Any:
+ schema = cls._create_schema_for_validation(context=context)
+
+ try:
+ return schema.load(data, **kwargs)
+ except ValidationError as e:
+ if raise_original_exception:
+ raise e
+ msg = "Trying to load data with schema failed. Data:\n%s\nError: %s" % (
+ json.dumps(data, indent=4) if isinstance(data, dict) else data,
+ json.dumps(e.messages, indent=4),
+ )
+ raise cls._create_validation_error(
+ message=msg,
+ no_personal_data_message=str(e),
+ ) from e
+
+ @classmethod
+ # pylint: disable-next=docstring-missing-param
+ def _create_schema_for_validation(cls, context: typing.Any) -> Schema:
+ """Create a schema of the resource with specific context. Should be overridden by subclass.
+
+ :return: The schema of the resource.
+ :rtype: Schema.
+ """
+ raise NotImplementedError()
+
+ def _default_context(self) -> dict:
+ """Get the default context for schema validation. Should be overridden by subclass.
+
+ :return: The default context for schema validation
+ :rtype: dict
+ """
+ raise NotImplementedError()
+
+ @property
+ def _schema_for_validation(self) -> Schema:
+ """Return the schema of this Resource with default context. Do not override this method.
+ Override _create_schema_for_validation instead.
+
+ :return: The schema of the resource.
+ :rtype: Schema.
+ """
+ return self._create_schema_for_validation(context=self._default_context())
+
+ def _dump_for_validation(self) -> typing.Dict:
+ """Convert the resource to a dictionary.
+
+ :return: Converted dictionary
+ :rtype: typing.Dict
+ """
+ res: dict = self._schema_for_validation.dump(self)
+ return res
+
+ @classmethod
+ def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception:
+ """The function to create the validation exception to raise in _try_raise and _validate when
+ raise_error is True.
+
+ Should be overridden by subclass.
+
+ :param message: The error message containing detailed information
+ :type message: str
+ :param no_personal_data_message: The error message without personal data
+ :type no_personal_data_message: str
+ :return: The validation exception to raise
+ :rtype: Exception
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def _try_raise(
+ cls, validation_result: MutableValidationResult, *, raise_error: typing.Optional[bool] = True
+ ) -> MutableValidationResult:
+ return validation_result.try_raise(raise_error=raise_error, error_func=cls._create_validation_error)
+
+ def _validate(self, raise_error: typing.Optional[bool] = False) -> MutableValidationResult:
+ """Validate the resource. If raise_error is True, raise ValidationError if validation fails and log warnings if
+        applicable; otherwise, return the validation result.
+
+ :param raise_error: Whether to raise ValidationError if validation fails.
+ :type raise_error: bool
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ result = self.__schema_validate()
+ result.merge_with(self._customized_validate())
+ return self._try_raise(result, raise_error=raise_error)
+
+ def _customized_validate(self) -> MutableValidationResult:
+ """Validate the resource with customized logic.
+
+ Override this method to add customized validation logic.
+
+ :return: The customized validation result
+ :rtype: MutableValidationResult
+ """
+ return self._create_empty_validation_result()
+
+ @classmethod
+ def _get_skip_fields_in_schema_validation(
+ cls,
+ ) -> typing.List[str]:
+ """Get the fields that should be skipped in schema validation.
+
+ Override this method to add customized validation logic.
+
+ :return: The fields to skip in schema validation
+ :rtype: typing.List[str]
+ """
+ return []
+
+ def __schema_validate(self) -> MutableValidationResult:
+ """Validate the resource with the schema.
+
+ :return: The validation result
+ :rtype: MutableValidationResult
+ """
+ data = self._dump_for_validation()
+ messages = self._schema_for_validation.validate(data)
+ for skip_field in self._get_skip_fields_in_schema_validation():
+ if skip_field in messages:
+ del messages[skip_field]
+ return ValidationResultBuilder.from_validation_messages(messages, data=data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/_constants.py
new file mode 100644
index 00000000..1e75a1c2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/_constants.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+ENDPOINT_AI_SERVICE_KIND = "AIServices"
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/capability_host.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/capability_host.py
new file mode 100644
index 00000000..f86ea8ed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/capability_host.py
@@ -0,0 +1,187 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import os
+from os import PathLike
+from typing import (
+ List,
+ Optional,
+ Union,
+ IO,
+ Any,
+ AnyStr,
+ Dict,
+)
+from pathlib import Path
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.constants._workspace import CapabilityHostKind
+from azure.ai.ml.constants._common import (
+ BASE_PATH_CONTEXT_KEY,
+ PARAMS_OVERRIDE_KEY,
+)
+
+from azure.ai.ml._schema.workspace.ai_workspaces.capability_host import (
+ CapabilityHostSchema,
+)
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml._restclient.v2024_10_01_preview.models._models_py3 import (
+ CapabilityHost as RestCapabilityHost,
+)
+from azure.ai.ml._restclient.v2024_10_01_preview.models._models_py3 import (
+ CapabilityHostProperties as RestCapabilityHostProperties,
+)
+
+
+@experimental
+class CapabilityHost(Resource):
+ """Initialize a CapabilityHost instance.
+    Capability host management is controlled by MLClient's capability host operations.
+
+ :param name: The name of the capability host.
+ :type name: str
+ :param description: The description of the capability host.
+ :type description: Optional[str]
+ :param vector_store_connections: A list of vector store (AI Search) connections.
+ :type vector_store_connections: Optional[List[str]]
+    :param ai_services_connections: A list of OpenAI service connections.
+ :type ai_services_connections: Optional[List[str]]
+ :param storage_connections: A list of storage connections. Default storage connection value is
+ projectname/workspaceblobstore for project workspace.
+ :type storage_connections: Optional[List[str]]
+ :param capability_host_kind: The kind of capability host, either as a string or CapabilityHostKind enum.
+ Default is AGENTS.
+ :type capability_host_kind: Union[str, CapabilityHostKind]
+ :param kwargs: Additional keyword arguments.
+ :type kwargs: Any
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_capability_host.py
+ :start-after: [START capability_host_object_create]
+ :end-before: [END capability_host_object_create]
+ :language: python
+ :dedent: 8
+ :caption: Create a CapabilityHost object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ description: Optional[str] = None,
+ vector_store_connections: Optional[List[str]] = None,
+ ai_services_connections: Optional[List[str]] = None,
+ storage_connections: Optional[List[str]] = None,
+ capability_host_kind: Union[str, CapabilityHostKind] = CapabilityHostKind.AGENTS,
+ **kwargs: Any,
+ ):
+ super().__init__(name=name, description=description, **kwargs)
+ self.capability_host_kind = capability_host_kind
+ self.ai_services_connections = ai_services_connections
+ self.storage_connections = storage_connections
+ self.vector_store_connections = vector_store_connections
+
+ def dump(
+ self,
+ dest: Optional[Union[str, PathLike, IO[AnyStr]]],
+ **kwargs: Any,
+ ) -> None:
+ """Dump the CapabilityHost content into a file in yaml format.
+
+ :param dest: The destination to receive this CapabilityHost's content.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+ and an exception will be raised if the file is not writable.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
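+
+        # Illustrative usage (not executed; assumes a writable local path that does not already exist):
+        #
+        #     host = CapabilityHost(name="my-capability-host")
+        #     host.dump("./capability_host.yml")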
+
+ def _to_dict(self) -> Dict:
+ """Dump the object into a dictionary.
+
+ :return: Dictionary representation of the object.
+ :rtype: Dict
+ """
+
+ return CapabilityHostSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[dict] = None,
+ yaml_path: Optional[Union[os.PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "CapabilityHost":
+ """Load a capabilityhost object from a yaml file.
+
+ :param cls: Indicates that this is a class method.
+ :type cls: class
+ :param data: Data Dictionary, defaults to None
+ :type data: Dict
+ :param yaml_path: YAML Path, defaults to None
+ :type yaml_path: Union[PathLike, str]
+ :param params_override: Fields to overwrite on top of the yaml file.
+ Format is [{"field1": "value1"}, {"field2": "value2"}], defaults to None
+ :type params_override: List[Dict]
+ :raises Exception: An exception
+ :return: Loaded CapabilityHost object.
+ :rtype: ~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost
+ """
+ params_override = params_override or []
+ data = data or {}
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ return cls(**load_from_dict(CapabilityHostSchema, data, context, **kwargs))
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestCapabilityHost) -> "CapabilityHost":
+ """Convert a REST object into a CapabilityHost object.
+
+ :param cls: Indicates that this is a class method.
+ :type cls: class
+ :param rest_obj: The REST object to convert.
+ :type rest_obj: ~azure.ai.ml._restclient.v2024_10_01_preview.models._models_py3.CapabilityHost
+ :return: CapabilityHost object.
+ :rtype: ~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost
+ """
+ capability_host = cls(
+ name=str(rest_obj.name),
+ description=(rest_obj.properties.description if rest_obj.properties else None),
+ ai_services_connections=(rest_obj.properties.ai_services_connections if rest_obj.properties else None),
+ storage_connections=(rest_obj.properties.storage_connections if rest_obj.properties else None),
+ vector_store_connections=(rest_obj.properties.vector_store_connections if rest_obj.properties else None),
+ capability_host_kind=(
+ rest_obj.properties.capability_host_kind if rest_obj.properties else CapabilityHostKind.AGENTS
+ ),
+ )
+ return capability_host
+
+ def _to_rest_object(self) -> RestCapabilityHost:
+ """
+ Convert the CapabilityHost instance to a RestCapabilityHost object.
+
+ :return: A RestCapabilityHost object representing the capability host for a Hub or Project workspace.
+ :rtype: azure.ai.ml._restclient.v2024_10_01_preview.models._models_py3.CapabilityHost
+ """
+
+ properties = RestCapabilityHostProperties(
+ ai_services_connections=self.ai_services_connections,
+ storage_connections=self.storage_connections,
+ vector_store_connections=self.vector_store_connections,
+ description=self.description,
+ capability_host_kind=self.capability_host_kind,
+ )
+ resource = RestCapabilityHost(
+ properties=properties,
+ )
+ return resource
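A minimal usage sketch (illustrative, not part of the patch above): it builds a CapabilityHost with the initializer shown in this file and writes its YAML spec with dump(). The connection names, the output path, and the azure.ai.ml.entities import path are assumed placeholders.

from azure.ai.ml.entities import CapabilityHost  # import path assumed

cap_host = CapabilityHost(
    name="my-capability-host",                             # placeholder name
    description="Agent capability host for my project",
    ai_services_connections=["my-aoai-connection"],        # placeholder connection names
    vector_store_connections=["my-search-connection"],
    storage_connections=["myproject/workspaceblobstore"],
    # capability_host_kind is omitted, so it defaults to CapabilityHostKind.AGENTS
)

cap_host.dump("./capability_host.yml")  # writes the YAML spec; raises if the file already exists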
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/hub.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/hub.py
new file mode 100644
index 00000000..4caac057
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/hub.py
@@ -0,0 +1,220 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, List, Optional
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace as RestWorkspace
+from azure.ai.ml._restclient.v2024_10_01_preview.models import WorkspaceHubConfig as RestWorkspaceHubConfig
+from azure.ai.ml._schema.workspace import HubSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import WorkspaceKind
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._workspace.customer_managed_key import CustomerManagedKey
+from azure.ai.ml.entities._workspace.network_acls import NetworkAcls
+from azure.ai.ml.entities._workspace.networking import ManagedNetwork
+from azure.ai.ml.entities._workspace.workspace import Workspace
+
+
+@experimental
+class Hub(Workspace):
+ """A Hub is a special type of workspace that acts as a parent and resource container for lightweight child
+ workspaces called projects. Resources like the hub's storage account, key vault,
+ and container registry are shared by all child projects.
+
+ As a type of workspace, hub management is controlled by an MLClient's workspace operations.
+
+ :param name: Name of the hub.
+ :type name: str
+ :param description: Description of the hub.
+ :type description: str
+ :param tags: Tags of the hub.
+ :type tags: dict
+ :param display_name: Display name for the hub. This is non-unique within the resource group.
+ :type display_name: str
+ :param location: The location to create the hub in.
+ If not specified, the same location as the resource group will be used.
+ :type location: str
+ :param resource_group: Name of resource group to create the hub in.
+ :type resource_group: str
+ :param managed_network: Hub's Managed Network configuration
+ :type managed_network: ~azure.ai.ml.entities.ManagedNetwork
+ :param storage_account: The resource ID of an existing storage account to use instead of creating a new one.
+ :type storage_account: str
+ :param key_vault: The resource ID of an existing key vault to use instead of creating a new one.
+ :type key_vault: str
+ :param container_registry: The resource ID of an existing container registry
+ to use instead of creating a new one.
+ :type container_registry: str
+ :param customer_managed_key: Key vault details for encrypting data with customer-managed keys.
+ If not specified, Microsoft-managed keys will be used by default.
+ :type customer_managed_key: ~azure.ai.ml.entities.CustomerManagedKey
+ :param image_build_compute: The name of the compute target to use for building environment
+ Docker images when the container registry is behind a VNet.
+ :type image_build_compute: str
+ :param public_network_access: Whether to allow public endpoint connectivity
+ when a workspace is private link enabled.
+ :type public_network_access: str
+ :param network_acls: The network access control list (ACL) settings of the workspace.
+ :type network_acls: ~azure.ai.ml.entities.NetworkAcls
+ :param identity: The hub's Managed Identity (user assigned, or system assigned).
+ :type identity: ~azure.ai.ml.entities.IdentityConfiguration
+ :param primary_user_assigned_identity: The hub's primary user assigned identity.
+ :type primary_user_assigned_identity: str
+ :param enable_data_isolation: A flag to determine if workspace has data isolation enabled.
+ The flag can only be set at creation time; it cannot be updated afterwards.
+ :type enable_data_isolation: bool
+ :param default_resource_group: The resource group that will be used by projects
+ created under this hub if no resource group is specified.
+ :type default_resource_group: str
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_hub]
+ :end-before: [END workspace_hub]
+ :language: python
+ :dedent: 8
+ :caption: Creating a Hub object.
+ """
+
+ # The field 'additional_workspace_storage_accounts' exists in the API but is currently unused.
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ display_name: Optional[str] = None,
+ location: Optional[str] = None,
+ resource_group: Optional[str] = None,
+ managed_network: Optional[ManagedNetwork] = None,
+ storage_account: Optional[str] = None,
+ key_vault: Optional[str] = None,
+ container_registry: Optional[str] = None,
+ customer_managed_key: Optional[CustomerManagedKey] = None,
+ public_network_access: Optional[str] = None,
+ network_acls: Optional[NetworkAcls] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ primary_user_assigned_identity: Optional[str] = None,
+ enable_data_isolation: bool = False,
+ default_resource_group: Optional[str] = None,
+ associated_workspaces: Optional[List[str]] = None, # hidden input for rest->client conversions.
+ **kwargs: Any,
+ ):
+ self._workspace_id = kwargs.pop("workspace_id", "")
+ # Ensure user can't overwrite/double input kind.
+ kwargs.pop("kind", None)
+ super().__init__(
+ name=name,
+ description=description,
+ tags=tags,
+ kind=WorkspaceKind.HUB,
+ display_name=display_name,
+ location=location,
+ storage_account=storage_account,
+ key_vault=key_vault,
+ container_registry=container_registry,
+ resource_group=resource_group,
+ customer_managed_key=customer_managed_key,
+ public_network_access=public_network_access,
+ network_acls=network_acls,
+ identity=identity,
+ primary_user_assigned_identity=primary_user_assigned_identity,
+ managed_network=managed_network,
+ enable_data_isolation=enable_data_isolation,
+ **kwargs,
+ )
+ self._default_resource_group = default_resource_group
+ self._associated_workspaces = associated_workspaces
+
+ @classmethod
+ def _get_schema_class(cls):
+ return HubSchema
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None) -> Optional["Hub"]:
+ if not rest_obj:
+ return None
+
+ workspace_object = Workspace._from_rest_object(rest_obj, v2_service_context)
+
+ default_resource_group = None
+
+ if hasattr(rest_obj, "workspace_hub_config"):
+ if rest_obj.workspace_hub_config and isinstance(rest_obj.workspace_hub_config, RestWorkspaceHubConfig):
+ default_resource_group = rest_obj.workspace_hub_config.default_workspace_resource_group
+
+ if workspace_object is not None:
+ return Hub(
+ name=workspace_object.name if workspace_object.name is not None else "",
+ description=workspace_object.description,
+ tags=workspace_object.tags,
+ display_name=workspace_object.display_name,
+ location=workspace_object.location,
+ resource_group=workspace_object.resource_group,
+ managed_network=workspace_object.managed_network,
+ customer_managed_key=workspace_object.customer_managed_key,
+ public_network_access=workspace_object.public_network_access,
+ network_acls=workspace_object.network_acls,
+ identity=workspace_object.identity,
+ primary_user_assigned_identity=workspace_object.primary_user_assigned_identity,
+ storage_account=rest_obj.storage_account,
+ key_vault=rest_obj.key_vault,
+ container_registry=rest_obj.container_registry,
+ workspace_id=rest_obj.workspace_id,
+ enable_data_isolation=rest_obj.enable_data_isolation,
+ default_resource_group=default_resource_group,
+ associated_workspaces=rest_obj.associated_workspaces if rest_obj.associated_workspaces else [],
+ id=rest_obj.id,
+ )
+ return None
+
+ # Helper function to deal with sub-rest object conversion.
+ def _hub_values_to_rest_object(self) -> RestWorkspaceHubConfig:
+ additional_workspace_storage_accounts = None
+ default_resource_group = None
+ if hasattr(self, "additional_workspace_storage_accounts"):
+ additional_workspace_storage_accounts = self.additional_workspace_storage_accounts
+ if hasattr(self, "default_resource_group"):
+ default_resource_group = self.default_resource_group
+ return RestWorkspaceHubConfig(
+ additional_workspace_storage_accounts=additional_workspace_storage_accounts,
+ default_workspace_resource_group=default_resource_group,
+ )
+
+ def _to_rest_object(self) -> RestWorkspace:
+ restWorkspace = super()._to_rest_object()
+ restWorkspace.workspace_hub_config = self._hub_values_to_rest_object()
+ return restWorkspace
+
+ @property
+ def default_resource_group(self) -> Optional[str]:
+ """The default resource group for this hub and its children.
+
+ :return: The resource group.
+ :rtype: Optional[str]
+ """
+ return self._default_resource_group
+
+ @default_resource_group.setter
+ def default_resource_group(self, value: str):
+ """Set the default resource group for child projects of this hub.
+
+ :param value: The new resource group.
+ :type value: str
+ """
+ if not value:
+ return
+ self._default_resource_group = value
+
+ # No setter, read-only
+ @property
+ def associated_workspaces(self) -> Optional[List[str]]:
+ """The workspaces associated with the hub.
+
+ :return: The workspaces associated with the hub.
+ :rtype: Optional[List[str]]
+ """
+ return self._associated_workspaces
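A minimal usage sketch, not taken from the patch: the Hub above is created through the MLClient workspace operations its docstring mentions. The subscription, resource group, names, and region are placeholders, and the exact MLClient call is assumed rather than shown in this file.

from azure.ai.ml import MLClient
from azure.ai.ml.entities import Hub
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",       # placeholder
    resource_group_name="<resource-group>",    # placeholder
)

hub = Hub(
    name="my-hub",                             # placeholder
    description="Shared parent for lightweight project workspaces",
    location="eastus",                         # placeholder region
)

# Hubs are a kind of workspace, so the standard workspace operations apply.
created_hub = ml_client.workspaces.begin_create(hub).result()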
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/project.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/project.py
new file mode 100644
index 00000000..ffad4922
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/_ai_workspaces/project.py
@@ -0,0 +1,89 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from typing import Any, Dict, Optional
+
+from azure.ai.ml._schema.workspace import ProjectSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import WorkspaceKind
+from azure.ai.ml.entities._workspace.workspace import Workspace
+
+
+# Effectively a lightweight wrapper around a v2 SDK workspace
+@experimental
+class Project(Workspace):
+ """A Project is a lightweight object for orchestrating AI applications, and is parented by a hub.
+ Unlike a standard workspace, a project does not have a variety of sub-resources directly associated with it.
+ Instead, its parent hub manages these resources, which are then used by the project and its siblings.
+
+ As a type of workspace, project management is controlled by an MLClient's workspace operations.
+
+ :param name: The name of the project.
+ :type name: str
+ :param hub_id: The hub parent of the project, as a resource ID.
+ :type hub_id: str
+ :param description: The description of the project.
+ :type description: Optional[str]
+ :param tags: Tags associated with the project.
+ :type tags: Optional[Dict[str, str]]
+ :param display_name: The display name of the project.
+ :type display_name: Optional[str]
+ :param location: The location of the project. Must match that of the parent hub
+ and is automatically assigned to match the parent hub's location during creation.
+ :type location: Optional[str]
+ :param resource_group: The project's resource group name.
+ :type resource_group: Optional[str]
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ hub_id: str,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ display_name: Optional[str] = None,
+ location: Optional[str] = None,
+ resource_group: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ # Ensure user can't overwrite/double input kind.
+ kwargs.pop("kind", None)
+ super().__init__(
+ name=name,
+ description=description,
+ tags=tags,
+ kind=WorkspaceKind.PROJECT,
+ display_name=display_name,
+ location=location,
+ resource_group=resource_group,
+ hub_id=hub_id,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Any:
+ return ProjectSchema
+
+ @property
+ def hub_id(self) -> str:
+ """The UID of the hub parent of the project.
+
+ :return: Resource ID of the parent hub.
+ :rtype: str
+ """
+ return self._hub_id if self._hub_id else ""
+
+ @hub_id.setter
+ def hub_id(self, value: str):
+ """Set the parent hub id of the project.
+
+ :param value: The hub id to assign to the project.
+ Note: cannot be reassigned after creation.
+ :type value: str
+ """
+ if not value:
+ return
+ self._hub_id = value
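An illustrative sketch, not part of the patch: a Project is parented by a hub via the hub's ARM resource ID. All identifiers below are placeholders, and the MLClient call is assumed rather than shown in this file.

from azure.ai.ml import MLClient
from azure.ai.ml.entities import Project
from azure.identity import DefaultAzureCredential

hub_resource_id = (
    "/subscriptions/<subscription-id>/resourceGroups/<resource-group>"
    "/providers/Microsoft.MachineLearningServices/workspaces/<hub-name>"  # placeholder hub ID
)

project = Project(
    name="my-project",                         # placeholder
    hub_id=hub_resource_id,
    description="Lightweight child workspace parented by <hub-name>",
)

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",       # placeholder
    resource_group_name="<resource-group>",    # placeholder
)
# Projects are workspaces too, so the same workspace operations create them.
created_project = ml_client.workspaces.begin_create(project).result()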
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/compute_runtime.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/compute_runtime.py
new file mode 100644
index 00000000..bc7ee127
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/compute_runtime.py
@@ -0,0 +1,41 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Optional
+
+from azure.ai.ml._restclient.v2023_06_01_preview.models import ComputeRuntimeDto as RestComputeRuntimeDto
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class ComputeRuntime(RestTranslatableMixin):
+ """Spark compute runtime configuration.
+
+ :keyword spark_runtime_version: Spark runtime version.
+ :paramtype spark_runtime_version: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_runtime]
+ :end-before: [END compute_runtime]
+ :language: python
+ :dedent: 8
+ :caption: Creating a ComputeRuntime object.
+ """
+
+ def __init__(
+ self,
+ *,
+ spark_runtime_version: Optional[str] = None,
+ ) -> None:
+ self.spark_runtime_version = spark_runtime_version
+
+ def _to_rest_object(self) -> RestComputeRuntimeDto:
+ return RestComputeRuntimeDto(spark_runtime_version=self.spark_runtime_version)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestComputeRuntimeDto) -> Optional["ComputeRuntime"]:
+ if not obj:
+ return None
+ return ComputeRuntime(spark_runtime_version=obj.spark_runtime_version)
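A tiny illustration of the private REST round-trip helpers above, not part of the patch; the Spark version string and the azure.ai.ml.entities import path are assumed.

from azure.ai.ml.entities import ComputeRuntime  # import path assumed

runtime = ComputeRuntime(spark_runtime_version="3.3")   # placeholder version
rest_dto = runtime._to_rest_object()                    # RestComputeRuntimeDto carrying the same version
assert ComputeRuntime._from_rest_object(rest_dto).spark_runtime_version == "3.3"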
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/connection_subtypes.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/connection_subtypes.py
new file mode 100644
index 00000000..d97e513e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/connection_subtypes.py
@@ -0,0 +1,748 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import re
+from typing import Any, Dict, List, Optional, Type, Union
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory
+from azure.ai.ml._schema.workspace.connections.connection_subtypes import (
+ APIKeyConnectionSchema,
+ AzureAISearchConnectionSchema,
+ AzureAIServicesConnectionSchema,
+ AzureBlobStoreConnectionSchema,
+ AzureContentSafetyConnectionSchema,
+ AzureOpenAIConnectionSchema,
+ AzureSpeechServicesConnectionSchema,
+ MicrosoftOneLakeConnectionSchema,
+ OpenAIConnectionSchema,
+ SerpConnectionSchema,
+ ServerlessConnectionSchema,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants._common import (
+ CONNECTION_ACCOUNT_NAME_KEY,
+ CONNECTION_API_TYPE_KEY,
+ CONNECTION_API_VERSION_KEY,
+ CONNECTION_CONTAINER_NAME_KEY,
+ CONNECTION_KIND_KEY,
+ CONNECTION_RESOURCE_ID_KEY,
+ CognitiveServiceKinds,
+ ConnectionTypes,
+)
+from azure.ai.ml.entities._credentials import AadCredentialConfiguration, ApiKeyConfiguration
+
+from .one_lake_artifacts import OneLakeConnectionArtifact
+from .workspace_connection import WorkspaceConnection
+
+# Dev notes: Any new classes require modifying the elif chains in the following functions in the
+# WorkspaceConnection parent class: _from_rest_object, _get_entity_class_from_type, _get_schema_class_from_type
+
+
+class AzureBlobStoreConnection(WorkspaceConnection):
+ """A connection to an Azure Blob Store.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param url: The URL or ARM resource ID of the external resource.
+ :type url: str
+ :param container_name: The name of the container.
+ :type container_name: str
+ :param account_name: The name of the account.
+ :type account_name: str
+ :param credentials: The credentials for authenticating to the blob store. This type of
+ connection accepts three types of credentials: account key or SAS token credentials,
+ or AadCredentialConfiguration for credential-less connections.
+ :type credentials: Union[
+ ~azure.ai.ml.entities.AccountKeyConfiguration,
+ ~azure.ai.ml.entities.SasTokenConfiguration,
+ ~azure.ai.ml.entities.AadCredentialConfiguration,
+ ]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ url: str,
+ container_name: str,
+ account_name: str,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+ # Blob store connections returned from the API generally have no credentials, but we still don't want
+ # to silently run over user inputted connections if they want to play with them locally, so double-check
+ # kwargs for them.
+ if metadata is None:
+ metadata = {}
+ metadata[CONNECTION_CONTAINER_NAME_KEY] = container_name
+ metadata[CONNECTION_ACCOUNT_NAME_KEY] = account_name
+
+ super().__init__(
+ url=url,
+ type=camel_to_snake(ConnectionCategory.AZURE_BLOB),
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_required_metadata_fields(cls) -> List[str]:
+ return [CONNECTION_CONTAINER_NAME_KEY, CONNECTION_ACCOUNT_NAME_KEY]
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return AzureBlobStoreConnectionSchema
+
+ @property
+ def container_name(self) -> Optional[str]:
+ """The name of the connection's container.
+
+ :return: The name of the container.
+ :rtype: Optional[str]
+ """
+ if self.metadata is not None:
+ return self.metadata.get(CONNECTION_CONTAINER_NAME_KEY, None)
+ return None
+
+ @container_name.setter
+ def container_name(self, value: str) -> None:
+ """Set the container name of the connection.
+
+ :param value: The new container name to set.
+ :type value: str
+ """
+ if self.metadata is None:
+ self.metadata = {}
+ self.metadata[CONNECTION_CONTAINER_NAME_KEY] = value
+
+ @property
+ def account_name(self) -> Optional[str]:
+ """The name of the connection's account
+
+ :return: The name of the account.
+ :rtype: Optional[str]
+ """
+ if self.metadata is not None:
+ return self.metadata.get(CONNECTION_ACCOUNT_NAME_KEY, None)
+ return None
+
+ @account_name.setter
+ def account_name(self, value: str) -> None:
+ """Set the account name of the connection.
+
+ :param value: The new account name to set.
+ :type value: str
+ """
+ if self.metadata is None:
+ self.metadata = {}
+ self.metadata[CONNECTION_ACCOUNT_NAME_KEY] = value
+
+
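A hedged sketch, not from the patch: the container and account names passed to the AzureBlobStoreConnection initializer above are stored in the connection's metadata and read back through the container_name and account_name properties. The names, URL, and account key are placeholders; the entity import paths are assumed.

from azure.ai.ml.entities import AccountKeyConfiguration, AzureBlobStoreConnection

blob_conn = AzureBlobStoreConnection(
    name="my-blob-connection",                                          # placeholder
    url="https://mystorageaccount.blob.core.windows.net/mycontainer",   # placeholder URL
    container_name="mycontainer",
    account_name="mystorageaccount",
    credentials=AccountKeyConfiguration(account_key="<account-key>"),   # placeholder key
)

print(blob_conn.container_name)  # "mycontainer", read back from metadata
print(blob_conn.account_name)    # "mystorageaccount"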
+# Dev note: One lake connections are unfortunately unique in that it's extremely
+# difficult for customers to find out what the target for their system ought to be.
+# Due to this, we construct the target internally by composing other,
+# more user-accessible inputs.
+class MicrosoftOneLakeConnection(WorkspaceConnection):
+ """A connection to a Microsoft One Lake. Connections of this type
+ are further specified by their artifact class type, although
+ the number of artifact classes is currently limited.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param endpoint: The endpoint of the connection.
+ :type endpoint: str
+ :param artifact: The artifact class used to further specify the connection.
+ :type artifact: Optional[~azure.ai.ml.entities.OneLakeArtifact]
+ :param one_lake_workspace_name: The name, not ID, of the workspace where the One Lake
+ resource lives.
+ :type one_lake_workspace_name: Optional[str]
+ :param credentials: The credentials for authenticating to the One Lake resource. This type of
+ connection accepts three types of credentials: access key or SAS token credentials,
+ or AadCredentialConfiguration for credential-less connections.
+ :type credentials: Union[
+ ~azure.ai.ml.entities.AccessKeyConfiguration,
+ ~azure.ai.ml.entities.SasTokenConfiguration,
+ ~azure.ai.ml.entities.AadCredentialConfiguration,
+ ]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: str,
+ artifact: Optional[OneLakeConnectionArtifact] = None,
+ one_lake_workspace_name: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+
+ # Allow target to be inputted for from-rest conversions where we don't
+ # need to worry about data-availability nonsense.
+ target = kwargs.pop("target", None)
+ if target is None:
+ if artifact is None:
+ raise ValueError("If target is unset, then artifact must be set")
+ if endpoint is None:
+ raise ValueError("If target is unset, then endpoint must be set")
+ if one_lake_workspace_name is None:
+ raise ValueError("If target is unset, then one_lake_workspace_name must be set")
+ target = MicrosoftOneLakeConnection._construct_target(endpoint, one_lake_workspace_name, artifact)
+ super().__init__(
+ target=target,
+ type=camel_to_snake(ConnectionCategory.AZURE_ONE_LAKE),
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return MicrosoftOneLakeConnectionSchema
+
+ # Target is constructed from user inputs, because it's apparently very difficult for users to
+ # directly access a One Lake's target URL.
+ @classmethod
+ def _construct_target(cls, endpoint: str, workspace: str, artifact: OneLakeConnectionArtifact) -> str:
+ artifact_name = artifact.name
+ # If a GUID-style artifact ID is supplied, the format is different
+ if re.match(".{8}-.{4}-.{4}-.{4}-.{12}", artifact_name):
+ return f"https://{endpoint}/{workspace}/{artifact_name}"
+ return f"https://{endpoint}/{workspace}/{artifact_name}.Lakehouse"
+
+
+# There are enough types of connections that only accept an API key credential,
+# or an API key credential or no credentials, that it merits a parent class for
+# all of them, one that's slightly more specific than the base Connection.
+# This file contains that parent class, as well as all of its children.
+# Not experimental since users should never see this directly,
+# so there's no need to add an extra warning.
+class ApiOrAadConnection(WorkspaceConnection):
+ """Internal parent class for all connections that accept either an api key or
+ entra ID as credentials. Entra ID credentials are implicitly assumed if no api key is provided.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param target: The URL or ARM resource ID of the external resource.
+ :type target: str
+ :param api_key: The api key to connect to the azure endpoint.
+ If unset, tries to use the user's Entra ID as credentials instead.
+ :type api_key: Optional[str]
+ :param api_version: The api version that this connection was created for.
+ :type api_version: Optional[str]
+ :param type: The type of the connection.
+ :type type: str
+ :param allow_entra: Whether or not this connection allows initialization without
+ an API key, using Entra ID (AAD) instead. Defaults to True.
+ :type allow_entra: bool
+ """
+
+ def __init__(
+ self,
+ *,
+ api_key: Optional[str] = None,
+ allow_entra: bool = True,
+ type: str,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs: Any,
+ ):
+ # See if credentials directly inputted via kwargs
+ credentials: Union[AadCredentialConfiguration, ApiKeyConfiguration] = kwargs.pop(
+ "credentials", AadCredentialConfiguration()
+ )
+ # Replace anything that isn't an API credential with an AAD credential.
+ # Importantly, this replaces the None credential default from the parent YAML schema.
+ if not isinstance(credentials, ApiKeyConfiguration):
+ credentials = AadCredentialConfiguration()
+ # Further replace that if a key is provided
+ if api_key:
+ credentials = ApiKeyConfiguration(key=api_key)
+ elif not allow_entra and isinstance(credentials, AadCredentialConfiguration):
+ # If no creds are provided in any capacity when they are needed, complain.
+ raise ValueError("This connection type must set the api_key value.")
+
+ super().__init__(
+ type=type,
+ credentials=credentials,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @property
+ def api_key(self) -> Optional[str]:
+ """The API key of the connection.
+
+ :return: The API key of the connection.
+ :rtype: Optional[str]
+ """
+ if isinstance(self._credentials, ApiKeyConfiguration):
+ return self._credentials.key
+ return None
+
+ @api_key.setter
+ def api_key(self, value: str) -> None:
+ """Set the API key of the connection. Setting this to None will
+ cause the connection to use the user's Entra ID as credentials.
+
+ :param value: The new API key to set.
+ :type value: str
+ """
+ if value is None:
+ self._credentials = AadCredentialConfiguration()
+ else:
+ self._credentials = ApiKeyConfiguration(key=value)
+
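A minimal sketch, not part of the patch: subclasses of this parent, such as the AzureAISearchConnection defined further down in this file, fall back to Entra ID (AAD) credentials whenever no api_key is supplied. The names and endpoint are placeholders and the entity import path is assumed.

from azure.ai.ml.entities import AzureAISearchConnection

keyed = AzureAISearchConnection(
    name="my-search-keyed",                             # placeholder
    endpoint="https://my-search.search.windows.net",    # placeholder endpoint
    api_key="<api-key>",
)
aad = AzureAISearchConnection(
    name="my-search-aad",                               # placeholder
    endpoint="https://my-search.search.windows.net",
)

print(keyed.api_key)   # "<api-key>"
print(aad.api_key)     # None, so Entra ID credentials are used instead
keyed.api_key = None   # switches the first connection over to Entra ID credentials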
+
+@experimental
+class AzureOpenAIConnection(ApiOrAadConnection):
+ """A Connection that is specifically designed for handling connections
+ to Azure Open AI.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param azure_endpoint: The URL or ARM resource ID of the Azure Open AI Resource.
+ :type azure_endpoint: str
+ :param api_key: The api key to connect to the azure endpoint.
+ If unset, tries to use the user's Entra ID as credentials instead.
+ :type api_key: Optional[str]
+ :param open_ai_resource_id: The fully qualified ID of the Azure Open AI resource to connect to.
+ :type open_ai_resource_id: Optional[str]
+ :param api_version: The api version that this connection was created for.
+ :type api_version: Optional[str]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ azure_endpoint: str,
+ api_key: Optional[str] = None,
+ api_version: Optional[str] = None,
+ api_type: str = "Azure", # Required API input, hidden to allow for rare overrides
+ open_ai_resource_id: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs: Any,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+ # Sneak in resource ID as it's inputted from rest conversions as a kwarg.
+ from_rest_resource_id = kwargs.pop("resource_id", None)
+ if open_ai_resource_id is None and from_rest_resource_id is not None:
+ open_ai_resource_id = from_rest_resource_id
+
+ if metadata is None:
+ metadata = {}
+ metadata[CONNECTION_API_VERSION_KEY] = api_version
+ metadata[CONNECTION_API_TYPE_KEY] = api_type
+ metadata[CONNECTION_RESOURCE_ID_KEY] = open_ai_resource_id
+
+ super().__init__(
+ azure_endpoint=azure_endpoint,
+ api_key=api_key,
+ type=camel_to_snake(ConnectionCategory.AZURE_OPEN_AI),
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_required_metadata_fields(cls) -> List[str]:
+ return [CONNECTION_API_VERSION_KEY, CONNECTION_API_TYPE_KEY, CONNECTION_RESOURCE_ID_KEY]
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return AzureOpenAIConnectionSchema
+
+ @property
+ def api_version(self) -> Optional[str]:
+ """The API version of the connection.
+
+ :return: The API version of the connection.
+ :rtype: Optional[str]
+ """
+ if self.metadata is not None and CONNECTION_API_VERSION_KEY in self.metadata:
+ res: str = self.metadata[CONNECTION_API_VERSION_KEY]
+ return res
+ return None
+
+ @api_version.setter
+ def api_version(self, value: str) -> None:
+ """Set the API version of the connection.
+
+ :param value: The new api version to set.
+ :type value: str
+ """
+ if not hasattr(self, "metadata") or self.metadata is None:
+ self.metadata = {}
+ self.metadata[CONNECTION_API_VERSION_KEY] = value
+
+ @property
+ def open_ai_resource_id(self) -> Optional[str]:
+ """The fully qualified ID of the Azure Open AI resource this connects to.
+
+ :return: The fully qualified ID of the Azure Open AI resource this connects to.
+ :rtype: Optional[str]
+ """
+ if self.metadata is not None and CONNECTION_RESOURCE_ID_KEY in self.metadata:
+ res: str = self.metadata[CONNECTION_RESOURCE_ID_KEY]
+ return res
+ return None
+
+ @open_ai_resource_id.setter
+ def open_ai_resource_id(self, value: Optional[str]) -> None:
+ """Set the fully qualified ID of the Azure Open AI resource to connect to.
+
+ :param value: The new resource id to set.
+ :type value: Optional[str]
+ """
+ if not hasattr(self, "metadata") or self.metadata is None:
+ self.metadata = {}
+ if value is None:
+ self.metadata.pop(CONNECTION_RESOURCE_ID_KEY, None)
+ return
+ self.metadata[CONNECTION_RESOURCE_ID_KEY] = value
+
+
+@experimental
+class AzureAIServicesConnection(ApiOrAadConnection):
+ """A Connection geared towards Azure AI services.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param endpoint: The URL or ARM resource ID of the external resource.
+ :type endpoint: str
+ :param api_key: The api key to connect to the azure endpoint.
+ If unset, tries to use the user's Entra ID as credentials instead.
+ :type api_key: Optional[str]
+ :param ai_services_resource_id: The fully qualified ID of the Azure AI service resource to connect to.
+ :type ai_services_resource_id: str
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: str,
+ api_key: Optional[str] = None,
+ ai_services_resource_id: str,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs: Any,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+ if metadata is None:
+ metadata = {}
+ metadata[CONNECTION_RESOURCE_ID_KEY] = ai_services_resource_id
+ super().__init__(
+ endpoint=endpoint,
+ api_key=api_key,
+ type=ConnectionTypes.AZURE_AI_SERVICES,
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return AzureAIServicesConnectionSchema
+
+ @classmethod
+ def _get_required_metadata_fields(cls) -> List[str]:
+ return [CONNECTION_RESOURCE_ID_KEY]
+
+ @property
+ def ai_services_resource_id(self) -> Optional[str]:
+ """The resource id of the ai service being connected to.
+
+ :return: The resource id of the ai service being connected to.
+ :rtype: Optional[str]
+ """
+ if self.metadata is not None and CONNECTION_RESOURCE_ID_KEY in self.metadata:
+ res: str = self.metadata[CONNECTION_RESOURCE_ID_KEY]
+ return res
+ return None
+
+ @ai_services_resource_id.setter
+ def ai_services_resource_id(self, value: str) -> None:
+ """Set the ai service resource id of the connection.
+
+ :param value: The new ai service resource id to set.
+ :type value: str
+ """
+ if not hasattr(self, "metadata") or self.metadata is None:
+ self.metadata = {}
+ self.metadata[CONNECTION_RESOURCE_ID_KEY] = value
+
+
+class AzureAISearchConnection(ApiOrAadConnection):
+ """A Connection that is specifically designed for handling connections to
+ Azure AI Search.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param endpoint: The URL or ARM resource ID of the Azure AI Search Service
+ :type endpoint: str
+ :param api_key: The API key needed to connect to the Azure AI Search Service.
+ :type api_key: Optional[str]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: str,
+ api_key: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs: Any,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+
+ super().__init__(
+ endpoint=endpoint,
+ api_key=api_key,
+ type=ConnectionTypes.AZURE_SEARCH,
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return AzureAISearchConnectionSchema
+
+
+class AzureContentSafetyConnection(ApiOrAadConnection):
+ """A Connection geared towards a Azure Content Safety service.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param endpoint: The URL or ARM resource ID of the external resource.
+ :type endpoint: str
+ :param api_key: The api key to connect to the azure endpoint.
+ If unset, tries to use the user's Entra ID as credentials instead.
+ :type api_key: Optional[str]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: str,
+ api_key: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs: Any,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+
+ if metadata is None:
+ metadata = {}
+ metadata[CONNECTION_KIND_KEY] = CognitiveServiceKinds.CONTENT_SAFETY
+
+ super().__init__(
+ endpoint=endpoint,
+ api_key=api_key,
+ type=ConnectionTypes.AZURE_CONTENT_SAFETY,
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return AzureContentSafetyConnectionSchema
+
+
+class AzureSpeechServicesConnection(ApiOrAadConnection):
+ """A Connection geared towards an Azure Speech service.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param endpoint: The URL or ARM resource ID of the external resource.
+ :type endpoint: str
+ :param api_key: The api key to connect to the azure endpoint.
+ If unset, tries to use the user's Entra ID as credentials instead.
+ :type api_key: Optional[str]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ # Valid cognitive service kinds include "AzureOpenAI", "ContentSafety", and "Speech".
+
+ def __init__(
+ self,
+ *,
+ endpoint: str,
+ api_key: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs: Any,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+
+ if metadata is None:
+ metadata = {}
+ metadata[CONNECTION_KIND_KEY] = CognitiveServiceKinds.SPEECH
+ super().__init__(
+ endpoint=endpoint,
+ api_key=api_key,
+ type=ConnectionTypes.AZURE_SPEECH_SERVICES,
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return AzureSpeechServicesConnectionSchema
+
+
+@experimental
+class APIKeyConnection(ApiOrAadConnection):
+ """A generic connection for any API key-based service.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param api_base: The URL to target with this connection.
+ :type api_base: str
+ :param api_key: The API key needed to connect to the api_base.
+ :type api_key: Optional[str]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ api_base: str,
+ api_key: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+ super().__init__(
+ api_base=api_base,
+ api_key=api_key,
+ type=camel_to_snake(ConnectionCategory.API_KEY),
+ allow_entra=False,
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return APIKeyConnectionSchema
+
+
+@experimental
+class OpenAIConnection(ApiOrAadConnection):
+ """A connection geared towards direct connections to Open AI.
+ Not to be confused with the AzureOpenAIWorkspaceConnection, which is for Azure's Open AI services.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param api_key: The API key needed to connect to OpenAI.
+ :type api_key: Optional[str]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ api_key: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+ super().__init__(
+ type=ConnectionCategory.Open_AI,
+ api_key=api_key,
+ allow_entra=False,
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return OpenAIConnectionSchema
+
+
+@experimental
+class SerpConnection(ApiOrAadConnection):
+ """A connection geared towards a Serp service (Open source search API Service)
+
+ :param name: Name of the connection.
+ :type name: str
+ :param api_key: The API key needed to connect to the Serp service.
+ :type api_key: Optional[str]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ api_key: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+ super().__init__(
+ type=ConnectionCategory.SERP,
+ api_key=api_key,
+ allow_entra=False,
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return SerpConnectionSchema
+
+
+@experimental
+class ServerlessConnection(ApiOrAadConnection):
+ """A connection geared towards a MaaS endpoint (Serverless).
+
+ :param name: Name of the connection.
+ :type name: str
+ :param endpoint: The serverless endpoint.
+ :type endpoint: str
+ :param api_key: The API key needed to connect to the endpoint.
+ :type api_key: Optional[str]
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[dict[str,str]]
+ """
+
+ def __init__(
+ self,
+ *,
+ endpoint: str,
+ api_key: Optional[str] = None,
+ metadata: Optional[Dict[Any, Any]] = None,
+ **kwargs,
+ ):
+ kwargs.pop("type", None) # make sure we never somehow use wrong type
+ super().__init__(
+ type=ConnectionCategory.SERVERLESS,
+ endpoint=endpoint,
+ api_key=api_key,
+ allow_entra=False,
+ from_child=True,
+ metadata=metadata,
+ **kwargs,
+ )
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ return ServerlessConnectionSchema
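A hedged sketch, not taken from the patch: constructing one of the subtypes above and registering it against a workspace. The endpoint, names, key, and API version are placeholders, and the MLClient connections call is assumed rather than shown in this file.

from azure.ai.ml import MLClient
from azure.ai.ml.entities import AzureOpenAIConnection
from azure.identity import DefaultAzureCredential

aoai_conn = AzureOpenAIConnection(
    name="my-aoai-connection",                            # placeholder
    azure_endpoint="https://my-aoai.openai.azure.com/",   # placeholder endpoint
    api_key="<api-key>",                                  # omit to fall back to Entra ID
    api_version="2024-06-01",                             # placeholder API version
)

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",                  # placeholder
    resource_group_name="<resource-group>",               # placeholder
    workspace_name="<project-or-workspace-name>",         # placeholder
)
created = ml_client.connections.create_or_update(aoai_conn)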
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/one_lake_artifacts.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/one_lake_artifacts.py
new file mode 100644
index 00000000..ea81602f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/one_lake_artifacts.py
@@ -0,0 +1,25 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any
+from azure.ai.ml._utils._experimental import experimental
+
+# Dev note: Supposedly there's going to be more artifact subclasses at some point.
+# If/when that comes to pass, we can worry about adding polymorphism to these classes.
+# For now, this is a one-off that's needed to help match the object structure that PF uses.
+
+
+# Why is this not called a "LakeHouseArtifact"? Because despite the under-the-hood type,
+# users expect this variety to be called "OneLake".
+@experimental
+class OneLakeConnectionArtifact:
+ """Artifact class used by the Connection subclass known
+ as a MicrosoftOneLakeConnection. Supplying this class further
+ specifies the connection as a Lake House connection.
+ """
+
+ # Note: Kwargs exist just to silently absorb type from schema.
+ def __init__(self, *, name: str, **kwargs: Any): # pylint: disable=unused-argument
+ self.name = name
+ self.type = "lake_house"
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/workspace_connection.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/workspace_connection.py
new file mode 100644
index 00000000..ab1ee9f8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/connections/workspace_connection.py
@@ -0,0 +1,677 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import warnings
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Type, Union, cast
+
+
+from azure.ai.ml._restclient.v2024_04_01_preview.models import (
+ WorkspaceConnectionPropertiesV2BasicResource as RestWorkspaceConnection,
+)
+from azure.ai.ml._restclient.v2024_04_01_preview.models import (
+ ConnectionCategory,
+ NoneAuthTypeWorkspaceConnectionProperties,
+ AADAuthTypeWorkspaceConnectionProperties,
+)
+
+from azure.ai.ml._schema.workspace.connections.workspace_connection import WorkspaceConnectionSchema
+from azure.ai.ml._utils.utils import _snake_to_camel, camel_to_snake, dump_yaml_to_file
+from azure.ai.ml.constants._common import (
+ BASE_PATH_CONTEXT_KEY,
+ PARAMS_OVERRIDE_KEY,
+ ConnectionTypes,
+ CognitiveServiceKinds,
+ CONNECTION_KIND_KEY,
+ CONNECTION_RESOURCE_ID_KEY,
+)
+from azure.ai.ml.entities._credentials import (
+ AccessKeyConfiguration,
+ ApiKeyConfiguration,
+ ManagedIdentityConfiguration,
+ NoneCredentialConfiguration,
+ PatTokenConfiguration,
+ SasTokenConfiguration,
+ ServicePrincipalConfiguration,
+ UsernamePasswordConfiguration,
+ _BaseIdentityConfiguration,
+ AccountKeyConfiguration,
+ AadCredentialConfiguration,
+)
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._system_data import SystemData
+from azure.ai.ml.entities._util import load_from_dict
+
+
+CONNECTION_CATEGORY_TO_CREDENTIAL_MAP = {
+ ConnectionCategory.AZURE_BLOB: [AccessKeyConfiguration, SasTokenConfiguration, AadCredentialConfiguration],
+ ConnectionTypes.AZURE_DATA_LAKE_GEN_2: [
+ ServicePrincipalConfiguration,
+ AadCredentialConfiguration,
+ ManagedIdentityConfiguration,
+ ],
+ ConnectionCategory.GIT: [PatTokenConfiguration, NoneCredentialConfiguration, UsernamePasswordConfiguration],
+ ConnectionCategory.PYTHON_FEED: [UsernamePasswordConfiguration, PatTokenConfiguration, NoneCredentialConfiguration],
+ ConnectionCategory.CONTAINER_REGISTRY: [ManagedIdentityConfiguration, UsernamePasswordConfiguration],
+}
+
+DATASTORE_CONNECTIONS = {
+ ConnectionCategory.AZURE_BLOB,
+ ConnectionTypes.AZURE_DATA_LAKE_GEN_2,
+ ConnectionCategory.AZURE_ONE_LAKE,
+}
+
+CONNECTION_ALTERNATE_TARGET_NAMES = ["target", "api_base", "url", "azure_endpoint", "endpoint"]
+
+
+# Dev note: The acceptable strings for the type field are all snake_cased versions of the string constants defined
+# In the rest client enum defined at _azure_machine_learning_services_enums.ConnectionCategory.
+# We avoid directly referencing it in the docs to avoid restclient references.
+class WorkspaceConnection(Resource):
+ """Azure ML connection provides a secure way to store authentication and configuration information needed
+ to connect and interact with the external resources.
+
+ Note: For connections to OpenAI, Cognitive Search, and Cognitive Services, use the respective subclasses
+ (ex: ~azure.ai.ml.entities.OpenAIConnection) instead of instantiating this class directly.
+
+ :param name: Name of the connection.
+ :type name: str
+ :param target: The URL or ARM resource ID of the external resource.
+ :type target: str
+ :param metadata: Metadata dictionary.
+ :type metadata: Optional[Dict[str, Any]]
+ :param type: The category of external resource for this connection. Possible values are: "git",
+ "python_feed", "container_registry", "feature_store", "s3", "snowflake", "azure_sql_db",
+ "azure_synapse_analytics", "azure_my_sql_db", "azure_postgres_db", "adls_gen_2",
+ "azure_one_lake", "custom".
+ :type type: str
+ :param credentials: The credentials for authenticating to the external resource. Note that certain connection
+ types (as defined by the type input) only accept certain types of credentials.
+ :type credentials: Union[
+ ~azure.ai.ml.entities.PatTokenConfiguration,
+ ~azure.ai.ml.entities.SasTokenConfiguration,
+ ~azure.ai.ml.entities.UsernamePasswordConfiguration,
+ ~azure.ai.ml.entities.ManagedIdentityConfiguration
+ ~azure.ai.ml.entities.ServicePrincipalConfiguration,
+ ~azure.ai.ml.entities.AccessKeyConfiguration,
+ ~azure.ai.ml.entities.ApiKeyConfiguration,
+ ~azure.ai.ml.entities.NoneCredentialConfiguration
+ ~azure.ai.ml.entities.AccountKeyConfiguration,
+ ~azure.ai.ml.entities.AadCredentialConfiguration,
+ None
+ ]
+ :param is_shared: For connections in a project, this controls whether or not the connection
+ is shared amongst the other projects that share the parent hub. Defaults to True.
+ :type is_shared: bool
+ """
+
+ def __init__(
+ self,
+ *,
+ # TODO : Check if this is okay since it shadows builtin-type type
+ type: str, # pylint: disable=redefined-builtin
+ credentials: Union[
+ PatTokenConfiguration,
+ SasTokenConfiguration,
+ UsernamePasswordConfiguration,
+ ManagedIdentityConfiguration,
+ ServicePrincipalConfiguration,
+ AccessKeyConfiguration,
+ ApiKeyConfiguration,
+ NoneCredentialConfiguration,
+ AccountKeyConfiguration,
+ AadCredentialConfiguration,
+ ],
+ is_shared: bool = True,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ):
+
+ # Dev note: This initializer has an undocumented kwarg "from_child" to determine if this initialization
+ # is from a child class.
+ # This kwarg is required to allow instantiation of types that are associated with subtypes without a
+ # warning printout.
+ # The additional undocumented kwarg "strict_typing" turns the warning into a value error.
+ from_child = kwargs.pop("from_child", False)
+ strict_typing = kwargs.pop("strict_typing", False)
+ correct_class = WorkspaceConnection._get_entity_class_from_type(type)
+ if not from_child and correct_class != WorkspaceConnection:
+ if strict_typing:
+ raise ValueError(
+ f"Cannot instantiate a base Connection with a type of {type}. "
+ f"Please use the appropriate subclass {correct_class.__name__} instead."
+ )
+ warnings.warn(
+ f"The connection of {type} has additional fields and should not be instantiated directly "
+ f"from the Connection class. Please use its subclass {correct_class.__name__} instead.",
+ )
+ # This disgusting code allows for a variety of input names to technically all
+ # act like the target field, while still maintaining the aggregate field as required.
+ target = None
+ for target_name in CONNECTION_ALTERNATE_TARGET_NAMES:
+ target = kwargs.pop(target_name, target)
+ if target is None and type not in {ConnectionCategory.SERP, ConnectionCategory.Open_AI}:
+ raise ValueError("target is a required field for Connection.")
+
+ tags = kwargs.pop("tags", None)
+ if tags is not None:
+ if metadata is not None:
+ # Update tags with metadata to make sure metadata values are preserved in case of conflicts.
+ tags.update(metadata)
+ metadata = tags
+ warnings.warn(
+ "Tags are a deprecated field for connections, use metadata instead. Since both "
+ + "metadata and tags are assigned, metadata values will take precedence in the event of conflicts."
+ )
+ else:
+ metadata = tags
+ warnings.warn("Tags are a deprecated field for connections, use metadata instead.")
+
+ super().__init__(**kwargs)
+
+ self.type = type
+ self._target = target
+ self._credentials = credentials
+ self._is_shared = is_shared
+ self._metadata = metadata
+ self._validate_cred_for_conn_cat()
+
+ def _validate_cred_for_conn_cat(self) -> None:
+ """Given a connection type, ensure that the given credentials are valid for that connection type.
+ Does not validate the actual data of the inputted credential, just that they are of the right class
+ type.
+
+ """
+ # Convert none credentials to AAD credentials for datastore connection types.
+ # The backend stores datastore aad creds as none, unlike other connection types with aad,
+ # which actually list them as aad. This IS distinct from regular none credentials, or so I've been told,
+ # so I will endeavor to smooth over that inconsistency here.
+ converted_type = _snake_to_camel(self.type).lower()
+ if self._credentials == NoneCredentialConfiguration() and any(
+ converted_type == _snake_to_camel(item).lower() for item in DATASTORE_CONNECTIONS
+ ):
+ self._credentials = AadCredentialConfiguration()
+
+ if self.type in CONNECTION_CATEGORY_TO_CREDENTIAL_MAP:
+ allowed_credentials = CONNECTION_CATEGORY_TO_CREDENTIAL_MAP[self.type]
+ if self.credentials is None and NoneCredentialConfiguration not in allowed_credentials:
+ raise ValueError(
+ f"Cannot instantiate a Connection with a type of {self.type} and no credentials."
+ f"Please supply credentials from one of the following types: {allowed_credentials}."
+ )
+ cred_type = type(self.credentials)
+ if cred_type not in allowed_credentials:
+ raise ValueError(
+ f"Cannot instantiate a Connection with a type of {self.type} and credentials of type"
+ f" {cred_type}. Please supply credentials from one of the following types: {allowed_credentials}."
+ )
+ # For unknown types, just let the user do whatever they want.
+
+ @property
+ def type(self) -> Optional[str]:
+ """Type of the connection, supported are 'git', 'python_feed' and 'container_registry'.
+
+ :return: Type of the job.
+ :rtype: str
+ """
+ return self._type
+
+ @type.setter
+ def type(self, value: str) -> None:
+ """Set the type of the connection, supported are 'git', 'python_feed' and 'container_registry'.
+
+ :param value: value for the type of connection.
+ :type: str
+ """
+ if not value:
+ return
+ self._type: Optional[str] = camel_to_snake(value)
+
+ @property
+ def target(self) -> Optional[str]:
+ """Target url for the connection.
+
+ :return: Target of the connection.
+ :rtype: Optional[str]
+ """
+ return self._target
+
+ @property
+ def endpoint(self) -> Optional[str]:
+ """Alternate name for the target of the connection,
+ which is used by some connection subclasses.
+
+ :return: The target of the connection.
+ :rtype: str
+ """
+ return self.target
+
+ @property
+ def azure_endpoint(self) -> Optional[str]:
+ """Alternate name for the target of the connection,
+ which is used by some connection subclasses.
+
+ :return: The target of the connection.
+ :rtype: str
+ """
+ return self.target
+
+ @property
+ def url(self) -> Optional[str]:
+ """Alternate name for the target of the connection,
+ which is used by some connection subclasses.
+
+ :return: The target of the connection.
+ :rtype: str
+ """
+ return self.target
+
+ @property
+ def api_base(self) -> Optional[str]:
+ """Alternate name for the target of the connection,
+ which is used by some connection subclasses.
+
+ :return: The target of the connection.
+ :rtype: str
+ """
+ return self.target
+
+ @property
+ def credentials(
+ self,
+ ) -> Union[
+ PatTokenConfiguration,
+ SasTokenConfiguration,
+ UsernamePasswordConfiguration,
+ ManagedIdentityConfiguration,
+ ServicePrincipalConfiguration,
+ AccessKeyConfiguration,
+ ApiKeyConfiguration,
+ NoneCredentialConfiguration,
+ AccountKeyConfiguration,
+ AadCredentialConfiguration,
+ ]:
+ """Credentials for connection.
+
+ :return: Credentials for connection.
+ :rtype: Union[
+ ~azure.ai.ml.entities.PatTokenConfiguration,
+ ~azure.ai.ml.entities.SasTokenConfiguration,
+ ~azure.ai.ml.entities.UsernamePasswordConfiguration,
+ ~azure.ai.ml.entities.ManagedIdentityConfiguration
+ ~azure.ai.ml.entities.ServicePrincipalConfiguration,
+ ~azure.ai.ml.entities.AccessKeyConfiguration,
+ ~azure.ai.ml.entities.ApiKeyConfiguration
+ ~azure.ai.ml.entities.NoneCredentialConfiguration,
+ ~azure.ai.ml.entities.AccountKeyConfiguration,
+ ~azure.ai.ml.entities.AadCredentialConfiguration,
+ ]
+ """
+ return self._credentials
+
+ @property
+ def metadata(self) -> Optional[Dict[str, Any]]:
+ """The connection's metadata dictionary.
+ :return: This connection's metadata.
+ :rtype: Optional[Dict[str, Any]]
+ """
+ return self._metadata if self._metadata is not None else {}
+
+ @metadata.setter
+ def metadata(self, value: Optional[Dict[str, Any]]) -> None:
+ """Set the metadata for the connection. Be warned that setting this will override
+ ALL metadata values, including those implicitly set by certain connection types to manage their
+ extra data. Usually, you should probably access the metadata dictionary, then add or remove values
+ individually as needed.
+ :param value: The new metadata for connection.
+ This completely overwrites the existing metadata dictionary.
+ :type value: Optional[Dict[str, Any]]
+ """
+ if not value:
+ return
+ self._metadata = value
+
+ @property
+ def tags(self) -> Optional[Dict[str, Any]]:
+ """Deprecated. Use metadata instead.
+ :return: This connection's metadata.
+ :rtype: Optional[Dict[str, Any]]
+ """
+ return self._metadata if self._metadata is not None else {}
+
+ @tags.setter
+ def tags(self, value: Optional[Dict[str, Any]]) -> None:
+ """Deprecated use metadata instead
+ :param value: The new metadata for connection.
+ This completely overwrites the existing metadata dictionary.
+ :type value: Optional[Dict[str, Any]]
+ """
+ if not value:
+ return
+ self._metadata = value
+
+ @property
+ def is_shared(self) -> bool:
+ """Get the Boolean describing if this connection is shared amongst its cohort within a hub.
+ Only applicable for connections created within a project.
+
+ :rtype: bool
+ """
+ return self._is_shared
+
+ @is_shared.setter
+ def is_shared(self, value: bool) -> None:
+ """Assign the is_shared property of the connection, determining if it is shared amongst other projects
+ within its parent hub. Only applicable for connections created within a project.
+
+ :param value: The new is_shared value.
+ :type value: bool
+ """
+ if not value:
+ return
+ self._is_shared = value
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the connection spec into a file in yaml format.
+
+ :param dest: The destination to receive this connection's spec.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+ and an exception will be raised if the file is not writable.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "WorkspaceConnection":
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ return cls._load_from_dict(data=data, context=context, **kwargs)
+
+ @classmethod
+ def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "WorkspaceConnection":
+ conn_type = data["type"] if "type" in data else None
+ schema_class = cls._get_schema_class_from_type(conn_type)
+ loaded_data: WorkspaceConnection = load_from_dict(schema_class, data, context, **kwargs)
+ return loaded_data
+
+ def _to_dict(self) -> Dict:
+ schema_class = WorkspaceConnection._get_schema_class_from_type(self.type)
+ # The schema class is resolved polymorphically from the connection type, so the
+ # dump below works for all connection subclasses despite the ambiguous static type.
+ res: dict = schema_class(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestWorkspaceConnection) -> "WorkspaceConnection":
+ conn_class = cls._get_entity_class_from_rest_obj(rest_obj)
+
+ popped_metadata = conn_class._get_required_metadata_fields()
+
+ rest_kwargs = cls._extract_kwargs_from_rest_obj(rest_obj=rest_obj, popped_metadata=popped_metadata)
+ # Check for alternative name for custom connection type (added for client clarity).
+ if rest_kwargs["type"].lower() == camel_to_snake(ConnectionCategory.CUSTOM_KEYS).lower():
+ rest_kwargs["type"] = ConnectionTypes.CUSTOM
+ if rest_kwargs["type"].lower() == camel_to_snake(ConnectionCategory.ADLS_GEN2).lower():
+ rest_kwargs["type"] = ConnectionTypes.AZURE_DATA_LAKE_GEN_2
+ target = rest_kwargs.get("target", "")
+ # Assigning the target to every alternate name accomplishes two things:
+ # it ensures that sub-classes properly receive their target, regardless of which
+ # alternate name they expose it under, while also still allowing our official
+ # client specs to list those inputs as 'required'.
+ for target_name in CONNECTION_ALTERNATE_TARGET_NAMES:
+ rest_kwargs[target_name] = target
+ if rest_obj.properties.category == ConnectionCategory.AZURE_ONE_LAKE:
+ # The Microsoft OneLake connection uniquely has client-only inputs
+ # that aren't just an alternate name for the target.
+ # This sets those inputs so that the initializer can still
+ # require those fields from users.
+ rest_kwargs["artifact"] = ""
+ rest_kwargs["one_lake_workspace_name"] = ""
+ if rest_obj.properties.category == ConnectionTypes.AI_SERVICES_REST_PLACEHOLDER:
+ # AI Services renames its metadata field for clarity when it is surfaced to users and
+ # passed into its initializer; 'ResourceId' doesn't tell much on its own.
+ # No default in pop; this should fail if we somehow don't get a resource ID.
+ rest_kwargs["ai_services_resource_id"] = rest_kwargs.pop(camel_to_snake(CONNECTION_RESOURCE_ID_KEY))
+ connection = conn_class(**rest_kwargs)
+ return cast(WorkspaceConnection, connection)
+
+ def _validate(self) -> str:
+ return str(self.name)
+
+ def _to_rest_object(self) -> RestWorkspaceConnection:
+ connection_properties_class: Any = NoneAuthTypeWorkspaceConnectionProperties
+ if self._credentials:
+ connection_properties_class = self._credentials._get_rest_properties_class()
+ # Convert from human readable type to corresponding api enum if needed.
+ conn_type = self.type
+ if conn_type == ConnectionTypes.CUSTOM:
+ conn_type = ConnectionCategory.CUSTOM_KEYS
+ elif conn_type == ConnectionTypes.AZURE_DATA_LAKE_GEN_2:
+ conn_type = ConnectionCategory.ADLS_GEN2
+ elif conn_type in {
+ ConnectionTypes.AZURE_CONTENT_SAFETY,
+ ConnectionTypes.AZURE_SPEECH_SERVICES,
+ }:
+ conn_type = ConnectionCategory.COGNITIVE_SERVICE
+ elif conn_type == ConnectionTypes.AZURE_SEARCH:
+ conn_type = ConnectionCategory.COGNITIVE_SEARCH
+ elif conn_type == ConnectionTypes.AZURE_AI_SERVICES:
+ # ConnectionCategory.AI_SERVICES category accidentally unpublished
+ conn_type = ConnectionTypes.AI_SERVICES_REST_PLACEHOLDER
+ # Some credential property bags have no credential input.
+ if connection_properties_class in {
+ NoneAuthTypeWorkspaceConnectionProperties,
+ AADAuthTypeWorkspaceConnectionProperties,
+ }:
+ properties = connection_properties_class(
+ target=self.target,
+ metadata=self.metadata,
+ category=_snake_to_camel(conn_type),
+ is_shared_to_all=self.is_shared,
+ )
+ else:
+ properties = connection_properties_class(
+ target=self.target,
+ credentials=self.credentials._to_workspace_connection_rest_object() if self._credentials else None,
+ metadata=self.metadata,
+ category=_snake_to_camel(conn_type),
+ is_shared_to_all=self.is_shared,
+ )
+
+ return RestWorkspaceConnection(properties=properties)
+
+ @classmethod
+ def _extract_kwargs_from_rest_obj(
+ cls, rest_obj: RestWorkspaceConnection, popped_metadata: List[str]
+ ) -> Dict[str, str]:
+ """Internal helper function with extracts all the fields needed to initialize a connection object
+ from its associated restful object. Pulls extra fields based on the supplied `popped_metadata` input.
+ Returns all the fields as a dictionary, which is expected to then be supplied to a
+ connection initializer as kwargs.
+
+ :param rest_obj: The rest object representation of a connection
+ :type rest_obj: RestWorkspaceConnection
+ :param popped_metadata: Key names that should be pulled from the rest object's metadata and
+ injected as top-level fields into the client connection's initializer.
+ This is needed for subclasses that require extra inputs compared to the base Connection class.
+ :type popped_metadata: List[str]
+
+ :return: A dictionary containing all kwargs needed to construct a connection.
+ :rtype: Dict[str, str]
+ """
+ properties = rest_obj.properties
+ credentials: Any = NoneCredentialConfiguration()
+
+ credentials_class = _BaseIdentityConfiguration._get_credential_class_from_rest_type(properties.auth_type)
+ # None and AAD auth types have a property bag class, but no credentials inside that.
+ # Thankfully they both have no inputs.
+
+ if credentials_class is AadCredentialConfiguration:
+ credentials = AadCredentialConfiguration()
+ elif credentials_class is not NoneCredentialConfiguration:
+ credentials = credentials_class._from_workspace_connection_rest_object(properties.credentials)
+
+ metadata = properties.metadata if hasattr(properties, "metadata") else {}
+ rest_kwargs = {
+ "id": rest_obj.id,
+ "name": rest_obj.name,
+ "target": properties.target,
+ "creation_context": SystemData._from_rest_object(rest_obj.system_data) if rest_obj.system_data else None,
+ "type": camel_to_snake(properties.category),
+ "credentials": credentials,
+ "metadata": metadata,
+ "is_shared": properties.is_shared_to_all if hasattr(properties, "is_shared_to_all") else True,
+ }
+
+ for name in popped_metadata:
+ if name in metadata:
+ rest_kwargs[camel_to_snake(name)] = metadata[name]
+ return rest_kwargs
+
+ @classmethod
+ def _get_entity_class_from_type(cls, type: str) -> Type:
+ """Helper function that derives the correct connection class given the client or server type.
+ Differs slightly from the rest object version in that it doesn't need to account for
+ rest object metadata.
+
+ The reason there are two functions at all is that certain API connection types
+ are obfuscated with different names when presented to the client. These types are
+ accounted for in the ConnectionTypes class in the constants file.
+
+ :param type: The type string describing the connection.
+ :type type: str
+
+ :return: The connection class the conn_type corresponds to.
+ :rtype: Type
+ """
+ from .connection_subtypes import (
+ AzureBlobStoreConnection,
+ MicrosoftOneLakeConnection,
+ AzureOpenAIConnection,
+ AzureAIServicesConnection,
+ AzureAISearchConnection,
+ AzureContentSafetyConnection,
+ AzureSpeechServicesConnection,
+ APIKeyConnection,
+ OpenAIConnection,
+ SerpConnection,
+ ServerlessConnection,
+ )
+
+ conn_type = _snake_to_camel(type).lower()
+ if conn_type is None:
+ return WorkspaceConnection
+
+ # Connection categories don't consistently follow camel casing, so lower-case
+ # everything to avoid problems.
+ CONNECTION_CATEGORY_TO_SUBCLASS_MAP = {
+ ConnectionCategory.AZURE_OPEN_AI.lower(): AzureOpenAIConnection,
+ ConnectionCategory.AZURE_BLOB.lower(): AzureBlobStoreConnection,
+ ConnectionCategory.AZURE_ONE_LAKE.lower(): MicrosoftOneLakeConnection,
+ ConnectionCategory.API_KEY.lower(): APIKeyConnection,
+ ConnectionCategory.OPEN_AI.lower(): OpenAIConnection,
+ ConnectionCategory.SERP.lower(): SerpConnection,
+ ConnectionCategory.SERVERLESS.lower(): ServerlessConnection,
+ _snake_to_camel(ConnectionTypes.AZURE_CONTENT_SAFETY).lower(): AzureContentSafetyConnection,
+ _snake_to_camel(ConnectionTypes.AZURE_SPEECH_SERVICES).lower(): AzureSpeechServicesConnection,
+ ConnectionCategory.COGNITIVE_SEARCH.lower(): AzureAISearchConnection,
+ _snake_to_camel(ConnectionTypes.AZURE_SEARCH).lower(): AzureAISearchConnection,
+ _snake_to_camel(ConnectionTypes.AZURE_AI_SERVICES).lower(): AzureAIServicesConnection,
+ ConnectionTypes.AI_SERVICES_REST_PLACEHOLDER.lower(): AzureAIServicesConnection,
+ }
+ return CONNECTION_CATEGORY_TO_SUBCLASS_MAP.get(conn_type, WorkspaceConnection)
+
+ @classmethod
+ def _get_entity_class_from_rest_obj(cls, rest_obj: RestWorkspaceConnection) -> Type:
+ """Helper function that converts a restful connection into the associated
+ connection class or subclass. Accounts for potential snake/camel case and
+ capitalization differences in the type, and sub-typing derived from metadata.
+
+ :param rest_obj: The rest object representation of the connection to derive a class from.
+ :type rest_obj: RestWorkspaceConnection
+
+ :return: The connection class the conn_type corresponds to.
+ :rtype: Type
+ """
+ conn_type = rest_obj.properties.category
+ conn_type = _snake_to_camel(conn_type).lower()
+ if conn_type is None:
+ return WorkspaceConnection
+
+ # Imports are done here to avoid circular imports on load.
+ from .connection_subtypes import (
+ AzureContentSafetyConnection,
+ AzureSpeechServicesConnection,
+ )
+
+ # Cognitive service connections have further subdivisions based on the kind of service.
+ if (
+ conn_type == ConnectionCategory.COGNITIVE_SERVICE.lower()
+ and hasattr(rest_obj.properties, "metadata")
+ and rest_obj.properties.metadata is not None
+ ):
+ kind = rest_obj.properties.metadata.get(CONNECTION_KIND_KEY, "").lower()
+ if kind == CognitiveServiceKinds.CONTENT_SAFETY.lower():
+ return AzureContentSafetyConnection
+ if kind == CognitiveServiceKinds.SPEECH.lower():
+ return AzureSpeechServicesConnection
+ return WorkspaceConnection
+
+ return cls._get_entity_class_from_type(type=conn_type)
+
+ @classmethod
+ def _get_schema_class_from_type(cls, conn_type: Optional[str]) -> Type:
+ """Helper function that converts a rest client connection category into the associated
+ connection schema class or subclass. Accounts for potential snake/camel case and
+ capitalization differences.
+
+ :param conn_type: The connection type.
+ :type conn_type: str
+
+ :return: The connection schema class the conn_type corresponds to.
+ :rtype: Type
+ """
+ if conn_type is None:
+ return WorkspaceConnectionSchema
+ entity_class = cls._get_entity_class_from_type(conn_type)
+ return entity_class._get_schema_class()
+
+ @classmethod
+ def _get_required_metadata_fields(cls) -> List[str]:
+ """Helper function that returns the required metadata fields for specific
+ connection type. This parent function returns nothing, but needs to be overwritten by child
+ classes, which are created under the expectation that they have extra fields that need to be
+ accounted for.
+
+ :return: A list of the required metadata fields for the specific connection type.
+ :rtype: List[str]
+ """
+ return []
+
+ @classmethod
+ def _get_schema_class(cls) -> Type:
+ """Helper function that maps this class to its associated schema class. Needs to be overridden by
+ child classes to allow the base class to be polymorphic in its schema reading.
+
+ :return: The appropriate schema class to use with this entity class.
+ :rtype: Type
+ """
+ return WorkspaceConnectionSchema
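The `_to_rest_object` and `_get_entity_class_from_type` helpers above both hinge on translating client-facing type strings into REST connection categories and then dispatching to a subclass. The following is a minimal standalone sketch of that dispatch pattern; the category strings and class names are illustrative placeholders, not the SDK's real constants.

# Simplified sketch of the type dispatch used by WorkspaceConnection.
def snake_to_camel_lower(value: str) -> str:
    """Normalize a snake_case type string the way the SDK does: camel-case, then lower-case."""
    return "".join(part.title() for part in value.split("_")).lower()

# Placeholder mapping from normalized category to subclass name.
CATEGORY_TO_SUBCLASS = {
    "azureopenai": "AzureOpenAIConnection",
    "cognitivesearch": "AzureAISearchConnection",
    "customkeys": "CustomConnection",
}

def resolve_subclass(conn_type: str) -> str:
    """Return the subclass name for a connection type, falling back to the base class."""
    return CATEGORY_TO_SUBCLASS.get(snake_to_camel_lower(conn_type), "WorkspaceConnection")

print(resolve_subclass("azure_open_ai"))  # AzureOpenAIConnection
print(resolve_subclass("serp"))           # WorkspaceConnection (fallback)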
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/customer_managed_key.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/customer_managed_key.py
new file mode 100644
index 00000000..88474dab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/customer_managed_key.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from typing import Optional
+
+
+class CustomerManagedKey:
+ """Key vault details for encrypting data with customer-managed keys.
+
+ :param key_vault: Key vault that is holding the customer-managed key.
+ :type key_vault: str
+ :param key_uri: URI for the customer-managed key.
+ :type key_uri: str
+ :param cosmosdb_id: ARM ID of a bring-your-own Cosmos DB account used to store the
+ customer's data with encryption.
+ :type cosmosdb_id: str
+ :param storage_id: ARM ID of a bring-your-own storage account used to store the
+ customer's data with encryption.
+ :type storage_id: str
+ :param search_id: ARM ID of a bring-your-own search account used to store the
+ customer's data with encryption.
+ :type search_id: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START customermanagedkey]
+ :end-before: [END customermanagedkey]
+ :language: python
+ :dedent: 8
+ :caption: Creating a CustomerManagedKey object.
+ """
+
+ def __init__(
+ self,
+ key_vault: Optional[str] = None,
+ key_uri: Optional[str] = None,
+ cosmosdb_id: Optional[str] = None,
+ storage_id: Optional[str] = None,
+ search_id: Optional[str] = None,
+ ):
+ self.key_vault = key_vault
+ self.key_uri = key_uri
+ self.cosmosdb_id = cosmosdb_id or ""
+ self.storage_id = storage_id or ""
+ self.search_id = search_id or ""
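A minimal usage sketch for CustomerManagedKey, assuming it is exported from azure.ai.ml.entities as the type references elsewhere in this package suggest; the ARM ID and key URI below are placeholders.

from azure.ai.ml.entities import CustomerManagedKey

# Placeholder resource identifiers; substitute a real key vault ARM ID and key URI.
cmk = CustomerManagedKey(
    key_vault="/subscriptions/<sub-id>/resourceGroups/<rg>/providers/Microsoft.KeyVault/vaults/<vault>",
    key_uri="https://<vault>.vault.azure.net/keys/<key-name>/<key-version>",
)

# Unset bring-your-own resource IDs are normalized to empty strings by __init__.
assert cmk.cosmosdb_id == "" and cmk.storage_id == "" and cmk.search_id == ""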
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/diagnose.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/diagnose.py
new file mode 100644
index 00000000..fa923dc4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/diagnose.py
@@ -0,0 +1,214 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import json
+from typing import Any, Dict, List, Optional
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ DiagnoseRequestProperties as RestDiagnoseRequestProperties,
+)
+from azure.ai.ml._restclient.v2024_10_01_preview.models import DiagnoseResponseResult as RestDiagnoseResponseResult
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ DiagnoseResponseResultValue as RestDiagnoseResponseResultValue,
+)
+from azure.ai.ml._restclient.v2024_10_01_preview.models import DiagnoseResult as RestDiagnoseResult
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ DiagnoseWorkspaceParameters as RestDiagnoseWorkspaceParameters,
+)
+
+
+class DiagnoseRequestProperties:
+ """DiagnoseRequestProperties."""
+
+ def __init__(
+ self,
+ *,
+ udr: Optional[Dict[str, Any]] = None,
+ nsg: Optional[Dict[str, Any]] = None,
+ resource_lock: Optional[Dict[str, Any]] = None,
+ dns_resolution: Optional[Dict[str, Any]] = None,
+ storage_account: Optional[Dict[str, Any]] = None,
+ key_vault: Optional[Dict[str, Any]] = None,
+ container_registry: Optional[Dict[str, Any]] = None,
+ application_insights: Optional[Dict[str, Any]] = None,
+ others: Optional[Dict[str, Any]] = None,
+ ):
+ self.udr = udr
+ self.nsg = nsg
+ self.resource_lock = resource_lock
+ self.dns_resolution = dns_resolution
+ self.storage_account = storage_account
+ self.key_vault = key_vault
+ self.container_registry = container_registry
+ self.application_insights = application_insights
+ self.others = others
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestDiagnoseRequestProperties) -> "DiagnoseRequestProperties":
+ return cls(
+ udr=rest_obj.udr,
+ nsg=rest_obj.nsg,
+ resource_lock=rest_obj.resource_lock,
+ dns_resolution=rest_obj.dns_resolution,
+ storage_account=rest_obj.storage_account,
+ key_vault=rest_obj.key_vault,
+ container_registry=rest_obj.container_registry,
+ application_insights=rest_obj.application_insights,
+ others=rest_obj.others,
+ )
+
+ def _to_rest_object(self) -> RestDiagnoseRequestProperties:
+ return RestDiagnoseRequestProperties(
+ udr=self.udr,
+ nsg=self.nsg,
+ resource_lock=self.resource_lock,
+ dns_resolution=self.dns_resolution,
+ storage_account=self.storage_account,
+ key_vault=self.key_vault,
+ container_registry=self.container_registry,
+ application_insights=self.application_insights,
+ others=self.others,
+ )
+
+
+class DiagnoseResponseResult:
+ """DiagnoseResponseResult."""
+
+ def __init__(
+ self,
+ *,
+ value: Optional["DiagnoseResponseResultValue"] = None,
+ ):
+ self.value = value
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestDiagnoseResponseResult) -> "DiagnoseResponseResult":
+ val = None
+ if rest_obj and rest_obj.value and isinstance(rest_obj.value, RestDiagnoseResponseResultValue):
+ # pylint: disable=protected-access
+ val = DiagnoseResponseResultValue._from_rest_object(rest_obj.value)
+ return cls(value=val)
+
+ def _to_rest_object(self) -> RestDiagnoseResponseResult:
+ return RestDiagnoseResponseResult(value=self.value)
+
+
+class DiagnoseResponseResultValue:
+ """DiagnoseResponseResultValue."""
+
+ def __init__(
+ self,
+ *,
+ user_defined_route_results: Optional[List["DiagnoseResult"]] = None,
+ network_security_rule_results: Optional[List["DiagnoseResult"]] = None,
+ resource_lock_results: Optional[List["DiagnoseResult"]] = None,
+ dns_resolution_results: Optional[List["DiagnoseResult"]] = None,
+ storage_account_results: Optional[List["DiagnoseResult"]] = None,
+ key_vault_results: Optional[List["DiagnoseResult"]] = None,
+ container_registry_results: Optional[List["DiagnoseResult"]] = None,
+ application_insights_results: Optional[List["DiagnoseResult"]] = None,
+ other_results: Optional[List["DiagnoseResult"]] = None,
+ ):
+ self.user_defined_route_results = user_defined_route_results
+ self.network_security_rule_results = network_security_rule_results
+ self.resource_lock_results = resource_lock_results
+ self.dns_resolution_results = dns_resolution_results
+ self.storage_account_results = storage_account_results
+ self.key_vault_results = key_vault_results
+ self.container_registry_results = container_registry_results
+ self.application_insights_results = application_insights_results
+ self.other_results = other_results
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestDiagnoseResponseResultValue) -> "DiagnoseResponseResultValue":
+ return cls(
+ user_defined_route_results=rest_obj.user_defined_route_results,
+ network_security_rule_results=rest_obj.network_security_rule_results,
+ resource_lock_results=rest_obj.resource_lock_results,
+ dns_resolution_results=rest_obj.dns_resolution_results,
+ storage_account_results=rest_obj.storage_account_results,
+ key_vault_results=rest_obj.key_vault_results,
+ container_registry_results=rest_obj.container_registry_results,
+ application_insights_results=rest_obj.application_insights_results,
+ other_results=rest_obj.other_results,
+ )
+
+ def _to_rest_object(self) -> RestDiagnoseResponseResultValue:
+ return RestDiagnoseResponseResultValue(
+ user_defined_route_results=self.user_defined_route_results,
+ network_security_rule_results=self.network_security_rule_results,
+ resource_lock_results=self.resource_lock_results,
+ dns_resolution_results=self.dns_resolution_results,
+ storage_account_results=self.storage_account_results,
+ key_vault_results=self.key_vault_results,
+ container_registry_results=self.container_registry_results,
+ application_insights_results=self.application_insights_results,
+ other_results=self.other_results,
+ )
+
+ def __json__(self):
+ results = self.__dict__.copy()
+ for k, v in results.items():
+ # Guard against result fields that were never populated (None) before iterating.
+ results[k] = [item.__dict__ for item in v] if v is not None else None
+ return results
+
+ def __str__(self) -> str:
+ return json.dumps(self, default=lambda o: o.__json__(), indent=2)
+
+
+class DiagnoseResult:
+ """Result of Diagnose."""
+
+ def __init__(
+ self,
+ *,
+ code: Optional[str] = None,
+ level: Optional[str] = None,
+ message: Optional[str] = None,
+ ):
+ self.code = code
+ self.level = level
+ self.message = message
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestDiagnoseResult) -> "DiagnoseResult":
+ return cls(
+ code=rest_obj.code,
+ level=rest_obj.level,
+ message=rest_obj.message,
+ )
+
+ def _to_rest_object(self) -> RestDiagnoseResult:
+ return RestDiagnoseResult(
+ code=self.code,
+ level=self.level,
+ message=self.message,
+ )
+
+
+class DiagnoseWorkspaceParameters:
+ """Parameters to diagnose a workspace."""
+
+ def __init__(
+ self,
+ *,
+ value: Optional["DiagnoseRequestProperties"] = None,
+ ):
+ self.value = value
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestDiagnoseWorkspaceParameters) -> "DiagnoseWorkspaceParameters":
+ val = None
+ if rest_obj.value and isinstance(rest_obj.value, RestDiagnoseRequestProperties):
+ # pylint: disable=protected-access
+ val = DiagnoseRequestProperties._from_rest_object(rest_obj.value)
+ return cls(value=val)
+
+ def _to_rest_object(self) -> RestDiagnoseWorkspaceParameters:
+ val = None
+ if self.value and isinstance(self.value, DiagnoseRequestProperties):
+ # pylint: disable=protected-access
+ val = self.value._to_rest_object()
+ return RestDiagnoseWorkspaceParameters(value=val)
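A hedged sketch of how these wrappers fit together, importing from the module path shown in this diff. `_to_rest_object` is private and is normally invoked by the workspace operations layer rather than by user code; the option dictionaries are placeholders whose exact keys are service-defined.

from azure.ai.ml.entities._workspace.diagnose import (
    DiagnoseRequestProperties,
    DiagnoseWorkspaceParameters,
)

# Request diagnostics for a few resources; empty dicts are used here as placeholders.
props = DiagnoseRequestProperties(key_vault={}, storage_account={}, dns_resolution={})
params = DiagnoseWorkspaceParameters(value=props)

# Translate to the autorest model that the diagnose API expects.
rest_params = params._to_rest_object()  # pylint: disable=protected-access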
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/feature_store_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/feature_store_settings.py
new file mode 100644
index 00000000..8c264db0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/feature_store_settings.py
@@ -0,0 +1,61 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Optional
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import FeatureStoreSettings as RestFeatureStoreSettings
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+from .compute_runtime import ComputeRuntime
+
+
+class FeatureStoreSettings(RestTranslatableMixin):
+ """Feature Store Settings
+
+ :param compute_runtime: The spark compute runtime settings. Defaults to None.
+ :type compute_runtime: Optional[~compute_runtime.ComputeRuntime]
+ :param offline_store_connection_name: The offline store connection name. Defaults to None.
+ :type offline_store_connection_name: Optional[str]
+ :param online_store_connection_name: The online store connection name. Defaults to None.
+ :type online_store_connection_name: Optional[str]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_featurestore.py
+ :start-after: [START configure_feature_store_settings]
+ :end-before: [END configure_feature_store_settings]
+ :language: python
+ :dedent: 8
+ :caption: Instantiating FeatureStoreSettings
+ """
+
+ def __init__(
+ self,
+ *,
+ compute_runtime: Optional[ComputeRuntime] = None,
+ offline_store_connection_name: Optional[str] = None,
+ online_store_connection_name: Optional[str] = None,
+ ) -> None:
+ self.compute_runtime = compute_runtime if compute_runtime else ComputeRuntime(spark_runtime_version="3.4.0")
+ self.offline_store_connection_name = offline_store_connection_name
+ self.online_store_connection_name = online_store_connection_name
+
+ def _to_rest_object(self) -> RestFeatureStoreSettings:
+ return RestFeatureStoreSettings(
+ compute_runtime=ComputeRuntime._to_rest_object(self.compute_runtime),
+ offline_store_connection_name=self.offline_store_connection_name,
+ online_store_connection_name=self.online_store_connection_name,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestFeatureStoreSettings) -> Optional["FeatureStoreSettings"]:
+ if not obj:
+ return None
+ return FeatureStoreSettings(
+ compute_runtime=ComputeRuntime._from_rest_object(obj.compute_runtime),
+ offline_store_connection_name=obj.offline_store_connection_name,
+ online_store_connection_name=obj.online_store_connection_name,
+ )
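A small sketch of constructing FeatureStoreSettings; the connection names are placeholders, and the import paths follow the file locations in this diff.

from azure.ai.ml.entities._workspace.compute_runtime import ComputeRuntime
from azure.ai.ml.entities._workspace.feature_store_settings import FeatureStoreSettings

settings = FeatureStoreSettings(
    # When compute_runtime is omitted, __init__ falls back to spark_runtime_version="3.4.0".
    compute_runtime=ComputeRuntime(spark_runtime_version="3.4.0"),
    offline_store_connection_name="my-offline-store-connection",
    online_store_connection_name="my-online-store-connection",
)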
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/network_acls.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/network_acls.py
new file mode 100644
index 00000000..fbb3b9ef
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/network_acls.py
@@ -0,0 +1,90 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import List, Optional
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import IPRule as RestIPRule
+from azure.ai.ml._restclient.v2024_10_01_preview.models import NetworkAcls as RestNetworkAcls
+from azure.ai.ml.entities._mixins import RestTranslatableMixin
+
+
+class IPRule(RestTranslatableMixin):
+ """Represents an IP rule with a value.
+
+ :param value: An IPv4 address or range in CIDR notation.
+ :type value: str
+ """
+
+ def __init__(self, value: Optional[str]):
+ self.value = value
+
+ def __repr__(self):
+ return f"IPRule(value={self.value})"
+
+ def _to_rest_object(self) -> RestIPRule:
+ return RestIPRule(value=self.value)
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestIPRule) -> "IPRule":
+ return cls(value=obj.value)
+
+
+class DefaultActionType:
+ """Specifies the default action when no IP rules are matched."""
+
+ DENY = "Deny"
+ ALLOW = "Allow"
+
+
+class NetworkAcls(RestTranslatableMixin):
+ """Network Access Setting for Workspace
+
+ :param default_action: Specifies the default action when no IP rules are matched.
+ :type default_action: str
+ :param ip_rules: Rules governing the accessibility of a resource from a specific IP address or IP range.
+ :type ip_rules: Optional[List[IPRule]]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_network_access_settings]
+ :end-before: [END workspace_network_access_settings]
+ :language: python
+ :dedent: 8
+ :caption: Configuring one of the three public network access settings.
+ """
+
+ def __init__(
+ self,
+ *,
+ default_action: str = DefaultActionType.ALLOW,
+ ip_rules: Optional[List[IPRule]] = None,
+ ):
+ self.default_action = default_action
+ self.ip_rules = ip_rules if ip_rules is not None else []
+
+ def __repr__(self):
+ ip_rules_repr = ", ".join(repr(ip_rule) for ip_rule in self.ip_rules)
+ return f"NetworkAcls(default_action={self.default_action}, ip_rules=[{ip_rules_repr}])"
+
+ def _to_rest_object(self) -> RestNetworkAcls:
+ return RestNetworkAcls(
+ default_action=self.default_action,
+ ip_rules=(
+ [ip_rule._to_rest_object() for ip_rule in self.ip_rules] # pylint: disable=protected-access
+ if self.ip_rules
+ else None
+ ),
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestNetworkAcls) -> "NetworkAcls":
+ return cls(
+ default_action=obj.default_action,
+ ip_rules=(
+ [IPRule._from_rest_object(ip_rule) for ip_rule in obj.ip_rules] # pylint: disable=protected-access
+ if obj.ip_rules
+ else []
+ ),
+ )
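A short example of the two classes above working together; the CIDR range and address are placeholders.

from azure.ai.ml.entities._workspace.network_acls import DefaultActionType, IPRule, NetworkAcls

# Deny unmatched traffic and allow only the listed addresses.
acls = NetworkAcls(
    default_action=DefaultActionType.DENY,
    ip_rules=[IPRule("203.0.113.0/24"), IPRule("198.51.100.7")],
)
print(acls)  # __repr__ shows the default action and each IPRule value.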
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/networking.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/networking.py
new file mode 100644
index 00000000..4576eac9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/networking.py
@@ -0,0 +1,348 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from abc import ABC
+from typing import Any, Dict, List, Optional
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import FqdnOutboundRule as RestFqdnOutboundRule
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ ManagedNetworkProvisionStatus as RestManagedNetworkProvisionStatus,
+)
+from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkSettings as RestManagedNetwork
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ PrivateEndpointDestination as RestPrivateEndpointOutboundRuleDestination,
+)
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ PrivateEndpointOutboundRule as RestPrivateEndpointOutboundRule,
+)
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ ServiceTagDestination as RestServiceTagOutboundRuleDestination,
+)
+from azure.ai.ml._restclient.v2024_10_01_preview.models import ServiceTagOutboundRule as RestServiceTagOutboundRule
+from azure.ai.ml.constants._workspace import IsolationMode, OutboundRuleCategory, OutboundRuleType
+
+
+class OutboundRule(ABC):
+ """Base class for Outbound Rules, cannot be instantiated directly. Please see FqdnDestination,
+ PrivateEndpointDestination, and ServiceTagDestination objects to create outbound rules.
+
+ :param name: Name of the outbound rule.
+ :type name: str
+ :param type: Type of the outbound rule. Supported types are "FQDN", "PrivateEndpoint", "ServiceTag"
+ :type type: str
+ :ivar type: Type of the outbound rule. Supported types are "FQDN", "PrivateEndpoint", "ServiceTag"
+ :vartype type: str
+ """
+
+ def __init__(
+ self,
+ *,
+ name: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.name = name
+ self.parent_rule_names = kwargs.pop("parent_rule_names", None)
+ self.type = kwargs.pop("type", None)
+ self.category = kwargs.pop("category", OutboundRuleCategory.USER_DEFINED)
+ self.status = kwargs.pop("status", None)
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: Any, name: str) -> Optional["OutboundRule"]:
+ if isinstance(rest_obj, RestFqdnOutboundRule):
+ rule_fqdnDestination = FqdnDestination(destination=rest_obj.destination, name=name)
+ rule_fqdnDestination.category = rest_obj.category
+ rule_fqdnDestination.status = rest_obj.status
+ return rule_fqdnDestination
+ if isinstance(rest_obj, RestPrivateEndpointOutboundRule):
+ rule_privateEndpointDestination = PrivateEndpointDestination(
+ service_resource_id=rest_obj.destination.service_resource_id,
+ subresource_target=rest_obj.destination.subresource_target,
+ spark_enabled=rest_obj.destination.spark_enabled,
+ fqdns=rest_obj.fqdns,
+ name=name,
+ )
+ rule_privateEndpointDestination.category = rest_obj.category
+ rule_privateEndpointDestination.status = rest_obj.status
+ return rule_privateEndpointDestination
+ if isinstance(rest_obj, RestServiceTagOutboundRule):
+ rule = ServiceTagDestination(
+ service_tag=rest_obj.destination.service_tag,
+ protocol=rest_obj.destination.protocol,
+ port_ranges=rest_obj.destination.port_ranges,
+ address_prefixes=rest_obj.destination.address_prefixes,
+ name=name,
+ )
+ rule.category = rest_obj.category
+ rule.status = rest_obj.status
+ return rule
+
+ return None
+
+
+class FqdnDestination(OutboundRule):
+ """Class representing a FQDN outbound rule.
+
+ :param name: Name of the outbound rule.
+ :type name: str
+ :param destination: Fully qualified domain name to which outbound connections are allowed.
+ For example: "xxxxxx.contoso.com".
+ :type destination: str
+ :ivar type: Type of the outbound rule. Set to "FQDN" for this class.
+ :vartype type: str
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START fqdn_outboundrule]
+ :end-before: [END fqdn_outboundrule]
+ :language: python
+ :dedent: 8
+ :caption: Creating a FqdnDestination outbound rule object.
+ """
+
+ def __init__(self, *, name: str, destination: str, **kwargs: Any) -> None:
+ self.destination = destination
+ OutboundRule.__init__(self, type=OutboundRuleType.FQDN, name=name, **kwargs)
+
+ def _to_rest_object(self) -> RestFqdnOutboundRule:
+ return RestFqdnOutboundRule(type=self.type, category=self.category, destination=self.destination)
+
+ def _to_dict(self) -> Dict:
+ return {
+ "name": self.name,
+ "type": OutboundRuleType.FQDN,
+ "category": self.category,
+ "destination": self.destination,
+ "status": self.status,
+ }
+
+
+class PrivateEndpointDestination(OutboundRule):
+ """Class representing a Private Endpoint outbound rule.
+
+ :param name: Name of the outbound rule.
+ :type name: str
+ :param service_resource_id: The resource URI of the root service that supports creation of the private link.
+ :type service_resource_id: str
+ :param subresource_target: The target endpoint of the subresource of the service.
+ :type subresource_target: str
+ :param spark_enabled: Whether the private endpoint can be used for Spark jobs. Defaults to False.
+ :type spark_enabled: bool
+ :param fqdns: String list of FQDNs particular to the Private Endpoint resource creation. For application
+ gateway Private Endpoints, this is the FQDN which will resolve to the private IP of the application
+ gateway PE inside the workspace's managed network.
+ :type fqdns: List[str]
+ :ivar type: Type of the outbound rule. Set to "PrivateEndpoint" for this class.
+ :vartype type: str
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START private_endpoint_outboundrule]
+ :end-before: [END private_endpoint_outboundrule]
+ :language: python
+ :dedent: 8
+ :caption: Creating a PrivateEndpointDestination outbound rule object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ service_resource_id: str,
+ subresource_target: str,
+ spark_enabled: bool = False,
+ fqdns: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.service_resource_id = service_resource_id
+ self.subresource_target = subresource_target
+ self.spark_enabled = spark_enabled
+ self.fqdns = fqdns
+ OutboundRule.__init__(self, type=OutboundRuleType.PRIVATE_ENDPOINT, name=name, **kwargs)
+
+ def _to_rest_object(self) -> RestPrivateEndpointOutboundRule:
+ return RestPrivateEndpointOutboundRule(
+ type=self.type,
+ category=self.category,
+ destination=RestPrivateEndpointOutboundRuleDestination(
+ service_resource_id=self.service_resource_id,
+ subresource_target=self.subresource_target,
+ spark_enabled=self.spark_enabled,
+ ),
+ fqdns=self.fqdns,
+ )
+
+ def _to_dict(self) -> Dict:
+ return {
+ "name": self.name,
+ "type": OutboundRuleType.PRIVATE_ENDPOINT,
+ "category": self.category,
+ "destination": {
+ "service_resource_id": self.service_resource_id,
+ "subresource_target": self.subresource_target,
+ "spark_enabled": self.spark_enabled,
+ },
+ "fqdns": self.fqdns,
+ "status": self.status,
+ }
+
+
+class ServiceTagDestination(OutboundRule):
+ """Class representing a Service Tag outbound rule.
+
+ :param name: Name of the outbound rule.
+ :type name: str
+ :param service_tag: Service Tag of an Azure service, maps to predefined IP addresses for its service endpoints.
+ :type service_tag: str
+ :param protocol: Allowed transport protocol, can be "TCP", "UDP", "ICMP" or "*" for all supported protocols.
+ :type protocol: str
+ :param port_ranges: A comma-separated list of single ports and/or port ranges, such as "80,1024-65535".
+ Traffic will be allowed to these port ranges.
+ :type port_ranges: str
+ :param address_prefixes: Optional list of CIDR prefixes or IP ranges. When provided, the service_tag argument
+ is ignored and address_prefixes is used instead.
+ :type address_prefixes: List[str]
+ :ivar type: Type of the outbound rule. Set to "ServiceTag" for this class.
+ :vartype type: str
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START service_tag_outboundrule]
+ :end-before: [END service_tag_outboundrule]
+ :language: python
+ :dedent: 8
+ :caption: Creating a ServiceTagDestination outbound rule object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ protocol: str,
+ port_ranges: str,
+ service_tag: Optional[str] = None,
+ address_prefixes: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.service_tag = service_tag
+ self.protocol = protocol
+ self.port_ranges = port_ranges
+ self.address_prefixes = address_prefixes
+ OutboundRule.__init__(self, type=OutboundRuleType.SERVICE_TAG, name=name, **kwargs)
+
+ def _to_rest_object(self) -> RestServiceTagOutboundRule:
+ return RestServiceTagOutboundRule(
+ type=self.type,
+ category=self.category,
+ destination=RestServiceTagOutboundRuleDestination(
+ service_tag=self.service_tag,
+ protocol=self.protocol,
+ port_ranges=self.port_ranges,
+ address_prefixes=self.address_prefixes,
+ ),
+ )
+
+ def _to_dict(self) -> Dict:
+ return {
+ "name": self.name,
+ "type": OutboundRuleType.SERVICE_TAG,
+ "category": self.category,
+ "destination": {
+ "service_tag": self.service_tag,
+ "protocol": self.protocol,
+ "port_ranges": self.port_ranges,
+ "address_prefixes": self.address_prefixes,
+ },
+ "status": self.status,
+ }
+
+
+class ManagedNetwork:
+ """Managed Network settings for a workspace.
+
+ :param isolation_mode: Isolation mode of the managed network. Defaults to Disabled.
+ :type isolation_mode: str
+ :param firewall_sku: Firewall SKU used for FQDN rules in AllowOnlyApprovedOutbound mode.
+ :type firewall_sku: str
+ :param outbound_rules: List of outbound rules for the managed network.
+ :type outbound_rules: List[~azure.ai.ml.entities.OutboundRule]
+ :param network_id: Network id for the managed network, not meant to be set by user.
+ :type network_id: str
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_managed_network]
+ :end-before: [END workspace_managed_network]
+ :language: python
+ :dedent: 8
+ :caption: Creating a ManagedNetwork object with one of each rule type.
+ """
+
+ def __init__(
+ self,
+ *,
+ isolation_mode: str = IsolationMode.DISABLED,
+ outbound_rules: Optional[List[OutboundRule]] = None,
+ firewall_sku: Optional[str] = None,
+ network_id: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.isolation_mode = isolation_mode
+ self.firewall_sku = firewall_sku
+ self.network_id = network_id
+ self.outbound_rules = outbound_rules
+ self.status = kwargs.pop("status", None)
+
+ def _to_rest_object(self) -> RestManagedNetwork:
+ rest_outbound_rules = (
+ {
+ # pylint: disable=protected-access
+ outbound_rule.name: outbound_rule._to_rest_object() # type: ignore[attr-defined]
+ for outbound_rule in self.outbound_rules
+ }
+ if self.outbound_rules
+ else {}
+ )
+ return RestManagedNetwork(
+ isolation_mode=self.isolation_mode, outbound_rules=rest_outbound_rules, firewall_sku=self.firewall_sku
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestManagedNetwork) -> "ManagedNetwork":
+ from_rest_outbound_rules = (
+ [
+ OutboundRule._from_rest_object(obj.outbound_rules[name], name=name) # pylint: disable=protected-access
+ for name in obj.outbound_rules
+ ]
+ if obj.outbound_rules
+ else []
+ )
+ return ManagedNetwork(
+ isolation_mode=obj.isolation_mode,
+ outbound_rules=from_rest_outbound_rules, # type: ignore[arg-type]
+ network_id=obj.network_id,
+ status=obj.status,
+ firewall_sku=obj.firewall_sku,
+ )
+
+
+class ManagedNetworkProvisionStatus:
+ """ManagedNetworkProvisionStatus.
+
+ :param status: Status for managed network provision.
+ :type status: str
+ :param spark_ready: Boolean value indicating whether the managed network is Spark-ready.
+ :type spark_ready: bool
+ """
+
+ def __init__(
+ self,
+ *,
+ status: Optional[str] = None,
+ spark_ready: Optional[bool] = None,
+ ):
+ self.status = status
+ self.spark_ready = spark_ready
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: RestManagedNetworkProvisionStatus) -> "ManagedNetworkProvisionStatus":
+ return cls(
+ status=rest_obj.status,
+ spark_ready=rest_obj.spark_ready,
+ )
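A sketch that builds a ManagedNetwork with one of each outbound rule type. The IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND constant name is an assumption inferred from the firewall_sku docstring, and all resource IDs, FQDNs, and service tags are placeholders.

from azure.ai.ml.constants._workspace import IsolationMode
from azure.ai.ml.entities._workspace.networking import (
    FqdnDestination,
    ManagedNetwork,
    PrivateEndpointDestination,
    ServiceTagDestination,
)

network = ManagedNetwork(
    isolation_mode=IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND,  # assumed constant name
    outbound_rules=[
        FqdnDestination(name="allow-pypi", destination="pypi.org"),
        ServiceTagDestination(
            name="allow-data-factory",
            service_tag="DataFactory",
            protocol="TCP",
            port_ranges="80,443",
        ),
        PrivateEndpointDestination(
            name="allow-storage-pe",
            service_resource_id=(
                "/subscriptions/<sub-id>/resourceGroups/<rg>/providers/"
                "Microsoft.Storage/storageAccounts/<account>"
            ),
            subresource_target="blob",
            spark_enabled=True,
        ),
    ],
)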
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/private_endpoint.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/private_endpoint.py
new file mode 100644
index 00000000..c9e8882e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/private_endpoint.py
@@ -0,0 +1,53 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Dict, Optional
+
+
+class EndpointConnection:
+ """Private Endpoint Connection related to a workspace private endpoint.
+
+ :param subscription_id: Subscription id of the connection.
+ :type subscription_id: str
+ :param resource_group: Resource group of the connection.
+ :type resource_group: str
+ :param vnet_name: Name of the virtual network of the connection.
+ :type vnet_name: str
+ :param subnet_name: Name of the subnet of the connection.
+ :type subnet_name: str
+ :param location: Location of the connection.
+ :type location: str
+ """
+
+ def __init__(
+ self,
+ subscription_id: str,
+ resource_group: str,
+ vnet_name: str,
+ subnet_name: str,
+ location: Optional[str] = None,
+ ):
+ self.subscription_id = subscription_id
+ self.resource_group = resource_group
+ self.location = location
+ self.vnet_name = vnet_name
+ self.subnet_name = subnet_name
+
+
+class PrivateEndpoint:
+ """Private Endpoint of a workspace.
+
+ :param approval_type: Approval type of the private endpoint.
+ :type approval_type: str
+ :param connections: List of private endpoint connections.
+ :type connections: List[~azure.ai.ml.entities.EndpointConnection]
+ """
+
+ def __init__(
+ self,
+ approval_type: Optional[str] = None,
+ connections: Optional[Dict[str, EndpointConnection]] = None,
+ ):
+ self.approval_type = approval_type
+ self.connections = connections
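A minimal sketch of wiring an EndpointConnection into a PrivateEndpoint; the approval type string and all resource names are placeholders.

from azure.ai.ml.entities._workspace.private_endpoint import EndpointConnection, PrivateEndpoint

pe = PrivateEndpoint(
    approval_type="ApprovalRequired",  # placeholder value
    connections={
        "my-pe-connection": EndpointConnection(
            subscription_id="<sub-id>",
            resource_group="<rg>",
            vnet_name="my-vnet",
            subnet_name="my-subnet",
            location="eastus",
        )
    },
)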
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/serverless_compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/serverless_compute.py
new file mode 100644
index 00000000..b78ede06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/serverless_compute.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Optional, Union
+
+from marshmallow.exceptions import ValidationError
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ ServerlessComputeSettings as RestServerlessComputeSettings,
+)
+from azure.ai.ml._schema._utils.utils import ArmId
+
+
+class ServerlessComputeSettings:
+ custom_subnet: Optional[ArmId]
+ no_public_ip: bool = False
+
+ def __init__(self, *, custom_subnet: Optional[Union[str, ArmId]] = None, no_public_ip: bool = False) -> None:
+ """Settings regarding serverless compute(s) in an Azure ML workspace.
+
+ :keyword custom_subnet: The ARM ID of the subnet to use for serverless compute(s).
+ :paramtype custom_subnet: Optional[Union[str, ArmId]]
+ :keyword no_public_ip: Whether or not to disable public IP addresses for serverless compute(s).
+ Defaults to False.
+ :paramtype no_public_ip: bool
+ :raises ValidationError: If the custom_subnet is not formatted as an ARM ID.
+ """
+ if isinstance(custom_subnet, str):
+ self.custom_subnet = ArmId(custom_subnet)
+ elif isinstance(custom_subnet, ArmId) or custom_subnet is None:
+ self.custom_subnet = custom_subnet
+ else:
+ raise ValidationError("custom_subnet must be a string, ArmId, or None.")
+ self.no_public_ip = no_public_ip
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, ServerlessComputeSettings):
+ return NotImplemented
+ return self.custom_subnet == other.custom_subnet and self.no_public_ip == other.no_public_ip
+
+ def _to_rest_object(self) -> RestServerlessComputeSettings:
+ return RestServerlessComputeSettings(
+ serverless_compute_custom_subnet=self.custom_subnet,
+ serverless_compute_no_public_ip=self.no_public_ip,
+ )
+
+ @classmethod
+ def _from_rest_object(cls, obj: RestServerlessComputeSettings) -> "ServerlessComputeSettings":
+ return cls(
+ custom_subnet=obj.serverless_compute_custom_subnet,
+ no_public_ip=obj.serverless_compute_no_public_ip,
+ )
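A short usage sketch; the subnet ARM ID is a placeholder. Passing a plain string lets __init__ wrap it in ArmId, and per the docstring a string that is not a valid ARM ID raises marshmallow's ValidationError.

from azure.ai.ml.entities._workspace.serverless_compute import ServerlessComputeSettings

settings = ServerlessComputeSettings(
    custom_subnet=(
        "/subscriptions/<sub-id>/resourceGroups/<rg>/providers/"
        "Microsoft.Network/virtualNetworks/<vnet>/subnets/<subnet>"
    ),
    no_public_ip=True,
)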
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace.py
new file mode 100644
index 00000000..495e00b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace.py
@@ -0,0 +1,491 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=too-many-instance-attributes
+
+from os import PathLike
+from pathlib import Path
+from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Type, Union
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import FeatureStoreSettings as RestFeatureStoreSettings
+from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkSettings as RestManagedNetwork
+from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedServiceIdentity as RestManagedServiceIdentity
+from azure.ai.ml._restclient.v2024_10_01_preview.models import NetworkAcls as RestNetworkAcls
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ ServerlessComputeSettings as RestServerlessComputeSettings,
+)
+from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace as RestWorkspace
+from azure.ai.ml._schema.workspace.workspace import WorkspaceSchema
+from azure.ai.ml._utils.utils import dump_yaml_to_file
+from azure.ai.ml.constants._common import (
+ BASE_PATH_CONTEXT_KEY,
+ PARAMS_OVERRIDE_KEY,
+ CommonYamlFields,
+ WorkspaceKind,
+ WorkspaceResourceConstants,
+)
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._resource import Resource
+from azure.ai.ml.entities._util import find_field_in_override, load_from_dict
+from azure.ai.ml.entities._workspace.serverless_compute import ServerlessComputeSettings
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from .customer_managed_key import CustomerManagedKey
+from .feature_store_settings import FeatureStoreSettings
+from .network_acls import NetworkAcls
+from .networking import ManagedNetwork
+
+
+class Workspace(Resource):
+ """Azure ML workspace.
+
+ :param name: Name of the workspace.
+ :type name: str
+ :param description: Description of the workspace.
+ :type description: str
+ :param tags: Tags of the workspace.
+ :type tags: dict
+ :param display_name: Display name for the workspace. This is non-unique within the resource group.
+ :type display_name: str
+ :param location: The location to create the workspace in.
+ If not specified, the same location as the resource group will be used.
+ :type location: str
+ :param resource_group: Name of resource group to create the workspace in.
+ :type resource_group: str
+ :param hbi_workspace: Whether the customer data is of high business impact (HBI),
+ containing sensitive business information.
+ For more information, see
+ https://learn.microsoft.com/azure/machine-learning/concept-data-encryption#encryption-at-rest.
+ :type hbi_workspace: bool
+ :param storage_account: The resource ID of an existing storage account to use instead of creating a new one.
+ :type storage_account: str
+ :param container_registry: The resource ID of an existing container registry
+ to use instead of creating a new one.
+ :type container_registry: str
+ :param key_vault: The resource ID of an existing key vault to use instead of creating a new one.
+ :type key_vault: str
+ :param application_insights: The resource ID of an existing application insights
+ to use instead of creating a new one.
+ :type application_insights: str
+ :param customer_managed_key: Key vault details for encrypting data with customer-managed keys.
+ If not specified, Microsoft-managed keys will be used by default.
+ :type customer_managed_key: ~azure.ai.ml.entities.CustomerManagedKey
+ :param image_build_compute: The name of the compute target to use for building environment
+ Docker images when the container registry is behind a VNet.
+ :type image_build_compute: str
+ :param public_network_access: Whether to allow public endpoint connectivity
+ when a workspace is private link enabled.
+ :type public_network_access: str
+ :param network_acls: The network access control list (ACL) settings of the workspace.
+ :type network_acls: ~azure.ai.ml.entities.NetworkAcls
+ :param identity: workspace's Managed Identity (user assigned, or system assigned)
+ :type identity: ~azure.ai.ml.entities.IdentityConfiguration
+ :param primary_user_assigned_identity: The workspace's primary user assigned identity
+ :type primary_user_assigned_identity: str
+ :param managed_network: workspace's Managed Network configuration
+ :type managed_network: ~azure.ai.ml.entities.ManagedNetwork
+ :param provision_network_now: Set to True to trigger provisioning of the managed VNet with the default options
+ when creating a workspace with the managed VNet enabled; otherwise this setting does nothing.
+ :type provision_network_now: Optional[bool]
+ :param system_datastores_auth_mode: The authentication mode for system datastores.
+ :type system_datastores_auth_mode: str
+ :param enable_data_isolation: A flag to determine if workspace has data isolation enabled.
+ The flag can only be set at the creation phase, it can't be updated.
+ :type enable_data_isolation: bool
+ :param allow_roleassignment_on_rg: Determines whether to allow workspace role assignments at the resource group level.
+ :type allow_roleassignment_on_rg: Optional[bool]
+ :param serverless_compute: The serverless compute settings for the workspace.
+ :type serverless_compute: ~azure.ai.ml.entities.ServerlessComputeSettings
+ :param workspace_hub: Deprecated. The resource ID of an existing workspace hub used to help create a project
+ workspace. Use the Project class instead.
+ :type workspace_hub: Optional[str]
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace]
+ :end-before: [END workspace]
+ :language: python
+ :dedent: 8
+ :caption: Creating a Workspace object.
+ """
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ display_name: Optional[str] = None,
+ location: Optional[str] = None,
+ resource_group: Optional[str] = None,
+ hbi_workspace: bool = False,
+ storage_account: Optional[str] = None,
+ container_registry: Optional[str] = None,
+ key_vault: Optional[str] = None,
+ application_insights: Optional[str] = None,
+ customer_managed_key: Optional[CustomerManagedKey] = None,
+ image_build_compute: Optional[str] = None,
+ public_network_access: Optional[str] = None,
+ network_acls: Optional[NetworkAcls] = None,
+ identity: Optional[IdentityConfiguration] = None,
+ primary_user_assigned_identity: Optional[str] = None,
+ managed_network: Optional[ManagedNetwork] = None,
+ provision_network_now: Optional[bool] = None,
+ system_datastores_auth_mode: Optional[str] = None,
+ enable_data_isolation: bool = False,
+ allow_roleassignment_on_rg: Optional[bool] = None,
+ hub_id: Optional[str] = None, # Hidden input, surfaced by Project
+ workspace_hub: Optional[str] = None, # Deprecated input maintained for backwards compat.
+ serverless_compute: Optional[ServerlessComputeSettings] = None,
+ **kwargs: Any,
+ ):
+ # Workspaces have subclasses that are differentiated by the 'kind' field in the REST API.
+ # Now that this value is occasionally surfaced (for sub-class YAML specifications),
+ # we've switched to using 'type' in the SDK for consistency with other polymorphic classes.
+ # That said, the code below quietly supports 'kind' as an input to maintain backwards
+ # compatibility with internal systems that likely still use 'kind' somewhere.
+ # 'type' takes precedence over 'kind' if they're both set, and this defaults to a normal workspace's type
+ # if nothing is set.
+ # pylint: disable=too-many-locals
+ self._kind = kwargs.pop("kind", None)
+ if self._kind is None:
+ self._kind = WorkspaceKind.DEFAULT
+
+ self.print_as_yaml = True
+ self._discovery_url: Optional[str] = kwargs.pop("discovery_url", None)
+ self._mlflow_tracking_uri: Optional[str] = kwargs.pop("mlflow_tracking_uri", None)
+ self._workspace_id = kwargs.pop("workspace_id", None)
+ self._feature_store_settings: Optional[FeatureStoreSettings] = kwargs.pop("feature_store_settings", None)
+ super().__init__(name=name, description=description, tags=tags, **kwargs)
+
+ self.display_name = display_name
+ self.location = location
+ self.resource_group = resource_group
+ self.hbi_workspace = hbi_workspace
+ self.storage_account = storage_account
+ self.container_registry = container_registry
+ self.key_vault = key_vault
+ self.application_insights = application_insights
+ self.customer_managed_key = customer_managed_key
+ self.image_build_compute = image_build_compute
+ self.public_network_access = public_network_access
+ self.identity = identity
+ self.primary_user_assigned_identity = primary_user_assigned_identity
+ self.managed_network = managed_network
+ self.provision_network_now = provision_network_now
+ self.system_datastores_auth_mode = system_datastores_auth_mode
+ self.enable_data_isolation = enable_data_isolation
+ self.allow_roleassignment_on_rg = allow_roleassignment_on_rg
+ if workspace_hub and not hub_id:
+ hub_id = workspace_hub
+ self.__hub_id = hub_id
+ # Overwrite kind if hub_id is provided. Technically not needed anymore,
+ # but kept for backwards if people try to just use a normal workspace like
+ # a project.
+ if hub_id:
+ self._kind = WorkspaceKind.PROJECT
+ self.serverless_compute: Optional[ServerlessComputeSettings] = serverless_compute
+ self.network_acls: Optional[NetworkAcls] = network_acls
+
+ @property
+ def discovery_url(self) -> Optional[str]:
+ """Backend service base URLs for the workspace.
+
+ :return: Backend service URLs of the workspace
+ :rtype: str
+ """
+ return self._discovery_url
+
+ # Exists to appease tox's mypy rules.
+ @property
+ def _hub_id(self) -> Optional[str]:
+ """The UID of the hub parent of the project. This is an internal property
+ that's surfaced by the Project sub-class, but exists here for backwards-compatibility
+ reasons.
+
+ :return: Resource ID of the parent hub.
+ :rtype: str
+ """
+ return self.__hub_id
+
+ # Exists to appease tox's mypy rules.
+ @_hub_id.setter
+ def _hub_id(self, value: str):
+ """Set the hub of the project. This is an internal property
+ that's surfaced by the Project sub-class, but exists here for backwards-compatibility
+ reasons.
+
+
+ :param value: The hub id to assign to the project.
+ Note: cannot be reassigned after creation.
+ :type value: str
+ """
+ if not value:
+ return
+ self.__hub_id = value
+
+ @property
+ def mlflow_tracking_uri(self) -> Optional[str]:
+ """MLflow tracking uri for the workspace.
+
+ :return: Returns mlflow tracking uri of the workspace.
+ :rtype: str
+ """
+ return self._mlflow_tracking_uri
+
+ def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
+ """Dump the workspace spec into a file in yaml format.
+
+ :param dest: The destination to receive this workspace's spec.
+ Must be either a path to a local file, or an already-open file stream.
+ If dest is a file path, a new file will be created,
+ and an exception is raised if the file exists.
+ If dest is an open file, the file will be written to directly,
+ and an exception will be raised if the file is not writable.
+ :type dest: Union[PathLike, str, IO[AnyStr]]
+ """
+ path = kwargs.pop("path", None)
+ yaml_serialized = self._to_dict()
+ dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)
+
+ def _to_dict(self) -> Dict:
+ res: dict = self._get_schema_class()(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+ return res
+
+ @classmethod
+ def _resolve_sub_cls_and_kind(
+ cls, data: Dict, params_override: Optional[List[Dict]] = None
+ ) -> Tuple[Type["Workspace"], str]:
+ """Given a workspace data dictionary, determine the appropriate workspace class and type string.
+ Allows for easier polymorphism between the workspace class and its children.
+ Adapted from similar code in the Job class.
+
+ :param data: A dictionary of values describing the workspace.
+ :type data: Dict
+ :param params_override: Override values from alternative sources (ex: CLI input).
+ :type params_override: Optional[List[Dict]]
+ :return: A tuple containing the workspace class and type string.
+ :rtype: Tuple[Type["Workspace"], str]
+ """
+ from azure.ai.ml.entities import Hub, Project
+
+ workspace_class: Optional[Type["Workspace"]] = None
+ type_in_override = find_field_in_override(CommonYamlFields.KIND, params_override)
+ type_str = type_in_override or data.get(CommonYamlFields.KIND, WorkspaceKind.DEFAULT)
+ if type_str is not None:
+ type_str = type_str.lower()
+ if type_str == WorkspaceKind.HUB:
+ workspace_class = Hub
+ elif type_str == WorkspaceKind.PROJECT:
+ workspace_class = Project
+ elif type_str == WorkspaceKind.DEFAULT:
+ workspace_class = Workspace
+ else:
+ msg = f"Unsupported workspace type: {type_str}."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.WORKSPACE,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return workspace_class, type_str
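+
+    # Editor's sketch (not part of the SDK) of the resolution behaviour described
+    # above, assuming CommonYamlFields.KIND is the YAML key "kind" and the
+    # WorkspaceKind constants are the lowercase strings shown:
+    #
+    #     Workspace._resolve_sub_cls_and_kind({"kind": "hub"})      # -> (Hub, "hub")
+    #     Workspace._resolve_sub_cls_and_kind({"kind": "project"})  # -> (Project, "project")
+    #     Workspace._resolve_sub_cls_and_kind({"kind": "bogus"})    # raises ValidationException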
+
+ @classmethod
+ def _load(
+ cls,
+ data: Optional[Dict] = None,
+ yaml_path: Optional[Union[PathLike, str]] = None,
+ params_override: Optional[list] = None,
+ **kwargs: Any,
+ ) -> "Workspace":
+ # This _load function is polymorphic and can return child classes.
+ # It was adapted from the Job class's similar function.
+ data = data or {}
+ params_override = params_override or []
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+ workspace_class, type_str = cls._resolve_sub_cls_and_kind(data, params_override)
+ schema_type = workspace_class._get_schema_class() # pylint: disable=protected-access
+ loaded_schema = load_from_dict(
+ schema_type,
+ data=data,
+ context=context,
+            additional_message=f"If you are trying to configure a workspace that is not of kind {type_str},"
+            f" please specify the correct workspace kind in the 'kind' property.",
+ **kwargs,
+ )
+ result = workspace_class(**loaded_schema)
+ if yaml_path:
+ result._source_path = yaml_path # pylint: disable=protected-access
+ return result
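+
+    # Editor's sketch (not part of the SDK): this polymorphic loader is typically
+    # reached through the public load_workspace helper; the file path below is
+    # hypothetical.
+    #
+    #     from azure.ai.ml import load_workspace
+    #     ws = load_workspace("./my_workspace.yml")  # may return Workspace, Hub, or Project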
+
+ @classmethod
+ def _from_rest_object(
+ cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None
+ ) -> Optional["Workspace"]:
+
+ if not rest_obj:
+ return None
+ customer_managed_key = (
+ CustomerManagedKey(
+ key_vault=rest_obj.encryption.key_vault_properties.key_vault_arm_id,
+ key_uri=rest_obj.encryption.key_vault_properties.key_identifier,
+ )
+ if rest_obj.encryption
+ and rest_obj.encryption.status == WorkspaceResourceConstants.ENCRYPTION_STATUS_ENABLED
+ else None
+ )
+
+ # TODO: Remove attribute check once Oct API version is out
+ mlflow_tracking_uri = None
+
+ if hasattr(rest_obj, "ml_flow_tracking_uri"):
+ try:
+ if v2_service_context:
+ # v2_service_context is required (not None) in get_mlflow_tracking_uri_v2
+ from azureml.mlflow import get_mlflow_tracking_uri_v2
+
+ mlflow_tracking_uri = get_mlflow_tracking_uri_v2(rest_obj, v2_service_context)
+ else:
+ mlflow_tracking_uri = rest_obj.ml_flow_tracking_uri
+ except ImportError:
+ mlflow_tracking_uri = rest_obj.ml_flow_tracking_uri
+
+ # TODO: Remove once Online Endpoints updates API version to at least 2023-08-01
+ allow_roleassignment_on_rg = None
+ if hasattr(rest_obj, "allow_role_assignment_on_rg"):
+ allow_roleassignment_on_rg = rest_obj.allow_role_assignment_on_rg
+ system_datastores_auth_mode = None
+ if hasattr(rest_obj, "system_datastores_auth_mode"):
+ system_datastores_auth_mode = rest_obj.system_datastores_auth_mode
+
+ # TODO: remove this once it is included in API response
+ managed_network = None
+ if hasattr(rest_obj, "managed_network"):
+ if rest_obj.managed_network and isinstance(rest_obj.managed_network, RestManagedNetwork):
+ managed_network = ManagedNetwork._from_rest_object( # pylint: disable=protected-access
+ rest_obj.managed_network
+ )
+
+ # TODO: Remove once it's included in response
+ provision_network_now = None
+ if hasattr(rest_obj, "provision_network_now"):
+ provision_network_now = rest_obj.provision_network_now
+
+        # The resource group name is the fifth segment of the ARM ID
+        # ("/subscriptions/<sub>/resourceGroups/<rg>/..."), i.e. index 4 after splitting.
+        armid_parts = str(rest_obj.id).split("/")
+        group = None if len(armid_parts) < 5 else armid_parts[4]
+ identity = None
+ if rest_obj.identity and isinstance(rest_obj.identity, RestManagedServiceIdentity):
+ identity = IdentityConfiguration._from_workspace_rest_object( # pylint: disable=protected-access
+ rest_obj.identity
+ )
+ feature_store_settings = None
+ if rest_obj.feature_store_settings and isinstance(rest_obj.feature_store_settings, RestFeatureStoreSettings):
+ feature_store_settings = FeatureStoreSettings._from_rest_object( # pylint: disable=protected-access
+ rest_obj.feature_store_settings
+ )
+ serverless_compute = None
+ # TODO: Remove attribute check once serverless_compute_settings is in API response contract
+ if hasattr(rest_obj, "serverless_compute_settings"):
+ if rest_obj.serverless_compute_settings and isinstance(
+ rest_obj.serverless_compute_settings, RestServerlessComputeSettings
+ ):
+ serverless_compute = ServerlessComputeSettings._from_rest_object( # pylint: disable=protected-access
+ rest_obj.serverless_compute_settings
+ )
+ network_acls = None
+ if hasattr(rest_obj, "network_acls"):
+ if rest_obj.network_acls and isinstance(rest_obj.network_acls, RestNetworkAcls):
+ network_acls = NetworkAcls._from_rest_object(rest_obj.network_acls) # pylint: disable=protected-access
+
+ return cls(
+ name=rest_obj.name,
+ id=rest_obj.id,
+ description=rest_obj.description,
+ kind=rest_obj.kind.lower() if rest_obj.kind else None,
+ tags=rest_obj.tags,
+ location=rest_obj.location,
+ resource_group=group,
+ display_name=rest_obj.friendly_name,
+ discovery_url=rest_obj.discovery_url,
+ hbi_workspace=rest_obj.hbi_workspace,
+ storage_account=rest_obj.storage_account,
+ container_registry=rest_obj.container_registry,
+ key_vault=rest_obj.key_vault,
+ application_insights=rest_obj.application_insights,
+ customer_managed_key=customer_managed_key,
+ image_build_compute=rest_obj.image_build_compute,
+ public_network_access=rest_obj.public_network_access,
+ network_acls=network_acls,
+ mlflow_tracking_uri=mlflow_tracking_uri,
+ identity=identity,
+ primary_user_assigned_identity=rest_obj.primary_user_assigned_identity,
+ managed_network=managed_network,
+ provision_network_now=provision_network_now,
+ system_datastores_auth_mode=system_datastores_auth_mode,
+ feature_store_settings=feature_store_settings,
+ enable_data_isolation=rest_obj.enable_data_isolation,
+ allow_roleassignment_on_rg=allow_roleassignment_on_rg,
+ hub_id=rest_obj.hub_resource_id,
+ workspace_id=rest_obj.workspace_id,
+ serverless_compute=serverless_compute,
+ )
+
+ def _to_rest_object(self) -> RestWorkspace:
+        """Note: Unlike most entities, the create operation for workspaces does NOT use this function,
+ and instead relies on its own internal conversion process to produce a valid ARM template.
+
+ :return: The REST API object-equivalent of this workspace.
+ :rtype: RestWorkspace
+ """
+ feature_store_settings = None
+ if self._feature_store_settings:
+ feature_store_settings = self._feature_store_settings._to_rest_object() # pylint: disable=protected-access
+
+ serverless_compute_settings = None
+ if self.serverless_compute:
+ serverless_compute_settings = self.serverless_compute._to_rest_object() # pylint: disable=protected-access
+
+ return RestWorkspace(
+ name=self.name,
+ identity=(
+ self.identity._to_workspace_rest_object() if self.identity else None # pylint: disable=protected-access
+ ),
+ location=self.location,
+ tags=self.tags,
+ description=self.description,
+ kind=self._kind,
+ friendly_name=self.display_name,
+ key_vault=self.key_vault,
+ application_insights=self.application_insights,
+ container_registry=self.container_registry,
+ storage_account=self.storage_account,
+ discovery_url=self.discovery_url,
+ hbi_workspace=self.hbi_workspace,
+ image_build_compute=self.image_build_compute,
+ public_network_access=self.public_network_access,
+ primary_user_assigned_identity=self.primary_user_assigned_identity,
+ managed_network=(
+ self.managed_network._to_rest_object() # pylint: disable=protected-access
+ if self.managed_network
+ else None
+ ),
+ provision_network_now=self.provision_network_now,
+ system_datastores_auth_mode=self.system_datastores_auth_mode,
+ feature_store_settings=feature_store_settings,
+ enable_data_isolation=self.enable_data_isolation,
+            allow_role_assignment_on_rg=self.allow_roleassignment_on_rg,  # name differs from the attribute due to swagger restclient casing
+ hub_resource_id=self._hub_id,
+ serverless_compute_settings=serverless_compute_settings,
+ )
+
+    # Helper for sub-class polymorphism. Needs to be overridden by child classes
+    # if they don't want to redefine things like _to_dict.
+ @classmethod
+ def _get_schema_class(cls) -> Type[WorkspaceSchema]:
+ return WorkspaceSchema
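+
+    # Editor's sketch (not part of the SDK) of the override pattern described above,
+    # using a hypothetical child class and schema:
+    #
+    #     class MyWorkspaceVariant(Workspace):
+    #         @classmethod
+    #         def _get_schema_class(cls) -> Type[WorkspaceSchema]:
+    #             return MyWorkspaceVariantSchema  # hypothetical WorkspaceSchema subclass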
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace_keys.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace_keys.py
new file mode 100644
index 00000000..4213b419
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_workspace/workspace_keys.py
@@ -0,0 +1,100 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import List, Optional
+
+from azure.ai.ml._restclient.v2024_10_01_preview.models import ListWorkspaceKeysResult
+
+
+class ContainerRegistryCredential:
+    """Credentials for the Azure Container Registry (ACR) associated with a given workspace.
+
+ :param location: Location of the ACR
+ :type location: str
+ :param username: Username of the ACR
+ :type username: str
+ :param passwords: Passwords to access the ACR
+ :type passwords: List[str]
+ """
+
+ def __init__(
+ self, *, location: Optional[str] = None, username: Optional[str] = None, passwords: Optional[List[str]] = None
+ ):
+ self.location = location
+ self.username = username
+ self.passwords = passwords
+
+
+class NotebookAccessKeys:
+    """Access keys for the notebook resource associated with a given workspace.
+
+ :param primary_access_key: Primary access key of notebook resource
+ :type primary_access_key: str
+ :param secondary_access_key: Secondary access key of notebook resource
+ :type secondary_access_key: str
+ """
+
+ def __init__(self, *, primary_access_key: Optional[str] = None, secondary_access_key: Optional[str] = None):
+ self.primary_access_key = primary_access_key
+ self.secondary_access_key = secondary_access_key
+
+
+class WorkspaceKeys:
+ """Workspace Keys.
+
+ :param user_storage_key: Key for storage account associated with given workspace
+ :type user_storage_key: str
+ :param user_storage_resource_id: Resource id of storage account associated with given workspace
+ :type user_storage_resource_id: str
+ :param app_insights_instrumentation_key: Key for app insights associated with given workspace
+ :type app_insights_instrumentation_key: str
+ :param container_registry_credentials: Key for ACR associated with given workspace
+ :type container_registry_credentials: ContainerRegistryCredential
+ :param notebook_access_keys: Key for notebook resource associated with given workspace
+ :type notebook_access_keys: NotebookAccessKeys
+ """
+
+ def __init__(
+ self,
+ *,
+ user_storage_key: Optional[str] = None,
+ user_storage_resource_id: Optional[str] = None,
+ app_insights_instrumentation_key: Optional[str] = None,
+ container_registry_credentials: Optional[ContainerRegistryCredential] = None,
+ notebook_access_keys: Optional[NotebookAccessKeys] = None
+ ):
+ self.user_storage_key = user_storage_key
+ self.user_storage_resource_id = user_storage_resource_id
+ self.app_insights_instrumentation_key = app_insights_instrumentation_key
+ self.container_registry_credentials = container_registry_credentials
+ self.notebook_access_keys = notebook_access_keys
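+
+    # Editor's note (not part of the SDK): WorkspaceKeys instances are normally
+    # obtained from the workspace operations client rather than constructed
+    # directly; the exact operation name below is an assumption.
+    #
+    #     keys = ml_client.workspaces.get_keys(name="my-workspace")
+    #     print(keys.user_storage_key)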
+
+ @classmethod
+ def _from_rest_object(cls, rest_obj: ListWorkspaceKeysResult) -> Optional["WorkspaceKeys"]:
+ if not rest_obj:
+ return None
+
+ container_registry_credentials = None
+ notebook_access_keys = None
+
+ if hasattr(rest_obj, "container_registry_credentials") and rest_obj.container_registry_credentials is not None:
+ container_registry_credentials = ContainerRegistryCredential(
+ location=rest_obj.container_registry_credentials.location,
+ username=rest_obj.container_registry_credentials.username,
+ passwords=rest_obj.container_registry_credentials.passwords,
+ )
+
+ if hasattr(rest_obj, "notebook_access_keys") and rest_obj.notebook_access_keys is not None:
+ notebook_access_keys = NotebookAccessKeys(
+ primary_access_key=rest_obj.notebook_access_keys.primary_access_key,
+ secondary_access_key=rest_obj.notebook_access_keys.secondary_access_key,
+ )
+
+ return WorkspaceKeys(
+ user_storage_key=rest_obj.user_storage_key,
+ user_storage_resource_id=rest_obj.user_storage_arm_id,
+ app_insights_instrumentation_key=rest_obj.app_insights_instrumentation_key,
+ container_registry_credentials=container_registry_credentials,
+ notebook_access_keys=notebook_access_keys,
+ )