author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here  (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py')
 -rw-r--r--  .venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py  |  307
 1 file changed, 307 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py b/.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py
new file mode 100644
index 00000000..b928dd34
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py
@@ -0,0 +1,307 @@
+import os
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Union
+
+import requests
+from tqdm.auto import tqdm as base_tqdm
+from tqdm.contrib.concurrent import thread_map
+
+from . import constants
+from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
+from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
+from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
+from .utils import tqdm as hf_tqdm
+
+
+logger = logging.get_logger(__name__)
+
+
+@validate_hf_hub_args
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[Dict, str]] = None,
+    proxies: Optional[Dict] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[base_tqdm] = None,
+    headers: Optional[Dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    # Deprecated args
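+    # (kept for backward compatibility; both are simply forwarded to `hf_hub_download`)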
+    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+    resume_download: Optional[bool] = None,
+) -> str:
+    """Download repo files.
+
+    Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
+    a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
+    to keep their actual filename relative to that folder. You can also filter which files to download using
+    `allow_patterns` and `ignore_patterns`.
+
+    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
+    option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
+    to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
+    cache-system, it's optimized for regularly pulling the latest version of a repository.
+
+    An alternative would be to clone the repo, but this requires git and git-lfs to be installed and properly
+    configured. It is also not possible to filter which files to download when cloning a repository using git.
+
+    Args:
+        repo_id (`str`):
+            A user or an organization name and a repo name separated by a `/`.
+        repo_type (`str`, *optional*):
+            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
+            `None` or `"model"` if downloading from a model. Default is `None`.
+        revision (`str`, *optional*):
+            An optional Git revision id which can be a branch name, a tag, or a
+            commit hash.
+        cache_dir (`str`, `Path`, *optional*):
+            Path to the folder where cached files are stored.
+        local_dir (`str` or `Path`, *optional*):
+            If provided, the downloaded files will be placed under this directory.
+        library_name (`str`, *optional*):
+            The name of the library to which the object corresponds.
+        library_version (`str`, *optional*):
+            The version of the library.
+        user_agent (`str`, `dict`, *optional*):
+            The user-agent info in the form of a dictionary or a string.
+        proxies (`dict`, *optional*):
+            Dictionary mapping protocol to the URL of the proxy passed to
+            `requests.request`.
+        etag_timeout (`float`, *optional*, defaults to `10`):
+            When fetching ETag, how many seconds to wait for the server to send
+            data before giving up. Passed to `requests.request`.
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether the file should be downloaded even if it already exists in the local cache.
+        token (`str`, `bool`, *optional*):
+            A token to be used for the download.
+                - If `True`, the token is read from the HuggingFace config
+                  folder.
+                - If a string, it's used as the authentication token.
+        headers (`dict`, *optional*):
+            Additional headers to include in the request. Those headers take precedence over the others.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, avoid downloading the file and return the path to the
+            local cached file if it exists.
+        allow_patterns (`List[str]` or `str`, *optional*):
+            If provided, only files matching at least one pattern are downloaded.
+        ignore_patterns (`List[str]` or `str`, *optional*):
+            If provided, files matching any of the patterns are not downloaded.
+        max_workers (`int`, *optional*):
+            Number of concurrent threads to download files (1 thread = 1 file download).
+            Defaults to 8.
+        tqdm_class (`tqdm`, *optional*):
+            If provided, overrides the default behavior for the progress bar. The passed
+            argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
+            Note that the `tqdm_class` is not passed to each individual download.
+            Defaults to the custom HF progress bar, which can be disabled by setting
+            the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
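+        endpoint (`str`, *optional*):
+            Hugging Face Hub base URL. Defaults to `"https://huggingface.co"`. Can also be
+            configured with the `HF_ENDPOINT` environment variable.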
+
+    Returns:
+        `str`: folder path of the repo snapshot.
+
+    Raises:
+        [`~utils.RepositoryNotFoundError`]
+            If the repository to download from cannot be found. This may be because it doesn't exist,
+            or because it is set to `private` and you do not have access.
+        [`~utils.RevisionNotFoundError`]
+            If the revision to download from cannot be found.
+        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+            If `token=True` and the token cannot be found.
+        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
+            If the ETag cannot be determined.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If some parameter value is invalid.
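+
+    Example (illustrative usage; the repo id and local path below are placeholders):
+
+    ```python
+    >>> from huggingface_hub import snapshot_download
+
+    >>> # Download a full snapshot of a repo into the cache and get the local folder path
+    >>> folder = snapshot_download(repo_id="user/repo-name")
+
+    >>> # Download only the config and safetensors weights into a chosen local directory
+    >>> snapshot_download(
+    ...     repo_id="user/repo-name",
+    ...     allow_patterns=["*.safetensors", "config.json"],
+    ...     local_dir="path/to/folder",
+    ... )
+    ```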
+    """
+    if cache_dir is None:
+        cache_dir = constants.HF_HUB_CACHE
+    if revision is None:
+        revision = constants.DEFAULT_REVISION
+    if isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    if repo_type is None:
+        repo_type = "model"
+    if repo_type not in constants.REPO_TYPES:
+        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
+
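+    # Cache layout: "<cache_dir>/<repo_type>s--<namespace>--<repo_name>" (e.g. "models--user--repo-name"),
+    # with "refs/" and "snapshots/" subfolders used below.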
+    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
+
+    repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
+    api_call_error: Optional[Exception] = None
+    if not local_files_only:
+        # try/except logic to handle different errors => taken from `hf_hub_download`
+        try:
+            # if we have internet connection we want to list files to download
+            api = HfApi(
+                library_name=library_name,
+                library_version=library_version,
+                user_agent=user_agent,
+                endpoint=endpoint,
+                headers=headers,
+            )
+            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
+        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
+            # Actually raise for those subclasses of ConnectionError
+            raise
+        except (
+            requests.exceptions.ConnectionError,
+            requests.exceptions.Timeout,
+            OfflineModeIsEnabled,
+        ) as error:
+            # Internet connection is down
+            # => will try to use local files only
+            api_call_error = error
+            pass
+        except RevisionNotFoundError:
+            # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
+            raise
+        except requests.HTTPError as error:
+            # Multiple reasons for an http error:
+            # - Repository is private and invalid/missing token sent
+            # - Repository is gated and invalid/missing token sent
+            # - Hub is down (error 500 or 504)
+            # => let's switch to 'local_files_only=True' to check if the files are already cached.
+            #    (if it's not the case, the error will be re-raised)
+            api_call_error = error
+            pass
+
+    # At this stage, if `repo_info` is None it means either:
+    # - internet connection is down
+    # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
+    # - repo is private/gated and invalid/missing token sent
+    # - Hub is down
+    # => let's check whether we can find the appropriate folder in the cache:
+    #    - if the specified revision is a commit hash, look inside "snapshots".
+    #    - if the specified revision is a branch or tag, look inside "refs".
+    # => if local_dir is not None, we will return the path to the local folder if it exists.
+    if repo_info is None:
+        # Try to get which commit hash corresponds to the specified revision
+        commit_hash = None
+        if REGEX_COMMIT_HASH.match(revision):
+            commit_hash = revision
+        else:
+            ref_path = os.path.join(storage_folder, "refs", revision)
+            if os.path.exists(ref_path):
+                # retrieve commit_hash from refs file
+                with open(ref_path) as f:
+                    commit_hash = f.read()
+
+        # Try to locate snapshot folder for this commit hash
+        if commit_hash is not None:
+            snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
+            if os.path.exists(snapshot_folder):
+                # Snapshot folder exists => let's return it
+                # (but we can't check if all the files are actually there)
+                return snapshot_folder
+        # If local_dir is not None, return it if it exists and is not empty
+        if local_dir is not None:
+            local_dir = Path(local_dir)
+            if local_dir.is_dir() and any(local_dir.iterdir()):
+                logger.warning(
+                    f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
+                )
+                return str(local_dir.resolve())
+        # If we couldn't find the appropriate folder on disk, raise an error.
+        if local_files_only:
+            raise LocalEntryNotFoundError(
+                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
+                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
+                "'local_files_only=False' as input."
+            )
+        elif isinstance(api_call_error, OfflineModeIsEnabled):
+            raise LocalEntryNotFoundError(
+                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
+                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
+                "'HF_HUB_OFFLINE=0' as environment variable."
+            ) from api_call_error
+        elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)):
+            # Repo not found => let's raise the actual error
+            raise api_call_error
+        else:
+            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
+            raise LocalEntryNotFoundError(
+                "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
+                " snapshot folder for the specified revision on the local disk. Please check your internet connection"
+                " and try again."
+            ) from api_call_error
+
+    # At this stage, internet connection is up and running
+    # => let's download the files!
+    assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
+    assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
+    filtered_repo_files = list(
+        filter_repo_objects(
+            items=[f.rfilename for f in repo_info.siblings],
+            allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
+        )
+    )
+    commit_hash = repo_info.sha
+    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
+    # if passed revision is not identical to commit_hash
+    # then revision has to be a branch name or tag name.
+    # In that case store a ref.
+    if revision != commit_hash:
+        ref_path = os.path.join(storage_folder, "refs", revision)
+        try:
+            os.makedirs(os.path.dirname(ref_path), exist_ok=True)
+            with open(ref_path, "w") as f:
+                f.write(commit_hash)
+        except OSError as e:
+            logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")
+
+    # we pass the commit_hash to hf_hub_download
+    # so no network call happens if we already
+    # have the file locally.
+    def _inner_hf_hub_download(repo_file: str):
+        return hf_hub_download(
+            repo_id,
+            filename=repo_file,
+            repo_type=repo_type,
+            revision=commit_hash,
+            endpoint=endpoint,
+            cache_dir=cache_dir,
+            local_dir=local_dir,
+            local_dir_use_symlinks=local_dir_use_symlinks,
+            library_name=library_name,
+            library_version=library_version,
+            user_agent=user_agent,
+            proxies=proxies,
+            etag_timeout=etag_timeout,
+            resume_download=resume_download,
+            force_download=force_download,
+            token=token,
+            headers=headers,
+        )
+
+    if constants.HF_HUB_ENABLE_HF_TRANSFER:
+        # when using hf_transfer we don't want extra parallelism
+        # from the one hf_transfer provides
+        for file in filtered_repo_files:
+            _inner_hf_hub_download(file)
+    else:
+        thread_map(
+            _inner_hf_hub_download,
+            filtered_repo_files,
+            desc=f"Fetching {len(filtered_repo_files)} files",
+            max_workers=max_workers,
+            # User can use their own tqdm class or the default one from `huggingface_hub.utils`
+            tqdm_class=tqdm_class or hf_tqdm,
+        )
+
+    if local_dir is not None:
+        return str(os.path.realpath(local_dir))
+    return snapshot_folder