aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/dsl/_settings.py
blob: 784aaece23f75af311773fd08a2456ac73946857 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import logging
from collections import deque
from typing import Any, Deque, Dict, Optional, Union

from azure.ai.ml.entities._builders import BaseNode
from azure.ai.ml.exceptions import UserErrorException

module_logger = logging.getLogger(__name__)


class _DSLSettingsStack:
    """LIFO stack of `_DSLSettings` frames, one per nested `dsl.pipeline` definition."""

    def __init__(self) -> None:
        self._stack: Deque["_DSLSettings"] = deque()

    def push(self) -> None:
        """Open a fresh settings frame for a new pipeline definition."""
        self._stack.append(_DSLSettings())

    def pop(self) -> "_DSLSettings":
        """Remove and return the innermost settings frame."""
        self._check_stack()
        return self._stack.pop()

    def top(self) -> "_DSLSettings":
        """Return the innermost settings frame without removing it."""
        self._check_stack()
        return self._stack[-1]

    def _check_stack(self) -> None:
        # Push/pop pairing is managed externally by the `pipeline` decorator, so an
        # empty stack can only mean the user called `set_pipeline_settings` outside
        # a `pipeline` decorated function — surface that as a user error.
        if not self._stack:
            error_message = "Please call `set_pipeline_settings` inside a `pipeline` decorated function."
            raise UserErrorException(
                message=error_message,
                no_personal_data_message=error_message,
            )


_dsl_settings_stack = _DSLSettingsStack()


class _DSLSettings:
    """Initialization & finalization job settings for DSL pipeline job.

    Store settings from `dsl.set_pipeline_settings` during pipeline definition.
    Each of the two jobs may be set either as a node object or as a node name (str);
    invalid types are ignored with a warning rather than raising.
    """

    def __init__(self) -> None:
        # Either a BaseNode, a node name (str), or None when unset.
        self._init_job: Any = None
        self._finalize_job: Any = None

    @property
    def init_job(self) -> Union[BaseNode, str]:
        """Node or node name to run before all other jobs (None if unset)."""
        return self._init_job

    @init_job.setter
    def init_job(self, value: Optional[Union[BaseNode, str]]) -> None:
        if isinstance(value, (BaseNode, str)):
            self._init_job = value
        else:
            # Lazy %-args: message is only formatted if the warning is actually emitted.
            module_logger.warning(
                "Initialization job setting is ignored as input parameter type is %r.", type(value)
            )

    @property
    def init_job_set(self) -> bool:
        """Whether an init job has been configured."""
        # note: need to use `BaseNode is not None` as `bool(BaseNode)` will return False
        return self._init_job is not None

    def init_job_name(self, jobs: Dict[str, BaseNode]) -> Optional[str]:
        """Resolve the configured init job to its name within ``jobs``.

        :param jobs: Mapping of node name to node for the pipeline being built.
        :return: The name directly if configured as a string; otherwise the name of
            the matching node (matched by identity), or None with a warning if absent.
        """
        if isinstance(self._init_job, str):
            return self._init_job
        for name, job in jobs.items():
            # Identity comparison: the configured node must be the same object.
            if self._init_job is job:
                return name
        module_logger.warning("Initialization job setting is ignored as cannot find corresponding job node.")
        return None

    @property
    def finalize_job(self) -> Union[BaseNode, str]:
        """Node or node name to run after the pipeline run finishes (None if unset)."""
        return self._finalize_job

    @finalize_job.setter
    def finalize_job(self, value: Optional[Union[BaseNode, str]]) -> None:
        if isinstance(value, (BaseNode, str)):
            self._finalize_job = value
        else:
            # Lazy %-args: message is only formatted if the warning is actually emitted.
            module_logger.warning(
                "Finalization job setting is ignored as input parameter type is %r.", type(value)
            )

    @property
    def finalize_job_set(self) -> bool:
        """Whether a finalize job has been configured."""
        # note: need to use `BaseNode is not None` as `bool(BaseNode)` will return False
        return self._finalize_job is not None

    def finalize_job_name(self, jobs: Dict[str, BaseNode]) -> Optional[str]:
        """Resolve the configured finalize job to its name within ``jobs``.

        :param jobs: Mapping of node name to node for the pipeline being built.
        :return: The name directly if configured as a string; otherwise the name of
            the matching node (matched by identity), or None with a warning if absent.
        """
        if isinstance(self._finalize_job, str):
            return self._finalize_job
        for name, job in jobs.items():
            # Identity comparison: the configured node must be the same object.
            if self._finalize_job is job:
                return name
        module_logger.warning("Finalization job setting is ignored as cannot find corresponding job node.")
        return None


def set_pipeline_settings(
    *,
    on_init: Optional[Union[BaseNode, str]] = None,
    on_finalize: Optional[Union[BaseNode, str]] = None,
) -> None:
    """Set pipeline settings for current `dsl.pipeline` definition.

    This function should be called inside a `dsl.pipeline` decorated function, otherwise will raise exception.

    :keyword on_init: On_init job node or name. On init job will be executed before all other jobs, \
                it should not have data connections to regular nodes.
    :paramtype on_init: Union[BaseNode, str]
    :keyword on_finalize: On_finalize job node or name. On finalize job will be executed after pipeline run \
                finishes (completed/failed/canceled), it should not have data connections to regular nodes.
    :paramtype on_finalize: Union[BaseNode, str]
    :return:
    """
    # `top()` raises a user error when called outside a pipeline definition.
    settings = _dsl_settings_stack.top()
    if on_init is not None:
        settings.init_job = on_init
    if on_finalize is not None:
        settings.finalize_job = on_finalize