about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/huggingface_hub/hub_mixin.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/huggingface_hub/hub_mixin.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/hub_mixin.py')
-rw-r--r--.venv/lib/python3.12/site-packages/huggingface_hub/hub_mixin.py836
1 files changed, 836 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/hub_mixin.py b/.venv/lib/python3.12/site-packages/huggingface_hub/hub_mixin.py
new file mode 100644
index 00000000..8bf32f12
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/hub_mixin.py
@@ -0,0 +1,836 @@
+import inspect
+import json
+import os
+from dataclasses import Field, asdict, dataclass, is_dataclass
+from pathlib import Path
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
+
+import packaging.version
+
+from . import constants
+from .errors import EntryNotFoundError, HfHubHTTPError
+from .file_download import hf_hub_download
+from .hf_api import HfApi
+from .repocard import ModelCard, ModelCardData
+from .utils import (
+    SoftTemporaryDirectory,
+    is_jsonable,
+    is_safetensors_available,
+    is_simple_optional_type,
+    is_torch_available,
+    logging,
+    unwrap_simple_optional_type,
+    validate_hf_hub_args,
+)
+
+
+if is_torch_available():
+    import torch  # type: ignore
+
+if is_safetensors_available():
+    import safetensors
+    from safetensors.torch import load_model as load_model_as_safetensor
+    from safetensors.torch import save_model as save_model_as_safetensor
+
+
+logger = logging.get_logger(__name__)
+
+
+# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
+class DataclassInstance(Protocol):
+    __dataclass_fields__: ClassVar[Dict[str, Field]]
+
+
+# Generic variable that is either ModelHubMixin or a subclass thereof
+T = TypeVar("T", bound="ModelHubMixin")
+# Generic variable to represent an args type
+ARGS_T = TypeVar("ARGS_T")
+ENCODER_T = Callable[[ARGS_T], Any]
+DECODER_T = Callable[[Any], ARGS_T]
+CODER_T = Tuple[ENCODER_T, DECODER_T]
+
+
+DEFAULT_MODEL_CARD = """
+---
+# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+# Doc / guide: https://huggingface.co/docs/hub/model-cards
+{{ card_data }}
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Library: {{ repo_url | default("[More Information Needed]", true) }}
+- Docs: {{ docs_url | default("[More Information Needed]", true) }}
+"""
+
+
+@dataclass
+class MixinInfo:
+    model_card_template: str
+    model_card_data: ModelCardData
+    repo_url: Optional[str] = None
+    docs_url: Optional[str] = None
+
+
+class ModelHubMixin:
+    """
+    A generic mixin to integrate ANY machine learning framework with the Hub.
+
+    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
+    have to be overwritten in  [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
+    of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
+
+    When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
+    `__init__` but to the class definition itself. This is useful to define metadata about the library integrating
+    [`ModelHubMixin`].
+
+    For more details on how to integrate the mixin with your library, checkout the [integration guide](../guides/integrations).
+
+    Args:
+        repo_url (`str`, *optional*):
+            URL of the library repository. Used to generate model card.
+        docs_url (`str`, *optional*):
+            URL of the library documentation. Used to generate model card.
+        model_card_template (`str`, *optional*):
+            Template of the model card. Used to generate model card. Defaults to a generic template.
+        language (`str` or `List[str]`, *optional*):
+            Language supported by the library. Used to generate model card.
+        library_name (`str`, *optional*):
+            Name of the library integrating ModelHubMixin. Used to generate model card.
+        license (`str`, *optional*):
+            License of the library integrating ModelHubMixin. Used to generate model card.
+            E.g: "apache-2.0"
+        license_name (`str`, *optional*):
+            Name of the library integrating ModelHubMixin. Used to generate model card.
+            Only used if `license` is set to `other`.
+            E.g: "coqui-public-model-license".
+        license_link (`str`, *optional*):
+            URL to the license of the library integrating ModelHubMixin. Used to generate model card.
+            Only used if `license` is set to `other` and `license_name` is set.
+            E.g: "https://coqui.ai/cpml".
+        pipeline_tag (`str`, *optional*):
+            Tag of the pipeline. Used to generate model card. E.g. "text-classification".
+        tags (`List[str]`, *optional*):
+            Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
+        coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
+            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
+            jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
+
+    Example:
+
+    ```python
+    >>> from huggingface_hub import ModelHubMixin
+
+    # Inherit from ModelHubMixin
+    >>> class MyCustomModel(
+    ...         ModelHubMixin,
+    ...         library_name="my-library",
+    ...         tags=["x-custom-tag", "arxiv:2304.12244"],
+    ...         repo_url="https://github.com/huggingface/my-cool-library",
+    ...         docs_url="https://huggingface.co/docs/my-cool-library",
+    ...         # ^ optional metadata to generate model card
+    ...     ):
+    ...     def __init__(self, size: int = 512, device: str = "cpu"):
+    ...         # define how to initialize your model
+    ...         super().__init__()
+    ...         ...
+    ...
+    ...     def _save_pretrained(self, save_directory: Path) -> None:
+    ...         # define how to serialize your model
+    ...         ...
+    ...
+    ...     @classmethod
+    ...     def from_pretrained(
+    ...         cls: Type[T],
+    ...         pretrained_model_name_or_path: Union[str, Path],
+    ...         *,
+    ...         force_download: bool = False,
+    ...         resume_download: Optional[bool] = None,
+    ...         proxies: Optional[Dict] = None,
+    ...         token: Optional[Union[str, bool]] = None,
+    ...         cache_dir: Optional[Union[str, Path]] = None,
+    ...         local_files_only: bool = False,
+    ...         revision: Optional[str] = None,
+    ...         **model_kwargs,
+    ...     ) -> T:
+    ...         # define how to deserialize your model
+    ...         ...
+
+    >>> model = MyCustomModel(size=256, device="gpu")
+
+    # Save model weights to local directory
+    >>> model.save_pretrained("my-awesome-model")
+
+    # Push model weights to the Hub
+    >>> model.push_to_hub("my-awesome-model")
+
+    # Download and initialize weights from the Hub
+    >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
+    >>> reloaded_model.size
+    256
+
+    # Model card has been correctly populated
+    >>> from huggingface_hub import ModelCard
+    >>> card = ModelCard.load("username/my-awesome-model")
+    >>> card.data.tags
+    ["x-custom-tag", "pytorch_model_hub_mixin", "model_hub_mixin"]
+    >>> card.data.library_name
+    "my-library"
+    ```
+    """
+
+    _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None
+    # ^ optional config attribute automatically set in `from_pretrained`
+    _hub_mixin_info: MixinInfo
+    # ^ information about the library integrating ModelHubMixin (used to generate model card)
+    _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
+    _hub_mixin_init_parameters: Dict[str, inspect.Parameter]  # __init__ parameters
+    _hub_mixin_jsonable_default_values: Dict[str, Any]  # default values for __init__ parameters
+    _hub_mixin_jsonable_custom_types: Tuple[Type, ...]  # custom types that can be encoded/decoded
+    _hub_mixin_coders: Dict[Type, CODER_T]  # encoders/decoders for custom types
+    # ^ internal values to handle config
+
+    def __init_subclass__(
+        cls,
+        *,
+        # Generic info for model card
+        repo_url: Optional[str] = None,
+        docs_url: Optional[str] = None,
+        # Model card template
+        model_card_template: str = DEFAULT_MODEL_CARD,
+        # Model card metadata
+        language: Optional[List[str]] = None,
+        library_name: Optional[str] = None,
+        license: Optional[str] = None,
+        license_name: Optional[str] = None,
+        license_link: Optional[str] = None,
+        pipeline_tag: Optional[str] = None,
+        tags: Optional[List[str]] = None,
+        # How to encode/decode arguments with custom type into a JSON config?
+        coders: Optional[
+            Dict[Type, CODER_T]
+            # Key is a type.
+            # Value is a tuple (encoder, decoder).
+            # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
+        ] = None,
+    ) -> None:
+        """Inspect __init__ signature only once when subclassing + handle modelcard."""
+        super().__init_subclass__()
+
+        # Will be reused when creating modelcard
+        tags = tags or []
+        tags.append("model_hub_mixin")
+
+        # Initialize MixinInfo if not existent
+        info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())
+
+        # If parent class has a MixinInfo, inherit from it as a copy
+        if hasattr(cls, "_hub_mixin_info"):
+            # Inherit model card template from parent class if not explicitly set
+            if model_card_template == DEFAULT_MODEL_CARD:
+                info.model_card_template = cls._hub_mixin_info.model_card_template
+
+            # Inherit from parent model card data
+            info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict())
+
+            # Inherit other info
+            info.docs_url = cls._hub_mixin_info.docs_url
+            info.repo_url = cls._hub_mixin_info.repo_url
+        cls._hub_mixin_info = info
+
+        # Update MixinInfo with metadata
+        if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD:
+            info.model_card_template = model_card_template
+        if repo_url is not None:
+            info.repo_url = repo_url
+        if docs_url is not None:
+            info.docs_url = docs_url
+        if language is not None:
+            info.model_card_data.language = language
+        if library_name is not None:
+            info.model_card_data.library_name = library_name
+        if license is not None:
+            info.model_card_data.license = license
+        if license_name is not None:
+            info.model_card_data.license_name = license_name
+        if license_link is not None:
+            info.model_card_data.license_link = license_link
+        if pipeline_tag is not None:
+            info.model_card_data.pipeline_tag = pipeline_tag
+        if tags is not None:
+            if info.model_card_data.tags is not None:
+                info.model_card_data.tags.extend(tags)
+            else:
+                info.model_card_data.tags = tags
+
+        info.model_card_data.tags = sorted(set(info.model_card_data.tags))
+
+        # Handle encoders/decoders for args
+        cls._hub_mixin_coders = coders or {}
+        cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())
+
+        # Inspect __init__ signature to handle config
+        cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
+        cls._hub_mixin_jsonable_default_values = {
+            param.name: cls._encode_arg(param.default)
+            for param in cls._hub_mixin_init_parameters.values()
+            if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default)
+        }
+        cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
+
+    def __new__(cls: Type[T], *args, **kwargs) -> T:
+        """Create a new instance of the class and handle config.
+
+        3 cases:
+        - If `self._hub_mixin_config` is already set, do nothing.
+        - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`.
+        - Otherwise, build `self._hub_mixin_config` from default values and passed values.
+        """
+        instance = super().__new__(cls)
+
+        # If `config` is already set, return early
+        if instance._hub_mixin_config is not None:
+            return instance
+
+        # Infer passed values
+        passed_values = {
+            **{
+                key: value
+                for key, value in zip(
+                    # [1:] to skip `self` parameter
+                    list(cls._hub_mixin_init_parameters)[1:],
+                    args,
+                )
+            },
+            **kwargs,
+        }
+
+        # If config passed as dataclass => set it and return early
+        if is_dataclass(passed_values.get("config")):
+            instance._hub_mixin_config = passed_values["config"]
+            return instance
+
+        # Otherwise, build config from default + passed values
+        init_config = {
+            # default values
+            **cls._hub_mixin_jsonable_default_values,
+            # passed values
+            **{
+                key: cls._encode_arg(value)  # Encode custom types as jsonable value
+                for key, value in passed_values.items()
+                if instance._is_jsonable(value)  # Only if jsonable or we have a custom encoder
+            },
+        }
+        passed_config = init_config.pop("config", {})
+
+        # Populate `init_config` with provided config
+        if isinstance(passed_config, dict):
+            init_config.update(passed_config)
+
+        # Set `config` attribute and return
+        if init_config != {}:
+            instance._hub_mixin_config = init_config
+        return instance
+
+    @classmethod
+    def _is_jsonable(cls, value: Any) -> bool:
+        """Check if a value is JSON serializable."""
+        if isinstance(value, cls._hub_mixin_jsonable_custom_types):
+            return True
+        return is_jsonable(value)
+
+    @classmethod
+    def _encode_arg(cls, arg: Any) -> Any:
+        """Encode an argument into a JSON serializable format."""
+        for type_, (encoder, _) in cls._hub_mixin_coders.items():
+            if isinstance(arg, type_):
+                if arg is None:
+                    return None
+                return encoder(arg)
+        return arg
+
+    @classmethod
+    def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
+        """Decode a JSON serializable value into an argument."""
+        if is_simple_optional_type(expected_type):
+            if value is None:
+                return None
+            expected_type = unwrap_simple_optional_type(expected_type)
+        # Dataclass => handle it
+        if is_dataclass(expected_type):
+            return _load_dataclass(expected_type, value)  # type: ignore[return-value]
+        # Otherwise => check custom decoders
+        for type_, (_, decoder) in cls._hub_mixin_coders.items():
+            if inspect.isclass(expected_type) and issubclass(expected_type, type_):
+                return decoder(value)
+        # Otherwise => don't decode
+        return value
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, Path],
+        *,
+        config: Optional[Union[dict, DataclassInstance]] = None,
+        repo_id: Optional[str] = None,
+        push_to_hub: bool = False,
+        model_card_kwargs: Optional[Dict[str, Any]] = None,
+        **push_to_hub_kwargs,
+    ) -> Optional[str]:
+        """
+        Save weights in local directory.
+
+        Args:
+            save_directory (`str` or `Path`):
+                Path to directory in which the model weights and configuration will be saved.
+            config (`dict` or `DataclassInstance`, *optional*):
+                Model configuration specified as a key/value dictionary or a dataclass instance.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Huggingface Hub after saving it.
+            repo_id (`str`, *optional*):
+                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
+                not provided.
+            model_card_kwargs (`Dict[str, Any]`, *optional*):
+                Additional arguments passed to the model card template to customize the model card.
+            push_to_hub_kwargs:
+                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
+        Returns:
+            `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
+        """
+        save_directory = Path(save_directory)
+        save_directory.mkdir(parents=True, exist_ok=True)
+
+        # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json
+        # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
+        # an existing config.json if it was not saved by `_save_pretrained`.
+        config_path = save_directory / constants.CONFIG_NAME
+        config_path.unlink(missing_ok=True)
+
+        # save model weights/files (framework-specific)
+        self._save_pretrained(save_directory)
+
+        # save config (if provided and if not serialized yet in `_save_pretrained`)
+        if config is None:
+            config = self._hub_mixin_config
+        if config is not None:
+            if is_dataclass(config):
+                config = asdict(config)  # type: ignore[arg-type]
+            if not config_path.exists():
+                config_str = json.dumps(config, sort_keys=True, indent=2)
+                config_path.write_text(config_str)
+
+        # save model card
+        model_card_path = save_directory / "README.md"
+        model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
+        if not model_card_path.exists():  # do not overwrite if already exists
+            self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md")
+
+        # push to the Hub if required
+        if push_to_hub:
+            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
+            if config is not None:  # kwarg for `push_to_hub`
+                kwargs["config"] = config
+            if repo_id is None:
+                repo_id = save_directory.name  # Defaults to `save_directory` name
+            return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
+        return None
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """
+        Overwrite this method in subclass to define how to save your model.
+        Check out our [integration guide](../guides/integrations) for instructions.
+
+        Args:
+            save_directory (`str` or `Path`):
+                Path to directory in which the model weights and configuration will be saved.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(
+        cls: Type[T],
+        pretrained_model_name_or_path: Union[str, Path],
+        *,
+        force_download: bool = False,
+        resume_download: Optional[bool] = None,
+        proxies: Optional[Dict] = None,
+        token: Optional[Union[str, bool]] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
+        local_files_only: bool = False,
+        revision: Optional[str] = None,
+        **model_kwargs,
+    ) -> T:
+        """
+        Download a model from the Huggingface Hub and instantiate it.
+
+        Args:
+            pretrained_model_name_or_path (`str`, `Path`):
+                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
+                - Or a path to a `directory` containing model weights saved using
+                    [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
+            revision (`str`, *optional*):
+                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
+                Defaults to the latest commit on `main` branch.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
+                the existing cache.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                cached when running `huggingface-cli login`.
+            cache_dir (`str`, `Path`, *optional*):
+                Path to the folder where cached files are stored.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+            model_kwargs (`Dict`, *optional*):
+                Additional kwargs to pass to the model during initialization.
+        """
+        model_id = str(pretrained_model_name_or_path)
+        config_file: Optional[str] = None
+        if os.path.isdir(model_id):
+            if constants.CONFIG_NAME in os.listdir(model_id):
+                config_file = os.path.join(model_id, constants.CONFIG_NAME)
+            else:
+                logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}")
+        else:
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=constants.CONFIG_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except HfHubHTTPError as e:
+                logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
+
+        # Read config
+        config = None
+        if config_file is not None:
+            with open(config_file, "r", encoding="utf-8") as f:
+                config = json.load(f)
+
+            # Decode custom types in config
+            for key, value in config.items():
+                if key in cls._hub_mixin_init_parameters:
+                    expected_type = cls._hub_mixin_init_parameters[key].annotation
+                    if expected_type is not inspect.Parameter.empty:
+                        config[key] = cls._decode_arg(expected_type, value)
+
+            # Populate model_kwargs from config
+            for param in cls._hub_mixin_init_parameters.values():
+                if param.name not in model_kwargs and param.name in config:
+                    model_kwargs[param.name] = config[param.name]
+
+            # Check if `config` argument was passed at init
+            if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
+                # Decode `config` argument if it was passed
+                config_annotation = cls._hub_mixin_init_parameters["config"].annotation
+                config = cls._decode_arg(config_annotation, config)
+
+                # Forward config to model initialization
+                model_kwargs["config"] = config
+
+            # Inject config if `**kwargs` are expected
+            if is_dataclass(cls):
+                for key in cls.__dataclass_fields__:
+                    if key not in model_kwargs and key in config:
+                        model_kwargs[key] = config[key]
+            elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
+                for key, value in config.items():
+                    if key not in model_kwargs:
+                        model_kwargs[key] = value
+
+            # Finally, also inject if `_from_pretrained` expects it
+            if cls._hub_mixin_inject_config and "config" not in model_kwargs:
+                model_kwargs["config"] = config
+
+        instance = cls._from_pretrained(
+            model_id=str(model_id),
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            resume_download=resume_download,
+            local_files_only=local_files_only,
+            token=token,
+            **model_kwargs,
+        )
+
+        # Implicitly set the config as instance attribute if not already set by the class
+        # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
+        if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
+            instance._hub_mixin_config = config
+
+        return instance
+
+    @classmethod
+    def _from_pretrained(
+        cls: Type[T],
+        *,
+        model_id: str,
+        revision: Optional[str],
+        cache_dir: Optional[Union[str, Path]],
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: Optional[bool],
+        local_files_only: bool,
+        token: Optional[Union[str, bool]],
+        **model_kwargs,
+    ) -> T:
+        """Overwrite this method in subclass to define how to load your model from pretrained.
+
+        Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
+        args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
+        method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
+        parameter to set on which device the model should be loaded.
+
+        Check out our [integration guide](../guides/integrations) for more instructions.
+
+        Args:
+            model_id (`str`):
+                ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
+            revision (`str`, *optional*):
+                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
+                latest commit on `main` branch.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
+                the existing cache.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`).
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                cached when running `huggingface-cli login`.
+            cache_dir (`str`, `Path`, *optional*):
+                Path to the folder where cached files are stored.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+            model_kwargs:
+                Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
+        """
+        raise NotImplementedError
+
+    @validate_hf_hub_args
+    def push_to_hub(
+        self,
+        repo_id: str,
+        *,
+        config: Optional[Union[dict, DataclassInstance]] = None,
+        commit_message: str = "Push model using huggingface_hub.",
+        private: Optional[bool] = None,
+        token: Optional[str] = None,
+        branch: Optional[str] = None,
+        create_pr: Optional[bool] = None,
+        allow_patterns: Optional[Union[List[str], str]] = None,
+        ignore_patterns: Optional[Union[List[str], str]] = None,
+        delete_patterns: Optional[Union[List[str], str]] = None,
+        model_card_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """
+        Upload model checkpoint to the Hub.
+
+        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
+        `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
+        details.
+
+        Args:
+            repo_id (`str`):
+                ID of the repository to push to (example: `"username/my-model"`).
+            config (`dict` or `DataclassInstance`, *optional*):
+                Model configuration specified as a key/value dictionary or a dataclass instance.
+            commit_message (`str`, *optional*):
+                Message to commit while pushing.
+            private (`bool`, *optional*):
+                Whether the repository created should be private.
+                If `None` (default), the repo will be public unless the organization's default is private.
+            token (`str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                cached when running `huggingface-cli login`.
+            branch (`str`, *optional*):
+                The git branch on which to push the model. This defaults to `"main"`.
+            create_pr (`boolean`, *optional*):
+                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
+            allow_patterns (`List[str]` or `str`, *optional*):
+                If provided, only files matching at least one pattern are pushed.
+            ignore_patterns (`List[str]` or `str`, *optional*):
+                If provided, files matching any of the patterns are not pushed.
+            delete_patterns (`List[str]` or `str`, *optional*):
+                If provided, remote files matching any of the patterns will be deleted from the repo.
+            model_card_kwargs (`Dict[str, Any]`, *optional*):
+                Additional arguments passed to the model card template to customize the model card.
+
+        Returns:
+            The url of the commit of your model in the given repository.
+        """
+        api = HfApi(token=token)
+        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
+
+        # Push the files to the repo in a single commit
+        with SoftTemporaryDirectory() as tmp:
+            saved_path = Path(tmp) / repo_id
+            self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
+            return api.upload_folder(
+                repo_id=repo_id,
+                repo_type="model",
+                folder_path=saved_path,
+                commit_message=commit_message,
+                revision=branch,
+                create_pr=create_pr,
+                allow_patterns=allow_patterns,
+                ignore_patterns=ignore_patterns,
+                delete_patterns=delete_patterns,
+            )
+
+    def generate_model_card(self, *args, **kwargs) -> ModelCard:
+        card = ModelCard.from_template(
+            card_data=self._hub_mixin_info.model_card_data,
+            template_str=self._hub_mixin_info.model_card_template,
+            repo_url=self._hub_mixin_info.repo_url,
+            docs_url=self._hub_mixin_info.docs_url,
+            **kwargs,
+        )
+        return card
+
+
+class PyTorchModelHubMixin(ModelHubMixin):
+    """
+    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
+    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
+    you should first set it back in training mode with `model.train()`.
+
+    See [`ModelHubMixin`] for more details on how to use the mixin.
+
+    Example:
+
+    ```python
+    >>> import torch
+    >>> import torch.nn as nn
+    >>> from huggingface_hub import PyTorchModelHubMixin
+
+    >>> class MyModel(
+    ...         nn.Module,
+    ...         PyTorchModelHubMixin,
+    ...         library_name="keras-nlp",
+    ...         repo_url="https://github.com/keras-team/keras-nlp",
+    ...         docs_url="https://keras.io/keras_nlp/",
+    ...         # ^ optional metadata to generate model card
+    ...     ):
+    ...     def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
+    ...         super().__init__()
+    ...         self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
+    ...         self.linear = nn.Linear(output_size, vocab_size)
+
+    ...     def forward(self, x):
+    ...         return self.linear(x + self.param)
+    >>> model = MyModel(hidden_size=256)
+
+    # Save model weights to local directory
+    >>> model.save_pretrained("my-awesome-model")
+
+    # Push model weights to the Hub
+    >>> model.push_to_hub("my-awesome-model")
+
+    # Download and initialize weights from the Hub
+    >>> model = MyModel.from_pretrained("username/my-awesome-model")
+    >>> model.hidden_size
+    256
+    ```
+    """
+
+    def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
+        tags = tags or []
+        tags.append("pytorch_model_hub_mixin")
+        kwargs["tags"] = tags
+        return super().__init_subclass__(*args, **kwargs)
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Save weights from a Pytorch model to a local directory."""
+        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
+        save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        *,
+        model_id: str,
+        revision: Optional[str],
+        cache_dir: Optional[Union[str, Path]],
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: Optional[bool],
+        local_files_only: bool,
+        token: Union[str, bool, None],
+        map_location: str = "cpu",
+        strict: bool = False,
+        **model_kwargs,
+    ):
+        """Load Pytorch pretrained weights and return the loaded model."""
+        model = cls(**model_kwargs)
+        if os.path.isdir(model_id):
+            print("Loading weights from local directory")
+            model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE)
+            return cls._load_as_safetensor(model, model_file, map_location, strict)
+        else:
+            try:
+                model_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=constants.SAFETENSORS_SINGLE_FILE,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+                return cls._load_as_safetensor(model, model_file, map_location, strict)
+            except EntryNotFoundError:
+                model_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=constants.PYTORCH_WEIGHTS_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+                return cls._load_as_pickle(model, model_file, map_location, strict)
+
+    @classmethod
+    def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
+        state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True)
+        model.load_state_dict(state_dict, strict=strict)  # type: ignore
+        model.eval()  # type: ignore
+        return model
+
+    @classmethod
+    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
+        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):  # type: ignore [attr-defined]
+            load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
+            if map_location != "cpu":
+                logger.warning(
+                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
+                    " This means that the model is loaded on 'cpu' first and then copied to the device."
+                    " This leads to a slower loading time."
+                    " Please update safetensors to version 0.4.3 or above for improved performance."
+                )
+                model.to(map_location)  # type: ignore [attr-defined]
+        else:
+            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
+        return model
+
+
+def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
+    """Load a dataclass instance from a dictionary.
+
+    Fields not expected by the dataclass are ignored.
+    """
+    return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__})