path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py
author    S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-master.tar.gz
two version of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py')
-rw-r--r--    .venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py    437
1 file changed, 437 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py
new file mode 100644
index 00000000..a9947933
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_cache_utils.py
@@ -0,0 +1,437 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import hashlib
+import logging
+import os.path
+import threading
+import time
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+from azure.ai.ml._utils._asset_utils import get_object_hash
+from azure.ai.ml._utils.utils import (
+ get_versioned_base_directory_for_cache,
+ is_concurrent_component_registration_enabled,
+ is_on_disk_cache_enabled,
+ is_private_preview_enabled,
+ write_to_shared_file,
+)
+from azure.ai.ml.constants._common import (
+ AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS,
+ AzureMLResourceType,
+ DefaultOpenEncoding,
+)
+from azure.ai.ml.entities import Component
+from azure.ai.ml.entities._builders import BaseNode
+from azure.ai.ml.entities._component.code import ComponentCodeMixin
+from azure.ai.ml.operations._operation_orchestrator import _AssetResolver
+
+logger = logging.getLogger(__name__)
+
+_ANONYMOUS_HASH_PREFIX = "anonymous-component-"
+_YAML_SOURCE_PREFIX = "yaml-source-"
+_CODE_INVOLVED_PREFIX = "code-involved-"
+EXPIRE_TIME_IN_SECONDS = 60 * 60 * 24 * 7 # 7 days
+
+_node_resolution_lock = defaultdict(threading.Lock)
+
+
+@dataclass
+class _CacheContent:
+ component_ref: Component
+ # the in-memory hash assumes that the code folders do not change during the run and
+ # uses the hash of the code path instead of the code content to simplify the calculation
+ in_memory_hash: str
+ # the on-disk hash is calculated based on the code content when applicable,
+ # so it still works even if the code folders change between runs
+ on_disk_hash: Optional[str] = None
+ arm_id: Optional[str] = None
+
+ def update_on_disk_hash(self):
+ self.on_disk_hash = CachedNodeResolver.calc_on_disk_hash_for_component(self.component_ref, self.in_memory_hash)
+
+
+class CachedNodeResolver(object):
+ """Class to resolve component in nodes with cached component resolution results.
+
+ This class is thread-safe if:
+ 1) self._resolve_nodes is not called concurrently. We guarantee this with a lock in self.resolve_nodes.
+ a) self._resolve_nodes won't be called recursively as all nodes will be skipped on
+ calling self.register_node_for_lazy_resolution.
+ b) it can't be called concurrently as node resolution involves filling back and will change the
+ state of nodes, e.g., hash of its inner component.
+ 2) self._resolve_component is only called concurrently on independent components
+ a) we have used an in-memory component hash to deduplicate components to resolve first;
+ b) dependent components have been resolved before registered as nodes are registered & resolved
+ layer by layer;
+ c) dependent code will never be an instance, so it won't cause cache hit issue described in d;
+ d) resolution of potential shared dependencies (1 instance used in 2 components) other than components
+ are thread-safe as they do not involve further dependency resolution. However, it's still a good practice to
+ resolve them before calling self.register_node_for_lazy_resolution as it will impact cache hit rate.
+ For example, if:
+ node1.component, node2.component = Component(environment=env1, ...), Component(environment=env1, ...)
+ root
+ | \
+ subgraph node2
+ |
+ node1
+ when registering node1, its component will be:
+ {
+ "name": "component1",
+ "environment": {
+ ...
+ }
+ ...
+ }
+ Its in-memory hash will be `hash_a` on registration.
+ Then when registering node2, the component will be:
+ {
+ "name": "component1",
+ "environment": "/subscriptions/.../environments/...",
+ ...
+ }
+ Its in-memory hash will be `hash_b`, so resolving node2 results in a cache miss.
+ """
+
+ def __init__(
+ self,
+ resolver: Callable[[Union[Component, str]], str],
+ client_key: str,
+ ):
+ self._resolver = resolver
+ self._cache: Dict[str, _CacheContent] = {}
+ self._nodes_to_resolve: List[BaseNode] = []
+
+ hash_obj = hashlib.sha256()
+ hash_obj.update(client_key.encode("utf-8"))
+ self._client_hash = hash_obj.hexdigest()
+ # all resolvers created with the same client key share one lock
+ self._lock = _node_resolution_lock[self._client_hash]
+
+ @staticmethod
+ def _get_component_registration_max_workers() -> int:
+ """Get the max workers for component registration.
+
+ Before Python 3.8, the default max_workers of a thread pool is the number of processors multiplied by 5.
+ That may send a large number of snapshot upload requests at once and cause the remote service to refuse them.
+ To avoid retrying those upload requests, max_workers uses the Python 3.8+ default,
+ min(32, os.cpu_count() + 4).
+
+ One risk is that asset_utils creates a new thread pool to upload files in subprocesses, which may cause
+ the number of threads to exceed max_workers.
+
+ :return: The number of workers to use for component registration
+ :rtype: int
+ """
+ default_max_workers = min(32, (os.cpu_count() or 1) + 4)
+ try:
+ max_workers = int(os.environ.get(AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS, default_max_workers))
+ except ValueError:
+ logger.info(
+ "Environment variable %s with value %s set but failed to parse. "
+ "Use the default max_worker %s as registration thread pool max_worker."
+ "Please reset the value to an integer.",
+ AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS,
+ os.environ.get(AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS),
+ default_max_workers,
+ )
+ max_workers = default_max_workers
+ return max_workers
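+ # Illustrative override (assuming the AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS constant equals an
+ # environment variable name of the same spelling; the value 8 below is a hypothetical choice):
+ #
+ #     export AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS=8
+ #
+ # A value that int() cannot parse falls back to min(32, os.cpu_count() + 4), with the info message logged above.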
+
+ @staticmethod
+ def _get_in_memory_hash_for_component(component: Component) -> str:
+ """Get a hash for a component.
+
+ This function assumes that the code folders do not change between hash calculations, which holds during
+ the resolution of a single root pipeline component/job.
+
+ :param component: The component
+ :type component: Component
+ :return: The hash of the component
+ :rtype: str
+ """
+ if not isinstance(component, Component):
+ # this shouldn't happen; handle it in case invalid call is made outside this class
+ raise ValueError(f"Component {component} is not a Component object.")
+
+ # For components with code, the code will be an absolute path before it is uploaded to the blob,
+ # so we use a mixture of the component's anonymous hash and its source path as its hash, in case
+ # there are 2 components with the same code but different ignore files.
+ # Here we check whether the component has a source path instead of whether it has code, as
+ # there is no harm in adding the source path to the hash even if the component doesn't have code.
+ # Note that we assume the content of the code folder won't change during the submission.
+ if component._source_path: # pylint: disable=protected-access
+ object_hash = hashlib.sha256()
+ object_hash.update(component._get_anonymous_hash().encode("utf-8")) # pylint: disable=protected-access
+ object_hash.update(component._source_path.encode("utf-8")) # pylint: disable=protected-access
+ return _YAML_SOURCE_PREFIX + object_hash.hexdigest()
+ # For components without code, like pipeline components, their dependencies have already
+ # been resolved before this function is called, so we can use their anonymous hash directly
+ return _ANONYMOUS_HASH_PREFIX + component._get_anonymous_hash() # pylint: disable=protected-access
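+ # Illustrative outcomes of the two branches above (digests are shortened and made up):
+ #   component loaded from YAML with a source path -> "yaml-source-3fa4..." (anonymous hash + source path)
+ #   component without a source path, e.g. a pipeline component -> "anonymous-component-9c1b..." (anonymous hash only)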
+
+ @staticmethod
+ def calc_on_disk_hash_for_component(component: Component, in_memory_hash: str) -> str:
+ """Get a hash for a component.
+
+ This function calculates the hash based on the component's code content if the component has code, so the
+ hash stays correct even if the code folder changes between runs.
+
+ :param component: The component to hash
+ :type component: Component
+ :param in_memory_hash: :attr:`_CacheContent.in_memory_hash`
+ :type in_memory_hash: str
+ :return: The hash of the component
+ :rtype: str
+ """
+ if not isinstance(component, Component):
+ # this shouldn't happen; handle it in case invalid call is made outside this class
+ raise ValueError(f"Component {component} is not a Component object.")
+
+ # TODO: calculate hash without resolving additional includes (copy code to temp folder)
+ # note that it's still thread-safe with current implementation, as only read operations are
+ # done on the original code folder
+ if not (
+ isinstance(component, ComponentCodeMixin)
+ and component._with_local_code() # pylint: disable=protected-access
+ ):
+ return in_memory_hash
+
+ with component._build_code() as code: # pylint: disable=protected-access
+ if hasattr(code, "_upload_hash"):
+ content_hash = code._upload_hash # pylint: disable=protected-access
+ else:
+ code_path = code.path if os.path.isabs(code.path) else os.path.join(code.base_path, code.path)
+ if os.path.exists(code_path):
+ content_hash = get_object_hash(code_path)
+ else:
+ # this will be gated by schema validation, so it shouldn't happen except for mock tests
+ return in_memory_hash
+
+ object_hash = hashlib.sha256()
+ object_hash.update(in_memory_hash.encode("utf-8"))
+
+ object_hash.update(content_hash.encode("utf-8"))
+ return _CODE_INVOLVED_PREFIX + object_hash.hexdigest()
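+ # Sketch of the combination above: on_disk_hash = "code-involved-" + sha256(in_memory_hash || content_hash),
+ # where content_hash comes from the uploaded code hash or from hashing the local code folder content, so two
+ # runs with identical code content map to the same on-disk cache entry even across processes.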
+
+ @property
+ def _on_disk_cache_dir(self) -> Path:
+ """Get the base path for on disk cache.
+
+ :return: The base path for the on disk cache
+ :rtype: Path
+ """
+ return get_versioned_base_directory_for_cache().joinpath(
+ "components",
+ self._client_hash,
+ )
+
+ def _get_on_disk_cache_path(self, on_disk_hash: str) -> Path:
+ """Get the on disk cache path for a component.
+
+ :param on_disk_hash: The hash of the component
+ :type on_disk_hash: str
+ :return: The path to the disk cache
+ :rtype: Path
+ """
+ return self._on_disk_cache_dir.joinpath(on_disk_hash)
+
+ def _load_from_on_disk_cache(self, on_disk_hash: str) -> Optional[str]:
+ """Load component arm id from on disk cache.
+
+ :param on_disk_hash: The hash of the component
+ :type on_disk_hash: str
+ :return: The cached component arm id if reading was successful, None otherwise
+ :rtype: Optional[str]
+ """
+ # on-disk cache will expire in a new SDK version
+ on_disk_cache_path = self._get_on_disk_cache_path(on_disk_hash)
+ if on_disk_cache_path.is_file() and time.time() - on_disk_cache_path.stat().st_ctime < EXPIRE_TIME_IN_SECONDS:
+ try:
+ return on_disk_cache_path.read_text(encoding=DefaultOpenEncoding.READ).strip()
+ except (OSError, PermissionError) as e:
+ logger.warning(
+ "Failed to read on-disk cache for component due to %s. "
+ "Please check if the file %s is in use or current user doesn't have the permission.",
+ type(e).__name__,
+ on_disk_cache_path.as_posix(),
+ )
+ return None
+
+ def _save_to_on_disk_cache(self, on_disk_hash: str, arm_id: str) -> None:
+ """Save component arm id to on disk cache.
+
+ :param on_disk_hash: The on disk hash of the component
+ :type on_disk_hash: str
+ :param arm_id: The component ARM ID
+ :type arm_id: str
+ """
+ # this shouldn't happen in real scenarios, but guard against current mock tests and potential future changes
+ if not isinstance(arm_id, str):
+ return
+ on_disk_cache_path = self._get_on_disk_cache_path(on_disk_hash)
+ on_disk_cache_path.parent.mkdir(parents=True, exist_ok=True)
+ try:
+ write_to_shared_file(on_disk_cache_path, arm_id)
+ except PermissionError:
+ logger.warning(
+ "Failed to save on-disk cache for component due to permission error. "
+ "Please check if the file %s is in use or current user doesn't have the permission.",
+ on_disk_cache_path.as_posix(),
+ )
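+ # Resulting on-disk layout (illustrative; the actual root comes from get_versioned_base_directory_for_cache()):
+ #   <cache-root>/components/<client_hash>/<on_disk_hash>   -> text file holding the resolved ARM id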
+
+ def _resolve_cache_contents(self, cache_contents_to_resolve: List[_CacheContent], resolver: _AssetResolver):
+ """Resolve all components to resolve and save the results in cache.
+
+ :param cache_contents_to_resolve: The cache contents to resolve
+ :type cache_contents_to_resolve: List[_CacheContent]
+ :param resolver: The resolver function
+ :type resolver: _AssetResolver
+ """
+
+ def _map_func(_cache_content: _CacheContent):
+ _cache_content.arm_id = resolver(_cache_content.component_ref, azureml_type=AzureMLResourceType.COMPONENT)
+ if is_on_disk_cache_enabled() and is_private_preview_enabled():
+ self._save_to_on_disk_cache(_cache_content.on_disk_hash, _cache_content.arm_id)
+
+ if (
+ len(cache_contents_to_resolve) > 1
+ and is_concurrent_component_registration_enabled()
+ and is_private_preview_enabled()
+ ):
+ # given deduplication has already been done, we can safely assume that there is no
+ # conflict in concurrent local cache access
+ with ThreadPoolExecutor(max_workers=self._get_component_registration_max_workers()) as executor:
+ list(executor.map(_map_func, cache_contents_to_resolve))
+ else:
+ list(map(_map_func, cache_contents_to_resolve))
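+ # Note: the concurrent branch above only runs when both feature flags are enabled and there is more than one
+ # distinct component to resolve; otherwise resolution falls back to a plain serial map.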
+
+ def _prepare_items_to_resolve(self) -> Tuple[Dict[str, List[BaseNode]], List[_CacheContent]]:
+ """Pop all nodes in self._nodes_to_resolve to prepare cache contents to resolve and nodes to resolve. Nodes in
+ self._nodes_to_resolve will be grouped by component hash and saved to a dict of list. Distinct dependent
+ components not in current cache will be saved to a list.
+
+ :return: a tuple of (dict of nodes to resolve, list of cache contents to resolve)
+ :rtype: Tuple[Dict[str, List[BaseNode]], List[_CacheContent]]
+ """
+ _components = list(map(lambda x: x._component, self._nodes_to_resolve)) # pylint: disable=protected-access
+ # we can do concurrent component in-memory hash calculation here
+ in_memory_component_hashes = map(self._get_in_memory_hash_for_component, _components)
+
+ dict_of_nodes_to_resolve = defaultdict(list)
+ cache_contents_to_resolve: List[_CacheContent] = []
+ for node, component_hash in zip(self._nodes_to_resolve, in_memory_component_hashes):
+ dict_of_nodes_to_resolve[component_hash].append(node)
+ if component_hash not in self._cache:
+ cache_content = _CacheContent(
+ component_ref=node._component, # pylint: disable=protected-access
+ in_memory_hash=component_hash,
+ )
+ self._cache[component_hash] = cache_content
+ cache_contents_to_resolve.append(cache_content)
+ self._nodes_to_resolve.clear()
+ return dict_of_nodes_to_resolve, cache_contents_to_resolve
+
+ def _resolve_cache_contents_from_disk(self, cache_contents_to_resolve: List[_CacheContent]) -> List[_CacheContent]:
+ """Check on-disk cache to resolve cache contents in cache_contents_to_resolve and return unresolved cache
+ contents.
+
+ :param cache_contents_to_resolve: The cache contents to resolve
+ :type cache_contents_to_resolve: List[_CacheContent]
+ :return: Unresolved cache contents
+ :rtype: List[_CacheContent]
+ """
+ # Note that we should recalculate the hash based on the code content for the local cache, as
+ # we can't assume that the code folders stay unchanged between runs.
+ # On-disk hash calculation can be slow as it involves data copying and artifact downloading.
+ # It is thread-safe given:
+ # 1. artifact downloading is thread-safe as we have a lock in ArtifactCache
+ # 2. data copying is thread-safe as there is only read operation on source folder
+ # and target folder is unique for each thread
+ if (
+ len(cache_contents_to_resolve) > 1
+ and is_concurrent_component_registration_enabled()
+ and is_private_preview_enabled()
+ ):
+ with ThreadPoolExecutor(max_workers=self._get_component_registration_max_workers()) as executor:
+ executor.map(_CacheContent.update_on_disk_hash, cache_contents_to_resolve)
+ else:
+ list(map(_CacheContent.update_on_disk_hash, cache_contents_to_resolve))
+
+ left_cache_contents_to_resolve = []
+ # need to deduplicate disk hash first if concurrent resolution is enabled
+ for cache_content in cache_contents_to_resolve:
+ cache_content.arm_id = self._load_from_on_disk_cache(cache_content.on_disk_hash)
+ if not cache_content.arm_id:
+ left_cache_contents_to_resolve.append(cache_content)
+
+ return left_cache_contents_to_resolve
+
+ def _fill_back_component_to_nodes(self, dict_of_nodes_to_resolve: Dict[str, List[BaseNode]]):
+ """Fill back resolved component to nodes.
+
+ :param dict_of_nodes_to_resolve: The nodes to resolve
+ :type dict_of_nodes_to_resolve: Dict[str, List[BaseNode]]
+ """
+ for component_hash, nodes in dict_of_nodes_to_resolve.items():
+ cache_content = self._cache[component_hash]
+ for node in nodes:
+ node._component = cache_content.arm_id # pylint: disable=protected-access
+
+ def _resolve_nodes(self):
+ """Processing logic of self.resolve_nodes.
+
+ Should not be called in subgraph creation.
+ """
+ dict_of_nodes_to_resolve, cache_contents_to_resolve = self._prepare_items_to_resolve()
+
+ if is_on_disk_cache_enabled() and is_private_preview_enabled():
+ cache_contents_to_resolve = self._resolve_cache_contents_from_disk(cache_contents_to_resolve)
+
+ self._resolve_cache_contents(cache_contents_to_resolve, resolver=self._resolver)
+
+ self._fill_back_component_to_nodes(dict_of_nodes_to_resolve)
+
+ def register_node_for_lazy_resolution(self, node: BaseNode):
+ """Register a node with its component to resolve.
+
+ :param node: The node
+ :type node: BaseNode
+ """
+ component = node._component # pylint: disable=protected-access
+
+ # directly resolve the node and skip registration if the resolution involves no remote call,
+ # so that all nodes will be skipped when a subgraph is resolved recursively
+ if isinstance(component, str):
+ node._component = self._resolver( # pylint: disable=protected-access
+ component, azureml_type=AzureMLResourceType.COMPONENT
+ )
+ return
+ if component.id is not None:
+ node._component = component.id # pylint: disable=protected-access
+ return
+
+ self._nodes_to_resolve.append(node)
+
+ def resolve_nodes(self):
+ """Resolve all dependent components with resolver and set resolved component arm id back to newly registered
+ nodes.
+
+ Registered nodes will be cleared after resolution.
+ """
+ if not self._nodes_to_resolve:
+ return
+
+ # Lock here as node resolution fills resolved ids back into nodes and changes their
+ # state, e.g., the hash of the inner component.
+ # Contention can only happen on concurrent external calls; within one external call, all nodes in a
+ # subgraph are skipped by register_node_for_lazy_resolution when the subgraph is resolved.
+ self._lock.acquire()
+ try:
+ self._resolve_nodes()
+ finally:
+ # release lock even if exception happens
+ self._lock.release()
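+
+# Minimal usage sketch (illustrative only; `resolve_asset`, `client_key` and `nodes` are hypothetical
+# stand-ins for what the pipeline operations layer would pass in):
+#
+#     cached_resolver = CachedNodeResolver(resolver=resolve_asset, client_key=client_key)
+#     for node in nodes:
+#         cached_resolver.register_node_for_lazy_resolution(node)
+#     cached_resolver.resolve_nodes()  # fills each node._component with the resolved ARM id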