Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch')
 .../_schema/_deployment/batch/__init__.py                                          |   5 +
 .../_schema/_deployment/batch/batch_deployment.py                                  |  92 +
 .../_schema/_deployment/batch/batch_deployment_settings.py                         |  26 +
 .../_schema/_deployment/batch/batch_job.py                                         | 132 +
 .../_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py |  52 +
 .../_schema/_deployment/batch/compute_binding.py                                   |  36 +
 .../_schema/_deployment/batch/job_definition_schema.py                             |  51 +
 .../_schema/_deployment/batch/model_batch_deployment.py                            |  46 +
 .../_schema/_deployment/batch/model_batch_deployment_settings.py                   |  56 +
 .../_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py        |  70 +
 .../_schema/_deployment/batch/run_settings_schema.py                               |  28 +
11 files changed, 594 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py
new file mode 100644
index 00000000..7a69176b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment.py
@@ -0,0 +1,92 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,no-else-return
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+from marshmallow.exceptions import ValidationError
+from azure.ai.ml._schema import (
+    UnionField,
+    ArmVersionedStr,
+    ArmStr,
+    RegistryStr,
+)
+from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job.creation_context import CreationContextSchema
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction, BatchDeploymentType
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchDeploymentSchema(DeploymentSchema):
+    compute = ComputeField(required=False)
+    error_threshold = fields.Int(
+        metadata={
+            "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+        the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+        For FileDataset, this value is the count of file failures.\r\n
+        For TabularDataset, this value is the count of record failures.\r\n
+        If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+        }
+    )
+    retry_settings = NestedField(BatchRetrySettingsSchema)
+    mini_batch_size = fields.Int()
+    logging_level = fields.Str(
+        metadata={
+            "description": """A string of the logging level name, which is defined in 'logging'.
+        Possible values are 'warning', 'info', and 'debug'."""
+        }
+    )
+    output_action = StringTransformedEnum(
+        allowed_values=[
+            BatchDeploymentOutputAction.APPEND_ROW,
+            BatchDeploymentOutputAction.SUMMARY_ONLY,
+        ],
+        metadata={"description": "Indicates how batch inferencing will handle output."},
+        dump_default=BatchDeploymentOutputAction.APPEND_ROW,
+    )
+    output_file_name = fields.Str(metadata={"description": "Customized output file name for append_row output action."})
+    max_concurrency_per_instance = fields.Int(
+        metadata={"description": "Indicates maximum number of parallelism per instance."}
+    )
+    resources = NestedField(JobResourceConfigurationSchema)
+    type = StringTransformedEnum(
+        allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
+    )
+
+    job_definition = ArmStr(azureml_type=AzureMLResourceType.JOB)
+    component = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    creation_context = NestedField(CreationContextSchema, dump_only=True)
+    provisioning_state = fields.Str(dump_only=True)
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:
+        from azure.ai.ml.entities import BatchDeployment, ModelBatchDeployment, PipelineComponentBatchDeployment
+
+        if "type" not in data:
+            return BatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+        elif data["type"] == BatchDeploymentType.PIPELINE:
+            return PipelineComponentBatchDeployment(**data)
+        elif data["type"] == BatchDeploymentType.MODEL:
+            return ModelBatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
+        else:
+            raise ValidationError(
+                "Deployment type must be of type " f"{BatchDeploymentType.PIPELINE} or {BatchDeploymentType.MODEL}."
+            )
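The post_load hook above dispatches on the optional "type" field. A minimal sketch of that dispatch, not part of the diff; it assumes azure-ai-ml (with marshmallow 3) is installed, and the context value is illustrative:

    from azure.ai.ml._schema._deployment.batch.batch_deployment import BatchDeploymentSchema
    from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY

    # DeploymentSchema is path-aware, so a base-path context is required at
    # construction time; "." is an arbitrary choice for this sketch.
    schema = BatchDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "."})

    # Per make() above, a loaded payload hydrates into:
    #   no "type" key                        -> BatchDeployment
    #   type == BatchDeploymentType.PIPELINE -> PipelineComponentBatchDeployment
    #   type == BatchDeploymentType.MODEL    -> ModelBatchDeployment
    # Any other value is already rejected by the StringTransformedEnum field,
    # so the trailing ValidationError branch is defensive.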
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py
new file mode 100644
index 00000000..2a36352c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_deployment_settings.py
@@ -0,0 +1,26 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.entities._deployment.deployment_settings import BatchRetrySettings
+
+module_logger = logging.getLogger(__name__)
+
+
+class BatchRetrySettingsSchema(metaclass=PatchedSchemaMeta):
+    max_retries = fields.Int(
+        metadata={"description": "The number of maximum tries for a failed or timeout mini batch."},
+    )
+    timeout = fields.Int(metadata={"description": "The timeout for a mini batch."})
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> BatchRetrySettings:
+        return BatchRetrySettings(**data)
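A quick load sketch for the schema above (not part of the diff; assumes azure-ai-ml is installed, values invented). Loading a plain dict returns the BatchRetrySettings entity via the post_load hook:

    from azure.ai.ml._schema._deployment.batch.batch_deployment_settings import BatchRetrySettingsSchema

    # Both fields are optional integers; post_load builds the entity.
    retry = BatchRetrySettingsSchema().load({"max_retries": 3, "timeout": 300})
    print(retry.max_retries, retry.timeout)  # 3 300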
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py
new file mode 100644
index 00000000..a1496f1e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_job.py
@@ -0,0 +1,132 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from typing import Any
+
+from marshmallow import fields
+from marshmallow.decorators import post_load
+
+from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import (
+    BatchJob,
+    CustomModelJobInput,
+    CustomModelJobOutput,
+    DataVersion,
+    LiteralJobInput,
+    MLFlowModelJobInput,
+    MLFlowModelJobOutput,
+    MLTableJobInput,
+    MLTableJobOutput,
+    TritonModelJobInput,
+    TritonModelJobOutput,
+    UriFileJobInput,
+    UriFileJobOutput,
+    UriFolderJobInput,
+    UriFolderJobOutput,
+)
+from azure.ai.ml._schema.core.fields import ArmStr, NestedField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants import AssetTypes
+from azure.ai.ml.constants._common import AzureMLResourceType, InputTypes
+from azure.ai.ml.constants._endpoint import EndpointYamlFields
+from azure.ai.ml.entities import ComputeConfiguration
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+from .compute_binding import ComputeBindingSchema
+
+
+class OutputDataSchema(metaclass=PatchedSchemaMeta):
+    datastore_id = ArmStr(azureml_type=AzureMLResourceType.DATASTORE)
+    path = fields.Str()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:
+        return DataVersion(**data)
+
+
+class BatchJobSchema(PathAwareSchema):
+    compute = NestedField(ComputeBindingSchema)
+    dataset = fields.Str()
+    error_threshold = fields.Int()
+    input_data = fields.Dict()
+    mini_batch_size = fields.Int()
+    name = fields.Str(data_key="job_name")
+    output_data = fields.Dict()
+    output_dataset = NestedField(OutputDataSchema)
+    output_file_name = fields.Str()
+    retry_settings = NestedField(BatchRetrySettingsSchema)
+    properties = fields.Dict(data_key="properties")
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=too-many-branches
+        if data.get(EndpointYamlFields.BATCH_JOB_INPUT_DATA, None):
+            for key, input_data in data[EndpointYamlFields.BATCH_JOB_INPUT_DATA].items():
+                if isinstance(input_data, Input):
+                    if input_data.type == AssetTypes.URI_FILE:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = UriFileJobInput(uri=input_data.path)
+                    if input_data.type == AssetTypes.URI_FOLDER:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = UriFolderJobInput(uri=input_data.path)
+                    if input_data.type == AssetTypes.TRITON_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = TritonModelJobInput(
+                            mode=input_data.mode, uri=input_data.path
+                        )
+                    if input_data.type == AssetTypes.MLFLOW_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = MLFlowModelJobInput(
+                            mode=input_data.mode, uri=input_data.path
+                        )
+                    if input_data.type == AssetTypes.MLTABLE:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = MLTableJobInput(
+                            mode=input_data.mode, uri=input_data.path
+                        )
+                    if input_data.type == AssetTypes.CUSTOM_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = CustomModelJobInput(
+                            mode=input_data.mode, uri=input_data.path
+                        )
+                    if input_data.type in {
+                        InputTypes.INTEGER,
+                        InputTypes.NUMBER,
+                        InputTypes.STRING,
+                        InputTypes.BOOLEAN,
+                    }:
+                        data[EndpointYamlFields.BATCH_JOB_INPUT_DATA][key] = LiteralJobInput(value=input_data.default)
+        if data.get(EndpointYamlFields.BATCH_JOB_OUTPUT_DATA, None):
+            for key, output_data in data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA].items():
+                if isinstance(output_data, Output):
+                    if output_data.type == AssetTypes.URI_FILE:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = UriFileJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.URI_FOLDER:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = UriFolderJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.TRITON_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = TritonModelJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.MLFLOW_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = MLFlowModelJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.MLTABLE:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = MLTableJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+                    if output_data.type == AssetTypes.CUSTOM_MODEL:
+                        data[EndpointYamlFields.BATCH_JOB_OUTPUT_DATA][key] = CustomModelJobOutput(
+                            mode=output_data.mode, uri=output_data.path
+                        )
+
+        if data.get(EndpointYamlFields.COMPUTE, None):
+            data[EndpointYamlFields.COMPUTE] = ComputeConfiguration(
+                **data[EndpointYamlFields.COMPUTE]
+            )._to_rest_object()
+
+        if data.get(EndpointYamlFields.RETRY_SETTINGS, None):
+            data[EndpointYamlFields.RETRY_SETTINGS] = data[EndpointYamlFields.RETRY_SETTINGS]._to_rest_object()
+
+        return BatchJob(**data)
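The make() hook above is essentially a translation table from SDK Input/Output entities to the data-plane REST models. A minimal sketch of one branch (illustrative only; the URI is invented, and azure-ai-ml is assumed installed):

    from azure.ai.ml.entities._inputs_outputs import Input
    from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import UriFolderJobInput

    inp = Input(type="uri_folder", path="azureml://datastores/workspaceblobstore/paths/data/")
    # Mirrors the AssetTypes.URI_FOLDER branch: only the path survives the
    # mapping, since the file/folder REST inputs are built without a mode,
    # unlike the model-typed inputs above which keep input_data.mode.
    rest_input = UriFolderJobInput(uri=inp.path)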
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py
new file mode 100644
index 00000000..f0b22fd7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/batch_pipeline_component_deployment_configurations_schema.py
@@ -0,0 +1,52 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import (
+    ArmVersionedStr,
+    PatchedSchemaMeta,
+    StringTransformedEnum,
+    UnionField,
+    ArmStr,
+    RegistryStr,
+)
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._job.job import JobType
+
+module_logger = logging.getLogger(__name__)
+
+
+# pylint: disable-next=name-too-long
+class BatchPipelineComponentDeploymentConfiguarationsSchema(metaclass=PatchedSchemaMeta):
+    component_id = fields.Str()
+    job = UnionField(
+        [
+            ArmStr(azureml_type=AzureMLResourceType.JOB),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    component = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    type = StringTransformedEnum(required=True, allowed_values=[JobType.PIPELINE])
+    settings = fields.Dict()
+    name = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.job_definition import JobDefinition
+
+        return JobDefinition(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py
new file mode 100644
index 00000000..2e4b0348
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/compute_binding.py
@@ -0,0 +1,36 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import ValidationError, fields, validates_schema
+
+from azure.ai.ml._schema.core.fields import ArmStr, StringTransformedEnum, UnionField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml.constants._common import LOCAL_COMPUTE_TARGET, AzureMLResourceType
+
+module_logger = logging.getLogger(__name__)
+
+
+class ComputeBindingSchema(metaclass=PatchedSchemaMeta):
+    target = UnionField(
+        [
+            StringTransformedEnum(allowed_values=[LOCAL_COMPUTE_TARGET]),
+            ArmStr(azureml_type=AzureMLResourceType.COMPUTE),
+            # Case for virtual clusters
+            ArmStr(azureml_type=AzureMLResourceType.VIRTUALCLUSTER),
+        ]
+    )
+    instance_count = fields.Integer()
+    instance_type = fields.Str(metadata={"description": "The instance type to make available to this job."})
+    location = fields.Str(metadata={"description": "The locations where this job may run."})
+    properties = fields.Dict(keys=fields.Str())
+
+    @validates_schema
+    def validate(self, data: Any, **kwargs):
+        if data.get("target") == LOCAL_COMPUTE_TARGET and data.get("instance_count", 1) != 1:
+            raise ValidationError("Local runs must have node count of 1.")
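The validates_schema check above can be exercised directly. A sketch, not part of the diff; it assumes azure-ai-ml is installed and that LOCAL_COMPUTE_TARGET serializes as "local":

    from marshmallow import ValidationError
    from azure.ai.ml._schema._deployment.batch.compute_binding import ComputeBindingSchema

    try:
        # Local compute with more than one node trips the schema-level check.
        ComputeBindingSchema().load({"target": "local", "instance_count": 2})
    except ValidationError as err:
        print(err.messages)  # Local runs must have node count of 1.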
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py
new file mode 100644
index 00000000..269f1da7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/job_definition_schema.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import (
+    ArmVersionedStr,
+    PatchedSchemaMeta,
+    StringTransformedEnum,
+    UnionField,
+    ArmStr,
+    RegistryStr,
+)
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._job.job import JobType
+
+module_logger = logging.getLogger(__name__)
+
+
+class JobDefinitionSchema(metaclass=PatchedSchemaMeta):
+    component_id = fields.Str()
+    job = UnionField(
+        [
+            ArmStr(azureml_type=AzureMLResourceType.JOB),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    component = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    type = StringTransformedEnum(required=True, allowed_values=[JobType.PIPELINE])
+    settings = fields.Dict()
+    name = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.job_definition import JobDefinition
+
+        return JobDefinition(**data)
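This schema mirrors BatchPipelineComponentDeploymentConfiguarationsSchema above; both hydrate the same JobDefinition entity. A minimal load sketch covering either one (payload invented; assumes azure-ai-ml is installed and that JobType.PIPELINE serializes as "pipeline"):

    from azure.ai.ml._schema._deployment.batch.job_definition_schema import JobDefinitionSchema

    job_def = JobDefinitionSchema().load(
        {
            "type": "pipeline",  # required; the only allowed value is JobType.PIPELINE
            "settings": {"default_compute": "cpu-cluster"},
        }
    )
    print(type(job_def).__name__)  # JobDefinition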
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py
new file mode 100644
index 00000000..0dbd8463
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment.py
@@ -0,0 +1,46 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, StringTransformedEnum
+from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
+from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._deployment import BatchDeploymentType
+from azure.ai.ml._schema import ExperimentalField
+from .model_batch_deployment_settings import ModelBatchDeploymentSettingsSchema
+
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelBatchDeploymentSchema(DeploymentSchema):
+    compute = ComputeField(required=True)
+    error_threshold = fields.Int(
+        metadata={
+            "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+        the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+        For FileDataset, this value is the count of file failures.\r\n
+        For TabularDataset, this value is the count of record failures.\r\n
+        If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+        }
+    )
+    resources = NestedField(JobResourceConfigurationSchema)
+    type = StringTransformedEnum(
+        allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
+    )
+
+    settings = ExperimentalField(NestedField(ModelBatchDeploymentSettingsSchema))
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:
+        from azure.ai.ml.entities import ModelBatchDeployment
+
+        return ModelBatchDeployment(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py
new file mode 100644
index 00000000..e1945751
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/model_batch_deployment_settings.py
@@ -0,0 +1,56 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum
+from azure.ai.ml.constants._deployment import BatchDeploymentOutputAction
+
+from .batch_deployment_settings import BatchRetrySettingsSchema
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelBatchDeploymentSettingsSchema(metaclass=PatchedSchemaMeta):
+    error_threshold = fields.Int(
+        metadata={
+            "description": """Error threshold, if the error count for the entire input goes above this value,\r\n
+        the batch inference will be aborted. Range is [-1, int.MaxValue].\r\n
+        For FileDataset, this value is the count of file failures.\r\n
+        For TabularDataset, this value is the count of record failures.\r\n
+        If set to -1 (the lower bound), all failures during batch inference will be ignored."""
+        }
+    )
+    instance_count = fields.Int()
+    retry_settings = NestedField(BatchRetrySettingsSchema)
+    mini_batch_size = fields.Int()
+    logging_level = fields.Str(
+        metadata={
+            "description": """A string of the logging level name, which is defined in 'logging'.
+        Possible values are 'warning', 'info', and 'debug'."""
+        }
+    )
+    output_action = StringTransformedEnum(
+        allowed_values=[
+            BatchDeploymentOutputAction.APPEND_ROW,
+            BatchDeploymentOutputAction.SUMMARY_ONLY,
+        ],
+        metadata={"description": "Indicates how batch inferencing will handle output."},
+        dump_default=BatchDeploymentOutputAction.APPEND_ROW,
+    )
+    output_file_name = fields.Str(metadata={"description": "Customized output file name for append_row output action."})
+    max_concurrency_per_instance = fields.Int(
+        metadata={"description": "Indicates maximum number of parallelism per instance."}
+    )
+    environment_variables = fields.Dict()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities import ModelBatchDeploymentSettings
+
+        return ModelBatchDeploymentSettings(**data)
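A short load sketch for the settings schema above (values invented; assumes azure-ai-ml is installed and that BatchDeploymentOutputAction.APPEND_ROW serializes as "append_row"):

    from azure.ai.ml._schema._deployment.batch.model_batch_deployment_settings import (
        ModelBatchDeploymentSettingsSchema,
    )

    settings = ModelBatchDeploymentSettingsSchema().load(
        {
            "mini_batch_size": 10,
            "output_action": "append_row",  # transformed by StringTransformedEnum
            "retry_settings": {"max_retries": 3, "timeout": 300},
        }
    )
    print(type(settings).__name__)  # ModelBatchDeploymentSettings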
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py
new file mode 100644
index 00000000..4bc884b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py
@@ -0,0 +1,70 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema import (
+    ArmVersionedStr,
+    ArmStr,
+    UnionField,
+    RegistryStr,
+    NestedField,
+)
+from azure.ai.ml._schema.core.fields import PipelineNodeNameStr, TypeSensitiveUnionField, PathAwareSchema
+from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._component import NodeType
+
+module_logger = logging.getLogger(__name__)
+
+
+class PipelineComponentBatchDeploymentSchema(PathAwareSchema):
+    name = fields.Str()
+    endpoint_name = fields.Str()
+    component = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            PipelineComponentFileRefField(),
+        ]
+    )
+    settings = fields.Dict()
+    name = fields.Str()
+    type = fields.Str()
+    job_definition = UnionField(
+        [
+            ArmStr(azureml_type=AzureMLResourceType.JOB),
+            NestedField("PipelineSchema", unknown=INCLUDE),
+        ]
+    )
+    tags = fields.Dict()
+    description = fields.Str(metadata={"description": "Description of the endpoint deployment."})
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.pipeline_component_batch_deployment import (
+            PipelineComponentBatchDeployment,
+        )
+
+        return PipelineComponentBatchDeployment(**data)
+
+
+class NodeNameStr(PipelineNodeNameStr):
+    def _get_field_name(self) -> str:
+        return "Pipeline node"
+
+
+def PipelineJobsField():
+    pipeline_enable_job_type = {NodeType.PIPELINE: [NestedField("PipelineSchema", unknown=INCLUDE)]}
+
+    pipeline_job_field = fields.Dict(
+        keys=NodeNameStr(),
+        values=TypeSensitiveUnionField(pipeline_enable_job_type),
+    )
+
+    return pipeline_job_field
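Unlike the PatchedSchemaMeta schemas above, PipelineComponentBatchDeploymentSchema extends PathAwareSchema, so it needs a base-path context at construction time. A sketch (payload and entity kwargs are assumptions, not confirmed by the diff):

    from azure.ai.ml._schema._deployment.batch.pipeline_component_batch_deployment_schema import (
        PipelineComponentBatchDeploymentSchema,
    )
    from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY

    schema = PipelineComponentBatchDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "."})
    deployment = schema.load({"name": "batch-dpl", "endpoint_name": "my-endpoint"})
    print(type(deployment).__name__)  # PipelineComponentBatchDeployment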
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py
new file mode 100644
index 00000000..54661ada
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_deployment/batch/run_settings_schema.py
@@ -0,0 +1,28 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+import logging
+from typing import Any
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema import PatchedSchemaMeta
+
+module_logger = logging.getLogger(__name__)
+
+
+class RunSettingsSchema(metaclass=PatchedSchemaMeta):
+    name = fields.Str()
+    display_name = fields.Str()
+    experiment_name = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict()
+    settings = fields.Dict()
+
+    @post_load
+    def make(self, data: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+        from azure.ai.ml.entities._deployment.run_settings import RunSettings
+
+        return RunSettings(**data)
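And a matching load sketch for RunSettingsSchema (values invented; assumes azure-ai-ml is installed). All fields are optional, so a partial payload loads cleanly:

    from azure.ai.ml._schema._deployment.batch.run_settings_schema import RunSettingsSchema

    run_settings = RunSettingsSchema().load(
        {"name": "scoring-run", "experiment_name": "batch-scoring", "tags": {"team": "ml"}}
    )
    print(type(run_settings).__name__)  # RunSettings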
