Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py | 1015
1 files changed, 1015 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py
new file mode 100644
index 00000000..ccb9c42b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py
@@ -0,0 +1,1015 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains pytorch-specific helpers."""
+
+import importlib
+import json
+import os
+import re
+from collections import defaultdict, namedtuple
+from functools import lru_cache
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
+
+from packaging import version
+
+from .. import constants, logging
+from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
+
+
+logger = logging.get_logger(__file__)
+
+if TYPE_CHECKING:
+    import torch
+
+# SAVING
+
+
+def save_torch_model(
+    model: "torch.nn.Module",
+    save_directory: Union[str, Path],
+    *,
+    filename_pattern: Optional[str] = None,
+    force_contiguous: bool = True,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+    metadata: Optional[Dict[str, str]] = None,
+    safe_serialization: bool = True,
+    is_main_process: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
+):
+    """
+    Saves a given torch model to disk, handling sharding and shared tensor issues.
+
+    See also [`save_torch_state_dict`] to save a state dict with more flexibility.
+
+    For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).
+
+    The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
+    saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
+    an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
+    [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
+    safetensors (the default). Otherwise, the shards are saved as pickle.
+
+    Before saving the model, the `save_directory` is cleaned from any previous shard files.
+
+    <Tip warning={true}>
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    <Tip warning={true}>
+
+    If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as
+    `shared_tensors_to_discard` to properly handle shared tensor saving. This ensures the correct duplicate tensors
+    are discarded during saving.
+
+    </Tip>
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to save on disk.
+        save_directory (`str` or `Path`):
+            The directory in which the model will be saved.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the file names in which the model will be saved.
+            Pattern must be a string that can be formatted with `filename_pattern.format(suffix=...)` and must
+            contain the keyword `suffix`. Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin`
+            depending on the `safe_serialization` parameter.
+        force_contiguous (`boolean`, *optional*):
+            Whether to force the state_dict to be saved as contiguous tensors. This has no effect on the correctness
+            of the model, but it could potentially change performance if the layout of the tensor was chosen
+            specifically for that reason. Defaults to `True`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+        metadata (`Dict[str, str]`, *optional*):
+            Extra information to save along with the model. Some metadata will be added for each dropped tensor.
+            This information will not be enough to recover the entire shared structure but might help in
+            understanding things.
+        safe_serialization (`bool`, *optional*):
+            Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
+            Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
+            in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful for distributed training (e.g. on
+            TPUs) when the function needs to be called on all processes. In this case, set `is_main_process=True` only
+            on the main process to avoid race conditions. Defaults to `True`.
+        shared_tensors_to_discard (`List[str]`, *optional*):
+            List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
+            detected, it will drop the first name alphabetically.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import save_torch_model
+    >>> model = ...  # A PyTorch model
+
+    # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors.
+    >>> save_torch_model(model, "path/to/folder")
+
+    # Load model back
+    >>> from huggingface_hub import load_torch_model
+    >>> load_torch_model(model, "path/to/folder")
+    ```
+    """
+    save_torch_state_dict(
+        state_dict=model.state_dict(),
+        filename_pattern=filename_pattern,
+        force_contiguous=force_contiguous,
+        max_shard_size=max_shard_size,
+        metadata=metadata,
+        safe_serialization=safe_serialization,
+        save_directory=save_directory,
+        is_main_process=is_main_process,
+        shared_tensors_to_discard=shared_tensors_to_discard,
+    )
+
+
+def save_torch_state_dict(
+    state_dict: Dict[str, "torch.Tensor"],
+    save_directory: Union[str, Path],
+    *,
+    filename_pattern: Optional[str] = None,
+    force_contiguous: bool = True,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+    metadata: Optional[Dict[str, str]] = None,
+    safe_serialization: bool = True,
+    is_main_process: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
+) -> None:
+    """
+    Save a model state dictionary to the disk, handling sharding and shared tensor issues.
+
+    See also [`save_torch_model`] to directly save a PyTorch model.
+
+    For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).
+
+    The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
+    saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
+    an index file is saved in the `save_directory` to indicate where each tensor is saved.
+    This helper uses [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the
+    shards are saved as safetensors (the default). Otherwise, the shards are saved as pickle.
+
+    Before saving the model, the `save_directory` is cleaned from any previous shard files.
+
+    <Tip warning={true}>
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    <Tip warning={true}>
+
+    If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as
+    `shared_tensors_to_discard` to properly handle shared tensor saving. This ensures the correct duplicate tensors
+    are discarded during saving.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, torch.Tensor]`):
+            The state dictionary to save.
+        save_directory (`str` or `Path`):
+            The directory in which the model will be saved.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the file names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on the
+            `safe_serialization` parameter.
+        force_contiguous (`boolean`, *optional*):
+            Whether to force the state_dict to be saved as contiguous tensors. This has no effect on the correctness
+            of the model, but it could potentially change performance if the layout of the tensor was chosen
+            specifically for that reason. Defaults to `True`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+        metadata (`Dict[str, str]`, *optional*):
+            Extra information to save along with the model. Some metadata will be added for each dropped tensor.
+            This information will not be enough to recover the entire shared structure but might help in
+            understanding things.
+        safe_serialization (`bool`, *optional*):
+            Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
+            Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
+            in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful for distributed training (e.g. on
+            TPUs) when the function needs to be called on all processes. In this case, set `is_main_process=True` only
+            on the main process to avoid race conditions. Defaults to `True`.
+        shared_tensors_to_discard (`List[str]`, *optional*):
+            List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
+            detected, it will drop the first name alphabetically.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import save_torch_state_dict
+    >>> model = ...  # A PyTorch model
+
+    # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors.
+    >>> state_dict = model.state_dict()
+    >>> save_torch_state_dict(state_dict, "path/to/folder")
+    ```
+    """
+    save_directory = str(save_directory)
+
+    if filename_pattern is None:
+        filename_pattern = (
+            constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
+            if safe_serialization
+            else constants.PYTORCH_WEIGHTS_FILE_PATTERN
+        )
+
+    if metadata is None:
+        metadata = {}
+    if safe_serialization:
+        try:
+            from safetensors.torch import save_file as save_file_fn
+        except ImportError as e:
+            raise ImportError(
+                "Please install `safetensors` to use safe serialization. "
+                "You can install it with `pip install safetensors`."
+            ) from e
+        # Clean state dict for safetensors
+        state_dict = _clean_state_dict_for_safetensors(
+            state_dict,
+            metadata,
+            force_contiguous=force_contiguous,
+            shared_tensors_to_discard=shared_tensors_to_discard,
+        )
+    else:
+        from torch import save as save_file_fn  # type: ignore[assignment]
+
+        logger.warning(
+            "You are using unsafe serialization. Due to security reasons, it is recommended not to load "
+            "pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
+            "using safe serialization by installing `safetensors` with `pip install safetensors`."
+        )
+
+    # Split dict
+    state_dict_split = split_torch_state_dict_into_shards(
+        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+    )
+
+    # Only main process should clean up existing files to avoid race conditions in a distributed environment
+    if is_main_process:
+        existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
+        for filename in os.listdir(save_directory):
+            if existing_files_regex.match(filename):
+                try:
+                    logger.debug(f"Removing existing file '{filename}' from folder.")
+                    os.remove(os.path.join(save_directory, filename))
+                except Exception as e:
+                    logger.warning(
+                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
+                    )
+
+    # Save each shard
+    per_file_metadata = {"format": "pt"}
+    if not state_dict_split.is_sharded:
+        per_file_metadata.update(metadata)
+    safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {}
+    for filename, tensors in state_dict_split.filename_to_tensors.items():
+        shard = {tensor: state_dict[tensor] for tensor in tensors}
+        save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs)
+        logger.debug(f"Shard saved to {filename}")
+
+    # Save the index (if any)
+    if state_dict_split.is_sharded:
+        index_path = filename_pattern.format(suffix="") + ".index.json"
+        index = {
+            "metadata": {**state_dict_split.metadata, **metadata},
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+        with open(os.path.join(save_directory, index_path), "w") as f:
+            json.dump(index, f, indent=2)
+        logger.info(
+            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). "
+            f"Model weights have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. "
+            f"You can find where each parameter has been saved in the index located at {index_path}."
+        )
+
+    logger.info(f"Model weights successfully saved to {save_directory}!")
+
+
+def split_torch_state_dict_into_shards(
+    state_dict: Dict[str, "torch.Tensor"],
+    *,
+    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+) -> StateDictSplit:
+    """
+    Split a model state dictionary into shards so that each shard is smaller than a given size.
+
+    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
+    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
+    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
+    [6+2+2GB], [6+2GB], [6GB].
+
+    <Tip>
+
+    To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses
+    `split_torch_state_dict_into_shards` under the hood.
+
+    </Tip>
+
+    <Tip warning={true}>
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, torch.Tensor]`):
+            The state dictionary to save.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the file names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"model{suffix}.safetensors"`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+
+    Returns:
+        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
+
+    Example:
+    ```py
+    >>> import json
+    >>> import os
+    >>> from safetensors.torch import save_file as safe_save_file
+    >>> from huggingface_hub import split_torch_state_dict_into_shards
+
+    >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
+    ...     state_dict_split = split_torch_state_dict_into_shards(state_dict)
+    ...     for filename, tensors in state_dict_split.filename_to_tensors.items():
+    ...         shard = {tensor: state_dict[tensor] for tensor in tensors}
+    ...         safe_save_file(
+    ...             shard,
+    ...             os.path.join(save_directory, filename),
+    ...             metadata={"format": "pt"},
+    ...         )
+    ...     if state_dict_split.is_sharded:
+    ...         index = {
+    ...             "metadata": state_dict_split.metadata,
+    ...             "weight_map": state_dict_split.tensor_to_filename,
+    ...         }
+    ...         with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
+    ...             f.write(json.dumps(index, indent=2))
+    ```
+    """
+    return split_state_dict_into_shards_factory(
+        state_dict,
+        max_shard_size=max_shard_size,
+        filename_pattern=filename_pattern,
+        get_storage_size=get_torch_storage_size,
+        get_storage_id=get_torch_storage_id,
+    )
+
+
+# LOADING
+
+
+def load_torch_model(
+    model: "torch.nn.Module",
+    checkpoint_path: Union[str, os.PathLike],
+    *,
+    strict: bool = False,
+    safe: bool = True,
+    weights_only: bool = False,
+    map_location: Optional[Union[str, "torch.device"]] = None,
+    mmap: bool = False,
+    filename_pattern: Optional[str] = None,
+) -> NamedTuple:
+    """
+    Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model in which to load the checkpoint.
+        checkpoint_path (`str` or `os.PathLike`):
+            Path to either the checkpoint file or directory containing the checkpoint(s).
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
+        safe (`bool`, *optional*, defaults to `True`):
+            If `safe` is `True`, the safetensors files will be loaded. If `safe` is `False`, the function will first
+            attempt to load safetensors files if they are available; otherwise it will fall back to loading pickle
+            files.
+            The `filename_pattern` parameter takes precedence over the `safe` parameter.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If `True`, only loads the model weights without optimizer states and other metadata.
+            Only supported in PyTorch >= 1.13.
+        map_location (`str` or `torch.device`, *optional*):
+            A `torch.device` object, string or a dict specifying how to remap storage locations. It
+            indicates the location where all tensors should be loaded.
+        mmap (`bool`, *optional*, defaults to `False`):
+            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
+            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints.
+        filename_pattern (`str`, *optional*):
+            The pattern to look for the index file. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"model{suffix}.safetensors"`.
+
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
+            - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not
+              in the checkpoint.
+            - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint
+              but not in the model.
+
+    Raises:
+        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
+            If the checkpoint file or directory does not exist.
+        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint
+            respectively.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If the checkpoint path is invalid or if the checkpoint format cannot be determined.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import load_torch_model
+    >>> model = ...  # A PyTorch model
+    >>> load_torch_model(model, "path/to/checkpoint")
+    ```
+    """
+    checkpoint_path = Path(checkpoint_path)
+
+    if not checkpoint_path.exists():
+        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")
+
+    # 1. Check if checkpoint is a single file
+    if checkpoint_path.is_file():
+        state_dict = load_state_dict_from_file(
+            checkpoint_file=checkpoint_path,
+            map_location=map_location,
+            weights_only=weights_only,
+        )
+        return model.load_state_dict(state_dict, strict=strict)
+
+    # 2. If not, checkpoint_path is a directory
+    if filename_pattern is None:
+        filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
+        index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
+        # Only fall back to pickle format if the safetensors index is not found and `safe` is False.
+        if not index_path.is_file() and not safe:
+            filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
+
+    index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
+
+    if index_path.is_file():
+        return _load_sharded_checkpoint(
+            model=model,
+            save_directory=checkpoint_path,
+            strict=strict,
+            weights_only=weights_only,
+            filename_pattern=filename_pattern,
+        )
+
+    # Look for single model file
+    model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin"))
+    if len(model_files) == 1:
+        state_dict = load_state_dict_from_file(
+            checkpoint_file=model_files[0],
+            map_location=map_location,
+            weights_only=weights_only,
+            mmap=mmap,
+        )
+        return model.load_state_dict(state_dict, strict=strict)
+
+    raise ValueError(
+        f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
+        "Expected either a sharded checkpoint with an index file, or a single model file."
+    )
+
+
+def _load_sharded_checkpoint(
+    model: "torch.nn.Module",
+    save_directory: os.PathLike,
+    *,
+    strict: bool = False,
+    weights_only: bool = False,
+    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
+) -> NamedTuple:
+    """
+    Loads a sharded checkpoint into a model. This is the same as
+    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
+    but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the
+    model.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model in which to load the checkpoint.
+        save_directory (`str` or `os.PathLike`):
+            A path to a folder containing the sharded checkpoint.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded
+            checkpoint.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If `True`, only loads the model weights without optimizer states and other metadata.
+            Only supported in PyTorch >= 1.13.
+        filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
+            The pattern to look for the index file. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
+            - `missing_keys` is a list of str containing the missing keys
+            - `unexpected_keys` is a list of str containing the unexpected keys
+    """
+    # 1. Load and validate index file
+    # The index file contains mapping of parameter names to shard files
+    index_path = filename_pattern.format(suffix="") + ".index.json"
+    index_file = os.path.join(save_directory, index_path)
+    with open(index_file, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    # 2. Validate keys if in strict mode
+    # This is done before loading any shards to fail fast
+    if strict:
+        _validate_keys_for_strict_loading(model, index["weight_map"].keys())
+
+    # 3. Load each shard using `load_state_dict`
+    # Get unique shard files (multiple parameters can be in same shard)
+    shard_files = list(set(index["weight_map"].values()))
+    for shard_file in shard_files:
+        # Load shard into memory
+        shard_path = os.path.join(save_directory, shard_file)
+        state_dict = load_state_dict_from_file(
+            shard_path,
+            map_location="cpu",
+            weights_only=weights_only,
+        )
+        # Update model with parameters from this shard. Keys were already validated globally above, so each
+        # shard is loaded with strict=False (a single shard never contains all of the model's keys).
+        model.load_state_dict(state_dict, strict=False)
+        # Explicitly remove the state dict from memory
+        del state_dict
+
+    # 4. Return compatibility info
+    loaded_keys = set(index["weight_map"].keys())
+    model_keys = set(model.state_dict().keys())
+    return _IncompatibleKeys(
+        missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
+    )
+
+
+def load_state_dict_from_file(
+    checkpoint_file: Union[str, os.PathLike],
+    map_location: Optional[Union[str, "torch.device"]] = None,
+    weights_only: bool = False,
+    mmap: bool = False,
+) -> Union[Dict[str, "torch.Tensor"], Any]:
+    """
+    Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.
+
+    Args:
+        checkpoint_file (`str` or `os.PathLike`):
+            Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
+        map_location (`str` or `torch.device`, *optional*):
+            A `torch.device` object, string or a dict specifying how to remap storage locations. It
+            indicates the location where all tensors should be loaded.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If `True`, only loads the model weights without optimizer states and other metadata.
+            Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when
+            loading safetensors files.
+        mmap (`bool`, *optional*, defaults to `False`):
+            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
+            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
+            loading safetensors files, as the `safetensors` library uses memory mapping by default.
+
+    Returns:
+        `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
+            - For safetensors files: always returns a dictionary mapping parameter names to tensors.
+            - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
+              an entire model, optimizer state, or any other Python object).
+
+    Raises:
+        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
+            If the checkpoint file does not exist.
+        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint
+            respectively.
+        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
+            If the checkpoint file format is invalid or if git-lfs files are not properly downloaded.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If the checkpoint file path is empty or invalid.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import load_state_dict_from_file
+
+    # Load a PyTorch checkpoint
+    >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
+    >>> model.load_state_dict(state_dict)
+
+    # Load a safetensors checkpoint
+    >>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
+    >>> model.load_state_dict(state_dict)
+    ```
+    """
+    checkpoint_path = Path(checkpoint_file)
+
+    # Check if file exists and is a regular file (not a directory)
+    if not checkpoint_path.is_file():
+        raise FileNotFoundError(
+            f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
+            "the file has been properly downloaded."
+        )
+
+    # Load safetensors checkpoint
+    if checkpoint_path.suffix == ".safetensors":
+        try:
+            from safetensors import safe_open
+            from safetensors.torch import load_file
+        except ImportError as e:
+            raise ImportError(
+                "Please install `safetensors` to load safetensors checkpoint. "
+                "You can install it with `pip install safetensors`."
+            ) from e
+
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
+            metadata = f.metadata()
+        # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
+        if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain valid metadata. Make sure "
+                "you save your model with the `save_torch_model` method."
+            )
+        device = str(map_location.type) if map_location is not None and hasattr(map_location, "type") else map_location
+        # meta device is not supported with safetensors, falling back to CPU
+        if device == "meta":
+            logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
+            device = "cpu"
+        return load_file(checkpoint_file, device=device)  # type: ignore[arg-type]
+
+    # Otherwise, load from pickle
+    try:
+        import torch
+        from torch import load
+    except ImportError as e:
+        raise ImportError(
+            "Please install `torch` to load torch tensors. You can install it with `pip install torch`."
+        ) from e
+
+    # Add additional kwargs; mmap is only supported in torch >= 2.1.0
+    additional_kwargs = {}
+    if version.parse(torch.__version__) >= version.parse("2.1.0"):
+        additional_kwargs["mmap"] = mmap
+
+    # weights_only is only supported in torch >= 1.13.0
+    if version.parse(torch.__version__) >= version.parse("1.13.0"):
+        additional_kwargs["weights_only"] = weights_only
+
+    return load(
+        checkpoint_file,
+        map_location=map_location,
+        **additional_kwargs,
+    )
+
+
+# HELPERS
+
+
+def _validate_keys_for_strict_loading(
+    model: "torch.nn.Module",
+    loaded_keys: Iterable[str],
+) -> None:
+    """
+    Validate that model keys match loaded keys when strict loading is enabled.
+
+    Args:
+        model: The PyTorch model being loaded.
+        loaded_keys: The keys present in the checkpoint.
+
+    Raises:
+        RuntimeError: If there are missing or unexpected keys in strict mode.
+    """
+    loaded_keys_set = set(loaded_keys)
+    model_keys = set(model.state_dict().keys())
+    missing_keys = model_keys - loaded_keys_set  # Keys in model but not in checkpoint
+    unexpected_keys = loaded_keys_set - model_keys  # Keys in checkpoint but not in model
+
+    if missing_keys or unexpected_keys:
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if missing_keys:
+            str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if unexpected_keys:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)])
+            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+
+def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
+    """Returns a unique id for a plain tensor, or a (potentially nested) tuple of unique ids for the flattened
+    tensor if the input is a wrapper tensor subclass.
+    """
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)
+
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
+    if tensor.device.type == "xla" and is_torch_tpu_available():
+        # NOTE: xla tensors don't have storage
+        # use some other unique id to distinguish.
+        # this is a XLA tensor, it must be created using torch_xla's
+        # device. So the following import is safe:
+        import torch_xla  # type: ignore[import]
+
+        unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
+    else:
+        unique_id = storage_ptr(tensor)
+
+    return unique_id
+
+
+def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]:
+    """
+    Return a unique identifier for a tensor storage.
+
+    Multiple different tensors can share the same underlying storage. This identifier is
+    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+    non-overlapping lifetimes may have the same id.
+    In the case of meta tensors, we return None since we can't tell if they share the same storage.
+
+    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
+    """
+    if tensor.device.type == "meta":
+        return None
+    else:
+        return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
+
+
+def get_torch_storage_size(tensor: "torch.Tensor") -> int:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
+    """
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs)
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
+    try:
+        return tensor.untyped_storage().nbytes()
+    except AttributeError:
+        # Fallback for torch==1.10
+        try:
+            return tensor.storage().size() * _get_dtype_size(tensor.dtype)
+        except NotImplementedError:
+            # Fallback for meta storage
+            # On torch >=2.0 this is the tensor size
+            return tensor.nelement() * _get_dtype_size(tensor.dtype)
+
+
+@lru_cache()
+def is_torch_tpu_available(check_device=True):
+    """
+    Checks if `torch_xla` is installed and potentially if a TPU is in the environment.
+
+    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463.
+    """
+    if importlib.util.find_spec("torch_xla") is not None:
+        if check_device:
+            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
+            try:
+                import torch_xla.core.xla_model as xm  # type: ignore[import]
+
+                _ = xm.xla_device()
+                return True
+            except RuntimeError:
+                return False
+        return True
+    return False
+
+
+def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
+    """
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            return _get_unique_id(tensor)  # type: ignore
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
+    try:
+        return tensor.untyped_storage().data_ptr()
+    except Exception:
+        # Fallback for torch==1.10
+        try:
+            return tensor.storage().data_ptr()
+        except NotImplementedError:
+            # Fallback for meta storage
+            return 0
+
+
+def _clean_state_dict_for_safetensors(
+    state_dict: Dict[str, "torch.Tensor"],
+    metadata: Dict[str, str],
+    force_contiguous: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
+):
+    """Remove shared tensors from state_dict and update metadata accordingly (for reloading).
+
+    Warning: `state_dict` and `metadata` are mutated in-place!
+
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
+    """
+    to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if metadata is None:
+                metadata = {}
+
+            if to_remove not in metadata:
+                # Do not override user data
+                metadata[to_remove] = kept_name
+            del state_dict[to_remove]
+    if force_contiguous:
+        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
+    return state_dict
+
+
+def _end_ptr(tensor: "torch.Tensor") -> int:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23.
+    """
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype)
+    else:
+        stop = tensor.data_ptr()
+    return stop
+
+
+def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44
+    """
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+
+    return filtered_tensors
+
+
+def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69.
+    """
+    import torch
+
+    tensors_dict = defaultdict(set)
+    for k, v in state_dict.items():
+        if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0:
+            # Need to add device as key because of multiple GPUs.
+            tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k)
+    tensors = list(sorted(tensors_dict.values()))
+    tensors = _filter_shared_not_shared(tensors, state_dict)
+    return tensors
+
+
+def _is_complete(tensor: "torch.Tensor") -> bool:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
+    """
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return all(_is_complete(getattr(tensor, attr)) for attr in attrs)
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
+    return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size(
+        tensor.dtype
+    ) == get_torch_storage_size(tensor)
+
+
+def _remove_duplicate_names(
+    state_dict: Dict[str, "torch.Tensor"],
+    *,
+    preferred_names: Optional[List[str]] = None,
+    discard_names: Optional[List[str]] = None,
+) -> Dict[str, List[str]]:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
+    """
+    if preferred_names is None:
+        preferred_names = []
+    unique_preferred_names = set(preferred_names)
+    if discard_names is None:
+        discard_names = []
+    unique_discard_names = set(discard_names)
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        complete_names = set([name for name in shared if _is_complete(state_dict[name])])
+        if not complete_names:
+            raise RuntimeError(
+                "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
+                f" for saving amongst: {shared}. None of them covers the entire storage. Refusing to save/load the"
+                " model since you could be storing much more memory than needed. Please refer to"
+                " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
+                " issue."
+            )
+
+        keep_name = sorted(list(complete_names))[0]
+
+        # Mechanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(unique_discard_names)
+        if preferred:
+            keep_name = sorted(list(preferred))[0]
+
+        if unique_preferred_names:
+            preferred = unique_preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(list(preferred))[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+@lru_cache()
+def _get_dtype_size(dtype: "torch.dtype") -> int:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344
+    """
+    import torch
+
+    # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
+    _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
+    _float8_e5m2 = getattr(torch, "float8_e5m2", None)
+    _SIZE = {
+        torch.int64: 8,
+        torch.float32: 4,
+        torch.int32: 4,
+        torch.bfloat16: 2,
+        torch.float16: 2,
+        torch.int16: 2,
+        torch.uint8: 1,
+        torch.int8: 1,
+        torch.bool: 1,
+        torch.float64: 8,
+        _float8_e4m3fn: 1,
+        _float8_e5m2: 1,
+    }
+    return _SIZE[dtype]
+
+
+class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
+    """
+    This is used to report missing and unexpected keys in the state dict.
+    Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
+    """
+
+    def __repr__(self) -> str:
+        if not self.missing_keys and not self.unexpected_keys:
+            return "<All keys matched successfully>"
+        return super().__repr__()
+
+    __str__ = __repr__
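For context, a minimal end-to-end sketch of the save/load API this file defines (not part of the diff above). It assumes `huggingface_hub` (providing these helpers) and `safetensors` are installed; `TinyTiedLM` and the `tmp/tiny-tied-lm` path are made up for illustration.

```python
import os

import torch
from huggingface_hub import load_torch_model, save_torch_model


class TinyTiedLM(torch.nn.Module):
    """Toy model with tied weights: `lm_head.weight` shares storage with `embed.weight`."""

    def __init__(self, vocab_size: int = 8, hidden: int = 4):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # tie -> shared storage


model = TinyTiedLM()
os.makedirs("tmp/tiny-tied-lm", exist_ok=True)

# The duplicate name is dropped on save and recorded in the safetensors metadata,
# mirroring the `_tied_weights_keys` advice in the docstrings above.
save_torch_model(model, "tmp/tiny-tied-lm", shared_tensors_to_discard=["lm_head.weight"])

# Reload into a fresh (still tied) model. The dropped alias is not present in the file,
# so it is reported as missing, but its values are restored through the shared storage
# when `embed.weight` is loaded.
result = load_torch_model(TinyTiedLM(), "tmp/tiny-tied-lm")
print(result.missing_keys)  # ['lm_head.weight']
```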
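And a small sketch of the greedy sharding rule documented in `split_torch_state_dict_into_shards`: tensors are packed in key order with no bin-packing optimization. The tiny byte sizes mirror the docstring's own [6, 6, 2, 6, 2, 2] example; the exact output filenames are an assumption based on the default pattern.

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

# Six 1-byte-per-element tensors with sizes [6, 6, 2, 6, 2, 2] bytes, 10-byte limit.
state_dict = {
    "a": torch.zeros(6, dtype=torch.uint8),
    "b": torch.zeros(6, dtype=torch.uint8),
    "c": torch.zeros(2, dtype=torch.uint8),
    "d": torch.zeros(6, dtype=torch.uint8),
    "e": torch.zeros(2, dtype=torch.uint8),
    "f": torch.zeros(2, dtype=torch.uint8),
}

split = split_torch_state_dict_into_shards(state_dict, max_shard_size=10)
print(split.is_sharded)  # True
for filename, tensors in split.filename_to_tensors.items():
    print(filename, tensors)
# Expected grouping, per the docstring: ['a'], ['b', 'c'], ['d', 'e', 'f'],
# with filenames following the default "model{suffix}.safetensors" pattern,
# e.g. model-00001-of-00003.safetensors.
```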