aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import io
import json
import logging
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Union
from urllib.parse import urljoin, urlparse

import yaml
from jsonschema import Draft7Validator, ValidationError
from jsonschema.exceptions import best_match

from azure.ai.ml._artifacts._artifact_utilities import get_datastore_info, get_storage_client
from azure.ai.ml._artifacts._constants import INVALID_MLTABLE_METADATA_SCHEMA_ERROR, INVALID_MLTABLE_METADATA_SCHEMA_MSG
from azure.ai.ml.constants._common import DefaultOpenEncoding
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
from azure.ai.ml.operations._datastore_operations import DatastoreOperations

from ._http_utils import HttpPipeline
from ._storage_utils import AzureMLDatastorePathUri
from .utils import load_yaml

module_logger = logging.getLogger(__name__)


def download_mltable_metadata_schema(mltable_schema_url: str, requests_pipeline: HttpPipeline):
    response = requests_pipeline.get(mltable_schema_url)
    return response.json()


def read_local_mltable_metadata_contents(*, path: str) -> Dict:
    metadata_path = str(Path(path, "MLTable"))
    return load_yaml(metadata_path)


def read_remote_mltable_metadata_contents(
    *,
    base_uri: str,
    datastore_operations: DatastoreOperations,
    requests_pipeline: HttpPipeline,
) -> Union[Dict, None]:
    scheme = urlparse(base_uri).scheme
    if scheme == "https":
        response = requests_pipeline.get(urljoin(base_uri, "MLTable"))
        yaml_file = io.BytesIO(response.content)
        return yaml.safe_load(yaml_file)
    if scheme == "azureml":
        datastore_path_uri = AzureMLDatastorePathUri(base_uri)
        datastore_info = get_datastore_info(datastore_operations, datastore_path_uri.datastore)
        storage_client = get_storage_client(**datastore_info)
        with TemporaryDirectory() as tmp_dir:
            starts_with = datastore_path_uri.path.rstrip("/")
            storage_client.download(f"{starts_with}/MLTable", tmp_dir)
            downloaded_mltable_path = Path(tmp_dir, "MLTable")
            with open(downloaded_mltable_path, "r", encoding=DefaultOpenEncoding.READ) as f:
                return yaml.safe_load(f)
    return None


def validate_mltable_metadata(*, mltable_metadata_dict: Dict, mltable_schema: Dict):
    # use json-schema to validate dict
    error: Union[ValidationError, None] = best_match(Draft7Validator(mltable_schema).iter_errors(mltable_metadata_dict))
    if error:
        err_path = ".".join(error.path)
        err_path = f"{err_path}: " if err_path != "" else ""
        msg = INVALID_MLTABLE_METADATA_SCHEMA_ERROR.format(
            jsonSchemaErrorPath=err_path,
            jsonSchemaMessage=error.message,
            invalidMLTableMsg=INVALID_MLTABLE_METADATA_SCHEMA_MSG,
            invalidSchemaSnippet=json.dumps(error.schema, indent=2),
        )
        raise ValidationException(
            message=msg,
            no_personal_data_message=INVALID_MLTABLE_METADATA_SCHEMA_MSG,
            error_type=ValidationErrorType.INVALID_VALUE,
            target=ErrorTarget.DATA,
            error_category=ErrorCategory.USER_ERROR,
        )