author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two version of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py  232
1 file changed, 232 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
new file mode 100644
index 00000000..11d4bb56
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
@@ -0,0 +1,232 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os.path
+
+import pydash
+from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load
+
+from ..._schema import NestedField, StringTransformedEnum, UnionField
+from ..._schema.component.component import ComponentSchema
+from ..._schema.core.fields import ArmVersionedStr, CodeField, EnvironmentField, RegistryStr
+from ..._schema.job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
+from ..._utils._arm_id_utils import parse_name_label
+from ..._utils.utils import get_valid_dot_keys_with_wildcard
+from ...constants._common import (
+ LABELLED_RESOURCE_NAME,
+ SOURCE_PATH_CONTEXT_KEY,
+ AzureMLResourceType,
+ DefaultOpenEncoding,
+)
+from ...constants._component import NodeType as PublicNodeType
+from .._utils import yaml_safe_load_with_base_resolver
+from .environment import InternalEnvironmentSchema
+from .input_output import (
+ InternalEnumParameterSchema,
+ InternalInputPortSchema,
+ InternalOutputPortSchema,
+ InternalParameterSchema,
+ InternalPrimitiveOutputSchema,
+ InternalSparkParameterSchema,
+)
+
+
+class NodeType:
+ COMMAND = "CommandComponent"
+ DATA_TRANSFER = "DataTransferComponent"
+ DISTRIBUTED = "DistributedComponent"
+ HDI = "HDInsightComponent"
+ SCOPE_V2 = "scope"
+ HDI_V2 = "hdinsight"
+ HEMERA_V2 = "hemera"
+ STARLITE_V2 = "starlite"
+ AE365EXEPOOL_V2 = "ae365exepool"
+ AETHER_BRIDGE_V2 = "aetherbridge"
+ PARALLEL = "ParallelComponent"
+ SCOPE = "ScopeComponent"
+ STARLITE = "StarliteComponent"
+ SWEEP = "SweepComponent"
+ PIPELINE = "PipelineComponent"
+ HEMERA = "HemeraComponent"
+ AE365EXEPOOL = "AE365ExePoolComponent"
+ IPP = "IntellectualPropertyProtectedComponent"
+    # the internal spark component has a type value that conflicts with the v2 spark component;
+    # this enum value is used to identify its create_function in factories
+ SPARK = "DummySpark"
+ AETHER_BRIDGE = "AetherBridgeComponent"
+
+ @classmethod
+ def all_values(cls):
+ all_values = []
+ for key, value in vars(cls).items():
+ if not key.startswith("_") and isinstance(value, str):
+ all_values.append(value)
+ return all_values
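A minimal sketch, outside the diff, of what all_values() collects: every public string-valued class attribute. The trimmed-down stand-in class below is illustrative only.

    class _NodeTypeSketch:
        COMMAND = "CommandComponent"
        SCOPE = "ScopeComponent"

        @classmethod
        def all_values(cls):
            # vars(cls) yields the class attributes; keep public string constants only
            return [v for k, v in vars(cls).items() if not k.startswith("_") and isinstance(v, str)]

    print(_NodeTypeSketch.all_values())  # ['CommandComponent', 'ScopeComponent']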
+
+
+class InternalComponentSchema(ComponentSchema):
+ class Meta:
+ unknown = INCLUDE
+
+ # override name as 1p components allow . in name, which is not allowed in v2 components
+ name = fields.Str()
+
+ # override to allow empty properties
+ tags = fields.Dict(keys=fields.Str())
+
+    # override inputs & outputs to support 1P inputs & outputs; stricter validation may be added later
+ # no need to check io type match since server will do that
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(InternalParameterSchema),
+ NestedField(InternalEnumParameterSchema),
+ NestedField(InternalInputPortSchema),
+ ]
+ ),
+ )
+ # support primitive output for all internal components for now
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
+ NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
+ ]
+ ),
+ )
+
+ # type field is required for registration
+ type = StringTransformedEnum(
+ allowed_values=NodeType.all_values(),
+ casing_transform=lambda x: parse_name_label(x)[0],
+ pass_original=True,
+ )
+
+ # need to resolve as it can be a local field
+ code = CodeField()
+
+ environment = UnionField(
+ [
+ RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+ NestedField(InternalEnvironmentSchema),
+ ]
+ )
+
+ def get_skip_fields(self):
+ return ["properties"]
+
+ def _serialize(self, obj, *, many: bool = False):
+ if many and obj is not None:
+ return super(InternalComponentSchema, self)._serialize(obj, many=many)
+ ret = super(InternalComponentSchema, self)._serialize(obj)
+ for attr_name in obj.__dict__.keys():
+ if (
+ not attr_name.startswith("_")
+ and attr_name not in self.get_skip_fields()
+ and attr_name not in self.dump_fields
+ ):
+ ret[attr_name] = self.get_attribute(obj, attr_name, None)
+ return ret
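A hedged illustration of what the _serialize override adds on top of the base marshmallow dump: public attributes that are neither declared fields nor in the skip list are copied through verbatim. The object and attribute names below are made up for the example.

    class _FakeComponent:
        def __init__(self):
            self.name = "my_component"   # declared schema field
            self.custom_flag = True      # undeclared attribute, passed through on dump
            self._internal = "ignored"   # leading underscore, never dumped

    obj = _FakeComponent()
    dumped = {"name": obj.name}          # stand-in for the base-class dump result
    declared, skipped = {"name"}, {"properties"}
    for attr in vars(obj):
        if not attr.startswith("_") and attr not in skipped and attr not in declared:
            dumped[attr] = getattr(obj, attr)
    print(dumped)  # {'name': 'my_component', 'custom_flag': True}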
+
+ # override param_override to ensure that param override happens after reloading the yaml
+ @pre_load
+ def add_param_overrides(self, data, **kwargs):
+ source_path = self.context.pop(SOURCE_PATH_CONTEXT_KEY, None)
+ if isinstance(data, dict) and source_path and os.path.isfile(source_path):
+
+ def should_node_overwritten(_root, _parts):
+ parts = _parts.copy()
+ parts.pop()
+ parts.append("type")
+ _input_type = pydash.get(_root, parts, None)
+ return isinstance(_input_type, str) and _input_type.lower() not in ["boolean"]
+
+ # do override here
+ with open(source_path, "r", encoding=DefaultOpenEncoding.READ) as f:
+ origin_data = yaml_safe_load_with_base_resolver(f)
+ for dot_key_wildcard, condition_func in [
+ ("version", None),
+ ("inputs.*.default", should_node_overwritten),
+ ("inputs.*.enum", should_node_overwritten),
+ ]:
+ for dot_key in get_valid_dot_keys_with_wildcard(
+ origin_data, dot_key_wildcard, validate_func=condition_func
+ ):
+ pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))
+ return super().add_param_overrides(data, **kwargs)
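A small sketch of the dot-key override step above, using only pydash and hypothetical data: keys resolved from wildcards such as inputs.*.default are re-read from the original YAML and written back onto the incoming (possibly overridden) data.

    import pydash

    origin_data = {"version": "0.0.2", "inputs": {"lr": {"type": "float", "default": 0.01}}}
    data = {"version": "0.0.1", "inputs": {"lr": {"type": "float", "default": None}}}

    # "inputs.*.default" would expand to ["inputs.lr.default"] for this document
    for dot_key in ["version", "inputs.lr.default"]:
        pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))

    print(data)  # {'version': '0.0.2', 'inputs': {'lr': {'type': 'float', 'default': 0.01}}}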
+
+ @post_dump(pass_original=True)
+ def simplify_input_output_port(self, data, original, **kwargs): # pylint:disable=unused-argument
+ # remove None in input & output
+ for io_ports in [data["inputs"], data["outputs"]]:
+ for port_name, port_definition in io_ports.items():
+ io_ports[port_name] = dict(filter(lambda item: item[1] is not None, port_definition.items()))
+
+        # hack: drop the input "mode" key to match current serialization expectations
+ for port_name, port_definition in data["inputs"].items():
+ if "mode" in port_definition:
+ del port_definition["mode"]
+
+ return data
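The None-stripping in simplify_input_output_port can be reproduced with a plain marshmallow schema; the mini-schema below is illustrative, not the SDK's own.

    from marshmallow import Schema, fields, post_dump

    class _PortSketch(Schema):
        type = fields.Str()
        default = fields.Raw(allow_none=True)

    class _ComponentSketch(Schema):
        inputs = fields.Dict(keys=fields.Str(), values=fields.Nested(_PortSketch))

        @post_dump
        def _drop_none(self, data, **kwargs):
            # same idea: remove None-valued entries from every dumped port
            for name, port in data.get("inputs", {}).items():
                data["inputs"][name] = {k: v for k, v in port.items() if v is not None}
            return data

    print(_ComponentSketch().dump({"inputs": {"lr": {"type": "float", "default": None}}}))
    # {'inputs': {'lr': {'type': 'float'}}}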
+
+ @post_dump(pass_original=True)
+ def add_back_type_label(self, data, original, **kwargs): # pylint:disable=unused-argument
+ type_label = original._type_label # pylint:disable=protected-access
+ if type_label:
+ data["type"] = LABELLED_RESOURCE_NAME.format(data["type"], type_label)
+ return data
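A sketch of the label round trip that add_back_type_label completes on dump, assuming the usual name@label convention behind parse_name_label and LABELLED_RESOURCE_NAME; the helper below is a simplified stand-in, not the SDK function.

    def _parse_name_label(value):
        # simplified: split "ScopeComponent@1-legacy" into ("ScopeComponent", "1-legacy")
        name, _, label = value.partition("@")
        return name, label or None

    raw_type = "ScopeComponent@1-legacy"
    name, label = _parse_name_label(raw_type)   # the type is stored without the label on load
    restored = "{}@{}".format(name, label)      # the label is re-attached when dumping
    print(name, label, restored)                # ScopeComponent 1-legacy ScopeComponent@1-legacy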
+
+
+class InternalSparkComponentSchema(InternalComponentSchema):
+ # type field is required for registration
+ type = StringTransformedEnum(
+ allowed_values=PublicNodeType.SPARK,
+ casing_transform=lambda x: parse_name_label(x)[0].lower(),
+ pass_original=True,
+ )
+
+ # override inputs:
+ # https://componentsdk.azurewebsites.net/components/spark_component.html#differences-with-other-component-types
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(InternalSparkParameterSchema),
+ NestedField(InternalInputPortSchema),
+ ]
+ ),
+ )
+
+ environment = EnvironmentField(
+ extra_fields=[NestedField(InternalEnvironmentSchema)],
+ allow_none=True,
+ )
+
+ jars = UnionField(
+ [
+ fields.List(fields.Str()),
+ fields.Str(),
+ ],
+ )
+ py_files = UnionField(
+ [
+ fields.List(fields.Str()),
+ fields.Str(),
+ ],
+ data_key="pyFiles",
+ attribute="py_files",
+ )
+
+ entry = UnionField(
+ [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
+ required=True,
+ metadata={"description": "Entry."},
+ )
+
+ files = fields.List(fields.Str(required=True))
+ archives = fields.List(fields.Str(required=True))
+ conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
+ args = fields.Str(metadata={"description": "Command Line arguments."})
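Fields such as jars and py_files above accept either a single string or a list of strings via UnionField. A plain-marshmallow approximation of that try-each-variant behavior, purely illustrative and not the SDK's UnionField:

    from marshmallow import Schema, ValidationError, fields

    class _SparkFilesSketch(Schema):
        jars = fields.Method(deserialize="_load_str_or_list")

        def _load_str_or_list(self, value):
            # try the list variant first, then fall back to a bare string
            for field in (fields.List(fields.Str()), fields.Str()):
                try:
                    return field.deserialize(value)
                except ValidationError:
                    continue
            raise ValidationError("expected a string or a list of strings")

    print(_SparkFilesSketch().load({"jars": "a.jar"}))              # {'jars': 'a.jar'}
    print(_SparkFilesSketch().load({"jars": ["a.jar", "b.jar"]}))   # {'jars': ['a.jar', 'b.jar']}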