author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  | gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py | 232 |
1 file changed, 232 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
new file mode 100644
index 00000000..11d4bb56
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
@@ -0,0 +1,232 @@

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os.path

import pydash
from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load

from ..._schema import NestedField, StringTransformedEnum, UnionField
from ..._schema.component.component import ComponentSchema
from ..._schema.core.fields import ArmVersionedStr, CodeField, EnvironmentField, RegistryStr
from ..._schema.job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
from ..._utils._arm_id_utils import parse_name_label
from ..._utils.utils import get_valid_dot_keys_with_wildcard
from ...constants._common import (
    LABELLED_RESOURCE_NAME,
    SOURCE_PATH_CONTEXT_KEY,
    AzureMLResourceType,
    DefaultOpenEncoding,
)
from ...constants._component import NodeType as PublicNodeType
from .._utils import yaml_safe_load_with_base_resolver
from .environment import InternalEnvironmentSchema
from .input_output import (
    InternalEnumParameterSchema,
    InternalInputPortSchema,
    InternalOutputPortSchema,
    InternalParameterSchema,
    InternalPrimitiveOutputSchema,
    InternalSparkParameterSchema,
)


class NodeType:
    COMMAND = "CommandComponent"
    DATA_TRANSFER = "DataTransferComponent"
    DISTRIBUTED = "DistributedComponent"
    HDI = "HDInsightComponent"
    SCOPE_V2 = "scope"
    HDI_V2 = "hdinsight"
    HEMERA_V2 = "hemera"
    STARLITE_V2 = "starlite"
    AE365EXEPOOL_V2 = "ae365exepool"
    AETHER_BRIDGE_V2 = "aetherbridge"
    PARALLEL = "ParallelComponent"
    SCOPE = "ScopeComponent"
    STARLITE = "StarliteComponent"
    SWEEP = "SweepComponent"
    PIPELINE = "PipelineComponent"
    HEMERA = "HemeraComponent"
    AE365EXEPOOL = "AE365ExePoolComponent"
    IPP = "IntellectualPropertyProtectedComponent"
    # the internal spark component's type value conflicts with the public spark component;
    # this enum value is used to identify its create_function in factories
    SPARK = "DummySpark"
    AETHER_BRIDGE = "AetherBridgeComponent"

    @classmethod
    def all_values(cls):
        all_values = []
        for key, value in vars(cls).items():
            if not key.startswith("_") and isinstance(value, str):
                all_values.append(value)
        return all_values


class InternalComponentSchema(ComponentSchema):
    class Meta:
        unknown = INCLUDE

    # override name as 1p components allow "." in name, which is not allowed in v2 components
    name = fields.Str()

    # override to allow empty properties
    tags = fields.Dict(keys=fields.Str())

    # override inputs & outputs to support 1P inputs & outputs; may need to do strict validation later
    # no need to check io type match since the server will do that
    inputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField(
            [
                NestedField(InternalParameterSchema),
                NestedField(InternalEnumParameterSchema),
                NestedField(InternalInputPortSchema),
            ]
        ),
    )
    # support primitive output for all internal components for now
    outputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField(
            [
                NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
                NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
            ]
        ),
    )

    # type field is required for registration
    type = StringTransformedEnum(
        allowed_values=NodeType.all_values(),
        casing_transform=lambda x: parse_name_label(x)[0],
        pass_original=True,
    )

    # need to resolve as it can be a local field
    code = CodeField()

    environment = UnionField(
        [
            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
            ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
            NestedField(InternalEnvironmentSchema),
        ]
    )

    def get_skip_fields(self):
        return ["properties"]

    def _serialize(self, obj, *, many: bool = False):
        if many and obj is not None:
            return super(InternalComponentSchema, self)._serialize(obj, many=many)
        ret = super(InternalComponentSchema, self)._serialize(obj)
        for attr_name in obj.__dict__.keys():
            if (
                not attr_name.startswith("_")
                and attr_name not in self.get_skip_fields()
                and attr_name not in self.dump_fields
            ):
                ret[attr_name] = self.get_attribute(obj, attr_name, None)
        return ret

    # override add_param_overrides to ensure that param override happens after reloading the yaml
    @pre_load
    def add_param_overrides(self, data, **kwargs):
        source_path = self.context.pop(SOURCE_PATH_CONTEXT_KEY, None)
        if isinstance(data, dict) and source_path and os.path.isfile(source_path):

            def should_node_overwritten(_root, _parts):
                parts = _parts.copy()
                parts.pop()
                parts.append("type")
                _input_type = pydash.get(_root, parts, None)
                return isinstance(_input_type, str) and _input_type.lower() not in ["boolean"]

            # do override here
            with open(source_path, "r", encoding=DefaultOpenEncoding.READ) as f:
                origin_data = yaml_safe_load_with_base_resolver(f)
            for dot_key_wildcard, condition_func in [
                ("version", None),
                ("inputs.*.default", should_node_overwritten),
                ("inputs.*.enum", should_node_overwritten),
            ]:
                for dot_key in get_valid_dot_keys_with_wildcard(
                    origin_data, dot_key_wildcard, validate_func=condition_func
                ):
                    pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))
        return super().add_param_overrides(data, **kwargs)

    @post_dump(pass_original=True)
    def simplify_input_output_port(self, data, original, **kwargs):  # pylint:disable=unused-argument
        # remove None entries in inputs & outputs
        for io_ports in [data["inputs"], data["outputs"]]:
            for port_name, port_definition in io_ports.items():
                io_ports[port_name] = dict(filter(lambda item: item[1] is not None, port_definition.items()))

        # hack: drop "mode" to match the current serialization expectation
        for port_name, port_definition in data["inputs"].items():
            if "mode" in port_definition:
                del port_definition["mode"]

        return data

    @post_dump(pass_original=True)
    def add_back_type_label(self, data, original, **kwargs):  # pylint:disable=unused-argument
        type_label = original._type_label  # pylint:disable=protected-access
        if type_label:
            data["type"] = LABELLED_RESOURCE_NAME.format(data["type"], type_label)
        return data


class InternalSparkComponentSchema(InternalComponentSchema):
    # type field is required for registration
    type = StringTransformedEnum(
        allowed_values=PublicNodeType.SPARK,
        casing_transform=lambda x: parse_name_label(x)[0].lower(),
        pass_original=True,
    )

    # override inputs:
    # https://componentsdk.azurewebsites.net/components/spark_component.html#differences-with-other-component-types
    inputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField(
            [
                NestedField(InternalSparkParameterSchema),
                NestedField(InternalInputPortSchema),
            ]
        ),
    )

    environment = EnvironmentField(
        extra_fields=[NestedField(InternalEnvironmentSchema)],
        allow_none=True,
    )

    jars = UnionField(
        [
            fields.List(fields.Str()),
            fields.Str(),
        ],
    )
    py_files = UnionField(
        [
            fields.List(fields.Str()),
            fields.Str(),
        ],
        data_key="pyFiles",
        attribute="py_files",
    )

    entry = UnionField(
        [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
        required=True,
        metadata={"description": "Entry."},
    )

    files = fields.List(fields.Str(required=True))
    archives = fields.List(fields.Str(required=True))
    conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
    args = fields.Str(metadata={"description": "Command Line arguments."})
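The `unknown = INCLUDE` meta option, together with the `_serialize` override, is what lets 1p component YAML carry fields the v2 schema never declared: undeclared keys survive `load`, and the override copies extra instance attributes back out on dump. A standalone marshmallow sketch of the load-side behavior; the schema and field names below are invented for illustration and are not part of the SDK:

from marshmallow import INCLUDE, Schema, fields

class PassThroughSchema(Schema):
    class Meta:
        unknown = INCLUDE  # keep keys that no declared field matches

    name = fields.Str()

data = PassThroughSchema().load({"name": "my_component", "aether": {"alias": "x"}})
print(data)  # {'name': 'my_component', 'aether': {'alias': 'x'}} -- unknown key kept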
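`add_param_overrides` re-reads the source YAML and copies `version`, `inputs.*.default`, and `inputs.*.enum` back over the loaded data, skipping boolean parameters. The sketch below mimics that wildcard dot-key expansion with plain pydash; the real `get_valid_dot_keys_with_wildcard` helper lives in `azure.ai.ml._utils.utils`, and this `expand_wildcard` stand-in may resolve keys differently:

import pydash

origin_data = {
    "version": "0.0.2",
    "inputs": {
        "threshold": {"type": "float", "default": 0.5},
        "verbose": {"type": "boolean", "default": "true"},
    },
}
data = {"version": "0.0.1", "inputs": {"threshold": {"type": "float"}, "verbose": {"type": "boolean"}}}

def expand_wildcard(root, pattern):
    # expand e.g. "inputs.*.default" into the concrete dot-keys present in root
    keys = [[]]
    for part in pattern.split("."):
        if part == "*":
            keys = [k + [name] for k in keys for name in pydash.get(root, k, {})]
        else:
            keys = [k + [part] for k in keys]
    return [k for k in keys if pydash.get(root, k) is not None]

def should_node_overwritten(root, parts):
    # skip boolean parameters, matching the condition in the schema above
    input_type = pydash.get(root, parts[:-1] + ["type"], None)
    return isinstance(input_type, str) and input_type.lower() not in ["boolean"]

for dot_key in expand_wildcard(origin_data, "inputs.*.default"):
    if should_node_overwritten(origin_data, dot_key):
        pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))

print(data["inputs"]["threshold"])  # {'type': 'float', 'default': 0.5}
print(data["inputs"]["verbose"])    # {'type': 'boolean'} -- boolean default not overridden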
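`add_back_type_label` depends on `parse_name_label` stripping an `@label` suffix at load time (so the bare type can be validated against `NodeType.all_values()`) and `LABELLED_RESOURCE_NAME` re-attaching it at dump time. A minimal sketch of that round trip, assuming the "{}@{}" format string; the real helpers live in `azure.ai.ml._utils._arm_id_utils` and `azure.ai.ml.constants._common` and may behave differently:

LABELLED_RESOURCE_NAME = "{}@{}"  # assumed format of the SDK constant

def parse_name_label(name):
    # "ScopeComponent@1-legacy" -> ("ScopeComponent", "1-legacy"); no label -> (name, None)
    name_part, _, label = name.partition("@")
    return name_part, label or None

loaded_type, type_label = parse_name_label("ScopeComponent@1-legacy")
assert loaded_type == "ScopeComponent"  # this bare value is what the type enum validates

# on dump, the label the user originally wrote is re-attached
dumped_type = LABELLED_RESOURCE_NAME.format(loaded_type, type_label)
assert dumped_type == "ScopeComponent@1-legacy"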