aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
new file mode 100644
index 00000000..445481ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_dump, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+from ..job.parameterized_spark import ParameterizedSparkSchema
+
+
+class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
+ type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
+ additional_includes = fields.List(fields.Str())
+
+ @post_dump
+ def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+ if (
+ component_schema_dict.get("additional_includes") is not None
+ and len(component_schema_dict["additional_includes"]) == 0
+ ):
+ component_schema_dict.pop("additional_includes")
+ return component_schema_dict
+
+
+class RestSparkComponentSchema(SparkComponentSchema):
+ """When component load from rest, won't validate on name since there might
+ be existing component with invalid name."""
+
+ name = fields.Str(required=True)
+
+
+class AnonymousSparkComponentSchema(AnonymousAssetSchema, SparkComponentSchema):
+ """Anonymous spark component schema.
+
+ Note inheritance follows order: AnonymousAssetSchema,
+ SparkComponentSchema because we need name and version to be
+ dump_only(marshmallow collects fields follows method resolution
+ order).
+ """
+
+ @post_load
+ def make(self, data, **kwargs):
+ from azure.ai.ml.entities._component.spark_component import SparkComponent
+
+ # Inline component will have source=YAML.JOB
+ # As we only regard full separate component file as YAML.COMPONENT
+ return SparkComponent(
+ base_path=self.context[BASE_PATH_CONTEXT_KEY],
+ _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+ **data,
+ )
+
+
+class SparkComponentFileRefField(FileRefField):
+ def _deserialize(self, value, attr, data, **kwargs):
+ # Get component info from component yaml file.
+ data = super()._deserialize(value, attr, data, **kwargs)
+ component_dict = yaml.safe_load(data)
+ source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+ # Update base_path to parent path of component file.
+ component_schema_context = deepcopy(self.context)
+ component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+ component = AnonymousSparkComponentSchema(context=component_schema_context).load(
+ component_dict, unknown=INCLUDE
+ )
+ component._source_path = source_path
+ component._source = ComponentSource.YAML_COMPONENT
+ return component