diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils')
3 files changed, 187 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py new file mode 100644 index 00000000..611c80a2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py @@ -0,0 +1,88 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Union + +from marshmallow import Schema, fields + +from azure.ai.ml._schema.core.fields import DataBindingStr, ExperimentalField, NestedField, UnionField +from azure.ai.ml._schema.core.schema import PathAwareSchema + +DATA_BINDING_SUPPORTED_KEY = "_data_binding_supported" + + +def _is_literal(field): + return not isinstance(field, (NestedField, fields.List, fields.Dict, UnionField)) + + +def _add_data_binding_to_field(field, attrs_to_skip, schema_stack): + if hasattr(field, DATA_BINDING_SUPPORTED_KEY) and getattr(field, DATA_BINDING_SUPPORTED_KEY): + return field + data_binding_field = DataBindingStr() + if isinstance(field, UnionField): + for field_obj in field.union_fields: + if not _is_literal(field_obj): + _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack) + field.insert_union_field(data_binding_field) + elif isinstance(field, fields.Dict): + # handle dict, dict value can be None + if field.value_field is not None: + field.value_field = _add_data_binding_to_field(field.value_field, attrs_to_skip, schema_stack=schema_stack) + elif isinstance(field, fields.List): + # handle list + field.inner = _add_data_binding_to_field(field.inner, attrs_to_skip, schema_stack=schema_stack) + elif isinstance(field, ExperimentalField): + field = ExperimentalField( + _add_data_binding_to_field(field.experimental_field, attrs_to_skip, schema_stack=schema_stack), + data_key=field.data_key, + attribute=field.attribute, + dump_only=field.dump_only, + required=field.required, + allow_none=field.allow_none, + ) + elif isinstance(field, NestedField): + # handle nested field + support_data_binding_expression_for_fields(field.schema, attrs_to_skip, schema_stack=schema_stack) + else: + # change basic fields to union + field = UnionField( + [data_binding_field, field], + data_key=field.data_key, + attribute=field.attribute, + dump_only=field.dump_only, + required=field.required, + allow_none=field.allow_none, + ) + + setattr(field, DATA_BINDING_SUPPORTED_KEY, True) + return field + + +# pylint: disable-next=docstring-missing-param +def support_data_binding_expression_for_fields( # pylint: disable=name-too-long + schema: Union[PathAwareSchema, Schema], attrs_to_skip=None, schema_stack=None +): + """Update fields inside schema to support data binding string. + + Only first layer of recursive schema is supported now. + """ + if hasattr(schema, DATA_BINDING_SUPPORTED_KEY) and getattr(schema, DATA_BINDING_SUPPORTED_KEY): + return + + setattr(schema, DATA_BINDING_SUPPORTED_KEY, True) + + if attrs_to_skip is None: + attrs_to_skip = [] + if schema_stack is None: + schema_stack = [] + schema_type_name = type(schema).__name__ + if schema_type_name in schema_stack: + return + schema_stack.append(schema_type_name) + for attr, field_obj in schema.load_fields.items(): + if attr not in attrs_to_skip: + schema.load_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack) + for attr, field_obj in schema.dump_fields.items(): + if attr not in attrs_to_skip: + schema.dump_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack) + schema_stack.pop() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py new file mode 100644 index 00000000..c1ee3568 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import copy +import logging +import re +from collections import OrderedDict +from typing import Any, Dict, Optional, Union + +from marshmallow.exceptions import ValidationError + +module_logger = logging.getLogger(__name__) + + +class ArmId(str): + def __new__(cls, content): + validate_arm_str(content) + return str.__new__(cls, content) + + +def validate_arm_str(arm_str: Union[ArmId, str]) -> bool: + """Validate whether the given string is in fact in the format of an ARM ID. + + :param arm_str: The string to validate. + :type arm_str: Either a string (in case of incorrect formatting) or ArmID (in case of correct formatting). + :returns: True if the string is correctly formatted, False otherwise. + :rtype: bool + """ + reg_str = ( + r"/subscriptions/[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}?/resourcegroups/.*/providers/[a-z.a-z]*/[a-z]*/.*" + ) + lowered = arm_str.lower() + match = re.match(reg_str, lowered) + if match and match.group() == lowered: + return True + raise ValidationError(f"ARM string {arm_str} is not formatted correctly.") + + +def get_subnet_str(vnet_name: str, subnet: str, sub_id: Optional[str] = None, rg: Optional[str] = None) -> str: + if vnet_name and not subnet: + raise ValidationError("Subnet is required when vnet name is specified.") + try: + validate_arm_str(subnet) + return subnet + except ValidationError: + return ( + f"/subscriptions/{sub_id}/resourceGroups/{rg}/" + f"providers/Microsoft.Network/virtualNetworks/{vnet_name}/subnets/{subnet}" + ) + + +def replace_key_in_odict(odict: OrderedDict, old_key: Any, new_key: Any): + if not odict or old_key not in odict: + return odict + return OrderedDict([(new_key, v) if k == old_key else (k, v) for k, v in odict.items()]) + + +# This is temporary until deployments(batch/K8S) support registry references +def exit_if_registry_assets(data: Dict, caller: str) -> None: + startswith = "azureml://registries/" + if ( + "environment" in data + and data["environment"] + and isinstance(data["environment"], str) + and data["environment"].startswith(startswith) + ): + raise ValidationError(f"Registry reference for environments is not supported for {caller}") + if "model" in data and data["model"] and isinstance(data["model"], str) and data["model"].startswith(startswith): + raise ValidationError(f"Registry reference for models is not supported for {caller}") + if ( + "code_configuration" in data + and data["code_configuration"].code + and isinstance(data["code_configuration"].code, str) + and data["code_configuration"].code.startswith(startswith) + ): + raise ValidationError(f"Registry reference for code_configuration.code is not supported for {caller}") + + +def _resolve_group_inputs_for_component(component, **kwargs): # pylint: disable=unused-argument + # Try resolve object's inputs & outputs and return a resolved new object + from azure.ai.ml.entities._inputs_outputs import GroupInput + + result = copy.copy(component) + + flatten_inputs = {} + for key, val in result.inputs.items(): + if isinstance(val, GroupInput): + flatten_inputs.update(val.flatten(group_parameter_name=key)) + continue + flatten_inputs[key] = val + + # Flatten group inputs + result._inputs = flatten_inputs # pylint: disable=protected-access + return result |