# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access,too-many-lines

import copy
import decimal
import hashlib
import json
import logging
import os
import random
import re
import string
import sys
import tempfile
import time
import warnings
from collections import OrderedDict
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from functools import singledispatch, wraps
from os import PathLike
from pathlib import Path, PureWindowsPath
from typing import IO, Any, AnyStr, Callable, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import urlparse
from uuid import UUID

import isodate
import pydash
import yaml

from azure.ai.ml._restclient.v2022_05_01.models import ListViewType, ManagedServiceIdentity
from azure.ai.ml._scope_dependent_operations import OperationScope
from azure.ai.ml._utils._http_utils import HttpPipeline
from azure.ai.ml.constants._common import (
    AZUREML_DISABLE_CONCURRENT_COMPONENT_REGISTRATION,
    AZUREML_DISABLE_ON_DISK_CACHE_ENV_VAR,
    AZUREML_INTERNAL_COMPONENTS_ENV_VAR,
    AZUREML_INTERNAL_COMPONENTS_SCHEMA_PREFIX,
    AZUREML_PRIVATE_FEATURES_ENV_VAR,
    CommonYamlFields,
    DefaultOpenEncoding,
    WorkspaceDiscoveryUrlKey,
)
from azure.ai.ml.exceptions import MlException
from azure.core.pipeline.policies import RetryPolicy

module_logger = logging.getLogger(__name__)

DEVELOPER_URL_MFE_ENV_VAR = "AZUREML_DEV_URL_MFE"

# Prefix used when hitting MFE, skipping ARM
MFE_PATH_PREFIX = "/mferp/managementfrontend"


def _get_mfe_url_override() -> Optional[str]:
    return os.getenv(DEVELOPER_URL_MFE_ENV_VAR)


def _is_https_url(url: str) -> Union[bool, str]:
    if url:
        return url.lower().startswith("https")
    return False


def _csv_parser(text: Optional[str], convert: Callable) -> Optional[str]:
    if not text:
        return None
    if "," in text:
        return ",".join(convert(t.strip()) for t in text.split(","))
    return convert(text)


def _snake_to_pascal_convert(text: str) -> str:
    return string.capwords(text.replace("_", " ")).replace(" ", "")


def snake_to_pascal(text: Optional[str]) -> Optional[str]:
    """Convert snake name to pascal.

    :param text: String to convert
    :type text: Optional[str]
    :return:
        * None if text is None
        * Converted text from snake_case to PascalCase
    :rtype: Optional[str]
    """
    return _csv_parser(text, _snake_to_pascal_convert)


def snake_to_kebab(text: Optional[str]) -> Optional[str]:
    """Convert snake name to kebab.

    :param text: String to convert
    :type text: Optional[str]
    :return:
        * None if text is None
        * Converted text from snake_case to kebab-case
    :rtype: Optional[str]
    """
    if text:
        return re.sub("_", "-", text)
    return None


# https://stackoverflow.com/questions/1175208
# This works for pascal to snake as well
def _camel_to_snake_convert(text: str) -> str:
    text = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", text)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", text).lower()


def camel_to_snake(text: Optional[str]) -> Optional[str]:
    """Convert camel name to snake.

    :param text: String to convert
    :type text: Optional[str]
    :return:
        * None if text is None
        * Converted text from camelCase to snake_case
    :rtype: Optional[str]
    """
    return _csv_parser(text, _camel_to_snake_convert)
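
# Quick reference for the casing helpers above; the values are assumed but easy
# to verify in a REPL, and this block is illustrative rather than part of the
# module's behavior:
#
#     >>> snake_to_pascal("max_total_trials")
#     'MaxTotalTrials'
#     >>> snake_to_kebab("max_total_trials")
#     'max-total-trials'
#     >>> camel_to_snake("maxTotalTrials")
#     'max_total_trials'
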
# This is snake to camelBack (lowerCamelCase), which is different from snake to camel (PascalCase)
# https://stackoverflow.com/questions/19053707
def snake_to_camel(text: Optional[str]) -> Optional[str]:
    """Convert snake name to camel.

    :param text: String to convert
    :type text: Optional[str]
    :return:
        * None if text is None
        * Converted text from snake_case to camelCase
    :rtype: Optional[str]
    """
    if text:
        return re.sub("_([a-zA-Z0-9])", lambda m: m.group(1).upper(), text)
    return None


# This is real snake to camel (PascalCase)
def _snake_to_camel(name):
    return re.sub(r"(?:^|_)([a-z])", lambda x: x.group(1).upper(), name)


def float_to_str(f):
    """Convert a float to a string without scientific notation.

    :param f: Float to convert
    :type f: float
    :return: String representation of the float
    :rtype: str
    """
    with decimal.localcontext() as ctx:
        ctx.prec = 20  # Support up to 20 significant figures.
        float_as_dec = ctx.create_decimal(repr(f))
        return format(float_as_dec, "f")


def create_requests_pipeline_with_retry(*, requests_pipeline: HttpPipeline, retries: int = 3) -> HttpPipeline:
    """Creates an HttpPipeline that reuses the same configuration as the supplied pipeline (including the transport),
    but overwrites the retry policy.

    :keyword requests_pipeline: Pipeline to base new one off of.
    :paramtype requests_pipeline: HttpPipeline
    :keyword retries: Number of retries. Defaults to 3.
    :paramtype retries: int
    :return: Pipeline identical to provided one, except with a new retry policy
    :rtype: HttpPipeline
    """
    return requests_pipeline.with_policies(retry_policy=get_retry_policy(num_retry=retries))


def get_retry_policy(num_retry: int = 3) -> RetryPolicy:
    """Retrieves a retry policy to use in an azure.core.pipeline.Pipeline

    :param num_retry: The number of retries
    :type num_retry: int
    :return: Returns the msrest or requests REST client retry policy.
    :rtype: RetryPolicy
    """
    status_forcelist = [413, 429, 500, 502, 503, 504]
    backoff_factor = 0.4
    return RetryPolicy(
        retry_total=num_retry,
        retry_read=num_retry,
        retry_connect=num_retry,
        retry_backoff_factor=backoff_factor,
        retry_on_status_codes=status_forcelist,
    )


def download_text_from_url(
    source_uri: str,
    requests_pipeline: HttpPipeline,
    timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> str:
    """Downloads the content from a URL.

    :param source_uri: URI to download
    :type source_uri: str
    :param requests_pipeline: Used to send the request
    :type requests_pipeline: HttpPipeline
    :param timeout: One of
        * a float that specifies both the connect and read timeouts
        * a 2-tuple that specifies the connect and read timeouts, in that order
    :type timeout: Optional[Union[float, Tuple[float, float]]]
    :return: The response text
    :rtype: str
    """
    if not timeout:
        timeout_params = {}
    else:
        connect_timeout, read_timeout = timeout if isinstance(timeout, tuple) else (timeout, timeout)
        timeout_params = {"read_timeout": read_timeout, "connection_timeout": connect_timeout}
    response = requests_pipeline.get(source_uri, **timeout_params)
    # Match old behavior from execution service's status API.
    if response.status_code == 404:
        return ""

    # _raise_request_error(response, "Retrieving content from " + uri)
    return response.text()
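
# float_to_str exists so floats embedded in REST payloads never pick up
# scientific notation. An assumed, illustrative contrast with plain str():
#
#     >>> str(1e-07)
#     '1e-07'
#     >>> float_to_str(1e-07)
#     '0.0000001'
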
def load_file(file_path: str) -> str:
    """Load a local file.

    :param file_path: The relative or absolute path to the local file.
    :type file_path: str
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if file or folder cannot be found.
    :return: A string representation of the local file's contents.
    :rtype: str
    """
    from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

    # These imports can't be placed at the top of the file because they would cause a circular import in
    # exceptions.py via _get_mfe_url_override
    try:
        with open(file_path, "r", encoding=DefaultOpenEncoding.READ) as f:
            cfg = f.read()
    except OSError as e:  # FileNotFoundError introduced in Python 3
        msg = "No such file or directory: {}"
        raise ValidationException(
            message=msg.format(file_path),
            no_personal_data_message=msg.format("[file_path]"),
            error_category=ErrorCategory.USER_ERROR,
            target=ErrorTarget.GENERAL,
            error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
        ) from e
    return cfg


def load_json(file_path: Optional[Union[str, os.PathLike]]) -> Dict:
    """Load a local json file.

    :param file_path: The relative or absolute path to the local file.
    :type file_path: Union[str, os.PathLike]
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if file or folder cannot be found.
    :return: A dictionary representation of the local file's contents.
    :rtype: Dict
    """
    from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

    # These imports can't be placed at the top of the file because they would cause a circular import in
    # exceptions.py via _get_mfe_url_override
    try:
        with open(file_path, "r", encoding=DefaultOpenEncoding.READ) as f:
            cfg = json.load(f)
    except OSError as e:  # FileNotFoundError introduced in Python 3
        msg = "No such file or directory: {}"
        raise ValidationException(
            message=msg.format(file_path),
            no_personal_data_message=msg.format("[file_path]"),
            error_category=ErrorCategory.USER_ERROR,
            target=ErrorTarget.GENERAL,
            error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
        ) from e
    return cfg
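
# Usage sketch for the loaders above; the paths are hypothetical. Both helpers
# convert a missing file into a ValidationException (FILE_OR_FOLDER_NOT_FOUND)
# instead of surfacing the raw OSError:
#
#     raw_text = load_file("./run.log")      # -> str
#     config = load_json("./config.json")    # -> Dict
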
def load_yaml(source: Optional[Union[AnyStr, PathLike, IO]]) -> Dict:
    """Load a local YAML file.

    :param source: Either
        * The relative or absolute path to the local file.
        * A readable file-like object
    :type source: Optional[Union[AnyStr, PathLike, IO]]
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if file or folder cannot be successfully loaded.
        Details will be provided in the error message.
    :return: A dictionary representation of the local file's contents.
    :rtype: Dict
    """
    from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

    # These imports can't be placed at the top of the file because they would cause a circular import in
    # exceptions.py via _get_mfe_url_override

    # Null check - just return an empty dict.
    # Certain CLI commands rely on this behavior to produce a resource
    # via CLI, which is then populated through CLArgs.
    if source is None:
        return {}

    if isinstance(source, (str, os.PathLike)):
        try:
            cm = open(source, "r", encoding=DefaultOpenEncoding.READ)
        except OSError as e:
            msg = "No such file or directory: {}"
            raise ValidationException(
                message=msg.format(source),
                no_personal_data_message=msg.format("[file_path]"),
                error_category=ErrorCategory.USER_ERROR,
                target=ErrorTarget.GENERAL,
                error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
            ) from e
    else:  # source is a subclass of IO
        if not source.readable():
            msg = "File Permissions Error: The already-open file passed in is not readable."
            raise ValidationException(
                message=msg,
                no_personal_data_message="File Permissions error",
                error_category=ErrorCategory.USER_ERROR,
                target=ErrorTarget.GENERAL,
                error_type=ValidationErrorType.INVALID_VALUE,
            )
        cm = nullcontext(enter_result=source)

    with cm as f:
        try:
            return yaml.safe_load(f)
        except yaml.YAMLError as e:
            msg = f"Error while parsing yaml file: {source} \n\n {str(e)}"
            raise ValidationException(
                message=msg,
                no_personal_data_message="Error while parsing yaml file",
                error_category=ErrorCategory.USER_ERROR,
                target=ErrorTarget.GENERAL,
                error_type=ValidationErrorType.CANNOT_PARSE,
            ) from e


# pylint: disable-next=docstring-missing-param
def dump_yaml(*args, **kwargs):
    """A thin wrapper over yaml.dump which forces `OrderedDict`s to be serialized as mappings.

    Otherwise behaves identically to yaml.dump.

    :return: The yaml object
    :rtype: Any
    """

    class OrderedDumper(yaml.Dumper):
        """A modified yaml serializer that forces pyyaml to represent an OrderedDict as a mapping instead of a
        sequence."""

    OrderedDumper.add_representer(OrderedDict, yaml.representer.SafeRepresenter.represent_dict)
    return yaml.dump(*args, Dumper=OrderedDumper, **kwargs)
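
# Why the custom Dumper matters: yaml.safe_dump refuses to serialize an
# OrderedDict, while dump_yaml emits it as a plain mapping. An assumed,
# illustrative check (pass sort_keys=False to keep insertion order, since
# pyyaml sorts mapping keys by default):
#
#     >>> dump_yaml(OrderedDict([("b", 1), ("a", 2)]), sort_keys=False)
#     'b: 1\na: 2\n'
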
def dump_yaml_to_file(
    dest: Optional[Union[AnyStr, PathLike, IO]],
    data_dict: Union[OrderedDict, Dict],
    default_flow_style: bool = False,
    args=None,  # pylint: disable=unused-argument
    **kwargs,
) -> None:
    """Dump dictionary to a local YAML file.

    :param dest: The relative or absolute path where the YAML dictionary will be dumped.
    :type dest: Optional[Union[AnyStr, PathLike, IO]]
    :param data_dict: Dictionary representing a YAML object
    :type data_dict: Union[OrderedDict, Dict]
    :param default_flow_style: Use flow style for formatting nested YAML collections instead of block style.
        Defaults to False.
    :type default_flow_style: bool
    :param path: Deprecated. Use 'dest' param instead.
    :type path: Optional[Union[AnyStr, PathLike]]
    :param args: Deprecated.
    :type args: Any
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if object cannot be successfully dumped.
        Details will be provided in the error message.
    """
    from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

    # These imports can't be placed at the top of the file because they would cause a circular import in
    # exceptions.py via _get_mfe_url_override

    # Check for deprecated path input, either named or as first unnamed input
    path = kwargs.pop("path", None)
    if dest is None:
        if path is not None:
            dest = path
            warnings.warn(
                "the 'path' input for dump functions is deprecated. Please use 'dest' instead.", DeprecationWarning
            )
        else:
            msg = "No dump destination provided."
            raise ValidationException(
                message=msg,
                no_personal_data_message="No dump destination provided",
                error_category=ErrorCategory.USER_ERROR,
                target=ErrorTarget.GENERAL,
                error_type=ValidationErrorType.MISSING_FIELD,
            )

    if isinstance(dest, (str, os.PathLike)):
        try:
            cm = open(dest, "w", encoding=DefaultOpenEncoding.WRITE)
        except OSError as e:  # FileNotFoundError introduced in Python 3
            msg = "No such parent folder path or not a file path: {}"
            raise ValidationException(
                message=msg.format(dest),
                no_personal_data_message=msg.format("[file_path]"),
                error_category=ErrorCategory.USER_ERROR,
                target=ErrorTarget.GENERAL,
                error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
            ) from e
    else:  # dest is a subclass of IO
        if not dest.writable():  # dest is a misformatted stream or file
            msg = "File Permissions Error: The already-open file passed in is not writable."
            raise ValidationException(
                message=msg,
                no_personal_data_message="File Permissions error",
                error_category=ErrorCategory.USER_ERROR,
                target=ErrorTarget.GENERAL,
                error_type=ValidationErrorType.CANNOT_PARSE,
            )
        cm = nullcontext(enter_result=dest)

    with cm as f:
        try:
            dump_yaml(data_dict, f, default_flow_style=default_flow_style)
        except yaml.YAMLError as e:
            msg = f"Error while dumping yaml file \n\n {str(e)}"
            raise ValidationException(
                message=msg,
                no_personal_data_message="Error while dumping yaml file",
                error_category=ErrorCategory.USER_ERROR,
                target=ErrorTarget.GENERAL,
                error_type=ValidationErrorType.CANNOT_PARSE,
            ) from e


def dict_eq(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> bool:
    """Compare two dictionaries, treating None and empty dicts as equal.

    :param dict1: The first dictionary
    :type dict1: Dict[str, Any]
    :param dict2: The second dictionary
    :type dict2: Dict[str, Any]
    :return: True if the two dictionaries are equal, False otherwise
    :rtype: bool
    """
    if not dict1 and not dict2:
        return True
    return dict1 == dict2


def xor(a: Any, b: Any) -> bool:
    """Logical XOR of the truthiness of two values.

    :param a: The first value
    :type a: Any
    :param b: The second value
    :type b: Any
    :return: False if the two values are both truthy or both falsy, True otherwise
    :rtype: bool
    """
    return bool(a) != bool(b)


def is_url(value: Union[PathLike, str]) -> bool:
    """Check if a string is a valid URL.

    :param value: The string to check
    :type value: Union[PathLike, str]
    :return: True if the string is a valid URL, False otherwise
    :rtype: bool
    """
    try:
        result = urlparse(str(value))
        return all([result.scheme, result.netloc])
    except ValueError:
        return False


def resolve_short_datastore_url(value: Union[PathLike, str], workspace: OperationScope) -> str:
    """Resolve a URL to long form if it is an azureml short-form datastore URL, otherwise return the same value.

    :param value: The URL to resolve
    :type value: Union[PathLike, str]
    :param workspace: The workspace
    :type workspace: OperationScope
    :return: The resolved URL
    :rtype: str
    """
    from azure.ai.ml.exceptions import ValidationException

    # This import can't be placed at the top of the file because it would cause a circular import in
    # exceptions.py via _get_mfe_url_override
    try:
        # Check if the URL is an azureml URL
        if urlparse(str(value)).scheme == "azureml":
            from azure.ai.ml._utils._storage_utils import AzureMLDatastorePathUri

            data_store_path_uri = AzureMLDatastorePathUri(value)
            if data_store_path_uri.uri_type == "Datastore":
                return data_store_path_uri.to_long_uri(
                    subscription_id=workspace.subscription_id,
                    resource_group_name=workspace.resource_group_name,
                    workspace_name=workspace.workspace_name,
                )
    except (ValueError, ValidationException):
        pass

    # If the URL is not a valid URL (e.g. a local path) or not an azureml URL
    # (e.g. an http URL), just return the same value
    return value


def is_mlflow_uri(value: Union[PathLike, str]) -> bool:
    """Check if a string is a valid mlflow uri.

    :param value: The string to check
    :type value: Union[PathLike, str]
    :return: True if the string is a valid mlflow uri, False otherwise
    :rtype: bool
    """
    try:
        return urlparse(str(value)).scheme == "runs"
    except ValueError:
        return False
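
# Assumed, illustrative behavior of the URI predicates above:
#
#     >>> is_url("https://example.com/data.csv")
#     True
#     >>> is_url("./local/data.csv")
#     False
#     >>> is_mlflow_uri("runs:/my-run-id/model")
#     True
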
def validate_ml_flow_folder(path: Union[str, Path], model_type: str) -> None:
    """Validate that the path is a valid MLflow model folder.

    :param path: The path to validate
    :type path: Union[str, Path]
    :param model_type: The model type
    :type model_type: str
    :return: No return value
    :rtype: None
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the path is not a valid MLflow model folder.
    """
    from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException

    # These imports can't be placed at the top of the file because they would cause a circular import in
    # exceptions.py via _get_mfe_url_override
    if not isinstance(path, str):
        path = path.as_posix()
    path_array = path.split("/")
    if model_type != "mlflow_model" or "." not in path_array[-1]:
        return
    msg = "Error with path {}. Model of type mlflow_model cannot register a file."
    raise ValidationException(
        message=msg.format(path),
        no_personal_data_message=msg.format("[path]"),
        target=ErrorTarget.MODEL,
        error_type=ValidationErrorType.INVALID_VALUE,
    )


# Modified from: https://stackoverflow.com/a/33245493/8093897
def is_valid_uuid(test_uuid: str) -> bool:
    """Check if a string is a valid UUID.

    :param test_uuid: The string to check
    :type test_uuid: str
    :return: True if the string is a valid UUID, False otherwise
    :rtype: bool
    """
    try:
        uuid_obj = UUID(test_uuid, version=4)
    except ValueError:
        return False
    return str(uuid_obj) == test_uuid


@singledispatch
def from_iso_duration_format(duration: Optional[Any] = None) -> Optional[int]:  # pylint: disable=unused-argument
    """Convert an ISO 8601 duration to seconds.

    :param duration: The duration to convert
    :type duration: Optional[Any]
    :return: The duration in seconds, or None for unregistered input types
    :rtype: Optional[int]
    """
    return None


@from_iso_duration_format.register(str)
def _(duration: str) -> int:
    return int(isodate.parse_duration(duration).total_seconds())


@from_iso_duration_format.register(timedelta)
def _(duration: timedelta) -> int:
    return int(duration.total_seconds())


def to_iso_duration_format_mins(time_in_mins: Optional[Union[int, float]]) -> Optional[str]:
    """Convert minutes to ISO 8601 duration format.

    :param time_in_mins: The time in minutes to convert
    :type time_in_mins: Optional[Union[int, float]]
    :return: The converted time in ISO duration format, or None if no time was given
    :rtype: Optional[str]
    """
    return isodate.duration_isoformat(timedelta(minutes=time_in_mins)) if time_in_mins else None


def from_iso_duration_format_mins(duration: Optional[str]) -> Optional[int]:
    """Convert an ISO 8601 duration to minutes.

    :param duration: The duration to convert
    :type duration: Optional[str]
    :return: The duration in minutes, or None if no duration was given
    :rtype: Optional[int]
    """
    return int(from_iso_duration_format(duration) / 60) if duration else None


def to_iso_duration_format(time_in_seconds: Optional[Union[int, float]]) -> Optional[str]:
    """Convert seconds to ISO 8601 duration format.

    :param time_in_seconds: The time in seconds to convert
    :type time_in_seconds: Optional[Union[int, float]]
    :return: The converted time in ISO duration format, or None if no time was given
    :rtype: Optional[str]
    """
    return isodate.duration_isoformat(timedelta(seconds=time_in_seconds)) if time_in_seconds else None


def to_iso_duration_format_ms(time_in_ms: Optional[Union[int, float]]) -> Optional[str]:
    """Convert milliseconds to ISO 8601 duration format.

    :param time_in_ms: The time in milliseconds to convert
    :type time_in_ms: Optional[Union[int, float]]
    :return: The converted time in ISO duration format, or None if no time was given
    :rtype: Optional[str]
    """
    return isodate.duration_isoformat(timedelta(milliseconds=time_in_ms)) if time_in_ms else None
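
# Round-trip sketch for the ISO 8601 duration helpers, with outputs assumed
# from isodate's formatting rules:
#
#     >>> to_iso_duration_format_mins(90)
#     'PT1H30M'
#     >>> from_iso_duration_format("PT1H30M")
#     5400
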
def from_iso_duration_format_ms(duration: Optional[str]) -> Optional[int]:
    """Convert an ISO 8601 duration to milliseconds.

    :param duration: The duration to convert
    :type duration: Optional[str]
    :return: The duration in milliseconds, or None if no duration was given
    :rtype: Optional[int]
    """
    return from_iso_duration_format(duration) * 1000 if duration else None


def to_iso_duration_format_days(time_in_days: Optional[int]) -> Optional[str]:
    """Convert days to ISO 8601 duration format.

    :param time_in_days: The time in days to convert
    :type time_in_days: Optional[int]
    :return: The converted time in ISO duration format, or None if no time was given
    :rtype: Optional[str]
    """
    return isodate.duration_isoformat(timedelta(days=time_in_days)) if time_in_days else None


@singledispatch
def from_iso_duration_format_days(duration: Optional[Any] = None) -> Optional[int]:  # pylint: disable=unused-argument
    return None


@from_iso_duration_format_days.register(str)
def _(duration: str) -> int:
    return int(isodate.parse_duration(duration).days)


@from_iso_duration_format_days.register(timedelta)
def _(duration: timedelta) -> int:
    return int(duration.days)


def _get_base_urls_from_discovery_service(
    workspace_operations: "WorkspaceOperations", workspace_name: str, requests_pipeline: HttpPipeline
) -> Dict[WorkspaceDiscoveryUrlKey, str]:
    """Fetch base urls for a workspace from the discovery service.

    :param WorkspaceOperations workspace_operations: The operations class used to fetch the workspace
    :param str workspace_name: The name of the workspace
    :param HttpPipeline requests_pipeline: An HTTP pipeline to make requests with
    :returns: A dictionary mapping url types to base urls
    :rtype: Dict[WorkspaceDiscoveryUrlKey, str]
    """
    discovery_url = workspace_operations.get(workspace_name).discovery_url

    return json.loads(
        download_text_from_url(
            discovery_url,
            create_requests_pipeline_with_retry(requests_pipeline=requests_pipeline),
        )
    )


def _get_mfe_base_url_from_discovery_service(
    workspace_operations: Any, workspace_name: str, requests_pipeline: HttpPipeline
) -> str:
    all_urls = _get_base_urls_from_discovery_service(workspace_operations, workspace_name, requests_pipeline)
    return f"{all_urls[WorkspaceDiscoveryUrlKey.API]}{MFE_PATH_PREFIX}"


def _get_mfe_base_url_from_registry_discovery_service(
    workspace_operations: Any, workspace_name: str, requests_pipeline: HttpPipeline
) -> str:
    all_urls = _get_base_urls_from_discovery_service(workspace_operations, workspace_name, requests_pipeline)
    return all_urls[WorkspaceDiscoveryUrlKey.API]


def _get_workspace_base_url(workspace_operations: Any, workspace_name: str, requests_pipeline: HttpPipeline) -> str:
    all_urls = _get_base_urls_from_discovery_service(workspace_operations, workspace_name, requests_pipeline)
    return all_urls[WorkspaceDiscoveryUrlKey.API]


def _get_mfe_base_url_from_batch_endpoint(endpoint: "BatchEndpoint") -> str:
    return endpoint.scoring_uri.split("/subscriptions/")[0]


# Allows using a modified client with a provided url
@contextmanager
def modified_operation_client(operation_to_modify, url_to_use):
    """Temporarily swap the base url of an operation's client.

    :param operation_to_modify: The operation to modify
    :type operation_to_modify: Any
    :param url_to_use: The url to use instead of the client's original base url
    :type url_to_use: str
    :return: A context manager; the modification is undone on exit
    :rtype: Any
    """
    original_api_base_url = None
    try:
        # Modify the operation
        if url_to_use:
            original_api_base_url = operation_to_modify._client._base_url
            operation_to_modify._client._base_url = url_to_use
        yield
    finally:
        # Undo the modification
        if original_api_base_url:
            operation_to_modify._client._base_url = original_api_base_url
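
# Usage sketch for modified_operation_client; `jobs_operations` and the url are
# hypothetical. Calls made inside the block hit the overridden base url, and
# the original url is restored on exit even if the body raises:
#
#     with modified_operation_client(jobs_operations, "https://eastus.api.azureml.ms"):
#         ...  # REST calls here skip ARM and go straight to the service
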
def from_iso_duration_format_min_sec(duration: str) -> str:
    """Convert an ISO 8601 duration to an "Xm Ys" display format.

    :param duration: The duration to convert
    :type duration: str
    :return: The converted duration
    :rtype: str
    """
    return duration.split(".")[0].replace("PT", "").replace("M", "m ") + "s"


def hash_dict(items: Dict[str, Any], keys_to_omit: Optional[Iterable[str]] = None) -> str:
    """Return a hash GUID of a dictionary, ignoring keys_to_omit.

    :param items: The dict to hash
    :type items: Dict[str, Any]
    :param keys_to_omit: Keys to omit before hashing
    :type keys_to_omit: Optional[Iterable[str]]
    :return: The hash GUID of the dictionary
    :rtype: str
    """
    if keys_to_omit is None:
        keys_to_omit = []
    items = pydash.omit(items, keys_to_omit)
    # Serialize the dict with sorted keys so equal dicts produce identical content
    serialized_component_interface = json.dumps(items, sort_keys=True)
    object_hash = hashlib.md5()  # nosec
    object_hash.update(serialized_component_interface.encode("utf-8"))
    return str(UUID(object_hash.hexdigest()))


def convert_identity_dict(
    identity: Optional[ManagedServiceIdentity] = None,
) -> ManagedServiceIdentity:
    """Convert identity to the right format.

    :param identity: The identity to convert
    :type identity: Optional[ManagedServiceIdentity]
    :return: The converted identity
    :rtype: ManagedServiceIdentity
    """
    if identity:
        if identity.type.lower() in ("system_assigned", "none"):
            identity = ManagedServiceIdentity(type="SystemAssigned")
        else:
            if identity.user_assigned_identities:
                if isinstance(identity.user_assigned_identities, dict):  # if the identity is already in right format
                    return identity
                ids = {}
                for id in identity.user_assigned_identities:  # pylint: disable=redefined-builtin
                    ids[id["resource_id"]] = {}
                identity.user_assigned_identities = ids
            identity.type = snake_to_camel(identity.type)
    else:
        identity = ManagedServiceIdentity(type="SystemAssigned")
    return identity


def strip_double_curly(io_binding_val: str) -> str:
    """Strip double curly brackets from a string.

    :param io_binding_val: The string to strip
    :type io_binding_val: str
    :return: The string with double curly brackets stripped
    :rtype: str
    """
    return io_binding_val.replace("${{", "").replace("}}", "")


def append_double_curly(io_binding_val: str) -> str:
    """Wrap a string in a ``${{...}}`` data-binding expression.

    :param io_binding_val: The string to wrap
    :type io_binding_val: str
    :return: The wrapped string
    :rtype: str
    """
    return f"${{{{{io_binding_val}}}}}"


def map_single_brackets_and_warn(command: str) -> str:
    """Map single-bracket parameters to double brackets and warn if any were found.

    :param command: The command to map
    :type command: str
    :return: The mapped command
    :rtype: str
    """

    def _check_for_parameter(param_prefix: str, command_string: str) -> Tuple[bool, str]:
        template_prefix = r"(?<!\{)\{"
        template_suffix = r"\.([^}]*)\}(?!\})"
        template = template_prefix + param_prefix + template_suffix
        should_warn = False
        if bool(re.search(template, command_string)):
            should_warn = True
            command_string = re.sub(template, r"${{" + param_prefix + r".\g<1>}}", command_string)
        return (should_warn, command_string)

    input_warn, command = _check_for_parameter("inputs", command)
    output_warn, command = _check_for_parameter("outputs", command)
    sweep_warn, command = _check_for_parameter("search_space", command)
    if input_warn or output_warn or sweep_warn:
        module_logger.warning("Use of {} for parameters is deprecated, instead use ${{}}.")
    return command
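
# hash_dict serializes with sorted keys and applies keys_to_omit first, so the
# following assumed invariants hold:
#
#     >>> hash_dict({"a": 1, "b": 2}) == hash_dict({"b": 2, "a": 1})
#     True
#     >>> hash_dict({"a": 1, "b": 2}, keys_to_omit=["b"]) == hash_dict({"a": 1})
#     True
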
def transform_dict_keys(data: Dict[str, Any], casing_transform: Callable[[str], str]) -> Dict[str, Any]:
    """Convert all keys of a nested dictionary according to the passed casing_transform function.

    :param data: The data to transform
    :type data: Dict[str, Any]
    :param casing_transform: A callable applied to all keys in data
    :type casing_transform: Callable[[str], str]
    :return: A dictionary with transformed keys
    :rtype: dict
    """
    return {
        casing_transform(key): transform_dict_keys(val, casing_transform) if isinstance(val, dict) else val
        for key, val in data.items()
    }


def merge_dict(origin: Dict, delta: Dict, dep: int = 0) -> dict:
    """Merge two dicts recursively.

    Note that the function returns a copy of the origin dict when the depth of the recursion is 0.

    :param origin: The original dictionary
    :type origin: dict
    :param delta: The delta dictionary
    :type delta: dict
    :param dep: The depth of the recursion
    :type dep: int
    :return: The merged dictionary
    :rtype: dict
    """
    result = copy.deepcopy(origin) if dep == 0 else origin
    for key, val in delta.items():
        # Look up in `result` (not `origin`) so nested dicts of the original are never mutated
        origin_val = result.get(key)
        # Merge delta dict with original dict
        if isinstance(origin_val, dict) and isinstance(val, dict):
            result[key] = merge_dict(origin_val, val, dep + 1)
            continue
        result[key] = copy.deepcopy(val)
    return result


def retry(
    exceptions: Union[Tuple[Exception], Exception],
    failure_msg: str,
    logger: Any,
    max_attempts: int = 1,
    delay_multiplier: float = 0.25,
) -> Callable:
    """Retry a function if it fails.

    :param exceptions: Exceptions to retry on.
    :type exceptions: Union[Tuple[Exception], Exception]
    :param failure_msg: Message to log on failure.
    :type failure_msg: str
    :param logger: Logger to use.
    :type logger: Any
    :param max_attempts: Maximum number of attempts.
    :type max_attempts: int
    :param delay_multiplier: Multiplier for the exponential-backoff delay between attempts.
    :type delay_multiplier: float
    :return: Decorated function.
    :rtype: Callable
    """

    def retry_decorator(f):
        @wraps(f)
        def func_with_retries(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
            tries = max_attempts + 1
            counter = 1
            while tries > 1:
                delay = delay_multiplier * 2**counter + random.uniform(0, 1)
                try:
                    return f(*args, **kwargs)
                except exceptions as e:
                    tries -= 1
                    counter += 1
                    if tries == 1:
                        logger.warning(failure_msg)
                        raise e
                    logger.info(f"Operation failed. Retrying in {delay} seconds.")
                    time.sleep(delay)

        return func_with_retries

    return retry_decorator


def get_list_view_type(include_archived: bool, archived_only: bool) -> ListViewType:
    """Get the list view type based on the include_archived and archived_only flags.

    :param include_archived: Whether to include archived items.
    :type include_archived: bool
    :param archived_only: Whether to only include archived items.
    :type archived_only: bool
    :return: The list view type.
    :rtype: ListViewType
    """
    if include_archived and archived_only:
        msg = "Cannot provide both archived-only and include-archived."
        raise MlException(message=msg, no_personal_data_message=msg)
    if include_archived:
        return ListViewType.ALL
    if archived_only:
        return ListViewType.ARCHIVED_ONLY
    return ListViewType.ACTIVE_ONLY
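
# merge_dict recurses into nested mappings instead of replacing them wholesale.
# An assumed, illustrative example:
#
#     >>> merge_dict({"a": {"b": 1}}, {"a": {"c": 2}, "d": 3})
#     {'a': {'b': 1, 'c': 2}, 'd': 3}
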
def is_data_binding_expression(
    value: str, binding_prefix: Union[str, List[str]] = "", is_singular: bool = True
) -> bool:
    """Check if a value is a data-binding expression with a specific binding target (prefix).

    Note that the function returns False if the value is not a str. For example, if binding_prefix is
    ["parent", "jobs"], then a value is a data-binding expression only if the binding target starts with
    "parent.jobs", like "${{parent.jobs.xxx}}". If is_singular is False, this returns True even if the value
    includes a non-binding part or multiple binding targets, like "${{parent.jobs.xxx}}_extra" or
    "${{parent.jobs.xxx}}_${{parent.jobs.xxx}}".

    :param value: Value to check.
    :type value: str
    :param binding_prefix: Prefix to check for.
    :type binding_prefix: Union[str, List[str]]
    :param is_singular: Whether the value must be a singular data-binding expression, like "${{parent.jobs.xxx}}".
    :type is_singular: bool
    :return: True if the value is a data-binding expression, False otherwise.
    :rtype: bool
    """
    return len(get_all_data_binding_expressions(value, binding_prefix, is_singular)) > 0


def get_all_data_binding_expressions(
    value: str, binding_prefix: Union[str, List[str]] = "", is_singular: bool = True
) -> List[str]:
    """Get all data-binding expressions in a value with a specific binding target (prefix).

    Note that the function returns an empty list if the value is not a str.

    :param value: Value to extract.
    :type value: str
    :param binding_prefix: Prefix to filter.
    :type binding_prefix: Union[str, List[str]]
    :param is_singular: Whether the value must be a singular data-binding expression, like "${{parent.jobs.xxx}}".
    :type is_singular: bool
    :return: List of data-binding expressions.
    :rtype: List[str]
    """
    if isinstance(binding_prefix, str):
        binding_prefix = [binding_prefix]
    if isinstance(value, str):
        target_regex = r"\$\{\{\s*(" + "\\.".join(binding_prefix) + r"\S*?)\s*\}\}"
        if is_singular:
            target_regex = "^" + target_regex + "$"
        return re.findall(target_regex, value)
    return []


def is_private_preview_enabled():
    """Check if private preview features are enabled.

    :return: True if private preview features are enabled, False otherwise.
    :rtype: bool
    """
    return os.getenv(AZUREML_PRIVATE_FEATURES_ENV_VAR) in ["True", "true", True]


def is_bytecode_optimization_enabled():
    """Check if bytecode optimization is enabled:

    1) the bytecode package is installed
    2) private preview is enabled
    3) the python version is between 3.6 and 3.11

    :return: True if bytecode optimization is enabled, False otherwise.
    :rtype: bool
    """
    try:
        import bytecode  # pylint: disable=unused-import

        return is_private_preview_enabled() and (3, 6) < sys.version_info < (3, 12)
    except ImportError:
        return False


def is_on_disk_cache_enabled():
    """Check if the on-disk cache for component registrations in pipeline submission is enabled.

    :return: True if the on-disk cache is enabled, False otherwise.
    :rtype: bool
    """
    return os.getenv(AZUREML_DISABLE_ON_DISK_CACHE_ENV_VAR) not in ["True", "true", True]


def is_concurrent_component_registration_enabled():  # pylint: disable=name-too-long
    """Check if concurrent component registration in pipeline submission is enabled.

    :return: True if concurrent component registration is enabled, False otherwise.
    :rtype: bool
    """
    return os.getenv(AZUREML_DISABLE_CONCURRENT_COMPONENT_REGISTRATION) not in ["True", "true", True]


def _is_internal_components_enabled():
    return os.getenv(AZUREML_INTERNAL_COMPONENTS_ENV_VAR) in ["True", "true", True]


def try_enable_internal_components(*, force=False) -> bool:
    """Try to enable internal components for the current process. This is the only function outside _internal that
    references _internal.

    :keyword force: Force enable internal components even if enabled before.
    :type force: bool
    :return: True if internal components are enabled, False otherwise.
    :rtype: bool
    """
    if _is_internal_components_enabled():
        from azure.ai.ml._internal import enable_internal_components_in_pipeline

        enable_internal_components_in_pipeline(force=force)

        return True
    return False
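
# Assumed, illustrative behavior of the binding-expression helpers above:
#
#     >>> is_data_binding_expression("${{parent.jobs.node1.outputs.out}}", ["parent", "jobs"])
#     True
#     >>> is_data_binding_expression("prefix_${{parent.jobs.node1.outputs.out}}", ["parent", "jobs"])
#     False
#     >>> get_all_data_binding_expressions(
#     ...     "prefix_${{parent.jobs.node1.outputs.out}}", ["parent", "jobs"], is_singular=False
#     ... )
#     ['parent.jobs.node1.outputs.out']
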
def is_internal_component_data(data: Dict[str, Any], *, raise_if_not_enabled: bool = False) -> bool:
    """Check if the data is internal component data by checking the schema url prefix.

    :param data: The data to check.
    :type data: Dict[str, Any]
    :keyword raise_if_not_enabled: Raise an exception if the data is internal component data but internal
        components are not enabled.
    :type raise_if_not_enabled: bool
    :return: True if the data is internal component data, False otherwise.
    :rtype: bool
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the data is internal component data but
        internal components are not enabled.
    """
    from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

    # These imports can't be placed at the top of the file because they would cause a circular import in
    # exceptions.py via _get_mfe_url_override

    schema = data.get(CommonYamlFields.SCHEMA, None)

    if schema is None or not isinstance(schema, str):
        return False

    if not schema.startswith(AZUREML_INTERNAL_COMPONENTS_SCHEMA_PREFIX):
        return False

    if not _is_internal_components_enabled() and raise_if_not_enabled:
        no_personal_data_message = (
            f"Internal components are a private feature in v2, please set environment variable "
            f"{AZUREML_INTERNAL_COMPONENTS_ENV_VAR} to true to use them."
        )
        msg = f"Detected schema url {schema}. {no_personal_data_message}"
        raise ValidationException(
            message=msg,
            target=ErrorTarget.COMPONENT,
            error_type=ValidationErrorType.INVALID_VALUE,
            no_personal_data_message=no_personal_data_message,
            error_category=ErrorCategory.USER_ERROR,
        )

    return True


def is_valid_node_name(name: str) -> bool:
    """Check if `name` is a valid node name.

    :param name: A node name
    :type name: str
    :return: True if the string is a valid Python identifier in the lower ASCII range, False otherwise.
    :rtype: bool
    """
    return isinstance(name, str) and name.isidentifier() and re.fullmatch(r"^[a-z_][a-z\d_]*", name) is not None
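
# is_valid_node_name accepts only lower-snake Python identifiers. Assumed,
# illustrative checks:
#
#     >>> is_valid_node_name("train_step_1")
#     True
#     >>> is_valid_node_name("TrainStep")  # uppercase is rejected
#     False
#     >>> is_valid_node_name("1_step")  # not a Python identifier
#     False
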
def parse_args_description_from_docstring(docstring: str) -> Dict[str, str]:
    """Return arg descriptions from a Google-style docstring.

    e.g.
        docstring = '''
        A pipeline with detailed docstring, including descriptions for inputs and outputs.

        In this pipeline docstring, there are descriptions for inputs and outputs.
        Input/Output descriptions can be inferred from the descriptions here.

        Args:
            job_in_path: a path parameter
            job_in_number: a number parameter
                with multi-line description
            job_in_int (int): an int parameter

        Other docstring xxxxxx
            random_key: random_value
        '''

    returns the dict:
        args = {
            'job_in_path': 'a path parameter',
            'job_in_number': 'a number parameter with multi-line description',
            'job_in_int': 'an int parameter'
        }

    :param docstring: A Google-style docstring
    :type docstring: str
    :return: A map of parameter names to parameter descriptions
    :rtype: Dict[str, str]
    """
    args = {}
    if not isinstance(docstring, str):
        return args
    lines = [line.strip() for line in docstring.splitlines()]
    for index, line in enumerate(lines):
        if line.lower() == "args:":
            args_region = lines[index + 1 :]
            args_line_end = args_region.index("") if "" in args_region else len(args_region)
            args_region = args_region[0:args_line_end]
            while len(args_region) > 0 and ":" in args_region[0]:
                arg_line = args_region[0]
                colon_index = arg_line.index(":")
                arg, description = (
                    arg_line[0:colon_index].strip(),
                    arg_line[colon_index + 1 :].strip(),
                )
                # Handle case like "param (float) : xxx"
                if "(" in arg:
                    arg = arg[0 : arg.index("(")].strip()
                args[arg] = description
                args_region.pop(0)
                # Handle multi-line descriptions, assuming a description has no colon inside.
                while len(args_region) > 0 and ":" not in args_region[0]:
                    args[arg] += " " + args_region[0]
                    args_region.pop(0)
    return args


def convert_windows_path_to_unix(path: Union[str, PathLike]) -> str:
    """Convert a Windows path to a Unix path.

    :param path: A Windows path
    :type path: Union[str, os.PathLike]
    :return: A Unix path
    :rtype: str
    """
    return PureWindowsPath(path).as_posix()


def _is_user_error_from_status_code(http_status_code):
    return 400 <= http_status_code < 500


def _str_to_bool(s: str) -> bool:
    """Convert a string to a boolean.

    Can be used as the type of an argparse argument: returns the argument's boolean value according to its
    literal value.

    :param s: The string to convert
    :type s: str
    :return: True if s is "true" (case-insensitive), otherwise False.
    :rtype: bool
    """
    if not isinstance(s, str):
        return False
    return s.lower() == "true"


def _is_user_error_from_exception_type(e: Optional[Exception]) -> bool:
    """Determine whether an exception is a user error based on its exception type.

    :param e: An exception
    :type e: Optional[Exception]
    :return: True if the exception is a user error
    :rtype: bool
    """
    # Connection errors happen on the user's network failure, so they should be user errors.
    # OSError/IOError with errno 28 ("No space left on device") should also be an SDK user error.
    return isinstance(e, (ConnectionError, KeyboardInterrupt)) or (isinstance(e, (IOError, OSError)) and e.errno == 28)


class DockerProxy:
    """A proxy class for the docker module that raises a more user-friendly error message if the docker module is
    not installed."""

    def __getattribute__(self, name: str) -> Any:
        try:
            import docker

            return getattr(docker, name)
        except ModuleNotFoundError as e:
            msg = "Please install docker in the current python environment with `pip install docker` and try again."
            raise MlException(message=msg, no_personal_data_message=msg) from e


def get_all_enum_values_iter(enum_type: type) -> Iterable[Any]:
    """Get all values of an enum type.

    :param enum_type: An "enum" (not necessarily enum.Enum)
    :type enum_type: Type
    :return: An iterable of all of the attributes of `enum_type`
    :rtype: Iterable[Any]
    """
    for key in dir(enum_type):
        if not key.startswith("_"):
            yield getattr(enum_type, key)
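
# Assumed, illustrative behavior of the small path/flag helpers above:
#
#     >>> convert_windows_path_to_unix("C:\\Users\\me\\data")
#     'C:/Users/me/data'
#     >>> _str_to_bool("True"), _str_to_bool("yes")
#     (True, False)
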
def write_to_shared_file(file_path: Union[str, PathLike], content: str) -> None:
    """Write content to a file and open up its permissions so that other users can read and write it.

    :param file_path: Path to the file.
    :type file_path: Union[str, os.PathLike]
    :param content: Content to write to the file.
    :type content: str
    """
    with open(file_path, "w", encoding=DefaultOpenEncoding.WRITE) as f:
        f.write(content)

    # share_mode means read/write for owner, group and others
    share_mode, mode_mask = 0o666, 0o777
    if os.stat(file_path).st_mode & mode_mask != share_mode:
        try:
            os.chmod(file_path, share_mode)
        except PermissionError:
            pass


def _get_valid_dot_keys_with_wildcard_impl(
    left_reversed_parts, root, *, validate_func=None, cur_node=None, processed_parts=None
) -> List[str]:
    if len(left_reversed_parts) == 0:
        if validate_func is None or validate_func(root, processed_parts):
            return [".".join(processed_parts)]
        return []
    if cur_node is None:
        cur_node = root
    if not isinstance(cur_node, dict):
        return []
    if processed_parts is None:
        processed_parts = []

    key: str = left_reversed_parts.pop()
    result = []
    if key == "*":
        for next_key in cur_node:
            if not isinstance(next_key, str):
                continue
            processed_parts.append(next_key)
            result.extend(
                _get_valid_dot_keys_with_wildcard_impl(
                    left_reversed_parts,
                    root,
                    validate_func=validate_func,
                    cur_node=cur_node[next_key],
                    processed_parts=processed_parts,
                )
            )
            processed_parts.pop()
    elif key in cur_node:
        processed_parts.append(key)
        result = _get_valid_dot_keys_with_wildcard_impl(
            left_reversed_parts,
            root,
            validate_func=validate_func,
            cur_node=cur_node[key],
            processed_parts=processed_parts,
        )
        processed_parts.pop()
    left_reversed_parts.append(key)
    return result


def get_valid_dot_keys_with_wildcard(
    root: Dict[str, Any],
    dot_key_wildcard: str,
    *,
    validate_func: Optional[Callable[[Dict[str, Any], List[str]], bool]] = None,
) -> List[str]:
    """Get all valid dot keys with a wildcard. Only "x.*.x" and "x.*" are supported for now.

    A valid dot key should satisfy the following conditions:
    1) It should be a valid dot key in the root node.
    2) It should satisfy the validation function.

    :param root: Root node.
    :type root: Dict[str, Any]
    :param dot_key_wildcard: Dot key with a wildcard, e.g. "a.*.c".
    :type dot_key_wildcard: str
    :keyword validate_func: Validation function. It takes two parameters: the root node and the dot key parts.
        If None, no validation will be performed.
    :paramtype validate_func: Optional[Callable[[Dict[str, Any], List[str]], bool]]
    :return: List of valid dot keys.
    :rtype: List[str]
    """
    left_reversed_parts = dot_key_wildcard.split(".")[::-1]
    return _get_valid_dot_keys_with_wildcard_impl(left_reversed_parts, root, validate_func=validate_func)


def get_base_directory_for_cache() -> Path:
    """Get the base directory for cache files.

    :return: The base directory for cache files.
    :rtype: Path
    """
    return Path(tempfile.gettempdir()).joinpath("azure-ai-ml")


def get_versioned_base_directory_for_cache() -> Path:
    """Get the base directory for cache files of the current version of azure-ai-ml. Cache files of different
    versions will be stored in different directories.

    :return: The base directory for cache files of the current version of azure-ai-ml.
    :rtype: Path
    """
    # import here to avoid circular import
    from azure.ai.ml._version import VERSION

    return get_base_directory_for_cache().joinpath(VERSION)


# pylint: disable-next=name-too-long
def get_resource_and_group_name_from_resource_id(armstr: str) -> Tuple[str, Optional[str]]:
    if armstr.find("/") == -1:
        return armstr, None
    return armstr.split("/")[-1], armstr.split("/")[-5]


# pylint: disable-next=name-too-long
def get_resource_group_name_from_resource_group_id(armstr: str) -> str:
    if armstr.find("/") == -1:
        return armstr
    return armstr.split("/")[-1]
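
# Wildcard dot-key resolution against an assumed, pipeline-like dict:
#
#     >>> root = {"jobs": {"train": {"compute": "cpu"}, "score": {"compute": "gpu"}}}
#     >>> get_valid_dot_keys_with_wildcard(root, "jobs.*.compute")
#     ['jobs.train.compute', 'jobs.score.compute']
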
def extract_name_and_version(azureml_id: str) -> Dict[str, str]:
    """Extract the name and version from an AzureML id.

    :param azureml_id: AzureML id.
    :type azureml_id: str
    :return: A dict of name and version.
    :rtype: Dict[str, str]
    """
    if not isinstance(azureml_id, str):
        raise ValueError("azureml_id should be a string but got {}: {}.".format(type(azureml_id), azureml_id))
    if azureml_id.count(":") != 1:
        raise ValueError("azureml_id should be in the format of name:version but got {}.".format(azureml_id))
    name, version = azureml_id.split(":")
    return {
        "name": name,
        "version": version,
    }


def _get_evaluator_properties():
    return {"is-promptflow": "true", "is-evaluator": "true"}


def _is_evaluator(properties: Dict[str, str]) -> bool:
    return properties.get("is-evaluator") == "true" and properties.get("is-promptflow") == "true"
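
# extract_name_and_version expects exactly one colon. An assumed, illustrative
# example:
#
#     >>> extract_name_and_version("my-model:3")
#     {'name': 'my-model', 'version': '3'}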