aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py88
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py94
3 files changed, 187 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py
new file mode 100644
index 00000000..611c80a2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py
@@ -0,0 +1,88 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Union
+
+from marshmallow import Schema, fields
+
+from azure.ai.ml._schema.core.fields import DataBindingStr, ExperimentalField, NestedField, UnionField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+
+DATA_BINDING_SUPPORTED_KEY = "_data_binding_supported"
+
+
+def _is_literal(field):
+ return not isinstance(field, (NestedField, fields.List, fields.Dict, UnionField))
+
+
+def _add_data_binding_to_field(field, attrs_to_skip, schema_stack):
+ if hasattr(field, DATA_BINDING_SUPPORTED_KEY) and getattr(field, DATA_BINDING_SUPPORTED_KEY):
+ return field
+ data_binding_field = DataBindingStr()
+ if isinstance(field, UnionField):
+ for field_obj in field.union_fields:
+ if not _is_literal(field_obj):
+ _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+ field.insert_union_field(data_binding_field)
+ elif isinstance(field, fields.Dict):
+ # handle dict, dict value can be None
+ if field.value_field is not None:
+ field.value_field = _add_data_binding_to_field(field.value_field, attrs_to_skip, schema_stack=schema_stack)
+ elif isinstance(field, fields.List):
+ # handle list
+ field.inner = _add_data_binding_to_field(field.inner, attrs_to_skip, schema_stack=schema_stack)
+ elif isinstance(field, ExperimentalField):
+ field = ExperimentalField(
+ _add_data_binding_to_field(field.experimental_field, attrs_to_skip, schema_stack=schema_stack),
+ data_key=field.data_key,
+ attribute=field.attribute,
+ dump_only=field.dump_only,
+ required=field.required,
+ allow_none=field.allow_none,
+ )
+ elif isinstance(field, NestedField):
+ # handle nested field
+ support_data_binding_expression_for_fields(field.schema, attrs_to_skip, schema_stack=schema_stack)
+ else:
+ # change basic fields to union
+ field = UnionField(
+ [data_binding_field, field],
+ data_key=field.data_key,
+ attribute=field.attribute,
+ dump_only=field.dump_only,
+ required=field.required,
+ allow_none=field.allow_none,
+ )
+
+ setattr(field, DATA_BINDING_SUPPORTED_KEY, True)
+ return field
+
+
+# pylint: disable-next=docstring-missing-param
+def support_data_binding_expression_for_fields( # pylint: disable=name-too-long
+ schema: Union[PathAwareSchema, Schema], attrs_to_skip=None, schema_stack=None
+):
+ """Update fields inside schema to support data binding string.
+
+ Only first layer of recursive schema is supported now.
+ """
+ if hasattr(schema, DATA_BINDING_SUPPORTED_KEY) and getattr(schema, DATA_BINDING_SUPPORTED_KEY):
+ return
+
+ setattr(schema, DATA_BINDING_SUPPORTED_KEY, True)
+
+ if attrs_to_skip is None:
+ attrs_to_skip = []
+ if schema_stack is None:
+ schema_stack = []
+ schema_type_name = type(schema).__name__
+ if schema_type_name in schema_stack:
+ return
+ schema_stack.append(schema_type_name)
+ for attr, field_obj in schema.load_fields.items():
+ if attr not in attrs_to_skip:
+ schema.load_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+ for attr, field_obj in schema.dump_fields.items():
+ if attr not in attrs_to_skip:
+ schema.dump_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+ schema_stack.pop()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
new file mode 100644
index 00000000..c1ee3568
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
@@ -0,0 +1,94 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import copy
+import logging
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Optional, Union
+
+from marshmallow.exceptions import ValidationError
+
+module_logger = logging.getLogger(__name__)
+
+
+class ArmId(str):
+ def __new__(cls, content):
+ validate_arm_str(content)
+ return str.__new__(cls, content)
+
+
+def validate_arm_str(arm_str: Union[ArmId, str]) -> bool:
+ """Validate whether the given string is in fact in the format of an ARM ID.
+
+ :param arm_str: The string to validate.
+ :type arm_str: Either a string (in case of incorrect formatting) or ArmID (in case of correct formatting).
+ :returns: True if the string is correctly formatted, False otherwise.
+ :rtype: bool
+ """
+ reg_str = (
+ r"/subscriptions/[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}?/resourcegroups/.*/providers/[a-z.a-z]*/[a-z]*/.*"
+ )
+ lowered = arm_str.lower()
+ match = re.match(reg_str, lowered)
+ if match and match.group() == lowered:
+ return True
+ raise ValidationError(f"ARM string {arm_str} is not formatted correctly.")
+
+
+def get_subnet_str(vnet_name: str, subnet: str, sub_id: Optional[str] = None, rg: Optional[str] = None) -> str:
+ if vnet_name and not subnet:
+ raise ValidationError("Subnet is required when vnet name is specified.")
+ try:
+ validate_arm_str(subnet)
+ return subnet
+ except ValidationError:
+ return (
+ f"/subscriptions/{sub_id}/resourceGroups/{rg}/"
+ f"providers/Microsoft.Network/virtualNetworks/{vnet_name}/subnets/{subnet}"
+ )
+
+
+def replace_key_in_odict(odict: OrderedDict, old_key: Any, new_key: Any):
+ if not odict or old_key not in odict:
+ return odict
+ return OrderedDict([(new_key, v) if k == old_key else (k, v) for k, v in odict.items()])
+
+
+# This is temporary until deployments(batch/K8S) support registry references
+def exit_if_registry_assets(data: Dict, caller: str) -> None:
+ startswith = "azureml://registries/"
+ if (
+ "environment" in data
+ and data["environment"]
+ and isinstance(data["environment"], str)
+ and data["environment"].startswith(startswith)
+ ):
+ raise ValidationError(f"Registry reference for environments is not supported for {caller}")
+ if "model" in data and data["model"] and isinstance(data["model"], str) and data["model"].startswith(startswith):
+ raise ValidationError(f"Registry reference for models is not supported for {caller}")
+ if (
+ "code_configuration" in data
+ and data["code_configuration"].code
+ and isinstance(data["code_configuration"].code, str)
+ and data["code_configuration"].code.startswith(startswith)
+ ):
+ raise ValidationError(f"Registry reference for code_configuration.code is not supported for {caller}")
+
+
+def _resolve_group_inputs_for_component(component, **kwargs): # pylint: disable=unused-argument
+ # Try resolve object's inputs & outputs and return a resolved new object
+ from azure.ai.ml.entities._inputs_outputs import GroupInput
+
+ result = copy.copy(component)
+
+ flatten_inputs = {}
+ for key, val in result.inputs.items():
+ if isinstance(val, GroupInput):
+ flatten_inputs.update(val.flatten(group_parameter_name=key))
+ continue
+ flatten_inputs[key] = val
+
+ # Flatten group inputs
+ result._inputs = flatten_inputs # pylint: disable=protected-access
+ return result