# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

import logging

from marshmallow import ValidationError, fields, post_load, pre_dump, pre_load

from azure.ai.ml._restclient.v2022_05_01.models import (
    InferenceContainerProperties,
    OperatingSystemType,
    Route,
)
from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField, UnionField, LocalPathField
from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml.constants._common import (
    ANONYMOUS_ENV_NAME,
    BASE_PATH_CONTEXT_KEY,
    CREATE_ENVIRONMENT_ERROR_MESSAGE,
    AzureMLResourceType,
    YAMLRefDocLinks,
)

from ..core.fields import ArmStr, RegistryStr, StringTransformedEnum, VersionField
from .asset import AnonymousAssetSchema, AssetSchema

module_logger = logging.getLogger(__name__)


class BuildContextSchema(metaclass=PatchedSchemaMeta):
    """Schema for a Docker build context; loads into a ``BuildContext`` entity."""

    dockerfile_path = fields.Str()
    path = UnionField(
        [
            # A local path is tried first ...
            LocalPathField(),
            # ... but a build context may also be given as an http url.
            fields.URL(),
        ]
    )

    @post_load
    def make(self, data, **kwargs):
        """Turn the loaded field values into a ``BuildContext``.

        :param data: Deserialized field values, forwarded as keyword arguments.
        :return: The constructed ``BuildContext``.
        """
        # Function-scope import; presumably avoids a circular import with the
        # entities package -- confirm before hoisting to module level.
        from azure.ai.ml.entities._assets.environment import BuildContext

        build_context = BuildContext(**data)
        return build_context


class RouteSchema(metaclass=PatchedSchemaMeta):
    """Schema for a single route (port and path); loads into a REST ``Route``."""

    port = fields.Int(required=True)
    path = fields.Str(required=True)

    @post_load
    def make(self, data, **kwargs):
        """Turn the loaded field values into a ``Route``.

        :param data: Deserialized field values, forwarded as keyword arguments.
        :return: The constructed ``Route``.
        """
        route = Route(**data)
        return route


class InferenceConfigSchema(metaclass=PatchedSchemaMeta):
    """Schema for the inference container routes; all three routes are required."""

    liveness_route = NestedField(RouteSchema, required=True)
    scoring_route = NestedField(RouteSchema, required=True)
    readiness_route = NestedField(RouteSchema, required=True)

    @post_load
    def make(self, data, **kwargs):
        """Turn the loaded routes into ``InferenceContainerProperties``.

        :param data: Deserialized field values, forwarded as keyword arguments.
        :return: The constructed ``InferenceContainerProperties``.
        """
        container_properties = InferenceContainerProperties(**data)
        return container_properties


class _BaseEnvironmentSchema(AssetSchema):
    """Fields and load/dump hooks shared by the named and anonymous environment schemas."""

    # Dump-only identifier: either a registry string or an ARM id of an environment.
    id = UnionField(
        [
            RegistryStr(dump_only=True),
            ArmStr(azureml_type=AzureMLResourceType.ENVIRONMENT, dump_only=True),
        ]
    )
    build = NestedField(
        BuildContextSchema,
        metadata={"description": "Docker build context to create the environment. Mutually exclusive with image"},
    )
    image = fields.Str()
    conda_file = UnionField([fields.Raw(), fields.Str()])
    inference_config = NestedField(InferenceConfigSchema)
    os_type = StringTransformedEnum(
        allowed_values=[OperatingSystemType.Linux, OperatingSystemType.Windows],
        required=False,
    )
    datastore = fields.Str(
        metadata={
            "description": "Name of the datastore to upload to.",
            "arm_type": AzureMLResourceType.DATASTORE,
        },
        required=False,
    )
    # Dump-only; populated from the entity's private attribute in ``validate``.
    intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema), dump_only=True)

    @pre_load
    def pre_load(self, data, **kwargs):
        """Reject input that cannot be an environment definition.

        :param data: Raw data about to be deserialized.
        :return: ``data`` unchanged when it passes the checks.
        :raises ValidationError: If ``data`` is a string, or if it contains the
            conda-only keys ``channels`` or ``dependencies``.
        """
        if isinstance(data, str):
            raise ValidationError("Environment schema data cannot be a string")

        # "channels" and "dependencies" belong in an environment's conda file,
        # not in the environment creation file itself.
        conda_only_keys = ("channels", "dependencies")
        if any(key in data for key in conda_only_keys):
            message = CREATE_ENVIRONMENT_ERROR_MESSAGE.format(YAMLRefDocLinks.ENVIRONMENT)
            raise ValidationError(message)

        return data

    @pre_dump
    def validate(self, data, **kwargs):
        """Prepare and sanity-check the object about to be dumped.

        :param data: The object being serialized.
        :return: ``data``; an ``Environment`` may gain an
            ``intellectual_property`` attribute as a side effect.
        :raises ValidationError: If ``data`` is not an ``Environment`` and is
            either ``None`` or lacks a ``get`` attribute.
        """
        # Function-scope import; presumably avoids a circular import with the
        # entities package -- confirm before hoisting to module level.
        from azure.ai.ml.entities._assets import Environment

        if not isinstance(data, Environment):
            if data is None or not hasattr(data, "get"):
                raise ValidationError("Environment cannot be None")
            return data

        # Expose the private value under the public name the schema field dumps.
        ipp_value = data._intellectual_property  # pylint: disable=protected-access
        if ipp_value:
            setattr(data, "intellectual_property", ipp_value)
        return data

    @post_load
    def make(self, data, **kwargs):
        """Turn the loaded field values into an ``Environment``.

        :param data: Deserialized field values, forwarded as keyword arguments.
        :return: The constructed ``Environment``.
        :raises ValidationError: If constructing the environment raises
            ``FileNotFoundError``.
        """
        # Function-scope import; presumably avoids a circular import with the
        # entities package -- confirm before hoisting to module level.
        from azure.ai.ml.entities._assets import Environment

        base_path = self.context[BASE_PATH_CONTEXT_KEY]
        try:
            return Environment(base_path=base_path, **data)
        except FileNotFoundError as e:
            # Environment.__init__() raises FileNotFoundError when build.path is
            # missing while hashing an anonymous environment. Re-raise it as a
            # ValidationError so schema validation collects it with other errors.
            raise ValidationError("Environment file not found: {}".format(e)) from e


class EnvironmentSchema(_BaseEnvironmentSchema):
    """Schema for a named environment: ``name`` is required, plus a ``version`` field."""

    name = fields.Str(required=True)
    version = VersionField()


class AnonymousEnvironmentSchema(_BaseEnvironmentSchema, AnonymousAssetSchema):
    """Schema for an anonymous environment, whose name and version are not user-chosen."""

    @pre_load
    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
    def trim_dump_only(self, data, **kwargs):
        """Drop ``name`` and ``version`` from the input, warning about an ignored name.

        trim_dump_only in PathAwareSchema removes all properties which are dump
        only, so by the time this schema is reached the name and version would
        already be gone and no warning could be shown. This override pops both
        keys itself, logs a warning when a non-default name was supplied, and
        then delegates to the parent implementation.
        """
        if data is None or isinstance(data, str):
            return data

        supplied_name = data.pop("name", None)
        data.pop("version", None)

        # ANONYMOUS_ENV_NAME is the default name given to anonymous
        # environments, so it is not worth warning about.
        is_custom_name = supplied_name is not None and supplied_name != ANONYMOUS_ENV_NAME
        if is_custom_name:
            module_logger.warning(
                "Warning: the provided asset name '%s' will not be used for anonymous registration",
                supplied_name,
            )

        return super(AnonymousEnvironmentSchema, self).trim_dump_only(data, **kwargs)