Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema')
.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py     |   3 +
.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py      |  37 +
.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py    | 232 ++
.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py  |  21 +
.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py | 113 +
.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py         |  75 +
6 files changed, 481 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py
new file mode 100644
index 00000000..d540fd20
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py
@@ -0,0 +1,3 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py
new file mode 100644
index 00000000..2dddf02b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+
+from ..._schema import NestedField
+from ..._schema.core.fields import DumpableEnumField, EnvironmentField
+from ..._schema.job import ParameterizedCommandSchema, ParameterizedParallelSchema
+from ..._schema.job.job_limits import CommandJobLimitsSchema
+from .._schema.node import InternalBaseNodeSchema, NodeType
+
+
+class CommandSchema(InternalBaseNodeSchema, ParameterizedCommandSchema):
+    class Meta:
+        exclude = ["code", "distribution"]  # internal command doesn't have code & distribution
+
+    environment = EnvironmentField()
+    type = DumpableEnumField(allowed_values=[NodeType.COMMAND])
+    limits = NestedField(CommandJobLimitsSchema)
+
+
+class DistributedSchema(CommandSchema):
+    class Meta:
+        exclude = ["code"]  # need to re-enable distribution, compared to CommandSchema
+
+    type = DumpableEnumField(allowed_values=[NodeType.DISTRIBUTED])
+
+
+class ParallelSchema(InternalBaseNodeSchema, ParameterizedParallelSchema):
+    class Meta:
+        # partition_keys can still be used (with an unknown-field warning), but dump needs to happen before it is set
+        exclude = ["task", "input_data", "mini_batch_error_threshold", "partition_keys"]
+
+    type = DumpableEnumField(allowed_values=[NodeType.PARALLEL])
+    compute = fields.Str()
+    environment = fields.Str()
+    limits = NestedField(CommandJobLimitsSchema)
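
The Meta.exclude trick in command.py above, where DistributedSchema re-enables distribution by shrinking the exclude list it inherits from CommandSchema, is plain marshmallow behavior. A minimal standalone sketch, with hypothetical field names rather than the SDK's internal field types:

from marshmallow import Schema, fields

class CommandLike(Schema):
    class Meta:
        exclude = ["code", "distribution"]

    command = fields.Str()
    code = fields.Str()
    distribution = fields.Str()

class DistributedLike(CommandLike):
    class Meta:
        exclude = ["code"]  # only "code" stays excluded, so "distribution" comes back

print(sorted(CommandLike().fields))      # ['command']
print(sorted(DistributedLike().fields))  # ['command', 'distribution']
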
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
new file mode 100644
index 00000000..11d4bb56
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
@@ -0,0 +1,232 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os.path
+
+import pydash
+from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load
+
+from ..._schema import NestedField, StringTransformedEnum, UnionField
+from ..._schema.component.component import ComponentSchema
+from ..._schema.core.fields import ArmVersionedStr, CodeField, EnvironmentField, RegistryStr
+from ..._schema.job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
+from ..._utils._arm_id_utils import parse_name_label
+from ..._utils.utils import get_valid_dot_keys_with_wildcard
+from ...constants._common import (
+    LABELLED_RESOURCE_NAME,
+    SOURCE_PATH_CONTEXT_KEY,
+    AzureMLResourceType,
+    DefaultOpenEncoding,
+)
+from ...constants._component import NodeType as PublicNodeType
+from .._utils import yaml_safe_load_with_base_resolver
+from .environment import InternalEnvironmentSchema
+from .input_output import (
+    InternalEnumParameterSchema,
+    InternalInputPortSchema,
+    InternalOutputPortSchema,
+    InternalParameterSchema,
+    InternalPrimitiveOutputSchema,
+    InternalSparkParameterSchema,
+)
+
+
+class NodeType:
+    COMMAND = "CommandComponent"
+    DATA_TRANSFER = "DataTransferComponent"
+    DISTRIBUTED = "DistributedComponent"
+    HDI = "HDInsightComponent"
+    SCOPE_V2 = "scope"
+    HDI_V2 = "hdinsight"
+    HEMERA_V2 = "hemera"
+    STARLITE_V2 = "starlite"
+    AE365EXEPOOL_V2 = "ae365exepool"
+    AETHER_BRIDGE_V2 = "aetherbridge"
+    PARALLEL = "ParallelComponent"
+    SCOPE = "ScopeComponent"
+    STARLITE = "StarliteComponent"
+    SWEEP = "SweepComponent"
+    PIPELINE = "PipelineComponent"
+    HEMERA = "HemeraComponent"
+    AE365EXEPOOL = "AE365ExePoolComponent"
+    IPP = "IntellectualPropertyProtectedComponent"
+    # the internal spark component's type value conflicts with the spark component;
+    # this enum is used to identify its create_function in factories
+    SPARK = "DummySpark"
+    AETHER_BRIDGE = "AetherBridgeComponent"
+
+    @classmethod
+    def all_values(cls):
+        all_values = []
+        for key, value in vars(cls).items():
+            if not key.startswith("_") and isinstance(value, str):
+                all_values.append(value)
+        return all_values
+
+
+class InternalComponentSchema(ComponentSchema):
+    class Meta:
+        unknown = INCLUDE
+
+    # override name, as 1P components allow "." in names, which is not allowed in v2 components
+    name = fields.Str()
+
+    # override to allow empty properties
+    tags = fields.Dict(keys=fields.Str())
+
+    # override inputs & outputs to support 1P inputs & outputs; may need to do strict validation later
+    # no need to check io type match since the server will do that
+    inputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(InternalParameterSchema),
+                NestedField(InternalEnumParameterSchema),
+                NestedField(InternalInputPortSchema),
+            ]
+        ),
+    )
+    # support primitive output for all internal components for now
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
+                NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
+            ]
+        ),
+    )
+
+    # type field is required for registration
+    type = StringTransformedEnum(
+        allowed_values=NodeType.all_values(),
+        casing_transform=lambda x: parse_name_label(x)[0],
+        pass_original=True,
+    )
+
+    # need to resolve as it can be a local field
+    code = CodeField()
+
+    environment = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+            NestedField(InternalEnvironmentSchema),
+        ]
+    )
+
+    def get_skip_fields(self):
+        return ["properties"]
+
+    def _serialize(self, obj, *, many: bool = False):
+        if many and obj is not None:
+            return super(InternalComponentSchema, self)._serialize(obj, many=many)
+        ret = super(InternalComponentSchema, self)._serialize(obj)
+        for attr_name in obj.__dict__.keys():
+            if (
+                not attr_name.startswith("_")
+                and attr_name not in self.get_skip_fields()
+                and attr_name not in self.dump_fields
+            ):
+                ret[attr_name] = self.get_attribute(obj, attr_name, None)
+        return ret
+
+    # override param_override to ensure that param override happens after reloading the yaml
+    @pre_load
+    def add_param_overrides(self, data, **kwargs):
+        source_path = self.context.pop(SOURCE_PATH_CONTEXT_KEY, None)
+        if isinstance(data, dict) and source_path and os.path.isfile(source_path):
+
+            def should_node_overwritten(_root, _parts):
+                parts = _parts.copy()
+                parts.pop()
+                parts.append("type")
+                _input_type = pydash.get(_root, parts, None)
+                return isinstance(_input_type, str) and _input_type.lower() not in ["boolean"]
+
+            # do override here
+            with open(source_path, "r", encoding=DefaultOpenEncoding.READ) as f:
+                origin_data = yaml_safe_load_with_base_resolver(f)
+                for dot_key_wildcard, condition_func in [
+                    ("version", None),
+                    ("inputs.*.default", should_node_overwritten),
+                    ("inputs.*.enum", should_node_overwritten),
+                ]:
+                    for dot_key in get_valid_dot_keys_with_wildcard(
+                        origin_data, dot_key_wildcard, validate_func=condition_func
+                    ):
+                        pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))
+        return super().add_param_overrides(data, **kwargs)
+
+    @post_dump(pass_original=True)
+    def simplify_input_output_port(self, data, original, **kwargs):  # pylint:disable=unused-argument
+        # remove None in input & output
+        for io_ports in [data["inputs"], data["outputs"]]:
+            for port_name, port_definition in io_ports.items():
+                io_ports[port_name] = dict(filter(lambda item: item[1] is not None, port_definition.items()))
+
+        # hack to match the current serialization expectation
+        for port_name, port_definition in data["inputs"].items():
+            if "mode" in port_definition:
+                del port_definition["mode"]
+
+        return data
+
+    @post_dump(pass_original=True)
+    def add_back_type_label(self, data, original, **kwargs):  # pylint:disable=unused-argument
+        type_label = original._type_label  # pylint:disable=protected-access
+        if type_label:
+            data["type"] = LABELLED_RESOURCE_NAME.format(data["type"], type_label)
+        return data
+
+
+class InternalSparkComponentSchema(InternalComponentSchema):
+    # type field is required for registration
+    type = StringTransformedEnum(
+        allowed_values=PublicNodeType.SPARK,
+        casing_transform=lambda x: parse_name_label(x)[0].lower(),
+        pass_original=True,
+    )
+
+    # override inputs:
+    # https://componentsdk.azurewebsites.net/components/spark_component.html#differences-with-other-component-types
+    inputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(InternalSparkParameterSchema),
+                NestedField(InternalInputPortSchema),
+            ]
+        ),
+    )
+
+    environment = EnvironmentField(
+        extra_fields=[NestedField(InternalEnvironmentSchema)],
+        allow_none=True,
+    )
+
+    jars = UnionField(
+        [
+            fields.List(fields.Str()),
+            fields.Str(),
+        ],
+    )
+    py_files = UnionField(
+        [
+            fields.List(fields.Str()),
+            fields.Str(),
+        ],
+        data_key="pyFiles",
+        attribute="py_files",
+    )
+
+    entry = UnionField(
+        [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
+        required=True,
+        metadata={"description": "Entry."},
+    )
+
+    files = fields.List(fields.Str(required=True))
+    archives = fields.List(fields.Str(required=True))
+    conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
+    args = fields.Str(metadata={"description": "Command Line arguments."})
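
InternalComponentSchema above leans on marshmallow's hook decorators: @pre_load (add_param_overrides) rewrites the raw input before validation, and @post_dump (simplify_input_output_port) prunes the serialized output. A runnable sketch of that pattern in plain marshmallow, with made-up field names; not the SDK's API:

from marshmallow import Schema, fields, post_dump, pre_load

class PortSketch(Schema):
    type = fields.Str()
    mode = fields.Str(allow_none=True)

    @pre_load
    def default_type(self, data, **kwargs):
        # like add_param_overrides: mutate incoming data before validation
        data.setdefault("type", "path")
        return data

    @post_dump
    def drop_none(self, data, **kwargs):
        # like simplify_input_output_port: prune None values from the dump
        return {k: v for k, v in data.items() if v is not None}

loaded = PortSketch().load({"mode": None})
print(loaded)                     # {'type': 'path', 'mode': None}
print(PortSketch().dump(loaded))  # {'type': 'path'}
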
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py
new file mode 100644
index 00000000..f7c20228
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from ..._schema import PathAwareSchema
+from ..._schema.core.fields import DumpableEnumField, VersionField
+
+
+class InternalEnvironmentSchema(PathAwareSchema):
+    docker = fields.Dict()
+    conda = fields.Dict()
+    os = DumpableEnumField(
+        # use an enum instead of a string transformer here to avoid changing the value
+        allowed_values=["Linux", "Windows", "linux", "windows"],
+        required=False,
+    )
+    name = fields.Str()
+    version = VersionField()
+    python = fields.Dict()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py
new file mode 100644
index 00000000..b1fe2188
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py
@@ -0,0 +1,113 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_dump, post_load
+
+from ..._schema import PatchedSchemaMeta, StringTransformedEnum, UnionField
+from ..._schema.component.input_output import InputPortSchema, ParameterSchema
+from ..._schema.core.fields import DumpableEnumField, PrimitiveValueField
+
+SUPPORTED_INTERNAL_PARAM_TYPES = [
+    "integer",
+    "Integer",
+    "boolean",
+    "Boolean",
+    "string",
+    "String",
+    "float",
+    "Float",
+    "double",
+    "Double",
+]
+
+
+SUPPORTED_INTERNAL_SPARK_PARAM_TYPES = [
+    "integer",
+    "Integer",
+    "boolean",
+    "Boolean",
+    "string",
+    "String",
+    "double",
+    "Double",
+    # remove float and add number
+    "number",
+]
+
+
+class InternalInputPortSchema(InputPortSchema):
+    # skip client-side validation of the type enum & support lists
+    type = UnionField(
+        [
+            fields.Str(),
+            fields.List(fields.Str()),
+        ],
+        required=True,
+        data_key="type",
+    )
+    is_resource = fields.Bool()
+    datastore_mode = fields.Str()
+
+    @post_dump(pass_original=True)
+    def resolve_list_type(self, data, original_data, **kwargs):  # pylint: disable=unused-argument
+        if isinstance(original_data.type, list):
+            data["type"] = original_data.type
+        return data
+
+
+class InternalOutputPortSchema(metaclass=PatchedSchemaMeta):
+    # skip client-side validation of the type enum
+    type = fields.Str(
+        required=True,
+        data_key="type",
+    )
+    description = fields.Str()
+    is_link_mode = fields.Bool()
+    datastore_mode = fields.Str()
+
+
+class InternalPrimitiveOutputSchema(metaclass=PatchedSchemaMeta):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
+        required=True,
+    )
+    description = fields.Str()
+
+
+class InternalParameterSchema(ParameterSchema):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
+        required=True,
+        data_key="type",
+    )
+
+
+class InternalSparkParameterSchema(ParameterSchema):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_INTERNAL_SPARK_PARAM_TYPES,
+        required=True,
+        data_key="type",
+    )
+
+
+class InternalEnumParameterSchema(ParameterSchema):
+    type = StringTransformedEnum(
+        allowed_values=["enum"],
+        required=True,
+        data_key="type",
+    )
+    default = PrimitiveValueField()
+    enum = fields.List(
+        PrimitiveValueField(),
+        required=True,
+    )
+
+    @post_dump
+    @post_load
+    def enum_value_to_string(self, data, **kwargs):  # pylint: disable=unused-argument
+        if "enum" in data:
+            data["enum"] = list(map(str, data["enum"]))
+        if "default" in data and data["default"] is not None:
+            data["default"] = str(data["default"])
+        return data
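
InternalEnumParameterSchema above registers enum_value_to_string as both a @post_dump and a @post_load hook, so enum values are coerced to strings in both directions. The same double registration in isolation, with a hypothetical schema name and sample values:

from marshmallow import Schema, fields, post_dump, post_load

class EnumSketch(Schema):
    enum = fields.List(fields.Raw(), required=True)
    default = fields.Raw()

    @post_dump
    @post_load
    def to_str(self, data, **kwargs):
        # stacking both decorators runs the hook after load *and* after dump
        if "enum" in data:
            data["enum"] = [str(v) for v in data["enum"]]
        if data.get("default") is not None:
            data["default"] = str(data["default"])
        return data

print(EnumSketch().load({"enum": [1, 2.5, True], "default": 1}))
# {'enum': ['1', '2.5', 'True'], 'default': '1'}
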
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
new file mode 100644
index 00000000..6dbadcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
@@ -0,0 +1,75 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields, post_load, pre_dump
+
+from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
+from ..._schema.core.fields import DumpableEnumField
+from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
+from ...constants._common import AzureMLResourceType
+from .component import InternalComponentSchema, NodeType
+
+
+class InternalBaseNodeSchema(BaseNodeSchema):
+    class Meta:
+        unknown = INCLUDE
+
+    component = UnionField(
+        [
+            # for registry type assets
+            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+            # existing component
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            # inline component or component file reference starting with FILE prefix
+            NestedField(InternalComponentSchema, unknown=INCLUDE),
+        ],
+        required=True,
+    )
+    type = DumpableEnumField(
+        allowed_values=NodeType.all_values(),
+    )
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        from ...entities._builders import parse_inputs_outputs
+
+        # parse inputs/outputs
+        data = parse_inputs_outputs(data)
+
+        # dict to node object
+        from ...entities._job.pipeline._load_component import pipeline_node_factory
+
+        return pipeline_node_factory.load_from_dict(data=data)
+
+    @pre_dump
+    def resolve_inputs_outputs(self, job, **kwargs):  # pylint: disable=unused-argument
+        return _resolve_inputs_outputs(job)
+
+
+class ScopeSchema(InternalBaseNodeSchema):
+    type = DumpableEnumField(allowed_values=[NodeType.SCOPE])
+    adla_account_name = fields.Str(required=True)
+    scope_param = fields.Str()
+    custom_job_name_suffix = fields.Str()
+    priority = fields.Int()
+    auto_token = fields.Int()
+    tokens = fields.Int()
+    vcp = fields.Float()
+
+
+class HDInsightSchema(InternalBaseNodeSchema):
+    type = DumpableEnumField(allowed_values=[NodeType.HDI])
+
+    compute_name = fields.Str()
+    queue = fields.Str()
+    driver_memory = fields.Str()
+    driver_cores = fields.Int()
+    executor_memory = fields.Str()
+    executor_cores = fields.Int()
+    number_executors = fields.Int()
+    conf = UnionField(
+        # dictionary or json string
+        union_fields=[fields.Dict(keys=fields.Str()), fields.Str()],
+    )
+    hdinsight_spark_job_name = fields.Str()
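
Both InternalComponentSchema and InternalBaseNodeSchema set Meta.unknown = INCLUDE, which is what lets 1P-only keys pass through a load without being declared on the schema. A minimal sketch of that behavior in plain marshmallow; the schema name and values are made up for illustration:

from marshmallow import INCLUDE, Schema, fields

class NodeSketch(Schema):
    class Meta:
        unknown = INCLUDE  # keep, rather than reject, undeclared keys

    type = fields.Str()

print(NodeSketch().load({"type": "ScopeComponent", "adla_account_name": "my-adla"}))
# {'type': 'ScopeComponent', 'adla_account_name': 'my-adla'}
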