diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py new file mode 100644 index 00000000..6dbadcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py @@ -0,0 +1,75 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import INCLUDE, fields, post_load, pre_dump + +from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField +from ..._schema.core.fields import DumpableEnumField +from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs +from ...constants._common import AzureMLResourceType +from .component import InternalComponentSchema, NodeType + + +class InternalBaseNodeSchema(BaseNodeSchema): + class Meta: + unknown = INCLUDE + + component = UnionField( + [ + # for registry type assets + RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT), + # existing component + ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True), + # inline component or component file reference starting with FILE prefix + NestedField(InternalComponentSchema, unknown=INCLUDE), + ], + required=True, + ) + type = DumpableEnumField( + allowed_values=NodeType.all_values(), + ) + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + from ...entities._builders import parse_inputs_outputs + + # parse inputs/outputs + data = parse_inputs_outputs(data) + + # dict to node object + from ...entities._job.pipeline._load_component import pipeline_node_factory + + return pipeline_node_factory.load_from_dict(data=data) + + @pre_dump + def resolve_inputs_outputs(self, job, **kwargs): # pylint: disable=unused-argument + return _resolve_inputs_outputs(job) + + +class ScopeSchema(InternalBaseNodeSchema): + type = DumpableEnumField(allowed_values=[NodeType.SCOPE]) + adla_account_name = fields.Str(required=True) + scope_param = fields.Str() + custom_job_name_suffix = fields.Str() + priority = fields.Int() + auto_token = fields.Int() + tokens = fields.Int() + vcp = fields.Float() + + +class HDInsightSchema(InternalBaseNodeSchema): + type = DumpableEnumField(allowed_values=[NodeType.HDI]) + + compute_name = fields.Str() + queue = fields.Str() + driver_memory = fields.Str() + driver_cores = fields.Int() + executor_memory = fields.Str() + executor_cores = fields.Int() + number_executors = fields.Int() + conf = UnionField( + # dictionary or json string + union_fields=[fields.Dict(keys=fields.Str()), fields.Str()], + ) + hdinsight_spark_job_name = fields.Str() |