diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/huggingface_hub/utils/_safetensors.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/utils/_safetensors.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/huggingface_hub/utils/_safetensors.py | 111 |
1 files changed, 111 insertions, 0 deletions
# File: .venv/lib/python3.12/site-packages/huggingface_hub/utils/_safetensors.py
"""Dataclasses describing the metadata of safetensors files and repos.

For more details regarding the safetensors format, check out
https://huggingface.co/docs/safetensors/index#format.
"""

import functools
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple


FILENAME_T = str
TENSOR_NAME_T = str
DTYPE_T = Literal["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"]


@dataclass
class TensorInfo:
    """Information about a tensor.

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        dtype (`str`):
            The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL").
        shape (`List[int]`):
            The shape of the tensor.
        data_offsets (`Tuple[int, int]`):
            The offsets of the data in the file as a tuple `[BEGIN, END]`.
        parameter_count (`int`):
            The number of parameters in the tensor. Computed in `__post_init__` as the product of the dimensions
            in `shape`; it is not accepted by the constructor.
    """

    dtype: DTYPE_T
    shape: List[int]
    data_offsets: Tuple[int, int]
    parameter_count: int = field(init=False)

    def __post_init__(self) -> None:
        """Derive `parameter_count` from `shape`."""
        try:
            # The initializer `1` makes an empty shape (scalar tensor) evaluate to 1 directly, instead of relying
            # on `reduce` raising `TypeError` for an empty iterable.
            self.parameter_count = functools.reduce(operator.mul, self.shape, 1)
        except TypeError:
            # Best-effort fallback: a shape that cannot be iterated or multiplied still counts as a single
            # parameter rather than failing construction.
            self.parameter_count = 1


@dataclass
class SafetensorsFileMetadata:
    """Metadata for a Safetensors file hosted on the Hub.

    This class is returned by [`parse_safetensors_file_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict[str, str]`):
            The metadata contained in the file.
        tensors (`Dict[str, TensorInfo]`):
            A map of all tensors. Keys are tensor names and values are information about the corresponding tensor, as a
            [`TensorInfo`] object.
        parameter_count (`Dict[str, int]`):
            A map of the number of parameters per data type. Keys are data types and values are the number of parameters
            of that data type.
    """

    metadata: Dict[str, str]
    tensors: Dict[TENSOR_NAME_T, TensorInfo]
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Sum the parameter counts of all tensors, grouped by dtype."""
        parameter_count: Dict[DTYPE_T, int] = defaultdict(int)
        for tensor in self.tensors.values():
            parameter_count[tensor.dtype] += tensor.parameter_count
        # Store a plain dict so that later lookups of a missing dtype do not silently insert a zero entry.
        self.parameter_count = dict(parameter_count)


@dataclass
class SafetensorsRepoMetadata:
    """Metadata for a Safetensors repo.

    A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file (non-sharded
    model) or a 'model.safetensors.index.json' index file (sharded model) at its root.

    This class is returned by [`get_safetensors_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`, *optional*):
            The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded
            models.
        sharded (`bool`):
            Whether the repo contains a sharded model or not.
        weight_map (`Dict[str, str]`):
            A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors.
        files_metadata (`Dict[str, SafetensorsFileMetadata]`):
            A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as
            a [`SafetensorsFileMetadata`] object.
        parameter_count (`Dict[str, int]`):
            A map of the number of parameters per data type. Keys are data types and values are the number of parameters
            of that data type.
    """

    metadata: Optional[Dict]
    sharded: bool
    weight_map: Dict[TENSOR_NAME_T, FILENAME_T]  # tensor name -> filename
    files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata]  # filename -> metadata
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Sum the per-dtype parameter counts of every file in the repo."""
        parameter_count: Dict[DTYPE_T, int] = defaultdict(int)
        for file_metadata in self.files_metadata.values():
            for dtype, nb_parameters_ in file_metadata.parameter_count.items():
                parameter_count[dtype] += nb_parameters_
        # Store a plain dict so that later lookups of a missing dtype do not silently insert a zero entry.
        self.parameter_count = dict(parameter_count)