Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py       3
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py       37
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py    232
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py   21
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py 113
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py          75
6 files changed, 481 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py
new file mode 100644
index 00000000..d540fd20
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py
@@ -0,0 +1,3 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py
new file mode 100644
index 00000000..2dddf02b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+
+from ..._schema import NestedField
+from ..._schema.core.fields import DumpableEnumField, EnvironmentField
+from ..._schema.job import ParameterizedCommandSchema, ParameterizedParallelSchema
+from ..._schema.job.job_limits import CommandJobLimitsSchema
+from .node import InternalBaseNodeSchema, NodeType
+
+
+class CommandSchema(InternalBaseNodeSchema, ParameterizedCommandSchema):
+ class Meta:
+ exclude = ["code", "distribution"] # internal command doesn't have code & distribution
+
+ environment = EnvironmentField()
+ type = DumpableEnumField(allowed_values=[NodeType.COMMAND])
+ limits = NestedField(CommandJobLimitsSchema)
+
+
+class DistributedSchema(CommandSchema):
+ class Meta:
+        exclude = ["code"]  # unlike CommandSchema, distribution stays enabled here, so only code is excluded
+
+ type = DumpableEnumField(allowed_values=[NodeType.DISTRIBUTED])
+
+
+class ParallelSchema(InternalBaseNodeSchema, ParameterizedParallelSchema):
+ class Meta:
+        # partition_keys can still be used (with an unknown-field warning), but a dump is needed before setting it
+ exclude = ["task", "input_data", "mini_batch_error_threshold", "partition_keys"]
+
+ type = DumpableEnumField(allowed_values=[NodeType.PARALLEL])
+ compute = fields.Str()
+ environment = fields.Str()
+ limits = NestedField(CommandJobLimitsSchema)
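
The three node schemas above re-include inherited fields by shrinking the parent's Meta.exclude list rather than re-declaring them: DistributedSchema defines a fresh Meta excluding only code, which brings distribution back. A minimal, self-contained marshmallow sketch of the pattern, with hypothetical schema names rather than the SDK's classes:

from marshmallow import Schema, fields

class BaseJobSketch(Schema):
    command = fields.Str()
    code = fields.Str()
    distribution = fields.Str()

class CommandSketch(BaseJobSketch):
    class Meta:
        exclude = ["code", "distribution"]  # neither is supported by this node type

class DistributedSketch(CommandSketch):
    class Meta:
        exclude = ["code"]  # a fresh Meta replaces the parent's, re-enabling distribution

print(DistributedSketch().load({"command": "echo hi", "distribution": "mpi"}))
# -> {'command': 'echo hi', 'distribution': 'mpi'}
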
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
new file mode 100644
index 00000000..11d4bb56
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
@@ -0,0 +1,232 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os.path
+
+import pydash
+from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load
+
+from ..._schema import NestedField, StringTransformedEnum, UnionField
+from ..._schema.component.component import ComponentSchema
+from ..._schema.core.fields import ArmVersionedStr, CodeField, EnvironmentField, RegistryStr
+from ..._schema.job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
+from ..._utils._arm_id_utils import parse_name_label
+from ..._utils.utils import get_valid_dot_keys_with_wildcard
+from ...constants._common import (
+ LABELLED_RESOURCE_NAME,
+ SOURCE_PATH_CONTEXT_KEY,
+ AzureMLResourceType,
+ DefaultOpenEncoding,
+)
+from ...constants._component import NodeType as PublicNodeType
+from .._utils import yaml_safe_load_with_base_resolver
+from .environment import InternalEnvironmentSchema
+from .input_output import (
+ InternalEnumParameterSchema,
+ InternalInputPortSchema,
+ InternalOutputPortSchema,
+ InternalParameterSchema,
+ InternalPrimitiveOutputSchema,
+ InternalSparkParameterSchema,
+)
+
+
+class NodeType:
+ COMMAND = "CommandComponent"
+ DATA_TRANSFER = "DataTransferComponent"
+ DISTRIBUTED = "DistributedComponent"
+ HDI = "HDInsightComponent"
+ SCOPE_V2 = "scope"
+ HDI_V2 = "hdinsight"
+ HEMERA_V2 = "hemera"
+ STARLITE_V2 = "starlite"
+ AE365EXEPOOL_V2 = "ae365exepool"
+ AETHER_BRIDGE_V2 = "aetherbridge"
+ PARALLEL = "ParallelComponent"
+ SCOPE = "ScopeComponent"
+ STARLITE = "StarliteComponent"
+ SWEEP = "SweepComponent"
+ PIPELINE = "PipelineComponent"
+ HEMERA = "HemeraComponent"
+ AE365EXEPOOL = "AE365ExePoolComponent"
+ IPP = "IntellectualPropertyProtectedComponent"
+    # the internal spark component's type value conflicts with the spark component;
+    # this enum is used to identify its create_function in factories
+ SPARK = "DummySpark"
+ AETHER_BRIDGE = "AetherBridgeComponent"
+
+ @classmethod
+ def all_values(cls):
+ all_values = []
+ for key, value in vars(cls).items():
+ if not key.startswith("_") and isinstance(value, str):
+ all_values.append(value)
+ return all_values
+
+
+class InternalComponentSchema(ComponentSchema):
+ class Meta:
+ unknown = INCLUDE
+
+    # override name as 1P components allow "." in names, which is not allowed in v2 components
+ name = fields.Str()
+
+ # override to allow empty properties
+ tags = fields.Dict(keys=fields.Str())
+
+    # override inputs & outputs to support 1P inputs & outputs; stricter validation may be added later
+ # no need to check io type match since server will do that
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(InternalParameterSchema),
+ NestedField(InternalEnumParameterSchema),
+ NestedField(InternalInputPortSchema),
+ ]
+ ),
+ )
+ # support primitive output for all internal components for now
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
+ NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
+ ]
+ ),
+ )
+
+ # type field is required for registration
+ type = StringTransformedEnum(
+ allowed_values=NodeType.all_values(),
+ casing_transform=lambda x: parse_name_label(x)[0],
+ pass_original=True,
+ )
+
+    # need to resolve as it can be a local path
+ code = CodeField()
+
+ environment = UnionField(
+ [
+ RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+ NestedField(InternalEnvironmentSchema),
+ ]
+ )
+
+ def get_skip_fields(self):
+ return ["properties"]
+
+ def _serialize(self, obj, *, many: bool = False):
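+        # dump declared fields via the parent schema first, then pass through
+        # any public instance attributes that are neither declared fields nor
+        # in the skip list, since internal components allow extra top-level keys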
+ if many and obj is not None:
+ return super(InternalComponentSchema, self)._serialize(obj, many=many)
+ ret = super(InternalComponentSchema, self)._serialize(obj)
+ for attr_name in obj.__dict__.keys():
+ if (
+ not attr_name.startswith("_")
+ and attr_name not in self.get_skip_fields()
+ and attr_name not in self.dump_fields
+ ):
+ ret[attr_name] = self.get_attribute(obj, attr_name, None)
+ return ret
+
+    # override add_param_overrides to ensure that param overrides happen after reloading the yaml
+ @pre_load
+ def add_param_overrides(self, data, **kwargs):
+ source_path = self.context.pop(SOURCE_PATH_CONTEXT_KEY, None)
+ if isinstance(data, dict) and source_path and os.path.isfile(source_path):
+
+ def should_node_overwritten(_root, _parts):
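+                # overwrite an input's default/enum only when its sibling "type"
+                # field is a non-boolean string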
+ parts = _parts.copy()
+ parts.pop()
+ parts.append("type")
+ _input_type = pydash.get(_root, parts, None)
+ return isinstance(_input_type, str) and _input_type.lower() not in ["boolean"]
+
+ # do override here
+ with open(source_path, "r", encoding=DefaultOpenEncoding.READ) as f:
+ origin_data = yaml_safe_load_with_base_resolver(f)
+ for dot_key_wildcard, condition_func in [
+ ("version", None),
+ ("inputs.*.default", should_node_overwritten),
+ ("inputs.*.enum", should_node_overwritten),
+ ]:
+ for dot_key in get_valid_dot_keys_with_wildcard(
+ origin_data, dot_key_wildcard, validate_func=condition_func
+ ):
+ pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))
+ return super().add_param_overrides(data, **kwargs)
+
+ @post_dump(pass_original=True)
+ def simplify_input_output_port(self, data, original, **kwargs): # pylint:disable=unused-argument
+ # remove None in input & output
+ for io_ports in [data["inputs"], data["outputs"]]:
+ for port_name, port_definition in io_ports.items():
+ io_ports[port_name] = dict(filter(lambda item: item[1] is not None, port_definition.items()))
+
+        # hack: drop "mode" from inputs to match current serialization expectations
+ for port_name, port_definition in data["inputs"].items():
+ if "mode" in port_definition:
+ del port_definition["mode"]
+
+ return data
+
+ @post_dump(pass_original=True)
+ def add_back_type_label(self, data, original, **kwargs): # pylint:disable=unused-argument
+ type_label = original._type_label # pylint:disable=protected-access
+ if type_label:
+ data["type"] = LABELLED_RESOURCE_NAME.format(data["type"], type_label)
+ return data
+
+
+class InternalSparkComponentSchema(InternalComponentSchema):
+ # type field is required for registration
+ type = StringTransformedEnum(
+ allowed_values=PublicNodeType.SPARK,
+ casing_transform=lambda x: parse_name_label(x)[0].lower(),
+ pass_original=True,
+ )
+
+ # override inputs:
+ # https://componentsdk.azurewebsites.net/components/spark_component.html#differences-with-other-component-types
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(InternalSparkParameterSchema),
+ NestedField(InternalInputPortSchema),
+ ]
+ ),
+ )
+
+ environment = EnvironmentField(
+ extra_fields=[NestedField(InternalEnvironmentSchema)],
+ allow_none=True,
+ )
+
+ jars = UnionField(
+ [
+ fields.List(fields.Str()),
+ fields.Str(),
+ ],
+ )
+ py_files = UnionField(
+ [
+ fields.List(fields.Str()),
+ fields.Str(),
+ ],
+ data_key="pyFiles",
+ attribute="py_files",
+ )
+
+ entry = UnionField(
+ [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
+ required=True,
+ metadata={"description": "Entry."},
+ )
+
+ files = fields.List(fields.Str(required=True))
+ archives = fields.List(fields.Str(required=True))
+ conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
+ args = fields.Str(metadata={"description": "Command Line arguments."})
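
add_param_overrides above re-applies selected values (version, input defaults, enums) from the original YAML after marshmallow has reloaded the data, using pydash dot-paths. A rough, self-contained sketch of the get/set_ mechanics it relies on, with hypothetical data and the wildcard already expanded (the real code expands inputs.*.default via get_valid_dot_keys_with_wildcard):

import pydash

origin_data = {"version": "0.0.2", "inputs": {"lr": {"type": "float", "default": 0.01}}}
loaded = {"version": 2, "inputs": {"lr": {"type": "float", "default": "0.01"}}}

# "inputs.*.default" expands to concrete dot-keys such as "inputs.lr.default"
for dot_key in ["version", "inputs.lr.default"]:
    pydash.set_(loaded, dot_key, pydash.get(origin_data, dot_key))

print(loaded)
# -> {'version': '0.0.2', 'inputs': {'lr': {'type': 'float', 'default': 0.01}}}
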
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py
new file mode 100644
index 00000000..f7c20228
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from ..._schema import PathAwareSchema
+from ..._schema.core.fields import DumpableEnumField, VersionField
+
+
+class InternalEnvironmentSchema(PathAwareSchema):
+ docker = fields.Dict()
+ conda = fields.Dict()
+ os = DumpableEnumField(
+        # use an enum instead of a string transformer here to avoid changing the value
+ allowed_values=["Linux", "Windows", "linux", "windows"],
+ required=False,
+ )
+ name = fields.Str()
+ version = VersionField()
+ python = fields.Dict()
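
The os field above accepts both casings verbatim: DumpableEnumField validates the value without transforming it. A plain-marshmallow stand-in for that behavior, with a hypothetical schema name:

from marshmallow import Schema, fields, validate

class EnvSketch(Schema):
    os = fields.Str(validate=validate.OneOf(["Linux", "Windows", "linux", "windows"]))

print(EnvSketch().load({"os": "linux"}))  # {'os': 'linux'} -- the value is kept as-is
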
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py
new file mode 100644
index 00000000..b1fe2188
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py
@@ -0,0 +1,113 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_dump, post_load
+
+from ..._schema import PatchedSchemaMeta, StringTransformedEnum, UnionField
+from ..._schema.component.input_output import InputPortSchema, ParameterSchema
+from ..._schema.core.fields import DumpableEnumField, PrimitiveValueField
+
+SUPPORTED_INTERNAL_PARAM_TYPES = [
+ "integer",
+ "Integer",
+ "boolean",
+ "Boolean",
+ "string",
+ "String",
+ "float",
+ "Float",
+ "double",
+ "Double",
+]
+
+
+SUPPORTED_INTERNAL_SPARK_PARAM_TYPES = [
+ "integer",
+ "Integer",
+ "boolean",
+ "Boolean",
+ "string",
+ "String",
+ "double",
+ "Double",
+ # remove float and add number
+ "number",
+]
+
+
+class InternalInputPortSchema(InputPortSchema):
+    # skip client-side validation of the type enum & support lists of types
+ type = UnionField(
+ [
+ fields.Str(),
+ fields.List(fields.Str()),
+ ],
+ required=True,
+ data_key="type",
+ )
+ is_resource = fields.Bool()
+ datastore_mode = fields.Str()
+
+ @post_dump(pass_original=True)
+ def resolve_list_type(self, data, original_data, **kwargs): # pylint: disable=unused-argument
+ if isinstance(original_data.type, list):
+ data["type"] = original_data.type
+ return data
+
+
+class InternalOutputPortSchema(metaclass=PatchedSchemaMeta):
+    # skip client-side validation of the type enum
+ type = fields.Str(
+ required=True,
+ data_key="type",
+ )
+ description = fields.Str()
+ is_link_mode = fields.Bool()
+ datastore_mode = fields.Str()
+
+
+class InternalPrimitiveOutputSchema(metaclass=PatchedSchemaMeta):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
+ required=True,
+ )
+ description = fields.Str()
+
+
+class InternalParameterSchema(ParameterSchema):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
+ required=True,
+ data_key="type",
+ )
+
+
+class InternalSparkParameterSchema(ParameterSchema):
+ type = DumpableEnumField(
+ allowed_values=SUPPORTED_INTERNAL_SPARK_PARAM_TYPES,
+ required=True,
+ data_key="type",
+ )
+
+
+class InternalEnumParameterSchema(ParameterSchema):
+ type = StringTransformedEnum(
+ allowed_values=["enum"],
+ required=True,
+ data_key="type",
+ )
+ default = PrimitiveValueField()
+ enum = fields.List(
+ PrimitiveValueField(),
+ required=True,
+ )
+
+ @post_dump
+ @post_load
+ def enum_value_to_string(self, data, **kwargs): # pylint: disable=unused-argument
+ if "enum" in data:
+ data["enum"] = list(map(str, data["enum"]))
+ if "default" in data and data["default"] is not None:
+ data["default"] = str(data["default"])
+ return data
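
enum_value_to_string above is registered for both post_dump and post_load, so enum entries and the default are coerced to strings in both directions. A self-contained marshmallow sketch of that double-hook pattern, with a hypothetical class name:

from marshmallow import Schema, fields, post_dump, post_load

class EnumParamSketch(Schema):
    type = fields.Str()
    default = fields.Raw()
    enum = fields.List(fields.Raw(), required=True)

    @post_dump
    @post_load
    def enum_value_to_string(self, data, **kwargs):
        # one function handles both hooks: dump and load apply the same coercion
        if "enum" in data:
            data["enum"] = [str(v) for v in data["enum"]]
        if data.get("default") is not None:
            data["default"] = str(data["default"])
        return data

print(EnumParamSketch().load({"type": "enum", "enum": [1, 2, 3], "default": 1}))
# -> {'type': 'enum', 'enum': ['1', '2', '3'], 'default': '1'}
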
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
new file mode 100644
index 00000000..6dbadcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
@@ -0,0 +1,75 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields, post_load, pre_dump
+
+from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
+from ..._schema.core.fields import DumpableEnumField
+from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
+from ...constants._common import AzureMLResourceType
+from .component import InternalComponentSchema, NodeType
+
+
+class InternalBaseNodeSchema(BaseNodeSchema):
+ class Meta:
+ unknown = INCLUDE
+
+ component = UnionField(
+ [
+ # for registry type assets
+            RegistryStr(azureml_type=AzureMLResourceType.COMPONENT),
+ # existing component
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+ # inline component or component file reference starting with FILE prefix
+ NestedField(InternalComponentSchema, unknown=INCLUDE),
+ ],
+ required=True,
+ )
+ type = DumpableEnumField(
+ allowed_values=NodeType.all_values(),
+ )
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ from ...entities._builders import parse_inputs_outputs
+
+ # parse inputs/outputs
+ data = parse_inputs_outputs(data)
+
+ # dict to node object
+ from ...entities._job.pipeline._load_component import pipeline_node_factory
+
+ return pipeline_node_factory.load_from_dict(data=data)
+
+ @pre_dump
+ def resolve_inputs_outputs(self, job, **kwargs): # pylint: disable=unused-argument
+ return _resolve_inputs_outputs(job)
+
+
+class ScopeSchema(InternalBaseNodeSchema):
+ type = DumpableEnumField(allowed_values=[NodeType.SCOPE])
+ adla_account_name = fields.Str(required=True)
+ scope_param = fields.Str()
+ custom_job_name_suffix = fields.Str()
+ priority = fields.Int()
+ auto_token = fields.Int()
+ tokens = fields.Int()
+ vcp = fields.Float()
+
+
+class HDInsightSchema(InternalBaseNodeSchema):
+ type = DumpableEnumField(allowed_values=[NodeType.HDI])
+
+ compute_name = fields.Str()
+ queue = fields.Str()
+ driver_memory = fields.Str()
+ driver_cores = fields.Int()
+ executor_memory = fields.Str()
+ executor_cores = fields.Int()
+ number_executors = fields.Int()
+ conf = UnionField(
+ # dictionary or json string
+ union_fields=[fields.Dict(keys=fields.Str()), fields.Str()],
+ )
+ hdinsight_spark_job_name = fields.Str()
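
make above is a post_load factory: loading a node schema yields a node entity rather than a dict, and Meta.unknown = INCLUDE lets undeclared keys flow through to the factory. A minimal sketch of the pattern, with a hypothetical NodeSketch class standing in for the SDK's pipeline_node_factory:

from marshmallow import INCLUDE, Schema, fields, post_load

class NodeSketch:
    def __init__(self, **kwargs):
        self.type = kwargs.pop("type", None)
        self.extras = kwargs  # unknown keys survive, mirroring Meta.unknown = INCLUDE

class NodeSchemaSketch(Schema):
    class Meta:
        unknown = INCLUDE

    type = fields.Str()

    @post_load
    def make(self, data, **kwargs):
        return NodeSketch(**data)

node = NodeSchemaSketch().load({"type": "ScopeComponent", "adla_account_name": "my-account"})
print(type(node).__name__, node.type, node.extras)
# -> NodeSketch ScopeComponent {'adla_account_name': 'my-account'}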