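# NOTE: stray repository-browser navigation text ("about summary refs log tree commit diff") was on this line; it is not part of the module.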
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from pathlib import Path
from typing import Any, Dict, NoReturn, Optional, Union, cast

from marshmallow import Schema

from azure.ai.ml._schema.component.data_transfer_component import (
    DataTransferCopyComponentSchema,
    DataTransferExportComponentSchema,
    DataTransferImportComponentSchema,
)
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, COMPONENT_TYPE, AssetTypes
from azure.ai.ml.constants._component import DataTransferTaskType, ExternalDataType, NodeType
from azure.ai.ml.entities._inputs_outputs.external_data import Database, FileSystem
from azure.ai.ml.entities._inputs_outputs.output import Output
from azure.ai.ml.entities._validation.core import MutableValidationResult
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

from ..._schema import PathAwareSchema
from .._util import convert_ordered_dict_to_dict, validate_attribute_type
from .component import Component


class DataTransferComponent(Component):
    """DataTransfer component version, used to define a data transfer component.

    :param task: Task type in the data transfer component. Possible values are "copy_data",
                 "import_data", and "export_data".
    :type task: str
    :param inputs: Mapping of input data bindings used in the job.
    :type inputs: dict
    :param outputs: Mapping of output data bindings used in the job.
    :type outputs: dict
    :param kwargs: Additional parameters for the data transfer component.
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
        Details will be provided in the error message.
    """

    def __init__(
        self,
        *,
        task: Optional[str] = None,
        inputs: Optional[Dict] = None,
        outputs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        # Must stay the first statement: locals() is expected to contain only the
        # constructor arguments at this point.
        validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())

        # The component type is fixed for this class; a caller-supplied value is overwritten.
        kwargs[COMPONENT_TYPE] = NodeType.DATA_TRANSFER
        # Default the base path to the current directory unless the caller provided one.
        kwargs.setdefault(BASE_PATH_CONTEXT_KEY, Path("."))

        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )
        self._task = task

    @classmethod
    def _attr_type_map(cls) -> dict:
        """Return the mapping handed to ``validate_attribute_type`` as ``attr_type_map``.

        :return: An empty mapping; this class declares no entries.
        :rtype: dict
        """
        return {}

    @property
    def task(self) -> Optional[str]:
        """Task type of the component.

        :return: Task type of the component.
        :rtype: str
        """
        return self._task

    def _to_dict(self) -> Dict:
        """Dump the component to a plain dictionary.

        ``self._other_parameter`` is merged with the parent's dictionary; on a key
        collision the parent's value wins because it is unpacked last.

        :return: The merged content after ``convert_ordered_dict_to_dict``.
        :rtype: dict
        """
        merged = {**self._other_parameter, **super()._to_dict()}
        return cast(dict, convert_ordered_dict_to_dict(merged))

    def __str__(self) -> str:
        """Return the YAML form of the component, falling back to the parent's ``__str__``.

        :return: String representation of the component.
        :rtype: str
        """
        try:
            as_yaml: str = self._to_yaml()
            return as_yaml
        # Best-effort fallback for any ordinary failure. ``Exception`` rather than
        # ``BaseException`` so KeyboardInterrupt / SystemExit are not swallowed.
        except Exception:  # pylint: disable=W0718
            fallback: str = super().__str__()
            return fallback

    @classmethod
    def _build_source_sink(cls, io_dict: Union[Dict, Database, FileSystem]) -> Union[Database, FileSystem]:
        """Build the component-level source/sink object matching ``io_dict``.

        Only the *kind* of ``io_dict`` is used: a new, argument-less ``Database`` or
        ``FileSystem`` is returned. The input dict is not modified.

        NOTE(review): a ``Database``/``FileSystem`` instance passed in is not reused and
        other dict keys are ignored - confirm this is the intended component-level behavior.

        :param io_dict: A ``Database``, a ``FileSystem``, or a dict whose ``"type"`` entry
            selects one of the two.
        :type io_dict: Union[Dict, Database, FileSystem]
        :return: A freshly constructed ``Database`` or ``FileSystem``.
        :rtype: Union[Database, FileSystem]
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ``io_dict`` is none of the
            supported kinds, or is a dict whose ``"type"`` is neither supported value.
        """
        if isinstance(io_dict, Database):
            return Database()
        if isinstance(io_dict, FileSystem):
            return FileSystem()

        if not isinstance(io_dict, dict):
            msg = "Source or sink only support dict, Database and FileSystem"
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.COMPONENT,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

        # Read the type without popping it, so the caller's dict stays intact and can be
        # reused (e.g. to build a second component from the same spec).
        data_type = io_dict.get("type")
        if data_type == ExternalDataType.DATABASE:
            return Database()
        if data_type == ExternalDataType.FILE_SYSTEM:
            return FileSystem()

        msg = "Type in source or sink only support {} and {}, currently got {}."
        raise ValidationException(
            message=msg.format(
                ExternalDataType.DATABASE,
                ExternalDataType.FILE_SYSTEM,
                data_type,
            ),
            # The user-supplied value is replaced by a fixed placeholder here.
            no_personal_data_message=msg.format(
                ExternalDataType.DATABASE,
                ExternalDataType.FILE_SYSTEM,
                "data_type",
            ),
            target=ErrorTarget.COMPONENT,
            error_category=ErrorCategory.USER_ERROR,
            error_type=ValidationErrorType.INVALID_VALUE,
        )


@experimental
class DataTransferCopyComponent(DataTransferComponent):
    """DataTransfer copy component version, used to define a data transfer copy component.

    :param data_copy_mode: Data copy mode in the copy task.
                           Possible values are "merge_with_overwrite" and "fail_if_conflict".
    :type data_copy_mode: str
    :param inputs: Mapping of input data bindings used in the job.
    :type inputs: dict
    :param outputs: Mapping of output data bindings used in the job.
    :type outputs: dict
    :param kwargs: Additional parameters for the data transfer copy component.
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
        Details will be provided in the error message.
    """

    def __init__(
        self,
        *,
        data_copy_mode: Optional[str] = None,
        inputs: Optional[Dict] = None,
        outputs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        # The task is fixed for this class; any "task" supplied by the caller is overwritten.
        kwargs.update(task=DataTransferTaskType.COPY_DATA)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        self._data_copy_mode = data_copy_mode

    @classmethod
    def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
        """Create the schema instance used to validate this component.

        :param context: Context object forwarded to the schema constructor.
        :type context: Any
        :return: A ``DataTransferCopyComponentSchema`` bound to ``context``.
        :rtype: Union[PathAwareSchema, Schema]
        """
        schema = DataTransferCopyComponentSchema(context=context)
        return schema

    @property
    def data_copy_mode(self) -> Optional[str]:
        """Data copy mode of the component.

        :return: Data copy mode of the component.
        :rtype: str
        """
        return self._data_copy_mode

    def _customized_validate(self) -> MutableValidationResult:
        """Run the parent's validation and add the input/output mapping checks.

        :return: The parent's validation result merged with the mapping result.
        :rtype: MutableValidationResult
        """
        combined = super(DataTransferCopyComponent, self)._customized_validate()
        mapping_result = self._validate_input_output_mapping()
        combined.merge_with(mapping_result)
        return combined

    def _validate_input_output_mapping(self) -> MutableValidationResult:
        """Check that the inputs and outputs form a valid copy-task mapping.

        Rules, in the order they are applied:

        * there must be exactly one output;
        * there must be at least one input;
        * with a single input, input and output types must both be set and equal;
        * with several inputs, the output type must be ``AssetTypes.URI_FOLDER``.

        :return: A validation result holding at most one error.
        :rtype: MutableValidationResult
        """
        result = self._create_empty_validation_result()
        num_inputs = len(self.inputs)
        num_outputs = len(self.outputs)

        if num_outputs != 1:
            result.append_error(
                message="Only support single output in {}, but there're {} outputs.".format(
                    DataTransferTaskType.COPY_DATA, num_outputs
                ),
                yaml_path="outputs",
            )
            return result

        if num_inputs < 1:
            result.append_error(
                message="Inputs must be set in task {}.".format(DataTransferTaskType.COPY_DATA),
                yaml_path="inputs",
            )
            return result

        if num_inputs == 1:
            # Exactly one entry on each side, so taking the first value is taking the only one.
            sole_input_type = next(iter(self.inputs.values())).type
            sole_output_type = next(iter(self.outputs.values())).type
            if sole_input_type is None or sole_output_type is None or sole_input_type != sole_output_type:
                result.append_error(
                    message="Input type {} doesn't exactly match with output type {} in task {}".format(
                        sole_input_type, sole_output_type, DataTransferTaskType.COPY_DATA
                    ),
                    yaml_path="outputs",
                )
            return result

        # Several inputs: only the single output's type is inspected.
        sole_output_type = next(iter(self.outputs.values())).type
        if sole_output_type is None or sole_output_type != AssetTypes.URI_FOLDER:
            result.append_error(
                message="output type {} need to be {} in task {}".format(
                    sole_output_type,
                    AssetTypes.URI_FOLDER,
                    DataTransferTaskType.COPY_DATA,
                ),
                yaml_path="outputs",
            )
        return result


@experimental
class DataTransferImportComponent(DataTransferComponent):
    """DataTransfer import component version, used to define a data transfer import component.

    :param source: The data source of the file system or database.
    :type source: dict
    :param outputs: Mapping of output data bindings used in the job.
                    Default value is an output port with the key "sink" and the type "mltable".
    :type outputs: dict
    :param kwargs: Additional parameters for the data transfer import component.
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
        Details will be provided in the error message.
    """

    def __init__(
        self,
        *,
        source: Optional[Dict] = None,
        outputs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        # The task is fixed for this class; any "task" supplied by the caller is overwritten.
        kwargs["task"] = DataTransferTaskType.IMPORT_DATA
        if not outputs:
            # No (or empty) outputs given: fall back to a single "sink" port of type mltable.
            outputs = {"sink": Output(type=AssetTypes.MLTABLE)}
        super().__init__(outputs=outputs, **kwargs)

        # A falsy source is replaced by an empty dict before being converted.
        self.source = self._build_source_sink(source or {})

    @classmethod
    def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
        """Create the schema instance used to validate this component.

        :param context: Context object forwarded to the schema constructor.
        :type context: Any
        :return: A ``DataTransferImportComponentSchema`` bound to ``context``.
        :rtype: Union[PathAwareSchema, Schema]
        """
        schema = DataTransferImportComponentSchema(context=context)
        return schema

    # pylint: disable-next=docstring-missing-param
    def __call__(self, *args: Any, **kwargs: Any) -> NoReturn:
        """Reject the call: an import-task component cannot be invoked as a function.

        :raises ~azure.ai.ml.exceptions.ValidationException: Always raised, whatever the arguments.
        """
        not_callable_msg = "DataTransfer component is not callable for import task."
        error = ValidationException(
            message=not_callable_msg,
            no_personal_data_message=not_callable_msg,
            target=ErrorTarget.COMPONENT,
            error_category=ErrorCategory.USER_ERROR,
        )
        raise error


@experimental
class DataTransferExportComponent(DataTransferComponent):
    """DataTransfer export component version, used to define a data transfer export component.

    :param sink: The sink of external data and databases.
    :type sink: Union[Dict, Database, FileSystem]
    :param inputs: Mapping of input data bindings used in the job.
    :type inputs: dict
    :param kwargs: Additional parameters for the data transfer export component.
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the component cannot be successfully validated.
        Details will be provided in the error message.
    """

    def __init__(
        self,
        *,
        inputs: Optional[Dict] = None,
        sink: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        # The task is fixed for this class; any "task" supplied by the caller is overwritten.
        kwargs.update(task=DataTransferTaskType.EXPORT_DATA)
        super().__init__(inputs=inputs, **kwargs)

        # A falsy sink is replaced by an empty dict before being converted.
        self.sink = self._build_source_sink(sink or {})

    @classmethod
    def _create_schema_for_validation(cls, context: Any) -> Union[PathAwareSchema, Schema]:
        """Create the schema instance used to validate this component.

        :param context: Context object forwarded to the schema constructor.
        :type context: Any
        :return: A ``DataTransferExportComponentSchema`` bound to ``context``.
        :rtype: Union[PathAwareSchema, Schema]
        """
        schema = DataTransferExportComponentSchema(context=context)
        return schema

    # pylint: disable-next=docstring-missing-param
    def __call__(self, *args: Any, **kwargs: Any) -> NoReturn:
        """Reject the call: an export-task component cannot be invoked as a function.

        :raises ~azure.ai.ml.exceptions.ValidationException: Always raised, whatever the arguments.
        """
        not_callable_msg = "DataTransfer component is not callable for export task."
        error = ValidationException(
            message=not_callable_msg,
            no_personal_data_message=not_callable_msg,
            target=ErrorTarget.COMPONENT,
            error_category=ErrorCategory.USER_ERROR,
        )
        raise error