# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os.path
import pydash
from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load
from ..._schema import NestedField, StringTransformedEnum, UnionField
from ..._schema.component.component import ComponentSchema
from ..._schema.core.fields import ArmVersionedStr, CodeField, EnvironmentField, RegistryStr
from ..._schema.job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
from ..._utils._arm_id_utils import parse_name_label
from ..._utils.utils import get_valid_dot_keys_with_wildcard
from ...constants._common import (
LABELLED_RESOURCE_NAME,
SOURCE_PATH_CONTEXT_KEY,
AzureMLResourceType,
DefaultOpenEncoding,
)
from ...constants._component import NodeType as PublicNodeType
from .._utils import yaml_safe_load_with_base_resolver
from .environment import InternalEnvironmentSchema
from .input_output import (
InternalEnumParameterSchema,
InternalInputPortSchema,
InternalOutputPortSchema,
InternalParameterSchema,
InternalPrimitiveOutputSchema,
InternalSparkParameterSchema,
)
class NodeType:
    """Type identifiers for internal components.

    Every public string attribute declared on this class is reported by
    :meth:`all_values`.
    """

    COMMAND = "CommandComponent"
    DATA_TRANSFER = "DataTransferComponent"
    DISTRIBUTED = "DistributedComponent"
    HDI = "HDInsightComponent"
    SCOPE_V2 = "scope"
    HDI_V2 = "hdinsight"
    HEMERA_V2 = "hemera"
    STARLITE_V2 = "starlite"
    AE365EXEPOOL_V2 = "ae365exepool"
    AETHER_BRIDGE_V2 = "aetherbridge"
    PARALLEL = "ParallelComponent"
    SCOPE = "ScopeComponent"
    STARLITE = "StarliteComponent"
    SWEEP = "SweepComponent"
    PIPELINE = "PipelineComponent"
    HEMERA = "HemeraComponent"
    AE365EXEPOOL = "AE365ExePoolComponent"
    IPP = "IntellectualPropertyProtectedComponent"
    # The internal spark component's type value conflicts with the spark
    # component's, so this placeholder value is what identifies its
    # create_function in factories.
    SPARK = "DummySpark"
    AETHER_BRIDGE = "AetherBridgeComponent"

    @classmethod
    def all_values(cls):
        """Return the public string values declared directly on ``cls``.

        Only the class's own namespace is inspected; names starting with an
        underscore and non-string attributes (such as this method) are left
        out. Values come back in declaration order.

        :return: The declared type identifiers.
        :rtype: list[str]
        """
        return [
            value
            for name, value in vars(cls).items()
            if isinstance(value, str) and not name.startswith("_")
        ]
class InternalComponentSchema(ComponentSchema):
    """Schema for internal components, built on top of ``ComponentSchema``.

    Unknown keys are included on load, and on dump every public instance
    attribute of the serialized object that has no declared dump field is
    copied into the output as well (see ``_serialize``).
    """

    class Meta:
        unknown = INCLUDE

    # 1p components allow "." in their name while v2 components do not, so the
    # name field is overridden here.
    name = fields.Str()
    # overridden to allow empty properties
    tags = fields.Dict(keys=fields.Str())
    # inputs & outputs are overridden to support 1P inputs & outputs; strict
    # validation may be needed later. IO type matching is not checked here
    # since the server will do that.
    inputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField(
            [
                NestedField(InternalParameterSchema),
                NestedField(InternalEnumParameterSchema),
                NestedField(InternalInputPortSchema),
            ]
        ),
    )
    # primitive outputs are supported for all internal components for now
    outputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField(
            [
                NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
                NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
            ]
        ),
    )
    # registration requires the type field
    type = StringTransformedEnum(
        allowed_values=NodeType.all_values(),
        casing_transform=lambda x: parse_name_label(x)[0],
        pass_original=True,
    )

    # code can be a local field, so it has to be resolved
    code = CodeField()

    environment = UnionField(
        [
            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
            ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
            NestedField(InternalEnvironmentSchema),
        ]
    )

    def get_skip_fields(self):
        """Return the attribute names that ``_serialize`` never copies through."""
        return ["properties"]

    def _serialize(self, obj, *, many: bool = False):
        """Serialize ``obj`` and add its undeclared public attributes.

        When ``many`` is set and ``obj`` is not None, serialization is handed
        to the parent class unchanged. Otherwise the parent result is extended
        with each entry of ``obj.__dict__`` whose name does not start with an
        underscore, is not listed by ``get_skip_fields`` and is not already a
        dump field.
        """
        if many and obj is not None:
            return super()._serialize(obj, many=many)

        serialized = super()._serialize(obj)
        for attr_name in obj.__dict__:
            if attr_name.startswith("_"):
                continue
            if attr_name in self.get_skip_fields() or attr_name in self.dump_fields:
                continue
            serialized[attr_name] = self.get_attribute(obj, attr_name, None)
        return serialized

    # param_override is overridden to ensure that the param override happens
    # after reloading the yaml
    @pre_load
    def add_param_overrides(self, data, **kwargs):
        """Restore selected values from the source YAML, then run the parent hook.

        The source path is popped from the schema context. If ``data`` is a
        dict and the path points at an existing file, the file is re-read and
        the matching dot keys for ``version``, ``inputs.*.default`` and
        ``inputs.*.enum`` are copied from the file content into ``data``.
        """
        source_path = self.context.pop(SOURCE_PATH_CONTEXT_KEY, None)
        if isinstance(data, dict) and source_path and os.path.isfile(source_path):

            def type_allows_override(_root, _parts):
                # Look at the "type" entry sitting next to the matched key.
                sibling_type_path = _parts.copy()
                sibling_type_path[-1] = "type"
                declared_type = pydash.get(_root, sibling_type_path, None)
                if not isinstance(declared_type, str):
                    return False
                return declared_type.lower() not in ["boolean"]

            override_rules = [
                ("version", None),
                ("inputs.*.default", type_allows_override),
                ("inputs.*.enum", type_allows_override),
            ]
            # do override here
            with open(source_path, "r", encoding=DefaultOpenEncoding.READ) as f:
                origin_data = yaml_safe_load_with_base_resolver(f)
                for dot_key_wildcard, condition_func in override_rules:
                    matched_keys = get_valid_dot_keys_with_wildcard(
                        origin_data, dot_key_wildcard, validate_func=condition_func
                    )
                    for dot_key in matched_keys:
                        pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))
        return super().add_param_overrides(data, **kwargs)

    @post_dump(pass_original=True)
    def simplify_input_output_port(self, data, original, **kwargs):  # pylint:disable=unused-argument
        """Drop None-valued entries from every dumped port and ``mode`` from inputs."""
        # remove None in input & output
        for section in (data["inputs"], data["outputs"]):
            for port_name, port_definition in section.items():
                section[port_name] = {
                    key: value for key, value in port_definition.items() if value is not None
                }

        # hack, to match the current serialization expectation
        for port_definition in data["inputs"].values():
            port_definition.pop("mode", None)

        return data

    @post_dump(pass_original=True)
    def add_back_type_label(self, data, original, **kwargs):  # pylint:disable=unused-argument
        """Re-attach the original object's type label to the dumped ``type``, if it has one."""
        label = original._type_label  # pylint:disable=protected-access
        if label:
            data["type"] = LABELLED_RESOURCE_NAME.format(data["type"], label)
        return data
class InternalSparkComponentSchema(InternalComponentSchema):
    """Schema for the internal spark component.

    Extends ``InternalComponentSchema`` by overriding ``type``, ``inputs`` and
    ``environment`` and by declaring the spark-specific fields below.
    """

    # registration requires the type field
    type = StringTransformedEnum(
        pass_original=True,
        allowed_values=PublicNodeType.SPARK,
        casing_transform=lambda x: parse_name_label(x)[0].lower(),
    )

    # inputs are overridden, see:
    # https://componentsdk.azurewebsites.net/components/spark_component.html#differences-with-other-component-types
    inputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField([NestedField(InternalSparkParameterSchema), NestedField(InternalInputPortSchema)]),
    )

    environment = EnvironmentField(allow_none=True, extra_fields=[NestedField(InternalEnvironmentSchema)])

    # jars and py_files each accept either a list of strings or a single string
    jars = UnionField([fields.List(fields.Str()), fields.Str()])
    py_files = UnionField(
        [fields.List(fields.Str()), fields.Str()],
        attribute="py_files",
        data_key="pyFiles",
    )

    entry = UnionField(
        [
            NestedField(SparkEntryFileSchema),
            NestedField(SparkEntryClassSchema),
        ],
        metadata={"description": "Entry."},
        required=True,
    )

    files = fields.List(fields.Str(required=True))
    archives = fields.List(fields.Str(required=True))
    conf = fields.Dict(values=fields.Raw(), keys=fields.Str())
    args = fields.Str(metadata={"description": "Command Line arguments."})