author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/huggingface_hub/serialization
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/serialization')
-rw-r--r--  .venv/lib/python3.12/site-packages/huggingface_hub/serialization/__init__.py      27
-rw-r--r--  .venv/lib/python3.12/site-packages/huggingface_hub/serialization/_base.py        210
-rw-r--r--  .venv/lib/python3.12/site-packages/huggingface_hub/serialization/_dduf.py        387
-rw-r--r--  .venv/lib/python3.12/site-packages/huggingface_hub/serialization/_tensorflow.py   95
-rw-r--r--  .venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py      1015
5 files changed, 1734 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/__init__.py b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/__init__.py
new file mode 100644
index 00000000..8949a22a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: F401
+"""Contains helpers to serialize tensors."""
+
+from ._base import StateDictSplit, split_state_dict_into_shards_factory
+from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards
+from ._torch import (
+    get_torch_storage_id,
+    get_torch_storage_size,
+    load_state_dict_from_file,
+    load_torch_model,
+    save_torch_model,
+    save_torch_state_dict,
+    split_torch_state_dict_into_shards,
+)
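The `__init__.py` above simply re-exports the serialization helpers. As a rough illustration of how the re-exported torch helpers fit together (a minimal sketch, assuming `torch` and `safetensors` are installed; the checkpoint directory name is a placeholder):

```python
import os

import torch

from huggingface_hub.serialization import load_torch_model, save_torch_model

model = torch.nn.Linear(4, 4)
os.makedirs("checkpoint-dir", exist_ok=True)
save_torch_model(model, "checkpoint-dir")  # writes model.safetensors (sharded if needed)
load_torch_model(model, "checkpoint-dir")  # loads the weights back into the same model
```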
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_base.py b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_base.py
new file mode 100644
index 00000000..b7b6454a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_base.py
@@ -0,0 +1,210 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains helpers to split tensors into shards."""
+
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
+
+from .. import logging
+
+
+TensorT = TypeVar("TensorT")
+TensorSizeFn_T = Callable[[TensorT], int]
+StorageIDFn_T = Callable[[TensorT], Optional[Any]]
+
+MAX_SHARD_SIZE = "5GB"
+SIZE_UNITS = {
+    "TB": 10**12,
+    "GB": 10**9,
+    "MB": 10**6,
+    "KB": 10**3,
+}
+
+
+logger = logging.get_logger(__file__)
+
+
+@dataclass
+class StateDictSplit:
+    is_sharded: bool = field(init=False)
+    metadata: Dict[str, Any]
+    filename_to_tensors: Dict[str, List[str]]
+    tensor_to_filename: Dict[str, str]
+
+    def __post_init__(self):
+        self.is_sharded = len(self.filename_to_tensors) > 1
+
+
+def split_state_dict_into_shards_factory(
+    state_dict: Dict[str, TensorT],
+    *,
+    get_storage_size: TensorSizeFn_T,
+    filename_pattern: str,
+    get_storage_id: StorageIDFn_T = lambda tensor: None,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+) -> StateDictSplit:
+    """
+    Split a model state dictionary in shards so that each shard is smaller than a given size.
+
+    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
+    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
+    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
+    [6+2+2GB], [6+2GB], [6GB].
+
+    <Tip warning={true}>
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard, which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, Tensor]`):
+            The state dictionary to save.
+        get_storage_size (`Callable[[Tensor], int]`):
+            A function that returns the size of a tensor when saved on disk in bytes.
+        get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
+            A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
+            same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
+            during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the file names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+
+    Returns:
+        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
+    """
+    storage_id_to_tensors: Dict[Any, List[str]] = {}
+
+    shard_list: List[Dict[str, TensorT]] = []
+    current_shard: Dict[str, TensorT] = {}
+    current_shard_size = 0
+    total_size = 0
+
+    if isinstance(max_shard_size, str):
+        max_shard_size = parse_size_to_int(max_shard_size)
+
+    for key, tensor in state_dict.items():
+        # when bnb serialization is used the weights in the state dict can be strings
+        # check: https://github.com/huggingface/transformers/pull/24416 for more details
+        if isinstance(tensor, str):
+            logger.info("Skipping tensor %s as it is a string (bnb serialization)", key)
+            continue
+
+        # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block`
+        storage_id = get_storage_id(tensor)
+        if storage_id is not None:
+            if storage_id in storage_id_to_tensors:
+                # We skip this tensor for now and will reassign to correct shard later
+                storage_id_to_tensors[storage_id].append(key)
+                continue
+            else:
+                # This is the first tensor with this storage_id, we create a new entry
+                # in the storage_id_to_tensors dict => we will assign the shard id later
+                storage_id_to_tensors[storage_id] = [key]
+
+        # Compute tensor size
+        tensor_size = get_storage_size(tensor)
+
+        # If this tensor is bigger than the maximal size, we put it in its own shard
+        if tensor_size > max_shard_size:
+            total_size += tensor_size
+            shard_list.append({key: tensor})
+            continue
+
+        # If adding this tensor would tip the current shard over the maximal size, we split.
+        # Current shard already has some tensors, we add it to the list of shards and create a new one.
+        if current_shard_size + tensor_size > max_shard_size:
+            shard_list.append(current_shard)
+            current_shard = {}
+            current_shard_size = 0
+
+        # Add the tensor to the current shard
+        current_shard[key] = tensor
+        current_shard_size += tensor_size
+        total_size += tensor_size
+
+    # Add the last shard
+    if len(current_shard) > 0:
+        shard_list.append(current_shard)
+    nb_shards = len(shard_list)
+
+    # Loop over the tensors that share the same storage and assign them together
+    for storage_id, keys in storage_id_to_tensors.items():
+        # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard
+        for shard in shard_list:
+            if keys[0] in shard:
+                for key in keys:
+                    shard[key] = state_dict[key]
+                break
+
+    # If we only have one shard, we return it => no need to build the index
+    if nb_shards == 1:
+        filename = filename_pattern.format(suffix="")
+        return StateDictSplit(
+            metadata={"total_size": total_size},
+            filename_to_tensors={filename: list(state_dict.keys())},
+            tensor_to_filename={key: filename for key in state_dict.keys()},
+        )
+
+    # Now that each tensor is assigned to a shard, let's assign a filename to each shard
+    tensor_name_to_filename = {}
+    filename_to_tensors = {}
+    for idx, shard in enumerate(shard_list):
+        filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}")
+        for key in shard:
+            tensor_name_to_filename[key] = filename
+        filename_to_tensors[filename] = list(shard.keys())
+
+    # Build the index and return
+    return StateDictSplit(
+        metadata={"total_size": total_size},
+        filename_to_tensors=filename_to_tensors,
+        tensor_to_filename=tensor_name_to_filename,
+    )
+
+
+def parse_size_to_int(size_as_str: str) -> int:
+    """
+    Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
+
+    Supported units are "TB", "GB", "MB", "KB".
+
+    Args:
+        size_as_str (`str`): The size to convert, expressed as a string with digits and a unit (e.g. `"5MB"`).
+
+    Example:
+
+    ```py
+    >>> parse_size_to_int("5MB")
+    5000000
+    ```
+    """
+    size_as_str = size_as_str.strip()
+
+    # Parse unit
+    unit = size_as_str[-2:].upper()
+    if unit not in SIZE_UNITS:
+        raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
+    multiplier = SIZE_UNITS[unit]
+
+    # Parse value
+    try:
+        value = float(size_as_str[:-2].strip())
+    except ValueError as e:
+        raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e
+
+    return int(value * multiplier)
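The factory above is framework-agnostic: it only needs a size function (and optionally a storage-id function). A toy sketch using plain `bytes` objects as stand-in tensors (illustrative only; the names and sizes are made up, and `parse_size_to_int` is imported from the private `_base` module shown here):

```python
from huggingface_hub.serialization import split_state_dict_into_shards_factory
from huggingface_hub.serialization._base import parse_size_to_int

print(parse_size_to_int("5MB"))  # 5000000

state_dict = {"a": b"x" * 600, "b": b"x" * 600, "c": b"x" * 200}
split = split_state_dict_into_shards_factory(
    state_dict,
    get_storage_size=len,              # on-disk size == byte length for this toy example
    filename_pattern="toy{suffix}.bin",
    max_shard_size="1KB",              # parsed to 1000 bytes internally
)
print(split.is_sharded)            # True: "a" fills one shard, "b" and "c" share the next
print(split.filename_to_tensors)   # {'toy-00001-of-00002.bin': ['a'], 'toy-00002-of-00002.bin': ['b', 'c']}
```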
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_dduf.py b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_dduf.py
new file mode 100644
index 00000000..a1debadb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_dduf.py
@@ -0,0 +1,387 @@
+import json
+import logging
+import mmap
+import os
+import shutil
+import zipfile
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, Generator, Iterable, Tuple, Union
+
+from ..errors import DDUFCorruptedFileError, DDUFExportError, DDUFInvalidEntryNameError
+
+
+logger = logging.getLogger(__name__)
+
+DDUF_ALLOWED_ENTRIES = {
+    # Allowed file extensions in a DDUF file
+    ".json",
+    ".model",
+    ".safetensors",
+    ".txt",
+}
+
+DDUF_FOLDER_REQUIRED_ENTRIES = {
+    # Each folder must contain at least one of these entries
+    "config.json",
+    "tokenizer_config.json",
+    "preprocessor_config.json",
+    "scheduler_config.json",
+}
+
+
+@dataclass
+class DDUFEntry:
+    """Object representing a file entry in a DDUF file.
+
+    See [`read_dduf_file`] for how to read a DDUF file.
+
+    Attributes:
+        filename (str):
+            The name of the file in the DDUF archive.
+        offset (int):
+            The offset of the file in the DDUF archive.
+        length (int):
+            The length of the file in the DDUF archive.
+        dduf_path (str):
+            The path to the DDUF archive (for internal use).
+    """
+
+    filename: str
+    length: int
+    offset: int
+
+    dduf_path: Path = field(repr=False)
+
+    @contextmanager
+    def as_mmap(self) -> Generator[bytes, None, None]:
+        """Open the file as a memory-mapped file.
+
+        Useful to load safetensors directly from the file.
+
+        Example:
+            ```py
+            >>> import safetensors.torch
+            >>> with entry.as_mmap() as mm:
+            ...     tensors = safetensors.torch.load(mm)
+            ```
+        """
+        with self.dduf_path.open("rb") as f:
+            with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mm:
+                yield mm[self.offset : self.offset + self.length]
+
+    def read_text(self, encoding: str = "utf-8") -> str:
+        """Read the file as text.
+
+        Useful for '.txt' and '.json' entries.
+
+        Example:
+            ```py
+            >>> import json
+            >>> index = json.loads(entry.read_text())
+            ```
+        """
+        with self.dduf_path.open("rb") as f:
+            f.seek(self.offset)
+            return f.read(self.length).decode(encoding=encoding)
+
+
+def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> Dict[str, DDUFEntry]:
+    """
+    Read a DDUF file and return a dictionary of entries.
+
+    Only the metadata is read; the data is not loaded into memory.
+
+    Args:
+        dduf_path (`str` or `os.PathLike`):
+            The path to the DDUF file to read.
+
+    Returns:
+        `Dict[str, DDUFEntry]`:
+            A dictionary of [`DDUFEntry`] indexed by filename.
+
+    Raises:
+        - [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format).
+
+    Example:
+        ```python
+        >>> import json
+        >>> import safetensors.torch
+        >>> from huggingface_hub import read_dduf_file
+
+        # Read DDUF metadata
+        >>> dduf_entries = read_dduf_file("FLUX.1-dev.dduf")
+
+        # Returns a mapping filename <> DDUFEntry
+        >>> dduf_entries["model_index.json"]
+        DDUFEntry(filename='model_index.json', offset=66, length=587)
+
+        # Load model index as JSON
+        >>> json.loads(dduf_entries["model_index.json"].read_text())
+        {'_class_name': 'FluxPipeline', '_diffusers_version': '0.32.0.dev0', '_name_or_path': 'black-forest-labs/FLUX.1-dev', ...
+
+        # Load VAE weights using safetensors
+        >>> with dduf_entries["vae/diffusion_pytorch_model.safetensors"].as_mmap() as mm:
+        ...     state_dict = safetensors.torch.load(mm)
+        ```
+    """
+    entries = {}
+    dduf_path = Path(dduf_path)
+    logger.info(f"Reading DDUF file {dduf_path}")
+    with zipfile.ZipFile(str(dduf_path), "r") as zf:
+        for info in zf.infolist():
+            logger.debug(f"Reading entry {info.filename}")
+            if info.compress_type != zipfile.ZIP_STORED:
+                raise DDUFCorruptedFileError("Data must not be compressed in DDUF file.")
+
+            try:
+                _validate_dduf_entry_name(info.filename)
+            except DDUFInvalidEntryNameError as e:
+                raise DDUFCorruptedFileError(f"Invalid entry name in DDUF file: {info.filename}") from e
+
+            offset = _get_data_offset(zf, info)
+
+            entries[info.filename] = DDUFEntry(
+                filename=info.filename, offset=offset, length=info.file_size, dduf_path=dduf_path
+            )
+
+    # Consistency checks on the DDUF file
+    if "model_index.json" not in entries:
+        raise DDUFCorruptedFileError("Missing required 'model_index.json' entry in DDUF file.")
+    index = json.loads(entries["model_index.json"].read_text())
+    _validate_dduf_structure(index, entries.keys())
+
+    logger.info(f"Done reading DDUF file {dduf_path}. Found {len(entries)} entries")
+    return entries
+
+
+def export_entries_as_dduf(
+    dduf_path: Union[str, os.PathLike], entries: Iterable[Tuple[str, Union[str, Path, bytes]]]
+) -> None:
+    """Write a DDUF file from an iterable of entries.
+
+    This is a lower-level helper than [`export_folder_as_dduf`] that allows more flexibility when serializing data.
+    In particular, you don't need to save the data on disk before exporting it in the DDUF file.
+
+    Args:
+        dduf_path (`str` or `os.PathLike`):
+            The path to the DDUF file to write.
+        entries (`Iterable[Tuple[str, Union[str, Path, bytes]]]`):
+            An iterable of entries to write in the DDUF file. Each entry is a tuple with the filename and the content.
+            The filename should be the path to the file in the DDUF archive.
+            The content can be a string or a `pathlib.Path` pointing to a file on the local disk, or the raw content as bytes.
+
+    Raises:
+        - [`DDUFExportError`]: If anything goes wrong during the export (e.g. invalid entry name, missing 'model_index.json', etc.).
+
+    Example:
+        ```python
+        # Export specific files from the local disk.
+        >>> from huggingface_hub import export_entries_as_dduf
+        >>> export_entries_as_dduf(
+        ...     dduf_path="stable-diffusion-v1-4-FP16.dduf",
+        ...     entries=[ # List entries to add to the DDUF file (here, only FP16 weights)
+        ...         ("model_index.json", "path/to/model_index.json"),
+        ...         ("vae/config.json", "path/to/vae/config.json"),
+        ...         ("vae/diffusion_pytorch_model.fp16.safetensors", "path/to/vae/diffusion_pytorch_model.fp16.safetensors"),
+        ...         ("text_encoder/config.json", "path/to/text_encoder/config.json"),
+        ...         ("text_encoder/model.fp16.safetensors", "path/to/text_encoder/model.fp16.safetensors"),
+        ...         # ... add more entries here
+        ...     ]
+        ... )
+        ```
+
+        ```python
+        # Export state_dicts one by one from a loaded pipeline
+        >>> from diffusers import DiffusionPipeline
+        >>> from typing import Generator, Tuple
+        >>> import safetensors.torch
+        >>> from huggingface_hub import export_entries_as_dduf
+        >>> pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        ... # ... do some work with the pipeline
+
+        >>> def as_entries(pipe: DiffusionPipeline) -> Generator[Tuple[str, bytes], None, None]:
+        ...     # Build a generator that yields the entries to add to the DDUF file.
+        ...     # The first element of the tuple is the filename in the DDUF archive (must use UNIX separator!). The second element is the content of the file.
+        ...     # Entries will be evaluated lazily when the DDUF file is created (only 1 entry is loaded in memory at a time)
+        ...     yield "vae/config.json", pipe.vae.to_json_string().encode()
+        ...     yield "vae/diffusion_pytorch_model.safetensors", safetensors.torch.save(pipe.vae.state_dict())
+        ...     yield "text_encoder/config.json", pipe.text_encoder.config.to_json_string().encode()
+        ...     yield "text_encoder/model.safetensors", safetensors.torch.save(pipe.text_encoder.state_dict())
+        ...     # ... add more entries here
+
+        >>> export_entries_as_dduf(dduf_path="stable-diffusion-v1-4.dduf", entries=as_entries(pipe))
+        ```
+    """
+    logger.info(f"Exporting DDUF file '{dduf_path}'")
+    filenames = set()
+    index = None
+    with zipfile.ZipFile(str(dduf_path), "w", zipfile.ZIP_STORED) as archive:
+        for filename, content in entries:
+            if filename in filenames:
+                raise DDUFExportError(f"Can't add duplicate entry: {filename}")
+            filenames.add(filename)
+
+            if filename == "model_index.json":
+                try:
+                    index = json.loads(_load_content(content).decode())
+                except json.JSONDecodeError as e:
+                    raise DDUFExportError("Failed to parse 'model_index.json'.") from e
+
+            try:
+                filename = _validate_dduf_entry_name(filename)
+            except DDUFInvalidEntryNameError as e:
+                raise DDUFExportError(f"Invalid entry name: {filename}") from e
+            logger.debug(f"Adding entry '{filename}' to DDUF file")
+            _dump_content_in_archive(archive, filename, content)
+
+    # Consistency checks on the DDUF file
+    if index is None:
+        raise DDUFExportError("Missing required 'model_index.json' entry in DDUF file.")
+    try:
+        _validate_dduf_structure(index, filenames)
+    except DDUFCorruptedFileError as e:
+        raise DDUFExportError("Invalid DDUF file structure.") from e
+
+    logger.info(f"Done writing DDUF file {dduf_path}")
+
+
+def export_folder_as_dduf(dduf_path: Union[str, os.PathLike], folder_path: Union[str, os.PathLike]) -> None:
+    """
+    Export a folder as a DDUF file.
+
+    Uses [`export_entries_as_dduf`] under the hood.
+
+    Args:
+        dduf_path (`str` or `os.PathLike`):
+            The path to the DDUF file to write.
+        folder_path (`str` or `os.PathLike`):
+            The path to the folder containing the diffusion model.
+
+    Example:
+        ```python
+        >>> from huggingface_hub import export_folder_as_dduf
+        >>> export_folder_as_dduf(dduf_path="FLUX.1-dev.dduf", folder_path="path/to/FLUX.1-dev")
+        ```
+    """
+    folder_path = Path(folder_path)
+
+    def _iterate_over_folder() -> Iterable[Tuple[str, Path]]:
+        for path in Path(folder_path).glob("**/*"):
+            if not path.is_file():
+                continue
+            if path.suffix not in DDUF_ALLOWED_ENTRIES:
+                logger.debug(f"Skipping file '{path}' (file type not allowed)")
+                continue
+            path_in_archive = path.relative_to(folder_path)
+            if len(path_in_archive.parts) >= 3:
+                logger.debug(f"Skipping file '{path}' (nested directories not allowed)")
+                continue
+            yield path_in_archive.as_posix(), path
+
+    export_entries_as_dduf(dduf_path, _iterate_over_folder())
+
+
+def _dump_content_in_archive(archive: zipfile.ZipFile, filename: str, content: Union[str, os.PathLike, bytes]) -> None:
+    with archive.open(filename, "w", force_zip64=True) as archive_fh:
+        if isinstance(content, (str, Path)):
+            content_path = Path(content)
+            with content_path.open("rb") as content_fh:
+                shutil.copyfileobj(content_fh, archive_fh, 1024 * 1024 * 8)  # type: ignore[misc]
+        elif isinstance(content, bytes):
+            archive_fh.write(content)
+        else:
+            raise DDUFExportError(f"Invalid content type for {filename}. Must be str, Path or bytes.")
+
+
+def _load_content(content: Union[str, Path, bytes]) -> bytes:
+    """Load the content of an entry as bytes.
+
+    Used only for small checks (not to dump content into archive).
+    """
+    if isinstance(content, (str, Path)):
+        return Path(content).read_bytes()
+    elif isinstance(content, bytes):
+        return content
+    else:
+        raise DDUFExportError(f"Invalid content type. Must be str, Path or bytes. Got {type(content)}.")
+
+
+def _validate_dduf_entry_name(entry_name: str) -> str:
+    if "." + entry_name.split(".")[-1] not in DDUF_ALLOWED_ENTRIES:
+        raise DDUFInvalidEntryNameError(f"File type not allowed: {entry_name}")
+    if "\\" in entry_name:
+        raise DDUFInvalidEntryNameError(f"Entry names must use UNIX separators ('/'). Got {entry_name}.")
+    entry_name = entry_name.strip("/")
+    if entry_name.count("/") > 1:
+        raise DDUFInvalidEntryNameError(f"DDUF only supports 1 level of directory. Got {entry_name}.")
+    return entry_name
+
+
+def _validate_dduf_structure(index: Any, entry_names: Iterable[str]) -> None:
+    """
+    Consistency checks on the DDUF file structure.
+
+    Rules:
+    - The 'model_index.json' entry is required and must contain a dictionary.
+    - Each folder name must correspond to an entry in 'model_index.json'.
+    - Each folder must contain at least a config file ('config.json', 'tokenizer_config.json', 'preprocessor_config.json', 'scheduler_config.json').
+
+    Args:
+        index (Any):
+            The content of the 'model_index.json' entry.
+        entry_names (Iterable[str]):
+            The list of entry names in the DDUF file.
+
+    Raises:
+        - [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format).
+    """
+    if not isinstance(index, dict):
+        raise DDUFCorruptedFileError(f"Invalid 'model_index.json' content. Must be a dictionary. Got {type(index)}.")
+
+    dduf_folders = {entry.split("/")[0] for entry in entry_names if "/" in entry}
+    for folder in dduf_folders:
+        if folder not in index:
+            raise DDUFCorruptedFileError(f"Missing required entry '{folder}' in 'model_index.json'.")
+        if not any(f"{folder}/{required_entry}" in entry_names for required_entry in DDUF_FOLDER_REQUIRED_ENTRIES):
+            raise DDUFCorruptedFileError(
+                f"Missing required file in folder '{folder}'. Must contains at least one of {DDUF_FOLDER_REQUIRED_ENTRIES}."
+            )
+
+
+def _get_data_offset(zf: zipfile.ZipFile, info: zipfile.ZipInfo) -> int:
+    """
+    Calculate the data offset for a file in a ZIP archive.
+
+    Args:
+        zf (`zipfile.ZipFile`):
+            The opened ZIP file. Must be opened in read mode.
+        info (`zipfile.ZipInfo`):
+            The file info.
+
+    Returns:
+        int: The offset of the file data in the ZIP archive.
+    """
+    if zf.fp is None:
+        raise DDUFCorruptedFileError("ZipFile object must be opened in read mode.")
+
+    # Step 1: Get the local file header offset
+    header_offset = info.header_offset
+
+    # Step 2: Read the local file header
+    zf.fp.seek(header_offset)
+    local_file_header = zf.fp.read(30)  # Fixed-size part of the local header
+
+    if len(local_file_header) < 30:
+        raise DDUFCorruptedFileError("Incomplete local file header.")
+
+    # Step 3: Parse the header fields to calculate the start of file data
+    # Local file header: https://en.wikipedia.org/wiki/ZIP_(file_format)#File_headers
+    filename_len = int.from_bytes(local_file_header[26:28], "little")
+    extra_field_len = int.from_bytes(local_file_header[28:30], "little")
+
+    # Data offset is after the fixed header, filename, and extra fields
+    data_offset = header_offset + 30 + filename_len + extra_field_len
+
+    return data_offset
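A small end-to-end sketch of the two public entry points above, writing a tiny archive from in-memory bytes and reading it back (assumes `huggingface_hub` is installed; the archive name and the `unet` component are placeholders, not a real diffusion model):

```python
import json

from huggingface_hub import export_entries_as_dduf, read_dduf_file

export_entries_as_dduf(
    dduf_path="tiny.dduf",
    entries=[
        # 'model_index.json' is mandatory and must reference every top-level folder.
        ("model_index.json", json.dumps({"unet": ["diffusers", "UNet2DConditionModel"]}).encode()),
        # Each folder must contain at least one of the required config files.
        ("unet/config.json", b"{}"),
    ],
)

entries = read_dduf_file("tiny.dduf")          # only metadata is read at this point
print(sorted(entries))                         # ['model_index.json', 'unet/config.json']
print(json.loads(entries["model_index.json"].read_text()))
```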
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_tensorflow.py b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_tensorflow.py
new file mode 100644
index 00000000..59ed8110
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_tensorflow.py
@@ -0,0 +1,95 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains tensorflow-specific helpers."""
+
+import math
+import re
+from typing import TYPE_CHECKING, Dict, Union
+
+from .. import constants
+from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
+
+
+if TYPE_CHECKING:
+    import tensorflow as tf
+
+
+def split_tf_state_dict_into_shards(
+    state_dict: Dict[str, "tf.Tensor"],
+    *,
+    filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+) -> StateDictSplit:
+    """
+    Split a model state dictionary in shards so that each shard is smaller than a given size.
+
+    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
+    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
+    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
+    [6+2+2GB], [6+2GB], [6GB].
+
+    <Tip warning={true}>
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard, which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, Tensor]`):
+            The state dictionary to save.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the file names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"tf_model{suffix}.h5"`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+
+    Returns:
+        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
+    """
+    return split_state_dict_into_shards_factory(
+        state_dict,
+        max_shard_size=max_shard_size,
+        filename_pattern=filename_pattern,
+        get_storage_size=get_tf_storage_size,
+    )
+
+
+def get_tf_storage_size(tensor: "tf.Tensor") -> int:
+    # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool).
+    # Better to overestimate than underestimate.
+    return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype))
+
+
+def _dtype_byte_size_tf(dtype) -> float:
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`.
+    Taken from https://github.com/huggingface/transformers/blob/74d9d0cebb0263a3f8ab9c280569170cc74651d0/src/transformers/modeling_tf_utils.py#L608.
+    NOTE: why not `tensor.numpy().nbytes`?
+    Example:
+    ```py
+    >>> _dtype_byte_size(tf.float32)
+    4
+    ```
+    """
+    import tensorflow as tf
+
+    if dtype == tf.bool:
+        return 1 / 8
+    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
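The helper above derives the per-element byte size from the trailing bit width in the dtype name. A standalone sketch of the same parsing idea, using plain strings so it runs without TensorFlow (the helper name is hypothetical, not part of the module):

```python
import math
import re


def approx_dtype_byte_size(dtype_name: str) -> float:
    """Bytes per element, derived from the trailing bit width (e.g. 'float32' -> 4)."""
    if dtype_name == "bool":
        return 1 / 8
    match = re.search(r"[^\d](\d+)$", dtype_name)
    if match is None:
        raise ValueError(f"Not a sized dtype name: {dtype_name}")
    return int(match.group(1)) // 8


# A (3, 3) float32 tensor stored on disk: ceil(9 elements * 4 bytes) = 36 bytes.
print(math.ceil(9 * approx_dtype_byte_size("float32")))  # 36
# Booleans are bit-packed in this estimate: ceil(9 * 1/8) = 2 bytes.
print(math.ceil(9 * approx_dtype_byte_size("bool")))     # 2
```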
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py
new file mode 100644
index 00000000..ccb9c42b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/serialization/_torch.py
@@ -0,0 +1,1015 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains pytorch-specific helpers."""
+
+import importlib
+import json
+import os
+import re
+from collections import defaultdict, namedtuple
+from functools import lru_cache
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
+
+from packaging import version
+
+from .. import constants, logging
+from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
+
+
+logger = logging.get_logger(__file__)
+
+if TYPE_CHECKING:
+    import torch
+
+# SAVING
+
+
+def save_torch_model(
+    model: "torch.nn.Module",
+    save_directory: Union[str, Path],
+    *,
+    filename_pattern: Optional[str] = None,
+    force_contiguous: bool = True,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+    metadata: Optional[Dict[str, str]] = None,
+    safe_serialization: bool = True,
+    is_main_process: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
+):
+    """
+    Saves a given torch model to disk, handling sharding and shared tensors issues.
+
+    See also [`save_torch_state_dict`] to save a state dict with more flexibility.
+
+    For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).
+
+    The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
+    saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
+    an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
+    [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
+    safetensors (the default). Otherwise, the shards are saved as pickle.
+
+    Before saving the model, the `save_directory` is cleaned of any previous shard files.
+
+    <Tip warning={true}>
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard, which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    <Tip warning={true}>
+
+    If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
+
+    </Tip>
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to save on disk.
+        save_directory (`str` or `Path`):
+            The directory in which the model will be saved.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the file names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization`
+            parameter.
+        force_contiguous (`boolean`, *optional*):
+            Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the
+            model, but it could potentially change performance if the layout of the tensor was chosen specifically for
+            that reason. Defaults to `True`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+        metadata (`Dict[str, str]`, *optional*):
+            Extra information to save along with the model. Some metadata will be added for each dropped tensor.
+            This information will not be enough to recover the entire shared structure but might help in understanding
+            things.
+        safe_serialization (`bool`, *optional*):
+            Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
+            Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
+            in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful in distributed training (e.g. on TPUs)
+            when this function needs to be called from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
+        shared_tensors_to_discard (`List[str]`, *optional*):
+            List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
+            detected, it will drop the first name alphabetically.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import save_torch_model
+    >>> model = ... # A PyTorch model
+
+    # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors.
+    >>> save_torch_model(model, "path/to/folder")
+
+    # Load model back
+    >>> from huggingface_hub import load_torch_model  # TODO
+    >>> load_torch_model(model, "path/to/folder")
+    >>>
+    ```
+    """
+    save_torch_state_dict(
+        state_dict=model.state_dict(),
+        filename_pattern=filename_pattern,
+        force_contiguous=force_contiguous,
+        max_shard_size=max_shard_size,
+        metadata=metadata,
+        safe_serialization=safe_serialization,
+        save_directory=save_directory,
+        is_main_process=is_main_process,
+        shared_tensors_to_discard=shared_tensors_to_discard,
+    )
+
+
+def save_torch_state_dict(
+    state_dict: Dict[str, "torch.Tensor"],
+    save_directory: Union[str, Path],
+    *,
+    filename_pattern: Optional[str] = None,
+    force_contiguous: bool = True,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+    metadata: Optional[Dict[str, str]] = None,
+    safe_serialization: bool = True,
+    is_main_process: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
+) -> None:
+    """
+    Save a model state dictionary to the disk, handling sharding and shared tensors issues.
+
+    See also [`save_torch_model`] to directly save a PyTorch model.
+
+    For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).
+
+    The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
+    saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
+    an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
+    [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
+    safetensors (the default). Otherwise, the shards are saved as pickle.
+
+    Before saving the model, the `save_directory` is cleaned of any previous shard files.
+
+    <Tip warning={true}>
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard, which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    <Tip warning={true}>
+
+    If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, torch.Tensor]`):
+            The state dictionary to save.
+        save_directory (`str` or `Path`):
+            The directory in which the model will be saved.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the file names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization`
+            parameter.
+        force_contiguous (`boolean`, *optional*):
+            Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the
+            model, but it could potentially change performance if the layout of the tensor was chosen specifically for
+            that reason. Defaults to `True`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+        metadata (`Dict[str, str]`, *optional*):
+            Extra information to save along with the model. Some metadata will be added for each dropped tensor.
+            This information will not be enough to recover the entire shared structure but might help in understanding
+            things.
+        safe_serialization (`bool`, *optional*):
+            Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
+            Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
+            in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful in distributed training (e.g. on TPUs)
+            when this function needs to be called from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
+        shared_tensors_to_discard (`List[str]`, *optional*):
+            List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
+            detected, it will drop the first name alphabetically.
+
+    Example:
+
+    ```py
+    >>> from huggingface_hub import save_torch_state_dict
+    >>> model = ... # A PyTorch model
+
+    # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors.
+    >>> state_dict = model_to_save.state_dict()
+    >>> save_torch_state_dict(state_dict, "path/to/folder")
+    ```
+    """
+    save_directory = str(save_directory)
+
+    if filename_pattern is None:
+        filename_pattern = (
+            constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
+            if safe_serialization
+            else constants.PYTORCH_WEIGHTS_FILE_PATTERN
+        )
+
+    if metadata is None:
+        metadata = {}
+    if safe_serialization:
+        try:
+            from safetensors.torch import save_file as save_file_fn
+        except ImportError as e:
+            raise ImportError(
+                "Please install `safetensors` to use safe serialization. "
+                "You can install it with `pip install safetensors`."
+            ) from e
+        # Clean state dict for safetensors
+        state_dict = _clean_state_dict_for_safetensors(
+            state_dict,
+            metadata,
+            force_contiguous=force_contiguous,
+            shared_tensors_to_discard=shared_tensors_to_discard,
+        )
+    else:
+        from torch import save as save_file_fn  # type: ignore[assignment]
+
+        logger.warning(
+            "You are using unsafe serialization. Due to security reasons, it is recommended not to load "
+            "pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
+            "using safe serialization by installing `safetensors` with `pip install safetensors`."
+        )
+    # Split dict
+    state_dict_split = split_torch_state_dict_into_shards(
+        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+    )
+
+    # Only main process should clean up existing files to avoid race conditions in distributed environment
+    if is_main_process:
+        existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
+        for filename in os.listdir(save_directory):
+            if existing_files_regex.match(filename):
+                try:
+                    logger.debug(f"Removing existing file '{filename}' from folder.")
+                    os.remove(os.path.join(save_directory, filename))
+                except Exception as e:
+                    logger.warning(
+                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
+                    )
+
+    # Save each shard
+    per_file_metadata = {"format": "pt"}
+    if not state_dict_split.is_sharded:
+        per_file_metadata.update(metadata)
+    safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {}
+    for filename, tensors in state_dict_split.filename_to_tensors.items():
+        shard = {tensor: state_dict[tensor] for tensor in tensors}
+        save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs)
+        logger.debug(f"Shard saved to {filename}")
+
+    # Save the index (if any)
+    if state_dict_split.is_sharded:
+        index_path = filename_pattern.format(suffix="") + ".index.json"
+        index = {
+            "metadata": {**state_dict_split.metadata, **metadata},
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+        with open(os.path.join(save_directory, index_path), "w") as f:
+            json.dump(index, f, indent=2)
+        logger.info(
+            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). "
+            f"Model weighs have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. "
+            f"You can find where each parameters has been saved in the index located at {index_path}."
+        )
+
+    logger.info(f"Model weights successfully saved to {save_directory}!")
+
+
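A minimal sketch of the sharded layout `save_torch_state_dict` produces when the state dict exceeds `max_shard_size` (assumes `torch` and `safetensors` are installed; the directory name, tensor names and shard size are placeholders):

```python
import json
import os

import torch

from huggingface_hub import save_torch_state_dict

state_dict = {f"layer_{i}.weight": torch.zeros(256, 256) for i in range(4)}  # ~262 KB each
os.makedirs("sharded-dir", exist_ok=True)
save_torch_state_dict(state_dict, "sharded-dir", max_shard_size="500KB")

print(sorted(os.listdir("sharded-dir")))
# ['model-00001-of-00004.safetensors', ..., 'model.safetensors.index.json']
with open(os.path.join("sharded-dir", "model.safetensors.index.json")) as f:
    index = json.load(f)
print(index["weight_map"]["layer_0.weight"])  # 'model-00001-of-00004.safetensors'
```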
+def split_torch_state_dict_into_shards(
+    state_dict: Dict[str, "torch.Tensor"],
+    *,
+    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
+) -> StateDictSplit:
+    """
+    Split a model state dictionary in shards so that each shard is smaller than a given size.
+
+    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
+    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
+    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
+    [6+2+2GB], [6+2GB], [6GB].
+
+
+    <Tip>
+
+    To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses
+    `split_torch_state_dict_into_shards` under the hood.
+
+    </Tip>
+
+    <Tip warning={true}>
+
+    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard, which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, torch.Tensor]`):
+            The state dictionary to save.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the file names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"model{suffix}.safetensors"`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+
+    Returns:
+        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
+
+    Example:
+    ```py
+    >>> import json
+    >>> import os
+    >>> from safetensors.torch import save_file as safe_save_file
+    >>> from huggingface_hub import split_torch_state_dict_into_shards
+
+    >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
+    ...     state_dict_split = split_torch_state_dict_into_shards(state_dict)
+    ...     for filename, tensors in state_dict_split.filename_to_tensors.items():
+    ...         shard = {tensor: state_dict[tensor] for tensor in tensors}
+    ...         safe_save_file(
+    ...             shard,
+    ...             os.path.join(save_directory, filename),
+    ...             metadata={"format": "pt"},
+    ...         )
+    ...     if state_dict_split.is_sharded:
+    ...         index = {
+    ...             "metadata": state_dict_split.metadata,
+    ...             "weight_map": state_dict_split.tensor_to_filename,
+    ...         }
+    ...         with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
+    ...             f.write(json.dumps(index, indent=2))
+    ```
+    """
+    return split_state_dict_into_shards_factory(
+        state_dict,
+        max_shard_size=max_shard_size,
+        filename_pattern=filename_pattern,
+        get_storage_size=get_torch_storage_size,
+        get_storage_id=get_torch_storage_id,
+    )
+
+
+# LOADING
+
+
+def load_torch_model(
+    model: "torch.nn.Module",
+    checkpoint_path: Union[str, os.PathLike],
+    *,
+    strict: bool = False,
+    safe: bool = True,
+    weights_only: bool = False,
+    map_location: Optional[Union[str, "torch.device"]] = None,
+    mmap: bool = False,
+    filename_pattern: Optional[str] = None,
+) -> NamedTuple:
+    """
+    Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model in which to load the checkpoint.
+        checkpoint_path (`str` or `os.PathLike`):
+            Path to either the checkpoint file or directory containing the checkpoint(s).
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
+        safe (`bool`, *optional*, defaults to `True`):
+            If `safe` is True, only safetensors files will be loaded. If `safe` is False, the function
+            will first attempt to load safetensors files if they are available, otherwise it will fall back to loading
+            pickle files. `filename_pattern` parameter takes precedence over `safe` parameter.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, only loads the model weights without optimizer states and other metadata.
+            Only supported in PyTorch >= 1.13.
+        map_location (`str` or `torch.device`, *optional*):
+            A `torch.device` object, string or a dict specifying how to remap storage locations. It
+            indicates the location where all tensors should be loaded.
+        mmap (`bool`, *optional*, defaults to `False`):
+            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
+            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints.
+        filename_pattern (`str`, *optional*):
+            The pattern to look for the index file. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"model{suffix}.safetensors"`.
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
+            - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint.
+            - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model.
+
+    Raises:
+        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
+            If the checkpoint file or directory does not exist.
+        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+           If the checkpoint path is invalid or if the checkpoint format cannot be determined.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import load_torch_model
+    >>> model = ... # A PyTorch model
+    >>> load_torch_model(model, "path/to/checkpoint")
+    ```
+    """
+    checkpoint_path = Path(checkpoint_path)
+
+    if not checkpoint_path.exists():
+        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")
+    # 1. Check if checkpoint is a single file
+    if checkpoint_path.is_file():
+        state_dict = load_state_dict_from_file(
+            checkpoint_file=checkpoint_path,
+            map_location=map_location,
+            weights_only=weights_only,
+        )
+        return model.load_state_dict(state_dict, strict=strict)
+
+    # 2. If not, checkpoint_path is a directory
+    if filename_pattern is None:
+        filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
+        index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
+        # Only fallback to pickle format if safetensors index is not found and safe is False.
+        if not index_path.is_file() and not safe:
+            filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
+
+    index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
+
+    if index_path.is_file():
+        return _load_sharded_checkpoint(
+            model=model,
+            save_directory=checkpoint_path,
+            strict=strict,
+            weights_only=weights_only,
+            filename_pattern=filename_pattern,
+        )
+
+    # Look for single model file
+    model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin"))
+    if len(model_files) == 1:
+        state_dict = load_state_dict_from_file(
+            checkpoint_file=model_files[0],
+            map_location=map_location,
+            weights_only=weights_only,
+            mmap=mmap,
+        )
+        return model.load_state_dict(state_dict, strict=strict)
+
+    raise ValueError(
+        f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
+        "Expected either a sharded checkpoint with an index file, or a single model file."
+    )
+
+
+def _load_sharded_checkpoint(
+    model: "torch.nn.Module",
+    save_directory: os.PathLike,
+    *,
+    strict: bool = False,
+    weights_only: bool = False,
+    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
+) -> NamedTuple:
+    """
+    Loads a sharded checkpoint into a model. This is the same as
+    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
+    but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the model.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model in which to load the checkpoint.
+        save_directory (`str` or `os.PathLike`):
+            A path to a folder containing the sharded checkpoint.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, only loads the model weights without optimizer states and other metadata.
+            Only supported in PyTorch >= 1.13.
+        filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
+            The pattern to look for the index file. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
+            Defaults to `"model{suffix}.safetensors"`.
+
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields,
+            - `missing_keys` is a list of str containing the missing keys
+            - `unexpected_keys` is a list of str containing the unexpected keys
+    """
+
+    # 1. Load and validate index file
+    # The index file contains mapping of parameter names to shard files
+    index_path = filename_pattern.format(suffix="") + ".index.json"
+    index_file = os.path.join(save_directory, index_path)
+    with open(index_file, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    # 2. Validate keys if in strict mode
+    # This is done before loading any shards to fail fast
+    if strict:
+        _validate_keys_for_strict_loading(model, index["weight_map"].keys())
+
+    # 3. Load each shard using `load_state_dict`
+    # Get unique shard files (multiple parameters can be in same shard)
+    shard_files = list(set(index["weight_map"].values()))
+    for shard_file in shard_files:
+        # Load shard into memory
+        shard_path = os.path.join(save_directory, shard_file)
+        state_dict = load_state_dict_from_file(
+            shard_path,
+            map_location="cpu",
+            weights_only=weights_only,
+        )
+        # Update model with parameters from this shard. Each shard only contains a subset of the
+        # keys, so strict loading must be disabled here; strict key validation is already done in step 2.
+        model.load_state_dict(state_dict, strict=False)
+        # Explicitly remove the state dict from memory
+        del state_dict
+
+    # 4. Return compatibility info
+    loaded_keys = set(index["weight_map"].keys())
+    model_keys = set(model.state_dict().keys())
+    return _IncompatibleKeys(
+        missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
+    )
+
+
+def load_state_dict_from_file(
+    checkpoint_file: Union[str, os.PathLike],
+    map_location: Optional[Union[str, "torch.device"]] = None,
+    weights_only: bool = False,
+    mmap: bool = False,
+) -> Union[Dict[str, "torch.Tensor"], Any]:
+    """
+    Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.
+
+    Args:
+        checkpoint_file (`str` or `os.PathLike`):
+            Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
+        map_location (`str` or `torch.device`, *optional*):
+            A `torch.device` object, string or a dict specifying how to remap storage locations. It
+            indicates the location where all tensors should be loaded.
+        weights_only (`bool`, *optional*, defaults to `False`):
+            If True, restricts unpickling to tensors, primitive types and dictionaries, which protects
+            against loading arbitrary pickled objects. Only supported for pickle (`.bin`) checkpoints
+            with PyTorch >= 1.13. Has no effect when loading safetensors files.
+        mmap (`bool`, *optional*, defaults to `False`):
+            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
+            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
+            loading safetensors files, as the `safetensors` library uses memory mapping by default.
+
+    Returns:
+        `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
+            - For safetensors files: always returns a dictionary mapping parameter names to tensors.
+            - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
+              an entire model, optimizer state, or any other Python object).
+
+    Raises:
+        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
+            If the checkpoint file does not exist.
+        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
+            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
+        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
+            If the checkpoint file format is invalid or if git-lfs files are not properly downloaded.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If the checkpoint file path is empty or invalid.
+
+    Example:
+    ```python
+    >>> from huggingface_hub import load_state_dict_from_file
+
+    # Load a PyTorch checkpoint
+    >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
+    >>> model.load_state_dict(state_dict)
+
+    # Load a safetensors checkpoint
+    >>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
+    >>> model.load_state_dict(state_dict)
+    ```
+    """
+    checkpoint_path = Path(checkpoint_file)
+
+    # Check if file exists and is a regular file (not a directory)
+    if not checkpoint_path.is_file():
+        raise FileNotFoundError(
+            f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
+            "the file has been properly downloaded."
+        )
+
+    # Load safetensors checkpoint
+    if checkpoint_path.suffix == ".safetensors":
+        try:
+            from safetensors import safe_open
+            from safetensors.torch import load_file
+        except ImportError as e:
+            raise ImportError(
+                "Please install `safetensors` to load safetensors checkpoint. "
+                "You can install it with `pip install safetensors`."
+            ) from e
+
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
+            metadata = f.metadata()
+        # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
+        if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain valid metadata. Make sure "
+                "you save your model with the `save_torch_model` method."
+            )
+        device = str(map_location.type) if map_location is not None and hasattr(map_location, "type") else map_location
+        # meta device is not supported with safetensors, falling back to CPU
+        if device == "meta":
+            logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
+            device = "cpu"
+        return load_file(checkpoint_file, device=device)  # type: ignore[arg-type]
+    # Otherwise, load from pickle
+    try:
+        import torch
+        from torch import load
+    except ImportError as e:
+        raise ImportError(
+            "Please install `torch` to load torch tensors. You can install it with `pip install torch`."
+        ) from e
+    # Add additional kwargs, mmap is only supported in torch >= 2.1.0
+    additional_kwargs = {}
+    if version.parse(torch.__version__) >= version.parse("2.1.0"):
+        additional_kwargs["mmap"] = mmap
+
+    # weights_only is only supported in torch >= 1.13.0
+    if version.parse(torch.__version__) >= version.parse("1.13.0"):
+        additional_kwargs["weights_only"] = weights_only
+
+    return load(
+        checkpoint_file,
+        map_location=map_location,
+        **additional_kwargs,
+    )
+
+
+# HELPERS
+
+
+def _validate_keys_for_strict_loading(
+    model: "torch.nn.Module",
+    loaded_keys: Iterable[str],
+) -> None:
+    """
+    Validate that model keys match loaded keys when strict loading is enabled.
+
+    Args:
+        model: The PyTorch model being loaded
+        loaded_keys: The keys present in the checkpoint
+
+    Raises:
+        RuntimeError: If there are missing or unexpected keys in strict mode
+    """
+    loaded_keys_set = set(loaded_keys)
+    model_keys = set(model.state_dict().keys())
+    missing_keys = model_keys - loaded_keys_set  # Keys in model but not in checkpoint
+    unexpected_keys = loaded_keys_set - model_keys  # Keys in checkpoint but not in model
+
+    if missing_keys or unexpected_keys:
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if missing_keys:
+            str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if unexpected_keys:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)])
+            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+
+def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
+    """Returns a unique id for a plain tensor,
+    or a (potentially nested) tuple of unique ids for the flattened tensors
+    if the input is a wrapper tensor subclass.
+    """
+
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)
+
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
+    if tensor.device.type == "xla" and is_torch_tpu_available():
+        # NOTE: XLA tensors don't have storage,
+        # so use another unique id to distinguish them.
+        # This is an XLA tensor, so it must have been created on a torch_xla
+        # device and the following import is safe:
+        import torch_xla  # type: ignore[import]
+
+        unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
+    else:
+        unique_id = storage_ptr(tensor)
+
+    return unique_id
+
+
+def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]:
+    """
+    Return a unique identifier for a tensor's storage.
+
+    Multiple different tensors can share the same underlying storage. This identifier is
+    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+    non-overlapping lifetimes may have the same id.
+    In the case of meta tensors, we return None since we can't tell if they share the same storage.
+
+    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
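+
+    Example (illustrative sketch; assumes `torch` is installed):
+    ```python
+    >>> import torch
+    >>> t = torch.zeros(4)
+    >>> view = t[:2]  # `view` shares the underlying storage of `t`
+    >>> get_torch_storage_id(t) == get_torch_storage_id(view)
+    True
+    ```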
+    """
+    if tensor.device.type == "meta":
+        return None
+    else:
+        return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
+
+
+def get_torch_storage_size(tensor: "torch.Tensor") -> int:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
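+
+    Example (illustrative sketch; assumes `torch` is installed):
+    ```python
+    >>> import torch
+    >>> get_torch_storage_size(torch.zeros(4, dtype=torch.float32))  # 4 elements * 4 bytes each
+    16
+    ```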
+    """
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs)
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
+    try:
+        return tensor.untyped_storage().nbytes()
+    except AttributeError:
+        # Fallback for torch==1.10
+        try:
+            return tensor.storage().size() * _get_dtype_size(tensor.dtype)
+        except NotImplementedError:
+            # Fallback for meta storage
+            # On torch >=2.0 this is the tensor size
+            return tensor.nelement() * _get_dtype_size(tensor.dtype)
+
+
+@lru_cache()
+def is_torch_tpu_available(check_device=True):
+    """
+    Checks if `torch_xla` is installed and, when `check_device` is True, whether a TPU device can actually be initialized.
+
+    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463.
+    """
+    if importlib.util.find_spec("torch_xla") is not None:
+        if check_device:
+            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
+            try:
+                import torch_xla.core.xla_model as xm  # type: ignore[import]
+
+                _ = xm.xla_device()
+                return True
+            except RuntimeError:
+                return False
+        return True
+    return False
+
+
+def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
+    """
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            return _get_unique_id(tensor)  # type: ignore
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
+    try:
+        return tensor.untyped_storage().data_ptr()
+    except Exception:
+        # Fallback for torch==1.10
+        try:
+            return tensor.storage().data_ptr()
+        except NotImplementedError:
+            # Fallback for meta storage
+            return 0
+
+
+def _clean_state_dict_for_safetensors(
+    state_dict: Dict[str, "torch.Tensor"],
+    metadata: Dict[str, str],
+    force_contiguous: bool = True,
+    shared_tensors_to_discard: Optional[List[str]] = None,
+):
+    """Remove shared tensors from state_dict and update metadata accordingly (for reloading).
+
+    Warning: `state_dict` and `metadata` are mutated in-place!
+
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
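+
+    Example (illustrative sketch; the tied-weight names are hypothetical):
+    ```python
+    >>> import torch
+    >>> weight = torch.zeros(2, 2)
+    >>> state_dict = {"embed.weight": weight, "lm_head.weight": weight}  # hypothetical tied weights sharing storage
+    >>> metadata = {}
+    >>> cleaned = _clean_state_dict_for_safetensors(state_dict, metadata)
+    >>> sorted(cleaned)
+    ['embed.weight']
+    >>> metadata  # maps the removed name back to the kept one, so the model can be reloaded
+    {'lm_head.weight': 'embed.weight'}
+    ```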
+    """
+    to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if metadata is None:
+                metadata = {}
+
+            if to_remove not in metadata:
+                # Do not override user data
+                metadata[to_remove] = kept_name
+            del state_dict[to_remove]
+    if force_contiguous:
+        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
+    return state_dict
+
+
+def _end_ptr(tensor: "torch.Tensor") -> int:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23.
+    """
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype)
+    else:
+        stop = tensor.data_ptr()
+    return stop
+
+
+def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44
+    """
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+
+    return filtered_tensors
+
+
+def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69.
+    """
+    import torch
+
+    tensors_dict = defaultdict(set)
+    for k, v in state_dict.items():
+        if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0:
+            # Need to add device as key because of multiple GPU.
+            tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k)
+    tensors = list(sorted(tensors_dict.values()))
+    tensors = _filter_shared_not_shared(tensors, state_dict)
+    return tensors
+
+
+def _is_complete(tensor: "torch.Tensor") -> bool:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
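+
+    Example (illustrative sketch; assumes `torch` is installed):
+    ```python
+    >>> import torch
+    >>> t = torch.zeros(4)
+    >>> _is_complete(t)  # `t` covers its entire storage
+    True
+    >>> _is_complete(t[:2])  # a view covering only part of the storage
+    False
+    ```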
+    """
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return all(_is_complete(getattr(tensor, attr)) for attr in attrs)
+    except ImportError:
+        # for torch versions below 2.1, fall back to the original implementation
+        pass
+
+    return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size(
+        tensor.dtype
+    ) == get_torch_storage_size(tensor)
+
+
+def _remove_duplicate_names(
+    state_dict: Dict[str, "torch.Tensor"],
+    *,
+    preferred_names: Optional[List[str]] = None,
+    discard_names: Optional[List[str]] = None,
+) -> Dict[str, List[str]]:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
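+
+    Example (illustrative sketch; the tied-weight names are hypothetical):
+    ```python
+    >>> import torch
+    >>> weight = torch.zeros(2, 2)
+    >>> state_dict = {"embed.weight": weight, "lm_head.weight": weight}  # hypothetical tied weights
+    >>> dict(_remove_duplicate_names(state_dict))
+    {'embed.weight': ['lm_head.weight']}
+    >>> dict(_remove_duplicate_names(state_dict, preferred_names=["lm_head.weight"]))
+    {'lm_head.weight': ['embed.weight']}
+    ```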
+    """
+    if preferred_names is None:
+        preferred_names = []
+    unique_preferred_names = set(preferred_names)
+    if discard_names is None:
+        discard_names = []
+    unique_discard_names = set(discard_names)
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        complete_names = {name for name in shared if _is_complete(state_dict[name])}
+        if not complete_names:
+            raise RuntimeError(
+                "Error while trying to find names to remove when saving the state dict, but found no suitable name to"
+                f" keep amongst: {shared}. None of them covers the entire storage. Refusing to save/load the model"
+                " since you could be storing much more memory than needed. Please refer to"
+                " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
+                " issue."
+            )
+
+        keep_name = sorted(list(complete_names))[0]
+
+        # Mechanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(unique_discard_names)
+        if preferred:
+            keep_name = sorted(list(preferred))[0]
+
+        if unique_preferred_names:
+            preferred = unique_preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(list(preferred))[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+@lru_cache()
+def _get_dtype_size(dtype: "torch.dtype") -> int:
+    """
+    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344
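+
+    Example (illustrative sketch; assumes `torch` is installed):
+    ```python
+    >>> import torch
+    >>> _get_dtype_size(torch.float16)
+    2
+    >>> _get_dtype_size(torch.float32)
+    4
+    ```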
+    """
+    import torch
+
+    # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
+    _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
+    _float8_e5m2 = getattr(torch, "float8_e5m2", None)
+    _SIZE = {
+        torch.int64: 8,
+        torch.float32: 4,
+        torch.int32: 4,
+        torch.bfloat16: 2,
+        torch.float16: 2,
+        torch.int16: 2,
+        torch.uint8: 1,
+        torch.int8: 1,
+        torch.bool: 1,
+        torch.float64: 8,
+        _float8_e4m3fn: 1,
+        _float8_e5m2: 1,
+    }
+    return _SIZE[dtype]
+
+
+class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
+    """
+    This is used to report missing and unexpected keys in the state dict.
+    Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
+
+    """
+
+    def __repr__(self) -> str:
+        if not self.missing_keys and not self.unexpected_keys:
+            return "<All keys matched successfully>"
+        return super().__repr__()
+
+    __str__ = __repr__