diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation')
5 files changed, 920 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py new file mode 100644 index 00000000..29ba05c5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/__init__.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +from .core import MutableValidationResult, ValidationResult, ValidationResultBuilder +from .path_aware_schema import PathAwareSchemaValidatableMixin +from .remote import RemoteValidatableMixin +from .schema import SchemaValidatableMixin + +__all__ = [ + "SchemaValidatableMixin", + "PathAwareSchemaValidatableMixin", + "RemoteValidatableMixin", + "MutableValidationResult", + "ValidationResult", + "ValidationResultBuilder", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py new file mode 100644 index 00000000..a7516c1d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/core.py @@ -0,0 +1,531 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import copy +import json +import logging +import os.path +import typing +from os import PathLike +from pathlib import Path +from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union, cast + +import pydash +import strictyaml +from marshmallow import ValidationError + +module_logger = logging.getLogger(__name__) + + +class _ValidationStatus: + """Validation status class. + + Validation status is used to indicate the status of an validation result. It can be one of the following values: + Succeeded, Failed. + """ + + SUCCEEDED = "Succeeded" + """Succeeded.""" + FAILED = "Failed" + """Failed.""" + + +class Diagnostic(object): + """Represents a diagnostic of an asset validation error with the location info.""" + + def __init__(self, yaml_path: str, message: Optional[str], error_code: Optional[str]) -> None: + """Init Diagnostic. + + :keyword yaml_path: A dash path from root to the target element of the diagnostic. jobs.job_a.inputs.input_str + :paramtype yaml_path: str + :keyword message: Error message of diagnostic. + :paramtype message: str + :keyword error_code: Error code of diagnostic. + :paramtype error_code: str + """ + self.yaml_path = yaml_path + self.message = message + self.error_code = error_code + self.local_path, self.value = None, None + + def __repr__(self) -> str: + """The asset friendly name and error message. + + :return: The formatted diagnostic + :rtype: str + """ + return "{}: {}".format(self.yaml_path, self.message) + + @classmethod + def create_instance( + cls, + yaml_path: str, + message: Optional[str] = None, + error_code: Optional[str] = None, + ) -> "Diagnostic": + """Create a diagnostic instance. + + :param yaml_path: A dash path from root to the target element of the diagnostic. jobs.job_a.inputs.input_str + :type yaml_path: str + :param message: Error message of diagnostic. + :type message: str + :param error_code: Error code of diagnostic. + :type error_code: str + :return: The created instance + :rtype: Diagnostic + """ + return cls( + yaml_path=yaml_path, + message=message, + error_code=error_code, + ) + + +class ValidationResult(object): + """Represents the result of job/asset validation. + + This class is used to organize and parse diagnostics from both client & server side before expose them. The result + is immutable. + """ + + def __init__(self) -> None: + self._target_obj: Optional[Dict] = None + self._errors: List = [] + self._warnings: List = [] + + @property + def error_messages(self) -> Dict: + """ + Return all messages of errors in the validation result. + + :return: A dictionary of error messages. The key is the yaml path of the error, and the value is the error + message. + :rtype: dict + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START validation_result] + :end-before: [END validation_result] + :language: markdown + :dedent: 8 + """ + messages = {} + for diagnostic in self._errors: + if diagnostic.yaml_path not in messages: + messages[diagnostic.yaml_path] = diagnostic.message + else: + messages[diagnostic.yaml_path] += "; " + diagnostic.message + return messages + + @property + def passed(self) -> bool: + """Returns boolean indicating whether any errors were found. + + :return: True if the validation passed, False otherwise. + :rtype: bool + """ + return not self._errors + + def _to_dict(self) -> typing.Dict[str, typing.Any]: + result: Dict = { + "result": _ValidationStatus.SUCCEEDED if self.passed else _ValidationStatus.FAILED, + } + for diagnostic_type, diagnostics in [ + ("errors", self._errors), + ("warnings", self._warnings), + ]: + messages = [] + for diagnostic in diagnostics: + message = { + "message": diagnostic.message, + "path": diagnostic.yaml_path, + "value": pydash.get(self._target_obj, diagnostic.yaml_path, diagnostic.value), + } + if diagnostic.local_path: + message["location"] = str(diagnostic.local_path) + messages.append(message) + if messages: + result[diagnostic_type] = messages + return result + + def __repr__(self) -> str: + """Get the string representation of the validation result. + + :return: The string representation + :rtype: str + """ + return json.dumps(self._to_dict(), indent=2) + + +class MutableValidationResult(ValidationResult): + """Used by the client side to construct a validation result. + + The result is mutable and should not be exposed to the user. + """ + + def __init__(self, target_obj: Optional[Dict] = None): + super().__init__() + self._target_obj = target_obj + + def merge_with( + self, + target: ValidationResult, + field_name: Optional[str] = None, + condition_skip: Optional[typing.Callable] = None, + overwrite: bool = False, + ) -> "MutableValidationResult": + """Merge errors & warnings in another validation results into current one. + + Will update current validation result. + If field_name is not None, then yaml_path in the other validation result will be updated accordingly. + * => field_name, jobs.job_a => field_name.jobs.job_a e.g.. If None, then no update. + + :param target: Validation result to merge. + :type target: ValidationResult + :param field_name: The base field name for the target to merge. + :type field_name: str + :param condition_skip: A function to determine whether to skip the merge of a diagnostic in the target. + :type condition_skip: typing.Callable + :param overwrite: Whether to overwrite the current validation result. If False, all diagnostics will be kept; + if True, current diagnostics with the same yaml_path will be dropped. + :type overwrite: bool + :return: The current validation result. + :rtype: MutableValidationResult + """ + for source_diagnostics, target_diagnostics in [ + (target._errors, self._errors), + (target._warnings, self._warnings), + ]: + if overwrite: + keys_to_remove = set(map(lambda x: x.yaml_path, source_diagnostics)) + target_diagnostics[:] = [ + diagnostic for diagnostic in target_diagnostics if diagnostic.yaml_path not in keys_to_remove + ] + for diagnostic in source_diagnostics: + if condition_skip and condition_skip(diagnostic): + continue + new_diagnostic = copy.deepcopy(diagnostic) + if field_name: + if new_diagnostic.yaml_path == "*": + new_diagnostic.yaml_path = field_name + else: + new_diagnostic.yaml_path = field_name + "." + new_diagnostic.yaml_path + target_diagnostics.append(new_diagnostic) + return self + + def try_raise( + self, + raise_error: Optional[bool] = True, + *, + error_func: Optional[typing.Callable[[str, str], Exception]] = None, + ) -> "MutableValidationResult": + """Try to raise an error from the validation result. + + If the validation is passed or raise_error is False, this method + will return the validation result. + + :param raise_error: Whether to raise the error. + :type raise_error: bool + :keyword error_func: A function to create the error. If None, a marshmallow.ValidationError will be created. + The first parameter of the function is the string representation of the validation result, + and the second parameter is the error message without personal data. + :type error_func: typing.Callable[[str, str], Exception] + :return: The current validation result. + :rtype: MutableValidationResult + """ + # pylint: disable=logging-not-lazy + if raise_error is False: + return self + + if self._warnings: + module_logger.warning("Warnings: %s" % str(self._warnings)) + + if not self.passed: + if error_func is None: + + def error_func(msg: Union[str, list, dict], _: Any) -> ValidationError: + return ValidationError(message=msg) + + raise error_func( + self.__repr__(), + "validation failed on the following fields: " + ", ".join(self.error_messages), + ) + return self + + def append_error( + self, + yaml_path: str = "*", + message: Optional[str] = None, + error_code: Optional[str] = None, + ) -> "MutableValidationResult": + """Append an error to the validation result. + + :param yaml_path: The yaml path of the error. + :type yaml_path: str + :param message: The message of the error. + :type message: str + :param error_code: The error code of the error. + :type error_code: str + :return: The current validation result. + :rtype: MutableValidationResult + """ + self._errors.append( + Diagnostic.create_instance( + yaml_path=yaml_path, + message=message, + error_code=error_code, + ) + ) + return self + + def resolve_location_for_diagnostics(self, source_path: str, resolve_value: bool = False) -> None: + """Resolve location/value for diagnostics based on the source path where the validatable object is loaded. + + Location includes local path of the exact file (can be different from the source path) & line number of the + invalid field. Value of a diagnostic is resolved from the validatable object in transfering to a dict by + default; however, when the validatable object is not available for the validation result, validation result is + created from marshmallow.ValidationError.messages e.g., it can be resolved from the source path. + + :param source_path: The path of the source file. + :type source_path: str + :param resolve_value: Whether to resolve the value of the invalid field from source file. + :type resolve_value: bool + """ + resolver = _YamlLocationResolver(source_path) + for diagnostic in self._errors + self._warnings: + res = resolver.resolve(diagnostic.yaml_path) + if res is not None: + diagnostic.local_path, value = res + if value is not None and resolve_value: + diagnostic.value = value + + def append_warning( + self, + yaml_path: str = "*", + message: Optional[str] = None, + error_code: Optional[str] = None, + ) -> "MutableValidationResult": + """Append a warning to the validation result. + + :param yaml_path: The yaml path of the warning. + :type yaml_path: str + :param message: The message of the warning. + :type message: str + :param error_code: The error code of the warning. + :type error_code: str + :return: The current validation result. + :rtype: MutableValidationResult + """ + self._warnings.append( + Diagnostic.create_instance( + yaml_path=yaml_path, + message=message, + error_code=error_code, + ) + ) + return self + + +class ValidationResultBuilder: + """A helper class to create a validation result.""" + + UNKNOWN_MESSAGE = "Unknown field." + + def __init__(self) -> None: + pass + + @classmethod + def success(cls) -> MutableValidationResult: + """Create a validation result with success status. + + :return: A validation result + :rtype: MutableValidationResult + """ + return MutableValidationResult() + + @classmethod + def from_single_message( + cls, singular_error_message: Optional[str] = None, yaml_path: str = "*", data: Optional[dict] = None + ) -> MutableValidationResult: + """Create a validation result with only 1 diagnostic. + + :param singular_error_message: diagnostic.message. + :type singular_error_message: Optional[str] + :param yaml_path: diagnostic.yaml_path. + :type yaml_path: str + :param data: serializedvalidation target. + :type data: Optional[Dict] + :return: The validation result + :rtype: MutableValidationResult + """ + obj = MutableValidationResult(target_obj=data) + if singular_error_message: + obj.append_error(message=singular_error_message, yaml_path=yaml_path) + return obj + + @classmethod + def from_validation_error( + cls, + error: ValidationError, + *, + source_path: Optional[Union[str, PathLike, IO[AnyStr]]] = None, + error_on_unknown_field: bool = False, + ) -> MutableValidationResult: + """Create a validation result from a ValidationError, which will be raised in marshmallow.Schema.load. Please + use this function only for exception in loading file. + + :param error: ValidationError raised by marshmallow.Schema.load. + :type error: ValidationError + :keyword source_path: The path to the source file. + :paramtype source_path: Optional[Union[str, PathLike, IO[AnyStr]]] + :keyword error_on_unknown_field: whether to raise error if there are unknown field diagnostics. + :paramtype error_on_unknown_field: bool + :return: The validation result + :rtype: MutableValidationResult + """ + obj = cls.from_validation_messages( + error.messages, data=error.data, error_on_unknown_field=error_on_unknown_field + ) + if source_path: + obj.resolve_location_for_diagnostics(cast(str, source_path), resolve_value=True) + return obj + + @classmethod + def from_validation_messages( + cls, errors: typing.Dict, data: typing.Dict, *, error_on_unknown_field: bool = False + ) -> MutableValidationResult: + """Create a validation result from error messages, which will be returned by marshmallow.Schema.validate. + + :param errors: error message returned by marshmallow.Schema.validate. + :type errors: dict + :param data: serialized data to validate + :type data: dict + :keyword error_on_unknown_field: whether to raise error if there are unknown field diagnostics. + :paramtype error_on_unknown_field: bool + :return: The validation result + :rtype: MutableValidationResult + """ + instance = MutableValidationResult(target_obj=data) + errors = copy.deepcopy(errors) + cls._from_validation_messages_recursively(errors, [], instance, error_on_unknown_field=error_on_unknown_field) + return instance + + @classmethod + def _from_validation_messages_recursively( + cls, + errors: typing.Union[typing.Dict, typing.List, str], + path_stack: typing.List[str], + instance: MutableValidationResult, + error_on_unknown_field: bool, + ) -> None: + cur_path = ".".join(path_stack) if path_stack else "*" + # single error message + if isinstance(errors, dict) and "_schema" in errors: + instance.append_error( + message=";".join(errors["_schema"]), + yaml_path=cur_path, + ) + # errors on attributes + elif isinstance(errors, dict): + for field, msgs in errors.items(): + # fields.Dict + if field in ["key", "value"]: + cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field) + else: + # Todo: Add hack logic here to deal with error message in nested TypeSensitiveUnionField in + # DataTransfer: will be a nested dict with None field as dictionary key. + # open a item to track: https://msdata.visualstudio.com/Vienna/_workitems/edit/2244262/ + if field is None: + cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field) + else: + path_stack.append(field) + cls._from_validation_messages_recursively(msgs, path_stack, instance, error_on_unknown_field) + path_stack.pop() + + # detailed error message + elif isinstance(errors, list) and all(isinstance(msg, str) for msg in errors): + if cls.UNKNOWN_MESSAGE in errors and not error_on_unknown_field: + # Unknown field is not a real error, so we should remove it and append a warning. + errors.remove(cls.UNKNOWN_MESSAGE) + instance.append_warning(message=cls.UNKNOWN_MESSAGE, yaml_path=cur_path) + if errors: + instance.append_error(message=";".join(errors), yaml_path=cur_path) + # union field + elif isinstance(errors, list): + + def msg2str(msg: Any) -> Any: + if isinstance(msg, str): + return msg + if isinstance(msg, dict) and len(msg) == 1 and "_schema" in msg and len(msg["_schema"]) == 1: + return str(msg["_schema"][0]) + + return str(msg) + + instance.append_error(message="; ".join([msg2str(x) for x in errors]), yaml_path=cur_path) + # unknown error + else: + instance.append_error(message=str(errors), yaml_path=cur_path) + + +class _YamlLocationResolver: + def __init__(self, source_path: str): + self._source_path = source_path + + def resolve(self, yaml_path: str, source_path: Optional[str] = None) -> Optional[Tuple]: + """Resolve the location & value of a yaml path starting from source_path. + + :param yaml_path: yaml path. + :type yaml_path: str + :param source_path: source path. + :type source_path: str + :return: the location & value of the yaml path based on source_path. + :rtype: Tuple[str, str] + """ + source_path = source_path or self._source_path + if source_path is None or not os.path.isfile(source_path): + return None, None + if yaml_path is None or yaml_path == "*": + return source_path, None + + attrs = yaml_path.split(".") + attrs.reverse() + + res: Optional[Tuple] = self._resolve_recursively(attrs, Path(source_path)) + return res + + def _resolve_recursively(self, attrs: List[str], source_path: Path) -> Optional[Tuple]: + with open(source_path, encoding="utf-8") as f: + try: + loaded_yaml = strictyaml.load(f.read()) + except Exception as e: # pylint: disable=W0718 + msg = "Can't load source file %s as a strict yaml:\n%s" % (source_path, str(e)) + module_logger.debug(msg) + return None, None + + while attrs: + attr = attrs[-1] + if loaded_yaml.is_mapping() and attr in loaded_yaml: + loaded_yaml = loaded_yaml.get(attr) + attrs.pop() + else: + try: + # if current object is a path of a valid yaml file, try to resolve location in new source file + next_path = Path(loaded_yaml.value) + if not next_path.is_absolute(): + next_path = source_path.parent / next_path + if next_path.is_file(): + return self._resolve_recursively(attrs, source_path=next_path) + except OSError: + pass + except TypeError: + pass + # if not, return current section + break + return ( + f"{source_path.resolve().absolute()}#line {loaded_yaml.start_line}", + None if attrs else loaded_yaml.value, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py new file mode 100644 index 00000000..959de310 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/path_aware_schema.py @@ -0,0 +1,53 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import typing +from os import PathLike +from pathlib import Path + +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + +from ..._schema import PathAwareSchema +from .._job.pipeline._attr_dict import try_get_non_arbitrary_attr +from .._util import convert_ordered_dict_to_dict +from .schema import SchemaValidatableMixin + + +class PathAwareSchemaValidatableMixin(SchemaValidatableMixin): + """The mixin class for schema validation. Entity classes inheriting from this class should have a base path + and a schema of PathAwareSchema. + """ + + @property + def __base_path_for_validation(self) -> typing.Union[str, PathLike]: + """Get the base path of the resource. + + It will try to return self.base_path, then self._base_path, then Path.cwd() if above attrs are non-existent or + `None. + + :return: The base path of the resource + :rtype: typing.Union[str, os.PathLike] + """ + return ( + try_get_non_arbitrary_attr(self, BASE_PATH_CONTEXT_KEY) + or try_get_non_arbitrary_attr(self, f"_{BASE_PATH_CONTEXT_KEY}") + or Path.cwd() + ) + + def _default_context(self) -> dict: + # Note that, although context can be passed, nested.schema will be initialized only once + # base_path works well because it's fixed after loaded + return {BASE_PATH_CONTEXT_KEY: self.__base_path_for_validation} + + @classmethod + def _create_schema_for_validation(cls, context: typing.Any) -> PathAwareSchema: + raise NotImplementedError() + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception: + raise NotImplementedError() + + def _dump_for_validation(self) -> typing.Dict: + # this is not a necessary step but to keep the same behavior as before + # empty items will be removed when converting to dict + return typing.cast(dict, convert_ordered_dict_to_dict(super()._dump_for_validation())) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py new file mode 100644 index 00000000..06f022a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/remote.py @@ -0,0 +1,162 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +import typing + +import msrest + +from azure.ai.ml._vendor.azure_resources.models import ( + Deployment, + DeploymentProperties, + DeploymentValidateResult, + ErrorResponse, +) +from azure.ai.ml.entities._mixins import RestTranslatableMixin + +from .core import MutableValidationResult, ValidationResultBuilder + +module_logger = logging.getLogger(__name__) + + +class PreflightResource(msrest.serialization.Model): + """Specified resource. + + Variables are only populated by the server, and will be ignored when sending a request. + + :ivar id: Resource ID. + :vartype id: str + :ivar name: Resource name. + :vartype name: str + :ivar type: Resource type. + :vartype type: str + :param location: Resource location. + :type location: str + :param tags: A set of tags. Resource tags. + :type tags: dict[str, str] + """ + + _attribute_map = { + "type": {"key": "type", "type": "str"}, + "name": {"key": "name", "type": "str"}, + "location": {"key": "location", "type": "str"}, + "api_version": {"key": "apiversion", "type": "str"}, + "properties": {"key": "properties", "type": "object"}, + } + + def __init__(self, **kwargs: typing.Any): + super(PreflightResource, self).__init__(**kwargs) + self.name = kwargs.get("name", None) + self.type = kwargs.get("type", None) + self.location = kwargs.get("location", None) + self.properties = kwargs.get("properties", None) + self.api_version = kwargs.get("api_version", None) + + +class ValidationTemplateRequest(msrest.serialization.Model): + """Export resource group template request parameters. + + :param resources: The rest objects to be validated. + :type resources: list[_models.Resource] + :param options: The export template options. A CSV-formatted list containing zero or more of + the following: 'IncludeParameterDefaultValue', 'IncludeComments', + 'SkipResourceNameParameterization', 'SkipAllParameterization'. + :type options: str + """ + + _attribute_map = { + "resources": {"key": "resources", "type": "[PreflightResource]"}, + "content_version": {"key": "contentVersion", "type": "str"}, + "parameters": {"key": "parameters", "type": "object"}, + "_schema": { + "key": "$schema", + "type": "str", + "default": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", + }, + } + + def __init__(self, **kwargs: typing.Any): + super(ValidationTemplateRequest, self).__init__(**kwargs) + self._schema = kwargs.get("_schema", None) + self.content_version = kwargs.get("content_version", None) + self.parameters = kwargs.get("parameters", None) + self.resources = kwargs.get("resources", None) + + +class RemoteValidatableMixin(RestTranslatableMixin): + @classmethod + def _get_resource_type(cls) -> str: + """Return resource type to be used in remote validation. + + Should be overridden by subclass. + + :return: The resource type + :rtype: str + """ + raise NotImplementedError() + + def _get_resource_name_version(self) -> typing.Tuple: + """Return resource name and version to be used in remote validation. + + Should be overridden by subclass. + + :return: The name and version + :rtype: typing.Tuple[str, str] + """ + raise NotImplementedError() + + def _to_preflight_resource(self, location: str, workspace_name: str) -> PreflightResource: + """Return the preflight resource to be used in remote validation. + + :param location: The location of the resource. + :type location: str + :param workspace_name: The workspace name + :type workspace_name: str + :return: The preflight resource + :rtype: PreflightResource + """ + name, version = self._get_resource_name_version() + return PreflightResource( + type=self._get_resource_type(), + name=f"{workspace_name}/{name}/{version}", + location=location, + properties=self._to_rest_object().properties, + api_version="2023-03-01-preview", + ) + + def _build_rest_object_for_remote_validation(self, location: str, workspace_name: str) -> Deployment: + return Deployment( + properties=DeploymentProperties( + mode="Incremental", + template=ValidationTemplateRequest( + _schema="https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", + content_version="1.0.0.0", + parameters={}, + resources=[self._to_preflight_resource(location=location, workspace_name=workspace_name)], + ), + ) + ) + + @classmethod + def _build_validation_result_from_rest_object(cls, rest_obj: DeploymentValidateResult) -> MutableValidationResult: + """Create a validation result from a rest object. Note that the created validation result does not have + target_obj so should only be used for merging. + + :param rest_obj: The Deployment Validate REST obj + :type rest_obj: DeploymentValidateResult + :return: The validation result created from rest_obj + :rtype: MutableValidationResult + """ + if not rest_obj.error or not rest_obj.error.details: + return ValidationResultBuilder.success() + result = MutableValidationResult(target_obj=None) + details: typing.List[ErrorResponse] = rest_obj.error.details + for detail in details: + result.append_error( + message=detail.message, + yaml_path=detail.target.replace("/", "."), + error_code=detail.code, + # will always be UserError for now, not sure if innerError can be passed back + ) + return result diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py new file mode 100644 index 00000000..9e34173d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_validation/schema.py @@ -0,0 +1,156 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import json +import logging +import typing + +from marshmallow import Schema, ValidationError + +from .core import MutableValidationResult, ValidationResultBuilder + +module_logger = logging.getLogger(__name__) + + +class SchemaValidatableMixin: + """The mixin class for schema validation.""" + + @classmethod + def _create_empty_validation_result(cls) -> MutableValidationResult: + """Simply create an empty validation result + + To reduce _ValidationResultBuilder importing, which is a private class. + + :return: An empty validation result + :rtype: MutableValidationResult + """ + return ValidationResultBuilder.success() + + @classmethod + def _load_with_schema( + cls, data: typing.Any, *, context: typing.Any, raise_original_exception: bool = False, **kwargs: typing.Any + ) -> typing.Any: + schema = cls._create_schema_for_validation(context=context) + + try: + return schema.load(data, **kwargs) + except ValidationError as e: + if raise_original_exception: + raise e + msg = "Trying to load data with schema failed. Data:\n%s\nError: %s" % ( + json.dumps(data, indent=4) if isinstance(data, dict) else data, + json.dumps(e.messages, indent=4), + ) + raise cls._create_validation_error( + message=msg, + no_personal_data_message=str(e), + ) from e + + @classmethod + # pylint: disable-next=docstring-missing-param + def _create_schema_for_validation(cls, context: typing.Any) -> Schema: + """Create a schema of the resource with specific context. Should be overridden by subclass. + + :return: The schema of the resource. + :rtype: Schema. + """ + raise NotImplementedError() + + def _default_context(self) -> dict: + """Get the default context for schema validation. Should be overridden by subclass. + + :return: The default context for schema validation + :rtype: dict + """ + raise NotImplementedError() + + @property + def _schema_for_validation(self) -> Schema: + """Return the schema of this Resource with default context. Do not override this method. + Override _create_schema_for_validation instead. + + :return: The schema of the resource. + :rtype: Schema. + """ + return self._create_schema_for_validation(context=self._default_context()) + + def _dump_for_validation(self) -> typing.Dict: + """Convert the resource to a dictionary. + + :return: Converted dictionary + :rtype: typing.Dict + """ + res: dict = self._schema_for_validation.dump(self) + return res + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception: + """The function to create the validation exception to raise in _try_raise and _validate when + raise_error is True. + + Should be overridden by subclass. + + :param message: The error message containing detailed information + :type message: str + :param no_personal_data_message: The error message without personal data + :type no_personal_data_message: str + :return: The validation exception to raise + :rtype: Exception + """ + raise NotImplementedError() + + @classmethod + def _try_raise( + cls, validation_result: MutableValidationResult, *, raise_error: typing.Optional[bool] = True + ) -> MutableValidationResult: + return validation_result.try_raise(raise_error=raise_error, error_func=cls._create_validation_error) + + def _validate(self, raise_error: typing.Optional[bool] = False) -> MutableValidationResult: + """Validate the resource. If raise_error is True, raise ValidationError if validation fails and log warnings if + applicable; Else, return the validation result. + + :param raise_error: Whether to raise ValidationError if validation fails. + :type raise_error: bool + :return: The validation result + :rtype: MutableValidationResult + """ + result = self.__schema_validate() + result.merge_with(self._customized_validate()) + return self._try_raise(result, raise_error=raise_error) + + def _customized_validate(self) -> MutableValidationResult: + """Validate the resource with customized logic. + + Override this method to add customized validation logic. + + :return: The customized validation result + :rtype: MutableValidationResult + """ + return self._create_empty_validation_result() + + @classmethod + def _get_skip_fields_in_schema_validation( + cls, + ) -> typing.List[str]: + """Get the fields that should be skipped in schema validation. + + Override this method to add customized validation logic. + + :return: The fields to skip in schema validation + :rtype: typing.List[str] + """ + return [] + + def __schema_validate(self) -> MutableValidationResult: + """Validate the resource with the schema. + + :return: The validation result + :rtype: MutableValidationResult + """ + data = self._dump_for_validation() + messages = self._schema_for_validation.validate(data) + for skip_field in self._get_skip_fields_in_schema_validation(): + if skip_field in messages: + del messages[skip_field] + return ValidationResultBuilder.from_validation_messages(messages, data=data) |