diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py new file mode 100644 index 00000000..c1ee3568 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import copy +import logging +import re +from collections import OrderedDict +from typing import Any, Dict, Optional, Union + +from marshmallow.exceptions import ValidationError + +module_logger = logging.getLogger(__name__) + + +class ArmId(str): + def __new__(cls, content): + validate_arm_str(content) + return str.__new__(cls, content) + + +def validate_arm_str(arm_str: Union[ArmId, str]) -> bool: + """Validate whether the given string is in fact in the format of an ARM ID. + + :param arm_str: The string to validate. + :type arm_str: Either a string (in case of incorrect formatting) or ArmID (in case of correct formatting). + :returns: True if the string is correctly formatted, False otherwise. + :rtype: bool + """ + reg_str = ( + r"/subscriptions/[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}?/resourcegroups/.*/providers/[a-z.a-z]*/[a-z]*/.*" + ) + lowered = arm_str.lower() + match = re.match(reg_str, lowered) + if match and match.group() == lowered: + return True + raise ValidationError(f"ARM string {arm_str} is not formatted correctly.") + + +def get_subnet_str(vnet_name: str, subnet: str, sub_id: Optional[str] = None, rg: Optional[str] = None) -> str: + if vnet_name and not subnet: + raise ValidationError("Subnet is required when vnet name is specified.") + try: + validate_arm_str(subnet) + return subnet + except ValidationError: + return ( + f"/subscriptions/{sub_id}/resourceGroups/{rg}/" + f"providers/Microsoft.Network/virtualNetworks/{vnet_name}/subnets/{subnet}" + ) + + +def replace_key_in_odict(odict: OrderedDict, old_key: Any, new_key: Any): + if not odict or old_key not in odict: + return odict + return OrderedDict([(new_key, v) if k == old_key else (k, v) for k, v in odict.items()]) + + +# This is temporary until deployments(batch/K8S) support registry references +def exit_if_registry_assets(data: Dict, caller: str) -> None: + startswith = "azureml://registries/" + if ( + "environment" in data + and data["environment"] + and isinstance(data["environment"], str) + and data["environment"].startswith(startswith) + ): + raise ValidationError(f"Registry reference for environments is not supported for {caller}") + if "model" in data and data["model"] and isinstance(data["model"], str) and data["model"].startswith(startswith): + raise ValidationError(f"Registry reference for models is not supported for {caller}") + if ( + "code_configuration" in data + and data["code_configuration"].code + and isinstance(data["code_configuration"].code, str) + and data["code_configuration"].code.startswith(startswith) + ): + raise ValidationError(f"Registry reference for code_configuration.code is not supported for {caller}") + + +def _resolve_group_inputs_for_component(component, **kwargs): # pylint: disable=unused-argument + # Try resolve object's inputs & outputs and return a resolved new object + from azure.ai.ml.entities._inputs_outputs import GroupInput + + result = copy.copy(component) + + flatten_inputs = {} + for key, val in result.inputs.items(): + if isinstance(val, GroupInput): + flatten_inputs.update(val.flatten(group_parameter_name=key)) + continue + flatten_inputs[key] = val + + # Flatten group inputs + result._inputs = flatten_inputs # pylint: disable=protected-access + return result |