about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
new file mode 100644
index 00000000..6dbadcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
@@ -0,0 +1,75 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields, post_load, pre_dump
+
+from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
+from ..._schema.core.fields import DumpableEnumField
+from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
+from ...constants._common import AzureMLResourceType
+from .component import InternalComponentSchema, NodeType
+
+
+class InternalBaseNodeSchema(BaseNodeSchema):
+    class Meta:
+        unknown = INCLUDE
+
+    component = UnionField(
+        [
+            # for registry type assets
+            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+            # existing component
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            # inline component or component file reference starting with FILE prefix
+            NestedField(InternalComponentSchema, unknown=INCLUDE),
+        ],
+        required=True,
+    )
+    type = DumpableEnumField(
+        allowed_values=NodeType.all_values(),
+    )
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        from ...entities._builders import parse_inputs_outputs
+
+        # parse inputs/outputs
+        data = parse_inputs_outputs(data)
+
+        # dict to node object
+        from ...entities._job.pipeline._load_component import pipeline_node_factory
+
+        return pipeline_node_factory.load_from_dict(data=data)
+
+    @pre_dump
+    def resolve_inputs_outputs(self, job, **kwargs):  # pylint: disable=unused-argument
+        return _resolve_inputs_outputs(job)
+
+
+class ScopeSchema(InternalBaseNodeSchema):
+    type = DumpableEnumField(allowed_values=[NodeType.SCOPE])
+    adla_account_name = fields.Str(required=True)
+    scope_param = fields.Str()
+    custom_job_name_suffix = fields.Str()
+    priority = fields.Int()
+    auto_token = fields.Int()
+    tokens = fields.Int()
+    vcp = fields.Float()
+
+
+class HDInsightSchema(InternalBaseNodeSchema):
+    type = DumpableEnumField(allowed_values=[NodeType.HDI])
+
+    compute_name = fields.Str()
+    queue = fields.Str()
+    driver_memory = fields.Str()
+    driver_cores = fields.Int()
+    executor_memory = fields.Str()
+    executor_cores = fields.Int()
+    number_executors = fields.Int()
+    conf = UnionField(
+        # dictionary or json string
+        union_fields=[fields.Dict(keys=fields.Str()), fields.Str()],
+    )
+    hdinsight_spark_job_name = fields.Str()