1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
|
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import hashlib
import logging
import os.path
import threading
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
from azure.ai.ml._utils._asset_utils import get_object_hash
from azure.ai.ml._utils.utils import (
get_versioned_base_directory_for_cache,
is_concurrent_component_registration_enabled,
is_on_disk_cache_enabled,
is_private_preview_enabled,
write_to_shared_file,
)
from azure.ai.ml.constants._common import (
AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS,
AzureMLResourceType,
DefaultOpenEncoding,
)
from azure.ai.ml.entities import Component
from azure.ai.ml.entities._builders import BaseNode
from azure.ai.ml.entities._component.code import ComponentCodeMixin
from azure.ai.ml.operations._operation_orchestrator import _AssetResolver
# Module-level logger used for cache read/write warnings and configuration fallbacks.
logger = logging.getLogger(__name__)

# Prefixes that tag how a component hash was derived:
# - anonymous hash only (component without a source path)
# - anonymous hash mixed with the component's source path
# - in-memory hash mixed with the hash of the component's code content
_ANONYMOUS_HASH_PREFIX = "anonymous-component-"
_YAML_SOURCE_PREFIX = "yaml-source-"
_CODE_INVOLVED_PREFIX = "code-involved-"

# On-disk cache entries older than this are ignored when loading: 7 days.
EXPIRE_TIME_IN_SECONDS = 7 * 24 * 60 * 60

# One lock per client hash; resolvers created for the same client share the same lock.
_node_resolution_lock = defaultdict(threading.Lock)
@dataclass
class _CacheContent:
    """Cache entry that tracks one component and the state of its resolution."""

    # The component object this entry was created for.
    component_ref: Component
    # The in-memory hash assumes that code folders do not change during the run, so it is
    # derived from the code path rather than the code content to keep the calculation cheap.
    in_memory_hash: str
    # The on-disk hash is calculated from the code content when applicable, so it stays
    # valid even if the code folders change between runs.
    on_disk_hash: Optional[str] = None
    # The resolved ARM id; filled in from the on-disk cache or by the resolver.
    arm_id: Optional[str] = None

    def update_on_disk_hash(self):
        """Recalculate :attr:`on_disk_hash` from the component and its in-memory hash."""
        refreshed_hash = CachedNodeResolver.calc_on_disk_hash_for_component(
            self.component_ref,
            self.in_memory_hash,
        )
        self.on_disk_hash = refreshed_hash
class CachedNodeResolver(object):
    """Class to resolve component in nodes with cached component resolution results.

    This class is thread-safe if:
    1) self._resolve_nodes is not called concurrently. We guarantee this with a lock in self.resolve_nodes.
        a) self._resolve_nodes won't be called recursively as all nodes will be skipped on
           calling self.register_node_for_lazy_resolution.
        b) it can't be called concurrently as node resolution involves filling back and will change the
           state of nodes, e.g., hash of its inner component.
    2) component resolution (the resolver call in self._resolve_cache_contents) only runs concurrently
       on independent components
        a) we have used an in-memory component hash to deduplicate components to resolve first;
        b) dependent components have been resolved before being registered, as nodes are registered &
           resolved layer by layer;
        c) dependent code will never be an instance, so it won't cause the cache hit issue described in d;
        d) resolution of potential shared dependencies (1 instance used in 2 components) other than
           components is thread-safe as it does not involve further dependency resolution. However, it's
           still a good practice to resolve them before calling self.register_node_for_lazy_resolution
           as it will impact the cache hit rate.
           For example, if:

               node1.component, node2.component = Component(environment=env1, ...), Component(environment=env1, ...)

               root
               |-- subgraph
               |   `-- node1
               `-- node2

           when registering node1, its component will be:
               {
                   "name": "component1",
                   "environment": {
                       ...
                   }
                   ...
               }
           Its in-memory hash will be `hash_a` on registration.
           Then when registering node2, the component will be:
               {
                   "name": "component1",
                   "environment": "/subscriptions/.../environments/...",
                   ...
               }
           Its in-memory hash will be `hash_b`, which will be a cache miss.
    """

    def __init__(
        self,
        resolver: Callable[[Union[Component, str]], str],
        client_key: str,
    ):
        """Create a resolver bound to one client.

        :param resolver: Callable used to resolve a component (or a component reference string); its
            return value is stored as the resolved component of the node
        :type resolver: Callable[[Union[Component, str]], str]
        :param client_key: Key identifying the client; its SHA-256 digest selects the on-disk cache
            directory and the shared resolution lock
        :type client_key: str
        """
        self._resolver = resolver
        self._cache: Dict[str, _CacheContent] = {}
        self._nodes_to_resolve: List[BaseNode] = []

        hash_obj = hashlib.sha256()
        hash_obj.update(client_key.encode("utf-8"))
        self._client_hash = hash_obj.hexdigest()

        # the same client share 1 lock
        self._lock = _node_resolution_lock[self._client_hash]

    @staticmethod
    def _get_component_registration_max_workers() -> int:
        """Get the max workers for component registration.

        Before Python 3.8, the default max_worker is the number of processors multiplied by 5.
        It may send a large number of snapshot upload requests, which can cause the remote to refuse requests.
        In order to avoid retrying the upload requests, max_worker will use the default value in Python 3.8,
        min(32, os.cpu_count + 4).

        The value can be overridden with the environment variable named by
        AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS. Values that are not positive integers are ignored
        (with an info log) and the default is used instead.

        1 risk is that, asset_utils will create a new thread pool to upload files in subprocesses, which may cause
        the number of threads exceed the max_worker.

        :return: The number of workers to use for component registration
        :rtype: int
        """
        default_max_workers = min(32, (os.cpu_count() or 1) + 4)
        try:
            max_workers = int(os.environ.get(AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS, default_max_workers))
        except ValueError:
            max_workers = None

        # ThreadPoolExecutor raises ValueError for max_workers <= 0, so treat non-positive
        # values the same way as unparsable ones.
        if max_workers is None or max_workers < 1:
            logger.info(
                "Environment variable %s with value %s set but failed to parse. "
                "Use the default max_worker %s as registration thread pool max_worker. "
                "Please reset the value to a positive integer.",
                AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS,
                os.environ.get(AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS),
                default_max_workers,
            )
            max_workers = default_max_workers
        return max_workers

    @staticmethod
    def _get_in_memory_hash_for_component(component: Component) -> str:
        """Get a hash for a component.

        This function assumes that there is no change in code folder among hash calculations, which is true during
        resolution of 1 root pipeline component/job.

        :param component: The component
        :type component: Component
        :return: The hash of the component
        :rtype: str
        :raises ValueError: If ``component`` is not a Component instance
        """
        if not isinstance(component, Component):
            # this shouldn't happen; handle it in case invalid call is made outside this class
            raise ValueError(f"Component {component} is not a Component object.")

        # For components with code, its code will be an absolute path before uploaded to blob,
        # so we can use a mixture of its anonymous hash and its source path as its hash, in case
        # there are 2 components with same code but different ignore files
        # Here we can check if the component has a source path instead of check if it has code, as
        # there is no harm to add a source path to the hash even if the component doesn't have code
        # Note that here we assume that the content of code folder won't change during the submission
        if component._source_path:  # pylint: disable=protected-access
            object_hash = hashlib.sha256()
            object_hash.update(component._get_anonymous_hash().encode("utf-8"))  # pylint: disable=protected-access
            object_hash.update(component._source_path.encode("utf-8"))  # pylint: disable=protected-access
            return _YAML_SOURCE_PREFIX + object_hash.hexdigest()

        # For components without code, like pipeline component, their dependencies have already
        # been resolved before calling this function, so we can use their anonymous hash directly
        return _ANONYMOUS_HASH_PREFIX + component._get_anonymous_hash()  # pylint: disable=protected-access

    @staticmethod
    def calc_on_disk_hash_for_component(component: Component, in_memory_hash: str) -> str:
        """Get a hash for a component.

        This function will calculate the hash based on the component's code folder if the component has code, so it's
        unique even if code folder is changed. If the component has no local code, or the code path does not
        exist, ``in_memory_hash`` is returned unchanged.

        :param component: The component to hash
        :type component: Component
        :param in_memory_hash: :attr:`_CacheContent.in_memory_hash`
        :type in_memory_hash: str
        :return: The hash of the component
        :rtype: str
        :raises ValueError: If ``component`` is not a Component instance
        """
        if not isinstance(component, Component):
            # this shouldn't happen; handle it in case invalid call is made outside this class
            raise ValueError(f"Component {component} is not a Component object.")

        # TODO: calculate hash without resolving additional includes (copy code to temp folder)
        # note that it's still thread-safe with current implementation, as only read operations are
        # done on the original code folder
        if not (
            isinstance(component, ComponentCodeMixin)
            and component._with_local_code()  # pylint: disable=protected-access
        ):
            return in_memory_hash

        with component._build_code() as code:  # pylint: disable=protected-access
            if hasattr(code, "_upload_hash"):
                content_hash = code._upload_hash  # pylint: disable=protected-access
            else:
                code_path = code.path if os.path.isabs(code.path) else os.path.join(code.base_path, code.path)
                if os.path.exists(code_path):
                    content_hash = get_object_hash(code_path)
                else:
                    # this will be gated by schema validation, so it shouldn't happen except for mock tests
                    return in_memory_hash

            # mix the in-memory hash with the code content hash
            object_hash = hashlib.sha256()
            object_hash.update(in_memory_hash.encode("utf-8"))
            object_hash.update(content_hash.encode("utf-8"))
            return _CODE_INVOLVED_PREFIX + object_hash.hexdigest()

    @property
    def _on_disk_cache_dir(self) -> Path:
        """Get the base path for on disk cache.

        :return: The base path for the on disk cache
        :rtype: Path
        """
        return get_versioned_base_directory_for_cache().joinpath(
            "components",
            self._client_hash,
        )

    def _get_on_disk_cache_path(self, on_disk_hash: str) -> Path:
        """Get the on disk cache path for a component.

        :param on_disk_hash: The hash of the component
        :type on_disk_hash: str
        :return: The path to the disk cache
        :rtype: Path
        """
        return self._on_disk_cache_dir.joinpath(on_disk_hash)

    def _load_from_on_disk_cache(self, on_disk_hash: str) -> Optional[str]:
        """Load component arm id from on disk cache.

        A cache file is only used if it exists and is younger than EXPIRE_TIME_IN_SECONDS.

        :param on_disk_hash: The hash of the component
        :type on_disk_hash: str
        :return: The cached component arm id if reading was successful, None otherwise
        :rtype: Optional[str]
        """
        # on-disk cache will expire in a new SDK version
        on_disk_cache_path = self._get_on_disk_cache_path(on_disk_hash)
        if on_disk_cache_path.is_file() and time.time() - on_disk_cache_path.stat().st_ctime < EXPIRE_TIME_IN_SECONDS:
            try:
                return on_disk_cache_path.read_text(encoding=DefaultOpenEncoding.READ).strip()
            except (OSError, PermissionError) as e:
                # best effort: a failed read is treated as a cache miss
                logger.warning(
                    "Failed to read on-disk cache for component due to %s. "
                    "Please check if the file %s is in use or current user doesn't have the permission.",
                    type(e).__name__,
                    on_disk_cache_path.as_posix(),
                )
        return None

    def _save_to_on_disk_cache(self, on_disk_hash: str, arm_id: str) -> None:
        """Save component arm id to on disk cache.

        :param on_disk_hash: The on disk hash of the component
        :type on_disk_hash: str
        :param arm_id: The component ARM ID
        :type arm_id: str
        """
        # this shouldn't happen in real case, but in case of current mock tests and potential future changes
        if not isinstance(arm_id, str):
            return
        on_disk_cache_path = self._get_on_disk_cache_path(on_disk_hash)
        on_disk_cache_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            write_to_shared_file(on_disk_cache_path, arm_id)
        except PermissionError:
            # best effort: failing to persist the cache entry must not fail the resolution
            logger.warning(
                "Failed to save on-disk cache for component due to permission error. "
                "Please check if the file %s is in use or current user doesn't have the permission.",
                on_disk_cache_path.as_posix(),
            )

    def _resolve_cache_contents(self, cache_contents_to_resolve: List[_CacheContent], resolver: _AssetResolver):
        """Resolve all components to resolve and save the results in cache.

        :param cache_contents_to_resolve: The cache contents to resolve
        :type cache_contents_to_resolve: List[_CacheContent]
        :param resolver: The resolver function
        :type resolver: _AssetResolver
        """

        def _map_func(_cache_content: _CacheContent):
            _cache_content.arm_id = resolver(_cache_content.component_ref, azureml_type=AzureMLResourceType.COMPONENT)
            if is_on_disk_cache_enabled() and is_private_preview_enabled():
                self._save_to_on_disk_cache(_cache_content.on_disk_hash, _cache_content.arm_id)

        if (
            len(cache_contents_to_resolve) > 1
            and is_concurrent_component_registration_enabled()
            and is_private_preview_enabled()
        ):
            # given deduplication has already been done, we can safely assume that there is no
            # conflict in concurrent local cache access
            with ThreadPoolExecutor(max_workers=self._get_component_registration_max_workers()) as executor:
                # consume the iterator so that exceptions raised in workers are propagated
                list(executor.map(_map_func, cache_contents_to_resolve))
        else:
            list(map(_map_func, cache_contents_to_resolve))

    def _prepare_items_to_resolve(self) -> Tuple[Dict[str, List[BaseNode]], List[_CacheContent]]:
        """Pop all nodes in self._nodes_to_resolve to prepare cache contents to resolve and nodes to resolve. Nodes in
        self._nodes_to_resolve will be grouped by component hash and saved to a dict of list. Distinct dependent
        components not in current cache will be saved to a list.

        :return: a tuple of (dict of nodes to resolve, list of cache contents to resolve)
        :rtype: Tuple[Dict[str, List[BaseNode]], List[_CacheContent]]
        """
        _components = list(map(lambda x: x._component, self._nodes_to_resolve))  # pylint: disable=protected-access
        # we can do concurrent component in-memory hash calculation here
        in_memory_component_hashes = map(self._get_in_memory_hash_for_component, _components)

        dict_of_nodes_to_resolve = defaultdict(list)
        cache_contents_to_resolve: List[_CacheContent] = []
        for node, component_hash in zip(self._nodes_to_resolve, in_memory_component_hashes):
            dict_of_nodes_to_resolve[component_hash].append(node)
            # only the first node seen for a given hash creates a cache entry to resolve
            if component_hash not in self._cache:
                cache_content = _CacheContent(
                    component_ref=node._component,  # pylint: disable=protected-access
                    in_memory_hash=component_hash,
                )
                self._cache[component_hash] = cache_content
                cache_contents_to_resolve.append(cache_content)
        self._nodes_to_resolve.clear()
        return dict_of_nodes_to_resolve, cache_contents_to_resolve

    def _resolve_cache_contents_from_disk(self, cache_contents_to_resolve: List[_CacheContent]) -> List[_CacheContent]:
        """Check on-disk cache to resolve cache contents in cache_contents_to_resolve and return unresolved cache
        contents.

        :param cache_contents_to_resolve: The cache contents to resolve
        :type cache_contents_to_resolve: List[_CacheContent]
        :return: Unresolved cache contents
        :rtype: List[_CacheContent]
        """
        # Note that we should recalculate the hash based on code for local cache, as
        # we can't assume that the code folder won't change among dependency
        # On-disk hash calculation can be slow as it involved data copying and artifact downloading.
        # It is thread-safe given:
        # 1. artifact downloading is thread-safe as we have a lock in ArtifactCache
        # 2. data copying is thread-safe as there is only read operation on source folder
        #    and target folder is unique for each thread
        if (
            len(cache_contents_to_resolve) > 1
            and is_concurrent_component_registration_enabled()
            and is_private_preview_enabled()
        ):
            with ThreadPoolExecutor(max_workers=self._get_component_registration_max_workers()) as executor:
                # Executor.map only re-raises a worker exception when its result is retrieved, so
                # consume the iterator; otherwise a failed hash calculation would be silently dropped
                # and on_disk_hash would stay None.
                list(executor.map(_CacheContent.update_on_disk_hash, cache_contents_to_resolve))
        else:
            list(map(_CacheContent.update_on_disk_hash, cache_contents_to_resolve))

        left_cache_contents_to_resolve = []
        # need to deduplicate disk hash first if concurrent resolution is enabled
        for cache_content in cache_contents_to_resolve:
            cache_content.arm_id = self._load_from_on_disk_cache(cache_content.on_disk_hash)
            if not cache_content.arm_id:
                left_cache_contents_to_resolve.append(cache_content)

        return left_cache_contents_to_resolve

    def _fill_back_component_to_nodes(self, dict_of_nodes_to_resolve: Dict[str, List[BaseNode]]):
        """Fill back resolved component to nodes.

        Every node grouped under a component hash gets the arm id stored in the cache for that hash.

        :param dict_of_nodes_to_resolve: The nodes to resolve
        :type dict_of_nodes_to_resolve: Dict[str, List[BaseNode]]
        """
        for component_hash, nodes in dict_of_nodes_to_resolve.items():
            cache_content = self._cache[component_hash]
            for node in nodes:
                node._component = cache_content.arm_id  # pylint: disable=protected-access

    def _resolve_nodes(self):
        """Processing logic of self.resolve_nodes.

        Should not be called in subgraph creation.
        """
        dict_of_nodes_to_resolve, cache_contents_to_resolve = self._prepare_items_to_resolve()

        # try the on-disk cache first; only entries it cannot serve go to the resolver
        if is_on_disk_cache_enabled() and is_private_preview_enabled():
            cache_contents_to_resolve = self._resolve_cache_contents_from_disk(cache_contents_to_resolve)

        self._resolve_cache_contents(cache_contents_to_resolve, resolver=self._resolver)

        self._fill_back_component_to_nodes(dict_of_nodes_to_resolve)

    def register_node_for_lazy_resolution(self, node: BaseNode):
        """Register a node with its component to resolve.

        :param node: The node
        :type node: BaseNode
        """
        component = node._component  # pylint: disable=protected-access

        # directly resolve node and skip registration if the resolution involves no remote call
        # so that all node will be skipped when resolving a subgraph recursively
        if isinstance(component, str):
            node._component = self._resolver(  # pylint: disable=protected-access
                component, azureml_type=AzureMLResourceType.COMPONENT
            )
            return
        if component.id is not None:
            node._component = component.id  # pylint: disable=protected-access
            return

        self._nodes_to_resolve.append(node)

    def resolve_nodes(self):
        """Resolve all dependent components with resolver and set resolved component arm id back to newly registered
        nodes.

        Registered nodes will be cleared after resolution.
        """
        if not self._nodes_to_resolve:
            return

        # Lock here as node resolution involves filling back and will change the
        # state of nodes, e.g. hash of its inner component.
        # This will happen only on concurrent external calls; In 1 external call, all nodes in
        # subgraph will be skipped on register_node_for_lazy_resolution when resolving subgraph
        # The context manager releases the lock even if an exception happens.
        with self._lock:
            self._resolve_nodes()
|