1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
|
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import abc
import logging
import sys
from contextlib import contextmanager
from types import CodeType, FrameType, FunctionType, MethodType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from azure.ai.ml._utils.utils import is_bytecode_optimization_enabled
logger = logging.getLogger(__name__)
class PersistentLocalsFunctionBuilder(abc.ABC):
errors = {
"not_callable": "func must be a function or a callable object",
"conflict_argument": "Injected param name __self conflicts with function args {args}",
"not_all_template_separators_used": "Not all template separators are used, "
"please switch to a compatible version of Python.",
"invalid_template": "Provided template functions are invalid in current environment, "
"please switch to a compatible version (3.9 e.g.) of Python "
"and/or check template functions.",
}
injected_param = "__self"
@classmethod
def make_error(cls, error_name: str, **kwargs) -> str:
"""Make error message with error_name and kwargs.
:param error_name: A key from :attr:`~PersistentLocalsFunctionBuilder.errors`
:type error_name: str
:return: Formatted error message
:rtype: str
"""
return cls.errors[error_name].format(**kwargs)
@abc.abstractmethod
def _call(self, func, _all_kwargs) -> Tuple[Any, dict]:
raise NotImplementedError()
def call(self, func, _all_kwargs) -> Tuple[Any, dict]:
"""Get outputs and locals in calling func with _all_kwargs. Locals will be used to update node variable names.
:param func: The function to execute.
:type func: Union[FunctionType, MethodType]
:param _all_kwargs: All kwargs to call self.func.
:type _all_kwargs: typing.Dict[str, typing.Any]
:return: A tuple of outputs and locals.
:rtype: typing.Tuple[typing.Any, typing.Dict]
"""
if isinstance(func, (FunctionType, MethodType)):
pass
elif hasattr(func, "__call__"):
func = func.__call__
else:
raise TypeError(self.make_error("not_callable"))
if self.injected_param in func.__code__.co_varnames:
raise ValueError(self.make_error("conflict_argument", args=list(func.__code__.co_varnames)))
return self._call(func, _all_kwargs)
class PersistentLocalsFunctionProfilerBuilder(PersistentLocalsFunctionBuilder):
@staticmethod
@contextmanager
# pylint: disable-next=docstring-missing-return,docstring-missing-rtype
def _replace_sys_profiler(profiler: Callable[[FrameType, str, Any], None]) -> Iterable[None]:
"""A context manager which replaces sys profiler to given profiler.
:param profiler: The profile function.
See https://docs.python.org/3/library/sys.html#sys.setprofile for more information
:type profiler: Callable[[FrameType, str, Any], None]
"""
original_profiler = sys.getprofile()
sys.setprofile(profiler)
try:
yield
finally:
sys.setprofile(original_profiler)
@staticmethod
def _get_func_variable_tracer(
_locals_data: Dict[str, Any], func_code: CodeType
) -> Callable[[FrameType, str, Any], None]:
"""Get a tracer to trace variable names in dsl.pipeline function.
:param _locals_data: A dict to save locals data.
:type _locals_data: dict
:param func_code: An code object to compare if current frame is inside user function.
:type func_code: CodeType
:return: A tracing function
:rtype: Callable[[FrameType, str, Any], None]
"""
def tracer(frame: FrameType, event: str, arg: Any) -> None: # pylint: disable=unused-argument
if frame.f_code == func_code and event == "return":
# Copy the locals of user's dsl function when it returns.
_locals_data.update(frame.f_locals.copy())
return tracer
def _call(self, func, _all_kwargs):
_locals = {}
func_variable_profiler = self._get_func_variable_tracer(_locals, func.__code__)
with self._replace_sys_profiler(func_variable_profiler):
outputs = func(**_all_kwargs)
return outputs, _locals
class PersistentLocalsFunction(object):
def __init__(
self,
_func,
*,
_self: Optional[Any] = None,
skip_locals: Optional[List[str]] = None,
):
"""
:param _func: The function to be wrapped.
:param _self: If original func is a method, _self should be provided, which is the instance of the method.
:param skip_locals: A list of local variables to skip when saving the locals.
"""
self._locals = {}
self._self = _self
# make function an instance method
self._func = MethodType(_func, self)
self._skip_locals = skip_locals
def __call__(__self, *args, **kwargs): # pylint: disable=no-self-argument
# Use __self in case self is also passed as a named argument in kwargs
__self._locals.clear()
try:
if __self._self:
return __self._func(__self._self, *args, **kwargs) # pylint: disable=not-callable
return __self._func(*args, **kwargs) # pylint: disable=not-callable
finally:
# always pop skip locals even if exception is raised in user code
if __self._skip_locals is not None:
for skip_local in __self._skip_locals:
__self._locals.pop(skip_local, None)
@property
def locals(self):
return self._locals
def _source_template_func(mock_arg):
return mock_arg
def _target_template_func(__self, mock_arg):
try:
return mock_arg
finally:
__self._locals = locals().copy() # pylint: disable=protected-access
try:
from bytecode import Bytecode, Instr, Label
class PersistentLocalsFunctionBytecodeBuilder(PersistentLocalsFunctionBuilder):
_template_separators = []
_template_separators_before_body = []
_template_separators_after_body = []
_template_body = []
_template_tail = None
__initialized = False
@classmethod
def _split(cls, instructions, separator, n=-1):
cur_start, index, result = 0, 0, []
while index < len(instructions) - len(separator) + 1:
if cls.is_instr_equal(instructions[index], separator[0]):
for i, template_body_instruction in enumerate(separator):
if not cls.is_instr_equal(instructions[index + i], template_body_instruction):
break
else:
result.append(instructions[cur_start:index])
cur_start = index + len(separator)
index += len(separator)
if len(result) == n:
break
continue
index += 1
result.append(instructions[cur_start:])
if n != -1 and len(result) != n:
msg = "can't split instructions into {} pieces with provided separators".format(n)
raise ValueError(msg)
return result
@classmethod
def _class_init_impl(cls):
"""Override this method to implement different template matching algorithm."""
cls._template_separators_before_body, cls._template_separators_after_body = cls._split(
cls.get_instructions(_source_template_func),
separator=cls._get_mock_body_instructions(),
n=2,
)
# use None to indicate the body
cls._template_separators = (
cls._template_separators_before_body + [None] + cls._template_separators_after_body
)
cls._template_body = cls._split_instructions_based_on_template(
cls.get_instructions(_target_template_func),
remove_mock_body=True,
)
cls._template_tail = cls._template_body.pop()
if len(cls._template_body) != len(cls._template_body):
raise ValueError(cls.make_error("invalid_template"))
@classmethod
def __class_init(cls):
if cls.__initialized:
return
cls._class_init_impl()
cls.__initialized = True
def __init__(self):
self.__class_init()
# region methods depending on package bytecode
@classmethod
def get_instructions(cls, func):
return list(Bytecode.from_code(func.__code__))
@classmethod
def is_instr_equal(cls, instr1: Instr, instr2: Instr) -> bool:
if instr1 is None and instr2 is None:
return True
if instr1 is None or instr2 is None:
return False
if instr1.__class__ != instr2.__class__:
return False
if isinstance(instr1, Instr):
if isinstance(instr1.arg, Label) and isinstance(instr2.arg, Label):
return True
return instr1.opcode == instr2.opcode and instr1.arg == instr2.arg
# objects like Label and TryBegin
return True
@classmethod
def is_instructions_equal(cls, instructions1: List[Instr], instructions2: List[Instr]) -> bool:
if len(instructions1) != len(instructions2):
return False
for instr1, instr2 in zip(instructions1, instructions2):
if not cls.is_instr_equal(instr1, instr2):
return False
return True
def _create_code(self, instructions: List[Instr], base_func: Union[FunctionType, MethodType]) -> CodeType:
"""Create the base bytecode for the function to be generated.
Will keep information of the function, such as name, globals, etc., but skip all instructions.
:param instructions: The list of instructions. Used to replace the instructions in base_func
:type instructions: List[Instr]
:param base_func: A function that provides base metadata (name, globals, etc...). Instructions will not
be kept
:type base_func: Union[FunctionType, MethodType]
:return: Generated code
:rtype: CodeType
"""
fn_code = Bytecode.from_code(base_func.__code__)
fn_code.clear()
fn_code.extend(instructions)
fn_code.argcount += 1
fn_code.argnames.insert(0, self.injected_param)
return fn_code.to_code()
@classmethod
def _get_mock_body_instructions(cls):
return [Instr("LOAD_FAST", "mock_arg")]
# endregion
@classmethod
def _get_pieces(cls, instructions: List[Instr], separators: List[Instr]) -> List[List[Instr]]:
"""Split the instructions into pieces by the separators.
Note that separators is a list of instructions. For example,
instructions: [I3, I1, I2, I3, I1, I3, I1, I2, I3]
separators: [I1, I2]
result: [[I3], [I3, I1, I3], [I3]]
:param instructions: The list of instructions to split
:type instructions: List[instr]
:param separators: The sequence of Instr to use as a delimiter
:type separators: List[Instr]
:return: A sublists of instructions that were delimited by separators
:rtype: List[List[Instr]]
"""
separator_iter = iter(separators)
def get_next_separator():
try:
while True:
separator = next(separator_iter)
if separator is not None:
return separator
except StopIteration:
return None
pieces = []
last_piece = []
cur_separator = get_next_separator()
for instr in instructions:
if cls.is_instr_equal(instr, cur_separator):
# skip the separator
pieces.append(last_piece)
cur_separator = get_next_separator()
last_piece = []
else:
last_piece.append(instr)
pieces.append(last_piece)
if cur_separator is not None:
raise ValueError(cls.make_error("not_all_template_separators_used"))
return pieces
@classmethod
def _split_instructions_based_on_template(
cls,
instructions: List[Instr],
*,
remove_mock_body: bool = False,
) -> List[List[Instr]]:
"""Split instructions into several pieces by separators.
For example, in Python 3.11, the template source instructions will be:
.. code-block:: python
[
Instr('RESUME', 0), # initial instruction shared by all functions
Instr('LOAD_FAST', 'mock_arg'), # the body execution instruction
Instr('RETURN_VALUE'), # the return instruction shared by all functions
]
Then the separators before body will be:
.. code-block:: python
[
Instr('RESUME', 0),
]
And the separators after body will be:
.. code-block:: python
[
Instr('RETURN_VALUE'),
]
For passed in instructions, we will split them with separators from beginning (the first RESUME) and
with reversed_separators from end (the last RETURN_VALUE).
:param instructions: The instructions to split
:type instructions: List[instr]
:keyword remove_mock_body: Whether to remove the mock body. Defaults to False
:paramtype remove_mock_body: bool
:return: The split instructions
:rtype: List[List[Instr]]
"""
if remove_mock_body:
# this parameter should be set as True only when processing the template target function,
# when we should ignore the mock body
pieces = cls._get_pieces(
instructions, cls._template_separators_before_body + cls._get_mock_body_instructions()
)
else:
pieces = cls._get_pieces(instructions, cls._template_separators_before_body)
reversed_pieces = cls._get_pieces(reversed(pieces.pop()), reversed(cls._template_separators_after_body))
while reversed_pieces:
pieces.append(list(reversed(reversed_pieces.pop())))
return pieces
def _build_instructions(self, func: Union[FunctionType, MethodType]) -> List[Instr]:
generated_instructions = []
for template_piece, input_piece, separator in zip(
self._template_body,
self._split_instructions_based_on_template(self.get_instructions(func)),
self._template_separators,
):
generated_instructions.extend(template_piece)
generated_instructions.extend(input_piece)
if separator is not None:
generated_instructions.append(separator)
generated_instructions.extend(self._template_tail)
return generated_instructions
def _build_func(self, func: Union[FunctionType, MethodType]) -> PersistentLocalsFunction:
"""Build a persistent locals function from the given function. Use bytecode injection to add try...finally
statement around code to persistent the locals in the function.
It will change the func bytecode in this way:
.. code-block:: python
def func(__self, *func_args):
try:
the func code...
finally:
__self.locals = locals().copy()
You can get the locals in func by this code:
.. code-block:: python
builder = PersistentLocalsFunctionBuilder()
persistent_locals_func = builder.build(your_func)
# Execute your func
result = persistent_locals_func(*args)
# Get the locals in the func.
func_locals = persistent_locals_func.locals
:param func: The function to modify
:type func: Union[FunctionType, MethodType]
:return: The built persistent locals function
:rtype: PersistentLocalsFunction
"""
generated_func = FunctionType(
self._create_code(self._build_instructions(func), func),
func.__globals__,
func.__name__,
func.__defaults__,
func.__closure__,
)
return PersistentLocalsFunction(
generated_func,
_self=func.__self__ if isinstance(func, MethodType) else None,
skip_locals=[self.injected_param],
)
def _call(self, func, _all_kwargs) -> Tuple[Any, dict]:
persistent_func = self._build_func(func)
outputs = persistent_func(**_all_kwargs)
return outputs, persistent_func.locals
except ImportError:
# Fall back to the profiler implementation
class PersistentLocalsFunctionBytecodeBuilder(PersistentLocalsFunctionProfilerBuilder):
pass
def _get_persistent_locals_builder() -> PersistentLocalsFunctionBuilder:
if is_bytecode_optimization_enabled():
return PersistentLocalsFunctionBytecodeBuilder()
return PersistentLocalsFunctionProfilerBuilder()
def get_outputs_and_locals(func, _all_kwargs):
"""Get outputs and locals from self.func. Locals will be used to update node variable names.
:param func: The function to execute.
:type func: Union[FunctionType, MethodType]
:param _all_kwargs: All kwargs to call self.func.
:type _all_kwargs: typing.Dict[str, typing.Any]
:return: A tuple of outputs and locals.
:rtype: typing.Tuple[typing.Dict, typing.Dict]
"""
return _get_persistent_locals_builder().call(func, _all_kwargs)
|