# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os.path

import pydash
from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load

from ..._schema import NestedField, StringTransformedEnum, UnionField
from ..._schema.component.component import ComponentSchema
from ..._schema.core.fields import ArmVersionedStr, CodeField, EnvironmentField, RegistryStr
from ..._schema.job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
from ..._utils._arm_id_utils import parse_name_label
from ..._utils.utils import get_valid_dot_keys_with_wildcard
from ...constants._common import (
    LABELLED_RESOURCE_NAME,
    SOURCE_PATH_CONTEXT_KEY,
    AzureMLResourceType,
    DefaultOpenEncoding,
)
from ...constants._component import NodeType as PublicNodeType
from .._utils import yaml_safe_load_with_base_resolver
from .environment import InternalEnvironmentSchema
from .input_output import (
    InternalEnumParameterSchema,
    InternalInputPortSchema,
    InternalOutputPortSchema,
    InternalParameterSchema,
    InternalPrimitiveOutputSchema,
    InternalSparkParameterSchema,
)


class NodeType:
    """Type names of internal (1p) components."""

    COMMAND = "CommandComponent"
    DATA_TRANSFER = "DataTransferComponent"
    DISTRIBUTED = "DistributedComponent"
    HDI = "HDInsightComponent"
    SCOPE_V2 = "scope"
    HDI_V2 = "hdinsight"
    HEMERA_V2 = "hemera"
    STARLITE_V2 = "starlite"
    AE365EXEPOOL_V2 = "ae365exepool"
    AETHER_BRIDGE_V2 = "aetherbridge"
    PARALLEL = "ParallelComponent"
    SCOPE = "ScopeComponent"
    STARLITE = "StarliteComponent"
    SWEEP = "SweepComponent"
    PIPELINE = "PipelineComponent"
    HEMERA = "HemeraComponent"
    AE365EXEPOOL = "AE365ExePoolComponent"
    IPP = "IntellectualPropertyProtectedComponent"
    # the internal spark component's type value conflicts with the v2 spark component;
    # this enum value is used to identify its create_function in factories
    SPARK = "DummySpark"
    AETHER_BRIDGE = "AetherBridgeComponent"

    @classmethod
    def all_values(cls):
        all_values = []
        for key, value in vars(cls).items():
            if not key.startswith("_") and isinstance(value, str):
                all_values.append(value)
        return all_values


class InternalComponentSchema(ComponentSchema):
    """Schema for internal (1p) components, relaxing several constraints of the v2 ComponentSchema."""

    class Meta:
        unknown = INCLUDE

    # override name as 1p components allow "." in name, which is not allowed in v2 components
    name = fields.Str()

    # override to allow empty properties
    tags = fields.Dict(keys=fields.Str())

    # override inputs & outputs to support 1P inputs & outputs; may need to do strict validation later.
    # no need to check io type match since the server will do that
    inputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField(
            [
                NestedField(InternalParameterSchema),
                NestedField(InternalEnumParameterSchema),
                NestedField(InternalInputPortSchema),
            ]
        ),
    )
    # support primitive output for all internal components for now
    outputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField(
            [
                NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
                NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
            ]
        ),
    )

    # type field is required for registration
    type = StringTransformedEnum(
        allowed_values=NodeType.all_values(),
        casing_transform=lambda x: parse_name_label(x)[0],
        pass_original=True,
    )

    # need to resolve as it can be a local path
    code = CodeField()

    environment = UnionField(
        [
            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
            ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
            NestedField(InternalEnvironmentSchema),
        ]
    )

    def get_skip_fields(self):
        return ["properties"]

    def _serialize(self, obj, *, many: bool = False):
        if many and obj is not None:
            return super(InternalComponentSchema, self)._serialize(obj, many=many)
        ret = super(InternalComponentSchema, self)._serialize(obj)
        # also dump attributes set on the object that are not declared as schema fields
        for attr_name in obj.__dict__.keys():
            if (
                not attr_name.startswith("_")
                and attr_name not in self.get_skip_fields()
                and attr_name not in self.dump_fields
            ):
                ret[attr_name] = self.get_attribute(obj, attr_name, None)
        return ret

    # override param_override to ensure that param override happens after reloading the yaml
    @pre_load
    def add_param_overrides(self, data, **kwargs):
        source_path = self.context.pop(SOURCE_PATH_CONTEXT_KEY, None)
        if isinstance(data, dict) and source_path and os.path.isfile(source_path):

            def should_node_overwritten(_root, _parts):
                parts = _parts.copy()
                parts.pop()
                parts.append("type")
                _input_type = pydash.get(_root, parts, None)
                return isinstance(_input_type, str) and _input_type.lower() not in ["boolean"]

            # do override here
            with open(source_path, "r", encoding=DefaultOpenEncoding.READ) as f:
                origin_data = yaml_safe_load_with_base_resolver(f)
                for dot_key_wildcard, condition_func in [
                    ("version", None),
                    ("inputs.*.default", should_node_overwritten),
                    ("inputs.*.enum", should_node_overwritten),
                ]:
                    for dot_key in get_valid_dot_keys_with_wildcard(
                        origin_data, dot_key_wildcard, validate_func=condition_func
                    ):
                        pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))
        return super().add_param_overrides(data, **kwargs)

    @post_dump(pass_original=True)
    def simplify_input_output_port(self, data, original, **kwargs):  # pylint:disable=unused-argument
        # remove None entries in inputs & outputs
        for io_ports in [data["inputs"], data["outputs"]]:
            for port_name, port_definition in io_ports.items():
                io_ports[port_name] = dict(filter(lambda item: item[1] is not None, port_definition.items()))

        # hack to match the current serialization expectation: drop "mode" from dumped inputs
        for port_name, port_definition in data["inputs"].items():
            if "mode" in port_definition:
                del port_definition["mode"]

        return data

    @post_dump(pass_original=True)
    def add_back_type_label(self, data, original, **kwargs):  # pylint:disable=unused-argument
        type_label = original._type_label  # pylint:disable=protected-access
        if type_label:
            data["type"] = LABELLED_RESOURCE_NAME.format(data["type"], type_label)
        return data


class InternalSparkComponentSchema(InternalComponentSchema):
    # type field is required for registration
    type = StringTransformedEnum(
        allowed_values=PublicNodeType.SPARK,
        casing_transform=lambda x: parse_name_label(x)[0].lower(),
        pass_original=True,
    )

    # override inputs:
    # https://componentsdk.azurewebsites.net/components/spark_component.html#differences-with-other-component-types
    inputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField(
            [
                NestedField(InternalSparkParameterSchema),
                NestedField(InternalInputPortSchema),
            ]
        ),
    )

    environment = EnvironmentField(
        extra_fields=[NestedField(InternalEnvironmentSchema)],
        allow_none=True,
    )

    jars = UnionField(
        [
            fields.List(fields.Str()),
            fields.Str(),
        ],
    )
    py_files = UnionField(
        [
            fields.List(fields.Str()),
            fields.Str(),
        ],
        data_key="pyFiles",
        attribute="py_files",
    )

    entry = UnionField(
        [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
        required=True,
        metadata={"description": "Entry."},
    )

    files = fields.List(fields.Str(required=True))
    archives = fields.List(fields.Str(required=True))
    conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
    args = fields.Str(metadata={"description": "Command Line arguments."})