Source code for cached_path._cached_path

import logging
import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

from .cache_file import CacheFile
from .common import PathOrStr, get_cache_dir
from .file_lock import FileLock
from .meta import Meta
from .schemes import (
    SchemeClient,
    get_scheme_client,
    get_supported_schemes,
    hf_get_from_cache,
)
from .util import (
    _lock_file_path,
    _meta_file_path,
    check_tarfile,
    find_latest_cached,
    resource_to_filename,
)

if TYPE_CHECKING:
    from rich.progress import Progress

logger = logging.getLogger("cached_path")


def cached_path(
    url_or_filename: PathOrStr,
    cache_dir: Optional[PathOrStr] = None,
    extract_archive: bool = False,
    force_extract: bool = False,
    quiet: bool = False,
    progress: Optional["Progress"] = None,
) -> Path:
    """
    Given something that might be a URL or local path, determine which.
    If it's a remote resource, download the file and cache it, and then return the path
    to the cached file. If it's already a local path, make sure the file exists and
    return the path.

    For URLs, the following schemes are all supported out-of-the-box:

    * ``http`` and ``https``,
    * ``s3`` for objects on `AWS S3`_,
    * ``gs`` for objects on `Google Cloud Storage (GCS)`_, and
    * ``hf`` for objects or repositories on `HuggingFace Hub`_.

    If you have `Beaker-py`_ installed you can also use URLs of the form:
    ``beaker://{user_name}/{dataset_name}/{file_path}``.

    You can also extend ``cached_path()`` to handle more schemes with :func:`add_scheme_client()`.

    .. _AWS S3: https://aws.amazon.com/s3/
    .. _Google Cloud Storage (GCS): https://cloud.google.com/storage
    .. _HuggingFace Hub: https://huggingface.co/
    .. _Beaker-py: https://github.com/allenai/beaker-py

    Examples
    --------

    To download a file over ``https``::

        cached_path("https://github.com/allenai/cached_path/blob/main/README.md")

    To download an object on GCS::

        cached_path("gs://allennlp-public-models/lerc-2020-11-18.tar.gz")

    To download the PyTorch weights for the model `epwalsh/bert-xsmall-dummy`_
    on HuggingFace, you could do::

        cached_path("hf://epwalsh/bert-xsmall-dummy/pytorch_model.bin")

    For paths or URLs that point to a tarfile or zipfile, you can append the path
    to a specific file within the archive to the ``url_or_filename``, preceded by a "!".
    The archive will be automatically extracted (provided you set ``extract_archive`` to ``True``),
    returning the local path to the specific file. For example::

        cached_path("model.tar.gz!weights.th", extract_archive=True)

    .. _epwalsh/bert-xsmall-dummy: https://huggingface.co/epwalsh/bert-xsmall-dummy

    Parameters
    ----------

    url_or_filename :
        A URL or path to parse and possibly download.

    cache_dir :
        The directory to cache downloads. If not specified, the global default cache directory
        will be used (``~/.cache/cached_path``). This can be set to something else with
        :func:`set_cache_dir()`.

    extract_archive :
        If ``True``, then zip or tar.gz archives will be automatically extracted,
        in which case the directory is returned.

    force_extract :
        If ``True`` and the file is an archive file, it will be extracted regardless
        of whether or not the extracted directory already exists.

        .. caution::
            Use this flag with caution! This can lead to race conditions if used
            from multiple processes on the same file.

    quiet :
        If ``True``, progress displays won't be printed.

    progress :
        A custom progress display to use. If not set and ``quiet=False``, a default display
        from :func:`~cached_path.get_download_progress()` will be used.

    Returns
    -------
    :class:`pathlib.Path`
        The local path to the (potentially cached) resource.

    Raises
    ------
    ``FileNotFoundError``
        If the resource cannot be found locally or remotely.

    ``ValueError``
        When the URL is invalid.

    ``Other errors``
        Other error types are possible as well depending on the client used to fetch
        the resource.

    """
    if not isinstance(url_or_filename, str):
        url_or_filename = str(url_or_filename)

    file_path: Path
    extraction_path: Optional[Path] = None
    etag: Optional[str] = None

    # If we're using the /a/b/foo.zip!c/d/file.txt syntax, handle it here.
    exclamation_index = url_or_filename.find("!")
    if extract_archive and exclamation_index >= 0:
        archive_path = url_or_filename[:exclamation_index]
        file_name = url_or_filename[exclamation_index + 1 :]

        # Call 'cached_path' recursively now to get the local path to the archive itself.
        cached_archive_path = cached_path(
            archive_path,
            cache_dir=cache_dir,
            extract_archive=True,
            force_extract=force_extract,
            quiet=quiet,
            progress=progress,
        )
        if not cached_archive_path.is_dir():
            raise ValueError(
                f"{url_or_filename} uses the ! syntax, but does not specify an archive file."
            )

        # Now return the full path to the desired file within the extracted archive,
        # provided it exists.
        file_path = cached_archive_path / file_name
        if not file_path.exists():
            raise FileNotFoundError(f"'{file_name}' not found within '{archive_path}'")

        return file_path

    parsed = urlparse(url_or_filename)

    if parsed.scheme in get_supported_schemes():
        # URL, so get it from the cache (downloading if necessary).
        file_path, etag = get_from_cache(
            url_or_filename, cache_dir, quiet=quiet, progress=progress
        )
        if extract_archive and (is_zipfile(file_path) or tarfile.is_tarfile(file_path)):
            # This is the path the file should be extracted to.
            # For example ~/.cached_path/cache/234234.21341 -> ~/.cached_path/cache/234234.21341-extracted
            extraction_path = file_path.parent / (file_path.name + "-extracted")
    elif parsed.scheme == "file":
        return cached_path(url_or_filename.replace("file://", "", 1))
    else:
        orig_url_or_filename = url_or_filename
        url_or_filename = Path(url_or_filename).expanduser()
        cache_dir = Path(cache_dir if cache_dir else get_cache_dir()).expanduser()
        cache_dir.mkdir(parents=True, exist_ok=True)

        if url_or_filename.exists():
            # File, and it exists.
            file_path = url_or_filename
            # Normalize the path.
            url_or_filename = url_or_filename.resolve()
            if (
                extract_archive
                and file_path.is_file()
                and (is_zipfile(file_path) or tarfile.is_tarfile(file_path))
            ):
                # We'll use a unique directory within the cache root to extract the archive to.
                # The name of the directory is a hash of the resource file path and its modification
                # time. That way, if the file changes, we'll know when to extract it again.
                extraction_name = (
                    resource_to_filename(url_or_filename, str(os.path.getmtime(file_path)))
                    + "-extracted"
                )
                extraction_path = cache_dir / extraction_name
        elif parsed.scheme == "":
            # File, but it doesn't exist.
            raise FileNotFoundError(f"file {url_or_filename} not found")
        else:
            # Something unknown.
            raise ValueError(f"unable to parse {orig_url_or_filename} as a URL or as a local path")

    if extraction_path is not None:
        # If the extracted directory already exists (and is non-empty), then no
        # need to create a lock file and extract again unless `force_extract=True`.
        if os.path.isdir(extraction_path) and os.listdir(extraction_path) and not force_extract:
            return extraction_path

        # Extract it.
        with FileLock(_lock_file_path(extraction_path)):
            # Check again if the directory exists now that we've acquired the lock.
            if os.path.isdir(extraction_path) and os.listdir(extraction_path):
                if force_extract:
                    logger.warning(
                        "Extraction directory for %s (%s) already exists, "
                        "overwriting it since 'force_extract' is 'True'",
                        url_or_filename,
                        extraction_path,
                    )
                else:
                    return extraction_path

            logger.info("Extracting %s to %s", url_or_filename, extraction_path)
            shutil.rmtree(extraction_path, ignore_errors=True)

            # We extract first to a temporary directory in case something goes wrong
            # during the extraction process, so we don't end up with a corrupted cache.
            tmp_extraction_dir = tempfile.mkdtemp(dir=os.path.split(extraction_path)[0])
            try:
                if is_zipfile(file_path):
                    with ZipFile(file_path, "r") as zip_file:
                        zip_file.extractall(tmp_extraction_dir)
                        zip_file.close()
                else:
                    tar_file = tarfile.open(file_path)
                    check_tarfile(tar_file)
                    tar_file.extractall(tmp_extraction_dir)
                    tar_file.close()

                # Extraction was successful, rename temp directory to final
                # cache directory and dump the meta data.
                os.replace(tmp_extraction_dir, extraction_path)
                meta = Meta.new(
                    url_or_filename,
                    extraction_path,
                    etag=etag,
                    extraction_dir=True,
                )
                meta.to_file()
            finally:
                shutil.rmtree(tmp_extraction_dir, ignore_errors=True)

        return extraction_path

    return file_path
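
# Illustrative sketch (not part of the original module): a typical call pattern for
# `cached_path()` when fetching and extracting a remote archive. The URL and the file
# name inside the archive below are hypothetical placeholders.
def _example_usage() -> Path:
    extracted_dir = cached_path(
        "https://example.com/model.tar.gz",  # hypothetical archive URL
        extract_archive=True,
    )
    # With `extract_archive=True`, the returned path is the extraction directory,
    # so files inside the archive can be addressed relative to it.
    return extracted_dir / "weights.th"  # hypothetical file within the archive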
def get_from_cache(
    url: str,
    cache_dir: Optional[PathOrStr] = None,
    quiet: bool = False,
    progress: Optional["Progress"] = None,
    no_downloads: bool = False,
    _client: Optional[SchemeClient] = None,
) -> Tuple[Path, Optional[str]]:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file and the ETag.
    """
    if url.startswith("hf://"):
        return hf_get_from_cache(url, cache_dir), None

    cache_dir = Path(cache_dir if cache_dir else get_cache_dir()).expanduser()
    cache_dir.mkdir(parents=True, exist_ok=True)

    client = _client or get_scheme_client(url)

    # Get the ETag to add to the filename, if it exists.
    try:
        etag = client.get_etag()
    except client.recoverable_errors:  # type: ignore
        # We might be offline, in which case we don't want to throw an error
        # just yet. Instead, we'll try to use the latest cached version of the
        # target resource, if it exists. We'll only throw an exception if we
        # haven't cached the resource at all yet.
        logger.warning(
            "Connection error occurred while trying to fetch ETag for %s. "
            "Will attempt to use latest cached version of resource",
            url,
        )
        latest_cached = find_latest_cached(url, cache_dir)
        if latest_cached:
            logger.info(
                "ETag request failed with recoverable error, using latest cached "
                "version of %s: %s",
                url,
                latest_cached,
            )
            meta = Meta.from_path(_meta_file_path(latest_cached))
            return latest_cached, meta.etag
        else:
            logger.error(
                "ETag request failed with recoverable error, "
                "but no cached version of %s could be found",
                url,
            )
            raise

    filename = resource_to_filename(url, etag)

    # Get cache path to put the file.
    cache_path = cache_dir / filename

    # Multiple processes may be trying to cache the same file at once, so we need
    # to be a little careful to avoid race conditions. We do this using a lock file.
    # Only one process can own this lock file at a time, and a process will block
    # on the call to `lock.acquire()` until the process currently holding the lock
    # releases it.
    logger.debug("waiting to acquire lock on %s", cache_path)
    with FileLock(_lock_file_path(cache_path), read_only_ok=True):
        if os.path.exists(cache_path):
            logger.info("cache of %s is up-to-date", url)
        elif no_downloads:
            raise FileNotFoundError(cache_path)
        else:
            size = client.get_size()
            with CacheFile(cache_path) as cache_file:
                logger.info("%s not found in cache, downloading to %s", url, cache_path)

                from .progress import BufferedWriterWithProgress, get_download_progress

                start_and_cleanup = progress is None
                progress = progress or get_download_progress(quiet=quiet)

                if start_and_cleanup:
                    progress.start()

                try:
                    display_url = url if len(url) <= 30 else f"\N{horizontal ellipsis}{url[-30:]}"
                    task_id = progress.add_task(f"Downloading [cyan i]{display_url}[/]", total=size)
                    writer_with_progress = BufferedWriterWithProgress(cache_file, progress, task_id)
                    client.get_resource(writer_with_progress)
                    progress.update(
                        task_id,
                        total=writer_with_progress.total_written,
                        completed=writer_with_progress.total_written,
                    )
                finally:
                    if start_and_cleanup:
                        progress.stop()

            logger.debug("creating metadata file for %s", cache_path)
            meta = Meta.new(
                url,
                cache_path,
                etag=etag,
            )
            meta.to_file()

    return cache_path, etag
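
# Illustrative sketch (not part of the original module): checking whether a resource is
# already present in the cache without downloading it, via `no_downloads=True`. Note that
# the ETag is still requested from the remote, so this is not a fully offline check.
# The URL below is a hypothetical placeholder.
def _example_cache_check(url: str = "https://example.com/data.json") -> Optional[Path]:
    try:
        path, _etag = get_from_cache(url, no_downloads=True)
        return path
    except FileNotFoundError:
        # Raised when the resource has not been cached yet and downloads are disabled.
        return None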