aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_func_utils.py
diff options
context:
space:
mode:
author S. Solomon Darnell 2025-03-28 21:52:21 -0500
committer S. Solomon Darnell 2025-03-28 21:52:21 -0500
commit 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_func_utils.py
parent cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download gn-ai-master.tar.gz
two versions of R2R are here (refs: HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_func_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_func_utils.py471
1 files changed, 471 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_func_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_func_utils.py
new file mode 100644
index 00000000..0e4b465a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_func_utils.py
@@ -0,0 +1,471 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import abc
+import logging
+import sys
+from contextlib import contextmanager
+from types import CodeType, FrameType, FunctionType, MethodType
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+from azure.ai.ml._utils.utils import is_bytecode_optimization_enabled
+
+logger = logging.getLogger(__name__)
+
+
class PersistentLocalsFunctionBuilder(abc.ABC):
    """Abstract base for builders that run a function and hand back its local variables.

    Concrete builders provide the capture strategy in :meth:`_call`; :meth:`call`
    performs the validation that is common to all of them and then delegates.
    """

    # Message templates for every error raised by this builder family, keyed by name.
    errors = {
        "not_callable": "func must be a function or a callable object",
        "conflict_argument": "Injected param name __self conflicts with function args {args}",
        "not_all_template_separators_used": "Not all template separators are used, "
        "please switch to a compatible version of Python.",
        "invalid_template": "Provided template functions are invalid in current environment, "
        "please switch to a compatible version (3.9 e.g.) of Python "
        "and/or check template functions.",
    }
    # Reserved variable name; functions that already use it are rejected by ``call``.
    injected_param = "__self"

    @classmethod
    def make_error(cls, error_name: str, **kwargs) -> str:
        """Render one of the message templates in :attr:`errors`.

        :param error_name: A key from :attr:`~PersistentLocalsFunctionBuilder.errors`
        :type error_name: str
        :return: The template formatted with ``kwargs``
        :rtype: str
        """
        template = cls.errors[error_name]
        return template.format(**kwargs)

    @abc.abstractmethod
    def _call(self, func, _all_kwargs) -> Tuple[Any, dict]:
        """Run ``func`` with ``_all_kwargs`` and return ``(outputs, locals)``; implemented by subclasses."""
        raise NotImplementedError()

    def call(self, func, _all_kwargs) -> Tuple[Any, dict]:
        """Validate ``func``, execute it with ``_all_kwargs`` and return its outputs and locals.

        Anything that is not a plain function or bound method is replaced by its
        ``__call__`` attribute before validation.

        :param func: The function to execute.
        :type func: Union[FunctionType, MethodType]
        :param _all_kwargs: All kwargs to call self.func.
        :type _all_kwargs: typing.Dict[str, typing.Any]
        :return: A tuple of outputs and locals.
        :rtype: typing.Tuple[typing.Any, typing.Dict]
        :raises TypeError: If ``func`` is neither a function/method nor has ``__call__``.
        :raises ValueError: If ``func`` already uses the injected parameter name.
        """
        if not isinstance(func, (FunctionType, MethodType)):
            if not hasattr(func, "__call__"):
                raise TypeError(self.make_error("not_callable"))
            func = func.__call__

        variable_names = func.__code__.co_varnames
        if self.injected_param in variable_names:
            raise ValueError(self.make_error("conflict_argument", args=list(variable_names)))

        return self._call(func, _all_kwargs)
+
+
class PersistentLocalsFunctionProfilerBuilder(PersistentLocalsFunctionBuilder):
    """Builder that captures a function's locals through a temporary ``sys`` profile hook."""

    @staticmethod
    @contextmanager
    # pylint: disable-next=docstring-missing-return,docstring-missing-rtype
    def _replace_sys_profiler(profiler: Callable[[FrameType, str, Any], None]) -> Iterable[None]:
        """Install ``profiler`` as the sys profiler for the duration of the ``with`` block.

        The profiler that was active on entry is put back on exit, including when the
        block raises.

        :param profiler: The profile function.
            See https://docs.python.org/3/library/sys.html#sys.setprofile for more information
        :type profiler: Callable[[FrameType, str, Any], None]
        """
        previous_profiler = sys.getprofile()
        sys.setprofile(profiler)
        try:
            yield
        finally:
            sys.setprofile(previous_profiler)

    @staticmethod
    def _get_func_variable_tracer(
        _locals_data: Dict[str, Any], func_code: CodeType
    ) -> Callable[[FrameType, str, Any], None]:
        """Create a profile function that records the locals of one specific code object.

        :param _locals_data: A dict that receives the recorded locals.
        :type _locals_data: dict
        :param func_code: The code object identifying the frames of interest.
        :type func_code: CodeType
        :return: A tracing function
        :rtype: Callable[[FrameType, str, Any], None]
        """

        def tracer(frame: FrameType, event: str, arg: Any) -> None:  # pylint: disable=unused-argument
            # Only "return" events of frames running the target code are of interest.
            if event != "return" or frame.f_code != func_code:
                return
            # Snapshot the frame's locals at the moment the function returns.
            _locals_data.update(frame.f_locals.copy())

        return tracer

    def _call(self, func, _all_kwargs):
        """Run ``func`` under the locals-recording profiler and return ``(outputs, locals)``."""
        captured_locals = {}
        profiler = self._get_func_variable_tracer(captured_locals, func.__code__)
        with self._replace_sys_profiler(profiler):
            result = func(**_all_kwargs)
        return result, captured_locals
+
+
class PersistentLocalsFunction(object):
    """Callable wrapper that runs a function and keeps the locals the function stored on it.

    The wrapped function is bound to this wrapper, so it receives the wrapper as its
    first positional argument and is expected to assign its locals to ``_locals`` there.
    """

    def __init__(
        self,
        _func,
        *,
        _self: Optional[Any] = None,
        skip_locals: Optional[List[str]] = None,
    ):
        """
        :param _func: The function to be wrapped.
        :param _self: If original func is a method, _self should be provided, which is the instance of the method.
        :param skip_locals: A list of local variables to skip when saving the locals.
        """
        self._locals = {}
        self._self = _self
        # make function an instance method
        self._func = MethodType(_func, self)
        self._skip_locals = skip_locals

    def __call__(__self, *args, **kwargs):  # pylint: disable=no-self-argument
        """Call the wrapped function and return its result.

        Previously recorded locals are cleared first. Names listed in ``skip_locals``
        are removed from the recorded locals afterwards, also when the call raises.
        """
        # Use __self in case self is also passed as a named argument in kwargs
        __self._locals.clear()
        try:
            # Compare against None rather than testing truthiness: an instance that is
            # falsy (e.g. defines __len__ or __bool__) must still be passed to its method.
            if __self._self is not None:
                return __self._func(__self._self, *args, **kwargs)  # pylint: disable=not-callable
            return __self._func(*args, **kwargs)  # pylint: disable=not-callable
        finally:
            # always pop skip locals even if exception is raised in user code
            if __self._skip_locals is not None:
                for skip_local in __self._skip_locals:
                    __self._locals.pop(skip_local, None)

    @property
    def locals(self):
        """The locals recorded by the most recent call (empty before the first call)."""
        return self._locals
+
+
# Bytecode template, not a regular helper: PersistentLocalsFunctionBytecodeBuilder reads
# this function's instructions and splits them around the load of ``mock_arg`` to learn
# which instructions come before and after a function body on the running interpreter.
# The builder depends on the exact instructions compiled from this body.
def _source_template_func(mock_arg):
    return mock_arg
+
+
# Bytecode template, not a regular helper: PersistentLocalsFunctionBytecodeBuilder reads
# this function's instructions, drops the load of ``mock_arg`` (the stand-in for a real
# body) and reuses the remaining pieces to wrap user code in the try/finally below.
# The builder depends on the exact instructions compiled from this body.
def _target_template_func(__self, mock_arg):
    try:
        return mock_arg
    finally:
        # Store a snapshot of the function's locals on the injected ``__self`` object.
        __self._locals = locals().copy()  # pylint: disable=protected-access
+
+
try:
    from bytecode import Bytecode, Instr, Label

    class PersistentLocalsFunctionBytecodeBuilder(PersistentLocalsFunctionBuilder):
        """Builder that captures a function's locals by rewriting the function's bytecode.

        The instructions of ``_source_template_func`` and ``_target_template_func`` are
        compared once per class to work out which instructions must be placed around a
        function body so that it behaves like ``_target_template_func``. Those pieces are
        then interleaved with the instructions of the function passed to :meth:`_call`.
        """

        # Every separator instruction in order, with ``None`` marking the body position.
        _template_separators = []
        # Source-template instructions found before / after the mock body.
        _template_separators_before_body = []
        _template_separators_after_body = []
        # Target-template pieces that are emitted in front of each input piece.
        _template_body = []
        # Last target-template piece, emitted after all other instructions.
        _template_tail = None
        # Guards the one-time computation of the class-level template data above.
        __initialized = False

        @classmethod
        def _split(cls, instructions, separator, n=-1):
            """Split ``instructions`` at each occurrence of the contiguous ``separator`` sequence.

            :param instructions: The instructions to split
            :type instructions: List[Instr]
            :param separator: The instruction sequence acting as delimiter (not kept in the result)
            :type separator: List[Instr]
            :param n: Required number of resulting pieces, or -1 to accept any number
            :type n: int
            :return: The pieces between the separators
            :rtype: List[List[Instr]]
            :raises ValueError: If ``n`` is not -1 and the split does not yield exactly ``n`` pieces
            """
            cur_start, index, result = 0, 0, []
            while index < len(instructions) - len(separator) + 1:
                if cls.is_instr_equal(instructions[index], separator[0]):
                    for i, template_body_instruction in enumerate(separator):
                        if not cls.is_instr_equal(instructions[index + i], template_body_instruction):
                            break
                    else:
                        # The whole separator matched: close the current piece and jump past it.
                        result.append(instructions[cur_start:index])
                        cur_start = index + len(separator)
                        index += len(separator)
                        if len(result) == n:
                            # One separator too many; the length check below reports it.
                            break
                        continue
                index += 1
            result.append(instructions[cur_start:])
            if n != -1 and len(result) != n:
                msg = "can't split instructions into {} pieces with provided separators".format(n)
                raise ValueError(msg)
            return result

        @classmethod
        def _class_init_impl(cls):
            """Override this method to implement different template matching algorithm.

            :raises ValueError: If the templates cannot be split as expected or the resulting
                pieces do not line up with the separators.
            """
            cls._template_separators_before_body, cls._template_separators_after_body = cls._split(
                cls.get_instructions(_source_template_func),
                separator=cls._get_mock_body_instructions(),
                n=2,
            )
            # use None to indicate the body
            cls._template_separators = (
                cls._template_separators_before_body + [None] + cls._template_separators_after_body
            )

            cls._template_body = cls._split_instructions_based_on_template(
                cls.get_instructions(_target_template_func),
                remove_mock_body=True,
            )
            cls._template_tail = cls._template_body.pop()
            # _build_instructions zips the template pieces with the separators, so after
            # removing the tail there must be exactly one template piece per separator.
            if len(cls._template_body) != len(cls._template_separators):
                raise ValueError(cls.make_error("invalid_template"))

        @classmethod
        def __class_init(cls):
            """Compute the class-level template data on first use only."""
            if cls.__initialized:
                return

            cls._class_init_impl()

            cls.__initialized = True

        def __init__(self):
            self.__class_init()

        # region methods depending on package bytecode
        @classmethod
        def get_instructions(cls, func):
            """Return the instructions of ``func`` as a list, as produced by ``Bytecode.from_code``.

            :param func: The function to read
            :type func: Union[FunctionType, MethodType]
            :return: The function's instructions
            :rtype: list
            """
            return list(Bytecode.from_code(func.__code__))

        @classmethod
        def is_instr_equal(cls, instr1: Instr, instr2: Instr) -> bool:
            """Tell whether two instruction entries are considered equivalent for template matching.

            ``None`` only equals ``None``. Entries of different classes never match. Two
            ``Instr`` match when both arguments are labels, or when opcode and argument are
            equal. Any other pair of same-class entries (e.g. labels) always matches.

            :param instr1: First entry
            :type instr1: Instr
            :param instr2: Second entry
            :type instr2: Instr
            :return: Whether the two entries match
            :rtype: bool
            """
            if instr1 is None and instr2 is None:
                return True
            if instr1 is None or instr2 is None:
                return False
            if instr1.__class__ != instr2.__class__:
                return False
            if isinstance(instr1, Instr):
                # Jump targets are not compared: two label arguments always match.
                if isinstance(instr1.arg, Label) and isinstance(instr2.arg, Label):
                    return True
                return instr1.opcode == instr2.opcode and instr1.arg == instr2.arg
            # objects like Label and TryBegin
            return True

        @classmethod
        def is_instructions_equal(cls, instructions1: List[Instr], instructions2: List[Instr]) -> bool:
            """Tell whether two instruction lists have equal length and match entry by entry.

            :param instructions1: First list
            :type instructions1: List[Instr]
            :param instructions2: Second list
            :type instructions2: List[Instr]
            :return: Whether the lists match according to :meth:`is_instr_equal`
            :rtype: bool
            """
            if len(instructions1) != len(instructions2):
                return False
            return all(cls.is_instr_equal(instr1, instr2) for instr1, instr2 in zip(instructions1, instructions2))

        def _create_code(self, instructions: List[Instr], base_func: Union[FunctionType, MethodType]) -> CodeType:
            """Create the code object for the function to be generated.

            Starts from the bytecode of ``base_func``, replaces all of its instructions with
            ``instructions`` and prepends the injected parameter to the argument list.

            :param instructions: The list of instructions. Used to replace the instructions in base_func
            :type instructions: List[Instr]
            :param base_func: A function that provides base metadata (name, globals, etc...). Instructions will not
                be kept
            :type base_func: Union[FunctionType, MethodType]
            :return: Generated code
            :rtype: CodeType
            """
            fn_code = Bytecode.from_code(base_func.__code__)
            fn_code.clear()
            fn_code.extend(instructions)
            # The generated function takes the injected parameter as its first argument.
            fn_code.argcount += 1
            fn_code.argnames.insert(0, self.injected_param)
            return fn_code.to_code()

        @classmethod
        def _get_mock_body_instructions(cls):
            """Return the instruction sequence that stands for "the body" in the template functions.

            :return: The mock body instructions
            :rtype: List[Instr]
            """
            return [Instr("LOAD_FAST", "mock_arg")]

        # endregion

        @classmethod
        def _get_pieces(cls, instructions: List[Instr], separators: List[Instr]) -> List[List[Instr]]:
            """Split the instructions at the given separators, consuming each separator once and in order.

            Each separator is a single instruction. The instructions are cut at the first
            instruction matching the first separator, then at the next instruction matching
            the second separator, and so on; ``None`` separators are skipped. Once all
            separators are used, the remaining instructions form the last piece. For example,
            instructions: [I3, I1, I2, I3, I1, I3, I1, I2, I3]
            separators: [I1, I2]
            result: [[I3], [], [I3, I1, I3, I1, I2, I3]]

            :param instructions: The instructions to split
            :type instructions: Iterable[Instr]
            :param separators: The instructions to use as delimiters, in order
            :type separators: Iterable[Instr]
            :return: The sublists of instructions that were delimited by separators
            :rtype: List[List[Instr]]
            :raises ValueError: If some separator is never found
            """
            separator_iter = iter(separators)

            def get_next_separator():
                # Next non-None separator, or None when the separators are exhausted.
                try:
                    while True:
                        separator = next(separator_iter)
                        if separator is not None:
                            return separator
                except StopIteration:
                    return None

            pieces = []
            last_piece = []
            cur_separator = get_next_separator()
            for instr in instructions:
                if cls.is_instr_equal(instr, cur_separator):
                    # skip the separator
                    pieces.append(last_piece)
                    cur_separator = get_next_separator()
                    last_piece = []
                else:
                    last_piece.append(instr)
            pieces.append(last_piece)

            if cur_separator is not None:
                raise ValueError(cls.make_error("not_all_template_separators_used"))

            return pieces

        @classmethod
        def _split_instructions_based_on_template(
            cls,
            instructions: List[Instr],
            *,
            remove_mock_body: bool = False,
        ) -> List[List[Instr]]:
            """Split instructions into several pieces by separators.
            For example, in Python 3.11, the template source instructions will be:

            .. code-block:: python

                [
                    Instr('RESUME', 0),  # initial instruction shared by all functions
                    Instr('LOAD_FAST', 'mock_arg'),  # the body execution instruction
                    Instr('RETURN_VALUE'),  # the return instruction shared by all functions
                ]

            Then the separators before body will be:

            .. code-block:: python

                [
                    Instr('RESUME', 0),
                ]

            And the separators after body will be:

            .. code-block:: python

                [
                    Instr('RETURN_VALUE'),
                ]

            For passed in instructions, we will split them with separators from beginning (the first RESUME) and
            with reversed_separators from end (the last RETURN_VALUE).

            :param instructions: The instructions to split
            :type instructions: List[instr]
            :keyword remove_mock_body: Whether to remove the mock body. Defaults to False
            :paramtype remove_mock_body: bool
            :return: The split instructions
            :rtype: List[List[Instr]]
            """
            if remove_mock_body:
                # this parameter should be set as True only when processing the template target function,
                # when we should ignore the mock body
                pieces = cls._get_pieces(
                    instructions, cls._template_separators_before_body + cls._get_mock_body_instructions()
                )
            else:
                pieces = cls._get_pieces(instructions, cls._template_separators_before_body)

            # Split the trailing piece from its end: walk it backwards with the
            # after-body separators reversed.
            reversed_pieces = cls._get_pieces(reversed(pieces.pop()), reversed(cls._template_separators_after_body))

            # Undo the reversal: restore both the order of the pieces and of their contents.
            while reversed_pieces:
                pieces.append(list(reversed(reversed_pieces.pop())))

            return pieces

        def _build_instructions(self, func: Union[FunctionType, MethodType]) -> List[Instr]:
            """Interleave the template pieces with the pieces of ``func`` to form the new instruction list.

            :param func: The function whose instructions are wrapped
            :type func: Union[FunctionType, MethodType]
            :return: The instructions of the generated function
            :rtype: List[Instr]
            """
            generated_instructions = []

            for template_piece, input_piece, separator in zip(
                self._template_body,
                self._split_instructions_based_on_template(self.get_instructions(func)),
                self._template_separators,
            ):
                generated_instructions.extend(template_piece)
                generated_instructions.extend(input_piece)
                # ``None`` marks the body position, where no separator instruction is emitted.
                if separator is not None:
                    generated_instructions.append(separator)
            generated_instructions.extend(self._template_tail)
            return generated_instructions

        def _build_func(self, func: Union[FunctionType, MethodType]) -> PersistentLocalsFunction:
            """Build a persistent locals function from the given function. Use bytecode injection to add try...finally
            statement around code to persistent the locals in the function.

            It will change the func bytecode in this way:

            .. code-block:: python

                def func(__self, *func_args):
                    try:
                        the func code...
                    finally:
                        __self._locals = locals().copy()

            You can get the locals in func by this code:

            .. code-block:: python

                builder = PersistentLocalsFunctionBytecodeBuilder()
                persistent_locals_func = builder._build_func(your_func)
                # Execute your func
                result = persistent_locals_func(*args)
                # Get the locals in the func.
                func_locals = persistent_locals_func.locals

            :param func: The function to modify
            :type func: Union[FunctionType, MethodType]
            :return: The built persistent locals function
            :rtype: PersistentLocalsFunction
            """
            generated_func = FunctionType(
                self._create_code(self._build_instructions(func), func),
                func.__globals__,
                func.__name__,
                func.__defaults__,
                func.__closure__,
            )
            # The FunctionType call above only receives positional defaults; carry over the
            # keyword-only defaults so such arguments stay optional in the generated function.
            kwdefaults = getattr(func, "__kwdefaults__", None)
            if kwdefaults:
                generated_func.__kwdefaults__ = dict(kwdefaults)
            return PersistentLocalsFunction(
                generated_func,
                _self=func.__self__ if isinstance(func, MethodType) else None,
                skip_locals=[self.injected_param],
            )

        def _call(self, func, _all_kwargs) -> Tuple[Any, dict]:
            """Run the bytecode-rewritten version of ``func`` and return ``(outputs, locals)``."""
            persistent_func = self._build_func(func)
            outputs = persistent_func(**_all_kwargs)
            return outputs, persistent_func.locals

except ImportError:
    # Fall back to the profiler implementation
    class PersistentLocalsFunctionBytecodeBuilder(PersistentLocalsFunctionProfilerBuilder):
        pass
+
+
def _get_persistent_locals_builder() -> PersistentLocalsFunctionBuilder:
    """Create the builder to use: bytecode-based when ``is_bytecode_optimization_enabled()``
    returns a true value, profiler-based otherwise.

    :return: A new builder instance
    :rtype: PersistentLocalsFunctionBuilder
    """
    builder_cls = (
        PersistentLocalsFunctionBytecodeBuilder
        if is_bytecode_optimization_enabled()
        else PersistentLocalsFunctionProfilerBuilder
    )
    return builder_cls()
+
+
def get_outputs_and_locals(func, _all_kwargs):
    """Execute ``func`` with ``_all_kwargs`` and return both its outputs and its local variables.

    The work is delegated to the builder chosen by ``_get_persistent_locals_builder``.

    :param func: The function to execute.
    :type func: Union[FunctionType, MethodType]
    :param _all_kwargs: All kwargs to call self.func.
    :type _all_kwargs: typing.Dict[str, typing.Any]
    :return: A tuple of outputs and locals.
    :rtype: typing.Tuple[typing.Dict, typing.Dict]
    """
    builder = _get_persistent_locals_builder()
    return builder.call(func, _all_kwargs)