Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py | 307 |
1 file changed, 307 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py b/.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py
new file mode 100644
index 00000000..b928dd34
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py
@@ -0,0 +1,307 @@
+import os
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Union
+
+import requests
+from tqdm.auto import tqdm as base_tqdm
+from tqdm.contrib.concurrent import thread_map
+
+from . import constants
+from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
+from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
+from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
+from .utils import tqdm as hf_tqdm
+
+
+logger = logging.get_logger(__name__)
+
+
+@validate_hf_hub_args
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[Dict, str]] = None,
+    proxies: Optional[Dict] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[base_tqdm] = None,
+    headers: Optional[Dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    # Deprecated args
+    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+    resume_download: Optional[bool] = None,
+) -> str:
+    """Download repo files.
+
+    Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
+    a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
+    to keep their actual filename relative to that folder. You can also filter which files to download using
+    `allow_patterns` and `ignore_patterns`.
+
+    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
+    option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
+    to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
+    cache-system, it's optimized for regularly pulling the latest version of a repository.
+
+    An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
+    configured. It is also not possible to filter which files to download when cloning a repository using git.
+
+    Args:
+        repo_id (`str`):
+            A user or an organization name and a repo name separated by a `/`.
+        repo_type (`str`, *optional*):
+            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
+            `None` or `"model"` if downloading from a model. Default is `None`.
+        revision (`str`, *optional*):
+            An optional Git revision id which can be a branch name, a tag, or a
+            commit hash.
+        cache_dir (`str`, `Path`, *optional*):
+            Path to the folder where cached files are stored.
+        local_dir (`str` or `Path`, *optional*):
+            If provided, the downloaded files will be placed under this directory.
+        library_name (`str`, *optional*):
+            The name of the library to which the object corresponds.
+        library_version (`str`, *optional*):
+            The version of the library.
+        user_agent (`str`, `dict`, *optional*):
+            The user-agent info in the form of a dictionary or a string.
+        proxies (`dict`, *optional*):
+            Dictionary mapping protocol to the URL of the proxy passed to
+            `requests.request`.
+        etag_timeout (`float`, *optional*, defaults to `10`):
+            When fetching ETag, how many seconds to wait for the server to send
+            data before giving up which is passed to `requests.request`.
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether the file should be downloaded even if it already exists in the local cache.
+        token (`str`, `bool`, *optional*):
+            A token to be used for the download.
+                - If `True`, the token is read from the HuggingFace config
+                  folder.
+                - If a string, it's used as the authentication token.
+        headers (`dict`, *optional*):
+            Additional headers to include in the request. Those headers take precedence over the others.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, avoid downloading the file and return the path to the
+            local cached file if it exists.
+        allow_patterns (`List[str]` or `str`, *optional*):
+            If provided, only files matching at least one pattern are downloaded.
+        ignore_patterns (`List[str]` or `str`, *optional*):
+            If provided, files matching any of the patterns are not downloaded.
+        max_workers (`int`, *optional*):
+            Number of concurrent threads to download files (1 thread = 1 file download).
+            Defaults to 8.
+        tqdm_class (`tqdm`, *optional*):
+            If provided, overwrites the default behavior for the progress bar. Passed
+            argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
+            Note that the `tqdm_class` is not passed to each individual download.
+            Defaults to the custom HF progress bar that can be disabled by setting
+            `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
+
+    Returns:
+        `str`: folder path of the repo snapshot.
+
+    Raises:
+        [`~utils.RepositoryNotFoundError`]
+            If the repository to download from cannot be found. This may be because it doesn't exist,
+            or because it is set to `private` and you do not have access.
+        [`~utils.RevisionNotFoundError`]
+            If the revision to download from cannot be found.
+        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+            If `token=True` and the token cannot be found.
+        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
+            ETag cannot be determined.
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            if some parameter value is invalid.
+    """
+    if cache_dir is None:
+        cache_dir = constants.HF_HUB_CACHE
+    if revision is None:
+        revision = constants.DEFAULT_REVISION
+    if isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    if repo_type is None:
+        repo_type = "model"
+    if repo_type not in constants.REPO_TYPES:
+        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")
+
+    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
+
+    repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
+    api_call_error: Optional[Exception] = None
+    if not local_files_only:
+        # try/except logic to handle different errors => taken from `hf_hub_download`
+        try:
+            # if we have internet connection we want to list files to download
+            api = HfApi(
+                library_name=library_name,
+                library_version=library_version,
+                user_agent=user_agent,
+                endpoint=endpoint,
+                headers=headers,
+            )
+            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
+        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
+            # Actually raise for those subclasses of ConnectionError
+            raise
+        except (
+            requests.exceptions.ConnectionError,
+            requests.exceptions.Timeout,
+            OfflineModeIsEnabled,
+        ) as error:
+            # Internet connection is down
+            # => will try to use local files only
+            api_call_error = error
+            pass
+        except RevisionNotFoundError:
+            # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
+            raise
+        except requests.HTTPError as error:
+            # Multiple reasons for an http error:
+            # - Repository is private and invalid/missing token sent
+            # - Repository is gated and invalid/missing token sent
+            # - Hub is down (error 500 or 504)
+            # => let's switch to 'local_files_only=True' to check if the files are already cached.
+            #    (if it's not the case, the error will be re-raised)
+            api_call_error = error
+            pass
+
+    # At this stage, if `repo_info` is None it means either:
+    # - internet connection is down
+    # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
+    # - repo is private/gated and invalid/missing token sent
+    # - Hub is down
+    # => let's look if we can find the appropriate folder in the cache:
+    #    - if the specified revision is a commit hash, look inside "snapshots".
+    #    - if the specified revision is a branch or tag, look inside "refs".
+    # => if local_dir is not None, we will return the path to the local folder if it exists.
+    if repo_info is None:
+        # Try to get which commit hash corresponds to the specified revision
+        commit_hash = None
+        if REGEX_COMMIT_HASH.match(revision):
+            commit_hash = revision
+        else:
+            ref_path = os.path.join(storage_folder, "refs", revision)
+            if os.path.exists(ref_path):
+                # retrieve commit_hash from refs file
+                with open(ref_path) as f:
+                    commit_hash = f.read()
+
+        # Try to locate snapshot folder for this commit hash
+        if commit_hash is not None:
+            snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
+            if os.path.exists(snapshot_folder):
+                # Snapshot folder exists => let's return it
+                # (but we can't check if all the files are actually there)
+                return snapshot_folder
+        # If local_dir is not None, return it if it exists and is not empty
+        if local_dir is not None:
+            local_dir = Path(local_dir)
+            if local_dir.is_dir() and any(local_dir.iterdir()):
+                logger.warning(
+                    f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
+                )
+                return str(local_dir.resolve())
+        # If we couldn't find the appropriate folder on disk, raise an error.
+        if local_files_only:
+            raise LocalEntryNotFoundError(
+                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
+                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
+                "'local_files_only=False' as input."
+            )
+        elif isinstance(api_call_error, OfflineModeIsEnabled):
+            raise LocalEntryNotFoundError(
+                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
+                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
+                "'HF_HUB_OFFLINE=0' as environment variable."
+            ) from api_call_error
+        elif isinstance(api_call_error, RepositoryNotFoundError) or isinstance(api_call_error, GatedRepoError):
+            # Repo not found => let's raise the actual error
+            raise api_call_error
+        else:
+            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
+            raise LocalEntryNotFoundError(
+                "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
+                " snapshot folder for the specified revision on the local disk. Please check your internet connection"
+                " and try again."
+            ) from api_call_error
+
+    # At this stage, internet connection is up and running
+    # => let's download the files!
+    assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
+    assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
+    filtered_repo_files = list(
+        filter_repo_objects(
+            items=[f.rfilename for f in repo_info.siblings],
+            allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
+        )
+    )
+    commit_hash = repo_info.sha
+    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
+    # if passed revision is not identical to commit_hash
+    # then revision has to be a branch name or tag name.
+    # In that case store a ref.
+    if revision != commit_hash:
+        ref_path = os.path.join(storage_folder, "refs", revision)
+        try:
+            os.makedirs(os.path.dirname(ref_path), exist_ok=True)
+            with open(ref_path, "w") as f:
+                f.write(commit_hash)
+        except OSError as e:
+            logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")
+
+    # we pass the commit_hash to hf_hub_download
+    # so no network call happens if we already
+    # have the file locally.
+    def _inner_hf_hub_download(repo_file: str):
+        return hf_hub_download(
+            repo_id,
+            filename=repo_file,
+            repo_type=repo_type,
+            revision=commit_hash,
+            endpoint=endpoint,
+            cache_dir=cache_dir,
+            local_dir=local_dir,
+            local_dir_use_symlinks=local_dir_use_symlinks,
+            library_name=library_name,
+            library_version=library_version,
+            user_agent=user_agent,
+            proxies=proxies,
+            etag_timeout=etag_timeout,
+            resume_download=resume_download,
+            force_download=force_download,
+            token=token,
+            headers=headers,
+        )
+
+    if constants.HF_HUB_ENABLE_HF_TRANSFER:
+        # when using hf_transfer we don't want extra parallelism
+        # from the one hf_transfer provides
+        for file in filtered_repo_files:
+            _inner_hf_hub_download(file)
+    else:
+        thread_map(
+            _inner_hf_hub_download,
+            filtered_repo_files,
+            desc=f"Fetching {len(filtered_repo_files)} files",
+            max_workers=max_workers,
+            # User can use its own tqdm class or the default one from `huggingface_hub.utils`
+            tqdm_class=tqdm_class or hf_tqdm,
+        )
+
+    if local_dir is not None:
+        return str(os.path.realpath(local_dir))
+    return snapshot_folder
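
For context, a minimal usage sketch of the `snapshot_download` API added in this file. The repo id and file patterns below are illustrative placeholders, not values from the diff:

    from huggingface_hub import snapshot_download

    # Fetch only the config and safetensors weights of a (hypothetical) model repo
    # into the shared HF cache; the returned path is the snapshots/<commit> folder.
    snapshot_path = snapshot_download(
        repo_id="some-org/some-model",                    # placeholder repo id
        allow_patterns=["*.safetensors", "config.json"],  # skip everything else
        revision="main",                                  # branch, tag, or commit hash
        max_workers=8,                                    # one thread per file download
    )
    print(snapshot_path)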
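
The offline fallback above resolves a snapshot purely from the cache layout: `refs/<revision>` stores the commit hash a branch or tag points to, and `snapshots/<commit>` holds the files. A small sketch of that resolution, reusing helpers this module itself imports; `resolve_cached_snapshot` is a hypothetical name, not part of huggingface_hub:

    import os

    from huggingface_hub.constants import HF_HUB_CACHE
    from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name


    def resolve_cached_snapshot(repo_id: str, revision: str, repo_type: str = "model"):
        # Map the repo to its cache folder, e.g. <HF_HUB_CACHE>/models--org--name.
        storage_folder = os.path.join(
            HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type=repo_type)
        )
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            # Branch or tag: the refs file stores the commit hash it points to.
            ref_path = os.path.join(storage_folder, "refs", revision)
            if not os.path.exists(ref_path):
                return None
            with open(ref_path) as f:
                commit_hash = f.read()
        snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
        return snapshot_folder if os.path.exists(snapshot_folder) else None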