path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py
author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py  437
1 file changed, 437 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py
new file mode 100644
index 00000000..a9947933
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py
@@ -0,0 +1,437 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import hashlib
+import logging
+import os.path
+import threading
+import time
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+from azure.ai.ml._utils._asset_utils import get_object_hash
+from azure.ai.ml._utils.utils import (
+    get_versioned_base_directory_for_cache,
+    is_concurrent_component_registration_enabled,
+    is_on_disk_cache_enabled,
+    is_private_preview_enabled,
+    write_to_shared_file,
+)
+from azure.ai.ml.constants._common import (
+    AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS,
+    AzureMLResourceType,
+    DefaultOpenEncoding,
+)
+from azure.ai.ml.entities import Component
+from azure.ai.ml.entities._builders import BaseNode
+from azure.ai.ml.entities._component.code import ComponentCodeMixin
+from azure.ai.ml.operations._operation_orchestrator import _AssetResolver
+
+logger = logging.getLogger(__name__)
+
+_ANONYMOUS_HASH_PREFIX = "anonymous-component-"
+_YAML_SOURCE_PREFIX = "yaml-source-"
+_CODE_INVOLVED_PREFIX = "code-involved-"
+EXPIRE_TIME_IN_SECONDS = 60 * 60 * 24 * 7  # 7 days
+
+_node_resolution_lock = defaultdict(threading.Lock)
+
+
+@dataclass
+class _CacheContent:
+    component_ref: Component
+    # the in-memory hash assumes that code folders do not change during the run and
+    # uses the hash of the code path instead of the code content to simplify the calculation
+    in_memory_hash: str
+    # the on-disk hash is calculated based on code content if applicable,
+    # so it still works even if the code folders change between runs
+    on_disk_hash: Optional[str] = None
+    arm_id: Optional[str] = None
+
+    def update_on_disk_hash(self):
+        self.on_disk_hash = CachedNodeResolver.calc_on_disk_hash_for_component(self.component_ref, self.in_memory_hash)
+
+
+class CachedNodeResolver(object):
+    """Class to resolve component in nodes with cached component resolution results.
+
+    This class is thread-safe if:
+    1) self._resolve_nodes is not called concurrently. We guarantee this with a lock in self.resolve_nodes.
+        a) self._resolve_nodes won't be called recursively, as all nodes will be skipped when
+          calling self.register_node_for_lazy_resolution.
+        b) it must not be called concurrently, as node resolution involves filling back results and
+          changes the state of nodes, e.g., the hash of their inner components.
+    2) self._resolve_component is only called concurrently on independent components
+        a) an in-memory component hash is used first to deduplicate the components to resolve;
+        b) dependent components have already been resolved before registration, as nodes are registered
+          & resolved layer by layer;
+        c) dependent code will never be an instance, so it won't cause the cache-hit issue described in d;
+        d) resolution of potentially shared dependencies other than components (1 instance used in
+          2 components) is thread-safe, as it does not involve further dependency resolution. However, it's
+          still good practice to resolve them before calling self.register_node_for_lazy_resolution, as this
+          will impact the cache hit rate.
+          For example, if:
+          node1.component, node2.component = Component(environment=env1, ...), Component(environment=env1, ...)
+          root
+           |        \
+          subgraph  node2
+            |
+          node1
+          when registering node1, its component will be:
+          {
+              "name": "component1",
+              "environment": {
+                  ...
+              }
+              ...
+          }
+          Its in-memory hash will be `hash_a` on registration.
+          Then when registering node2, the component will be:
+          {
+              "name": "component1",
+              "environment": "/subscriptions/.../environments/...",
+              ...
+          }
+          Its in-memory hash will be `hash_b`, which will be a cache miss.
+    """
+
+    def __init__(
+        self,
+        resolver: Callable[[Union[Component, str]], str],
+        client_key: str,
+    ):
+        self._resolver = resolver
+        self._cache: Dict[str, _CacheContent] = {}
+        self._nodes_to_resolve: List[BaseNode] = []
+
+        hash_obj = hashlib.sha256()
+        hash_obj.update(client_key.encode("utf-8"))
+        self._client_hash = hash_obj.hexdigest()
+        # all resolvers for the same client share 1 lock
+        self._lock = _node_resolution_lock[self._client_hash]
+
+    @staticmethod
+    def _get_component_registration_max_workers() -> int:
+        """Get the max workers for component registration.
+
+        Before Python 3.8, the default max_workers is the number of processors multiplied by 5. That can
+        send so many snapshot upload requests that the remote service starts refusing them.
+        To avoid having to retry those upload requests, max_workers uses the default from Python 3.8,
+        min(32, os.cpu_count() + 4).
+
+        One risk is that asset_utils will create a new thread pool to upload files in subprocesses, which may
+        cause the number of threads to exceed max_workers.
+
+        :return: The number of workers to use for component registration
+        :rtype: int
+        """
+        default_max_workers = min(32, (os.cpu_count() or 1) + 4)
+        try:
+            max_workers = int(os.environ.get(AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS, default_max_workers))
+        except ValueError:
+            logger.info(
+                "Environment variable %s with value %s set but failed to parse. "
+                "Use the default max_worker %s as registration thread pool max_worker."
+                "Please reset the value to an integer.",
+                AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS,
+                os.environ.get(AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS),
+                default_max_workers,
+            )
+            max_workers = default_max_workers
+        return max_workers
+
+    @staticmethod
+    def _get_in_memory_hash_for_component(component: Component) -> str:
+        """Get a hash for a component.
+
+        This function assumes that there is no change in code folder among hash calculations, which is true during
+        resolution of 1 root pipeline component/job.
+
+        :param component: The component
+        :type component: Component
+        :return: The hash of the component
+        :rtype: str
+        """
+        if not isinstance(component, Component):
+            # this shouldn't happen; handle it in case an invalid call is made outside this class
+            raise ValueError(f"Component {component} is not a Component object.")
+
+        # For components with code, the code will be an absolute path before it is uploaded to blob,
+        # so we can use a mixture of the anonymous hash and the source path as the component hash, in case
+        # there are 2 components with the same code but different ignore files.
+        # We check whether the component has a source path instead of whether it has code, as
+        # there is no harm in adding a source path to the hash even if the component doesn't have code.
+        # Note that we assume the content of the code folder won't change during the submission.
+        if component._source_path:  # pylint: disable=protected-access
+            object_hash = hashlib.sha256()
+            object_hash.update(component._get_anonymous_hash().encode("utf-8"))  # pylint: disable=protected-access
+            object_hash.update(component._source_path.encode("utf-8"))  # pylint: disable=protected-access
+            return _YAML_SOURCE_PREFIX + object_hash.hexdigest()
+        # For components without code, like pipeline component, their dependencies have already
+        # been resolved before calling this function, so we can use their anonymous hash directly
+        return _ANONYMOUS_HASH_PREFIX + component._get_anonymous_hash()  # pylint: disable=protected-access
+
+    @staticmethod
+    def calc_on_disk_hash_for_component(component: Component, in_memory_hash: str) -> str:
+        """Get a hash for a component.
+
+        This function calculates the hash based on the content of the component's code folder if the component
+        has code, so it remains valid even if the code folder changes between runs.
+
+        :param component: The component to hash
+        :type component: Component
+        :param in_memory_hash: :attr:`_CacheContent.in_memory_hash`
+        :type in_memory_hash: str
+        :return: The hash of the component
+        :rtype: str
+        """
+        if not isinstance(component, Component):
+            # this shouldn't happen; handle it in case an invalid call is made outside this class
+            raise ValueError(f"Component {component} is not a Component object.")
+
+        # TODO: calculate hash without resolving additional includes (copy code to temp folder)
+        # note that it's still thread-safe with current implementation, as only read operations are
+        # done on the original code folder
+        if not (
+            isinstance(component, ComponentCodeMixin)
+            and component._with_local_code()  # pylint: disable=protected-access
+        ):
+            return in_memory_hash
+
+        with component._build_code() as code:  # pylint: disable=protected-access
+            if hasattr(code, "_upload_hash"):
+                content_hash = code._upload_hash  # pylint: disable=protected-access
+            else:
+                code_path = code.path if os.path.isabs(code.path) else os.path.join(code.base_path, code.path)
+                if os.path.exists(code_path):
+                    content_hash = get_object_hash(code_path)
+                else:
+                    # this will be gated by schema validation, so it shouldn't happen except for mock tests
+                    return in_memory_hash
+
+            object_hash = hashlib.sha256()
+            object_hash.update(in_memory_hash.encode("utf-8"))
+
+            object_hash.update(content_hash.encode("utf-8"))
+            return _CODE_INVOLVED_PREFIX + object_hash.hexdigest()
+
+    @property
+    def _on_disk_cache_dir(self) -> Path:
+        """Get the base path for on disk cache.
+
+        :return: The base path for the on disk cache
+        :rtype: Path
+        """
+        return get_versioned_base_directory_for_cache().joinpath(
+            "components",
+            self._client_hash,
+        )
+
+    def _get_on_disk_cache_path(self, on_disk_hash: str) -> Path:
+        """Get the on disk cache path for a component.
+
+        :param on_disk_hash: The hash of the component
+        :type on_disk_hash: str
+        :return: The path to the disk cache
+        :rtype: Path
+        """
+        return self._on_disk_cache_dir.joinpath(on_disk_hash)
+
+    def _load_from_on_disk_cache(self, on_disk_hash: str) -> Optional[str]:
+        """Load component arm id from on disk cache.
+
+        :param on_disk_hash: The hash of the component
+        :type on_disk_hash: str
+        :return: The cached component arm id if reading was successful, None otherwise
+        :rtype: Optional[str]
+        """
+        # on-disk cache will expire in a new SDK version
+        on_disk_cache_path = self._get_on_disk_cache_path(on_disk_hash)
+        if on_disk_cache_path.is_file() and time.time() - on_disk_cache_path.stat().st_ctime < EXPIRE_TIME_IN_SECONDS:
+            try:
+                return on_disk_cache_path.read_text(encoding=DefaultOpenEncoding.READ).strip()
+            except (OSError, PermissionError) as e:
+                logger.warning(
+                    "Failed to read on-disk cache for component due to %s. "
+                    "Please check if the file %s is in use or current user doesn't have the permission.",
+                    type(e).__name__,
+                    on_disk_cache_path.as_posix(),
+                )
+        return None
+
+    def _save_to_on_disk_cache(self, on_disk_hash: str, arm_id: str) -> None:
+        """Save component arm id to on disk cache.
+
+        :param on_disk_hash: The on disk hash of the component
+        :type on_disk_hash: str
+        :param arm_id: The component ARM ID
+        :type arm_id: str
+        """
+        # this shouldn't happen in a real case, but guard against current mock tests and potential future changes
+        if not isinstance(arm_id, str):
+            return
+        on_disk_cache_path = self._get_on_disk_cache_path(on_disk_hash)
+        on_disk_cache_path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            write_to_shared_file(on_disk_cache_path, arm_id)
+        except PermissionError:
+            logger.warning(
+                "Failed to save on-disk cache for component due to permission error. "
+                "Please check if the file %s is in use or current user doesn't have the permission.",
+                on_disk_cache_path.as_posix(),
+            )
+
+    def _resolve_cache_contents(self, cache_contents_to_resolve: List[_CacheContent], resolver: _AssetResolver):
+        """Resolve all components to resolve and save the results in cache.
+
+        :param cache_contents_to_resolve: The cache contents to resolve
+        :type cache_contents_to_resolve: List[_CacheContent]
+        :param resolver: The resolver function
+        :type resolver: _AssetResolver
+        """
+
+        def _map_func(_cache_content: _CacheContent):
+            _cache_content.arm_id = resolver(_cache_content.component_ref, azureml_type=AzureMLResourceType.COMPONENT)
+            if is_on_disk_cache_enabled() and is_private_preview_enabled():
+                self._save_to_on_disk_cache(_cache_content.on_disk_hash, _cache_content.arm_id)
+
+        if (
+            len(cache_contents_to_resolve) > 1
+            and is_concurrent_component_registration_enabled()
+            and is_private_preview_enabled()
+        ):
+            # given deduplication has already been done, we can safely assume that there is no
+            # conflict in concurrent local cache access
+            with ThreadPoolExecutor(max_workers=self._get_component_registration_max_workers()) as executor:
+                list(executor.map(_map_func, cache_contents_to_resolve))
+        else:
+            list(map(_map_func, cache_contents_to_resolve))
+
+    def _prepare_items_to_resolve(self) -> Tuple[Dict[str, List[BaseNode]], List[_CacheContent]]:
+        """Pop all nodes in self._nodes_to_resolve to prepare cache contents to resolve and nodes to resolve. Nodes in
+        self._nodes_to_resolve will be grouped by component hash and saved to a dict of list. Distinct dependent
+        components not in current cache will be saved to a list.
+
+        :return: a tuple of (dict of nodes to resolve, list of cache contents to resolve)
+        :rtype: Tuple[Dict[str, List[BaseNode]],  List[_CacheContent]]
+        """
+        _components = list(map(lambda x: x._component, self._nodes_to_resolve))  # pylint: disable=protected-access
+        # we can do concurrent component in-memory hash calculation here
+        in_memory_component_hashes = map(self._get_in_memory_hash_for_component, _components)
+
+        dict_of_nodes_to_resolve = defaultdict(list)
+        cache_contents_to_resolve: List[_CacheContent] = []
+        for node, component_hash in zip(self._nodes_to_resolve, in_memory_component_hashes):
+            dict_of_nodes_to_resolve[component_hash].append(node)
+            if component_hash not in self._cache:
+                cache_content = _CacheContent(
+                    component_ref=node._component,  # pylint: disable=protected-access
+                    in_memory_hash=component_hash,
+                )
+                self._cache[component_hash] = cache_content
+                cache_contents_to_resolve.append(cache_content)
+        self._nodes_to_resolve.clear()
+        return dict_of_nodes_to_resolve, cache_contents_to_resolve
+
+    def _resolve_cache_contents_from_disk(self, cache_contents_to_resolve: List[_CacheContent]) -> List[_CacheContent]:
+        """Check on-disk cache to resolve cache contents in cache_contents_to_resolve and return unresolved cache
+        contents.
+
+        :param cache_contents_to_resolve: The cache contents to resolve
+        :type cache_contents_to_resolve: List[_CacheContent]
+        :return: Unresolved cache contents
+        :rtype: List[_CacheContent]
+        """
+        # Note that we should recalculate the hash based on the code content for the local cache, as
+        # we can't assume that the code folder stays unchanged between runs.
+        # On-disk hash calculation can be slow, as it involves data copying and artifact downloading.
+        # It is thread-safe given:
+        # 1. artifact downloading is thread-safe as we have a lock in ArtifactCache
+        # 2. data copying is thread-safe as there is only read operation on source folder
+        #    and target folder is unique for each thread
+        if (
+            len(cache_contents_to_resolve) > 1
+            and is_concurrent_component_registration_enabled()
+            and is_private_preview_enabled()
+        ):
+            with ThreadPoolExecutor(max_workers=self._get_component_registration_max_workers()) as executor:
+                executor.map(_CacheContent.update_on_disk_hash, cache_contents_to_resolve)
+        else:
+            list(map(_CacheContent.update_on_disk_hash, cache_contents_to_resolve))
+
+        left_cache_contents_to_resolve = []
+        # need to deduplicate disk hash first if concurrent resolution is enabled
+        for cache_content in cache_contents_to_resolve:
+            cache_content.arm_id = self._load_from_on_disk_cache(cache_content.on_disk_hash)
+            if not cache_content.arm_id:
+                left_cache_contents_to_resolve.append(cache_content)
+
+        return left_cache_contents_to_resolve
+
+    def _fill_back_component_to_nodes(self, dict_of_nodes_to_resolve: Dict[str, List[BaseNode]]):
+        """Fill back resolved component to nodes.
+
+        :param dict_of_nodes_to_resolve: The nodes to resolve
+        :type dict_of_nodes_to_resolve: Dict[str, List[BaseNode]]
+        """
+        for component_hash, nodes in dict_of_nodes_to_resolve.items():
+            cache_content = self._cache[component_hash]
+            for node in nodes:
+                node._component = cache_content.arm_id  # pylint: disable=protected-access
+
+    def _resolve_nodes(self):
+        """Processing logic of self.resolve_nodes.
+
+        Should not be called in subgraph creation.
+        """
+        dict_of_nodes_to_resolve, cache_contents_to_resolve = self._prepare_items_to_resolve()
+
+        if is_on_disk_cache_enabled() and is_private_preview_enabled():
+            cache_contents_to_resolve = self._resolve_cache_contents_from_disk(cache_contents_to_resolve)
+
+        self._resolve_cache_contents(cache_contents_to_resolve, resolver=self._resolver)
+
+        self._fill_back_component_to_nodes(dict_of_nodes_to_resolve)
+
+    def register_node_for_lazy_resolution(self, node: BaseNode):
+        """Register a node with its component to resolve.
+
+        :param node: The node
+        :type node: BaseNode
+        """
+        component = node._component  # pylint: disable=protected-access
+
+        # directly resolve the node and skip registration if the resolution involves no remote call,
+        # so that all nodes will be skipped when resolving a subgraph recursively
+        if isinstance(component, str):
+            node._component = self._resolver(  # pylint: disable=protected-access
+                component, azureml_type=AzureMLResourceType.COMPONENT
+            )
+            return
+        if component.id is not None:
+            node._component = component.id  # pylint: disable=protected-access
+            return
+
+        self._nodes_to_resolve.append(node)
+
+    def resolve_nodes(self):
+        """Resolve all dependent components with resolver and set resolved component arm id back to newly registered
+        nodes.
+
+        Registered nodes will be cleared after resolution.
+        """
+        if not self._nodes_to_resolve:
+            return
+
+        # Lock here, as node resolution involves filling back results and changes the
+        # state of nodes, e.g. the hash of their inner components.
+        # Contention only happens on concurrent external calls; within 1 external call, all nodes in a
+        # subgraph will be skipped by register_node_for_lazy_resolution when the subgraph is resolved
+        self._lock.acquire()
+        try:
+            self._resolve_nodes()
+        finally:
+            # release lock even if exception happens
+            self._lock.release()
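
For orientation, a minimal usage sketch of the resolver added above (the names resolve_asset and pipeline_nodes are illustrative stand-ins for the orchestrator's asset-resolver callable and the nodes of a pipeline being submitted; they are not part of this module):

    # hypothetical wiring -- resolve_asset and pipeline_nodes are stand-ins, not module APIs
    from azure.ai.ml._utils._cache_utils import CachedNodeResolver

    resolver = CachedNodeResolver(
        resolver=resolve_asset,
        client_key="<subscription>/<resource-group>/<workspace>",
    )
    for node in pipeline_nodes:
        # nodes whose component is already a string or already has an id are resolved immediately;
        # the rest are queued and deduplicated by their in-memory component hash
        resolver.register_node_for_lazy_resolution(node)
    # resolve each distinct queued component once (optionally via the on-disk cache and a
    # thread pool), then fill the resulting ARM ids back into the registered nodes
    resolver.resolve_nodes()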