aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
new file mode 100644
index 00000000..6dbadcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
@@ -0,0 +1,75 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields, post_load, pre_dump
+
+from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
+from ..._schema.core.fields import DumpableEnumField
+from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
+from ...constants._common import AzureMLResourceType
+from .component import InternalComponentSchema, NodeType
+
+
+class InternalBaseNodeSchema(BaseNodeSchema):
+ class Meta:
+ unknown = INCLUDE
+
+ component = UnionField(
+ [
+ # for registry type assets
+ RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ # inline component or component file reference starting with FILE prefix
+ NestedField(InternalComponentSchema, unknown=INCLUDE),
+ ],
+ required=True,
+ )
+ type = DumpableEnumField(
+ allowed_values=NodeType.all_values(),
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ from ...entities._builders import parse_inputs_outputs
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+
+ # dict to node object
+ from ...entities._job.pipeline._load_component import pipeline_node_factory
+
+ return pipeline_node_factory.load_from_dict(data=data)
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs): # pylint: disable=unused-argument
+ return _resolve_inputs_outputs(job)
+
+
+class ScopeSchema(InternalBaseNodeSchema):
+ type = DumpableEnumField(allowed_values=[NodeType.SCOPE])
+ adla_account_name = fields.Str(required=True)
+ scope_param = fields.Str()
+ custom_job_name_suffix = fields.Str()
+ priority = fields.Int()
+ auto_token = fields.Int()
+ tokens = fields.Int()
+ vcp = fields.Float()
+
+
+class HDInsightSchema(InternalBaseNodeSchema):
+ type = DumpableEnumField(allowed_values=[NodeType.HDI])
+
+ compute_name = fields.Str()
+ queue = fields.Str()
+ driver_memory = fields.Str()
+ driver_cores = fields.Int()
+ executor_memory = fields.Str()
+ executor_cores = fields.Int()
+ number_executors = fields.Int()
+ conf = UnionField(
+ # dictionary or json string
+ union_fields=[fields.Dict(keys=fields.Str()), fields.Str()],
+ )
+ hdinsight_spark_job_name = fields.Str()