aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
new file mode 100644
index 00000000..5772a607
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
@@ -0,0 +1,143 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from pathlib import Path
+
+from marshmallow import ValidationError, fields, post_dump, pre_dump, pre_load
+from marshmallow.fields import Field
+
+from azure.ai.ml._schema.component.input_output import InputPortSchema, OutputPortSchema, ParameterSchema
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ ExperimentalField,
+ NestedField,
+ PythonFuncNameStr,
+ UnionField,
+)
+from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled, load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+
+from .._utils.utils import _resolve_group_inputs_for_component
+from ..assets.asset import AssetSchema
+from ..core.fields import RegistryStr
+
+
+class ComponentNameStr(PythonFuncNameStr):
+ def _get_field_name(self):
+ return "Component"
+
+
+class ComponentYamlRefField(Field):
+ """Allows you to nest a :class:`Schema <marshmallow.Schema>`
+ inside a yaml ref field.
+ """
+
+ def _jsonschema_type_mapping(self):
+ schema = {"type": "string"}
+ if self.name is not None:
+ schema["title"] = self.name
+ if self.dump_only:
+ schema["readonly"] = True
+ return schema
+
+ def _deserialize(self, value, attr, data, **kwargs):
+ if not isinstance(value, str):
+ raise ValidationError(f"Nested yaml ref field expected a string but got {type(value)}.")
+
+ base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+
+ source_path = Path(value)
+ # raise if the string is not a valid path, like "azureml:xxx"
+ try:
+ source_path.resolve()
+ except OSError as ex:
+ raise ValidationError(f"Nested file ref field expected a local path but got {value}.") from ex
+
+ if not source_path.is_absolute():
+ source_path = base_path / source_path
+
+ if not source_path.is_file():
+ raise ValidationError(
+ f"Nested yaml ref field expected a local path but can't find {value} based on {base_path.as_posix()}."
+ )
+
+ loaded_value = load_yaml(source_path)
+
+ # local import to avoid circular import
+ from azure.ai.ml.entities import Component
+
+ component = Component._load(data=loaded_value, yaml_path=source_path) # pylint: disable=protected-access
+ return component
+
+ def _serialize(self, value, attr, obj, **kwargs):
+ raise ValidationError("Serialize on RefField is not supported.")
+
+
+class ComponentSchema(AssetSchema):
+ schema = fields.Str(data_key="$schema", attribute="_schema")
+ name = ComponentNameStr(required=True)
+ id = UnionField(
+ [
+ RegistryStr(dump_only=True),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, dump_only=True),
+ ]
+ )
+ display_name = fields.Str()
+ description = fields.Str()
+ tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+ is_deterministic = fields.Bool()
+ inputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField(
+ [
+ NestedField(ParameterSchema),
+ NestedField(InputPortSchema),
+ ]
+ ),
+ )
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=NestedField(OutputPortSchema),
+ )
+ # hide in private preview
+ if is_private_preview_enabled():
+ intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema))
+
+ def __init__(self, *args, **kwargs):
+ # Remove schema_ignored to enable serialize and deserialize schema.
+ self._declared_fields.pop("schema_ignored", None)
+ super().__init__(*args, **kwargs)
+
+ @pre_load
+ def convert_version_to_str(self, data, **kwargs): # pylint: disable=unused-argument
+ if isinstance(data, dict) and data.get("version", None):
+ data["version"] = str(data["version"])
+ return data
+
+ @pre_dump
+ def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
+ # The ipp field is set on the component object as "_intellectual_property".
+ # We need to set it as "intellectual_property" before dumping so that Marshmallow
+ # can pick up the field correctly on dump and show it back to the user.
+ ipp_field = data._intellectual_property # pylint: disable=protected-access
+ if ipp_field:
+ setattr(data, "intellectual_property", ipp_field)
+ return data
+
+ @post_dump
+ def convert_input_value_to_str(self, data, **kwargs): # pylint:disable=unused-argument
+ if isinstance(data, dict) and data.get("inputs", None):
+ input_dict = data["inputs"]
+ for input_value in input_dict.values():
+ input_type = input_value.get("type", None)
+ if isinstance(input_type, str) and input_type.lower() == "float":
+ # Convert number to string to avoid precision issue
+ for key in ["default", "min", "max"]:
+ if input_value.get(key, None) is not None:
+ input_value[key] = str(input_value[key])
+ return data
+
+ @pre_dump
+ def flatten_group_inputs(self, data, **kwargs): # pylint: disable=unused-argument
+ return _resolve_group_inputs_for_component(data)