about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
new file mode 100644
index 00000000..445481ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py
@@ -0,0 +1,79 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_dump, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+from ..job.parameterized_spark import ParameterizedSparkSchema
+
+
class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
    """Schema for a spark component: generic component fields plus the
    spark-specific parameterization fields."""

    type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
    additional_includes = fields.List(fields.Str())

    @post_dump
    def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
        # Keep the dumped dict clean: an explicitly-present but empty
        # additional_includes list carries no information, so drop it.
        includes = component_schema_dict.get("additional_includes")
        if includes is not None and not includes:
            del component_schema_dict["additional_includes"]
        return component_schema_dict
+
+
class RestSparkComponentSchema(SparkComponentSchema):
    """Schema used when a component is loaded from the REST layer.

    The name field is a plain string (no name validation) because
    already-existing components may carry names that would no longer
    pass validation.
    """

    name = fields.Str(required=True)
+
+
class AnonymousSparkComponentSchema(AnonymousAssetSchema, SparkComponentSchema):
    """Anonymous spark component schema.

    MRO note: AnonymousAssetSchema is listed first on purpose — marshmallow
    collects fields following the method resolution order, and we need the
    name and version fields it declares to be dump_only.
    """

    @post_load
    def make(self, data, **kwargs):
        # Imported lazily to avoid a circular import with the entities package.
        from azure.ai.ml.entities._component.spark_component import SparkComponent

        # A component defined inline in a job defaults to source YAML.JOB;
        # only a full standalone component file is regarded as YAML.COMPONENT.
        source = kwargs.pop("_source", ComponentSource.YAML_JOB)
        return SparkComponent(
            base_path=self.context[BASE_PATH_CONTEXT_KEY],
            _source=source,
            **data,
        )
+
+
class SparkComponentFileRefField(FileRefField):
    """File-reference field that loads the referenced YAML file as an
    anonymous spark component."""

    def _deserialize(self, value, attr, data, **kwargs):
        # FileRefField resolves the reference and returns the file's text.
        yaml_text = super()._deserialize(value, attr, data, **kwargs)
        component_dict = yaml.safe_load(yaml_text)
        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value

        # Rebase base_path onto the component file's own directory so that
        # relative paths inside the component resolve against that file.
        nested_context = deepcopy(self.context)
        nested_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
        schema = AnonymousSparkComponentSchema(context=nested_context)
        component = schema.load(component_dict, unknown=INCLUDE)

        component._source_path = source_path
        component._source = ComponentSource.YAML_COMPONENT
        return component