Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update flowvision download func #127

Merged
merged 6 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- Refactor `Vision Transformer` model [#115](https://github.com/Oneflow-Inc/vision/pull/115)
- Refine `flowvision.models.ModelCreator` to support `ModelCreator.model_list` func [#123](https://github.com/Oneflow-Inc/vision/pull/123)
- Refactor README [#124](https://github.com/Oneflow-Inc/vision/pull/124)
- Refine `load_state_dict_from_url` in `flowvision.models.utils` to support downloading pretrained weights to cache dir `~/.oneflow/flowvision_cache` [#127](https://github.com/Oneflow-Inc/vision/pull/127)


**Docs Update**
Expand Down
60 changes: 47 additions & 13 deletions flowvision/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,65 @@
import zipfile
import tarfile
import warnings
import logging
from urllib.parse import urlparse
from urllib.request import Request, urlopen
from tqdm import tqdm
from typing import Optional

import oneflow as flow

HASH_REGEX = re.compile(r"([a-f0-9]*)_")


def get_cache_dir(cache_dir: Optional[str] = None) -> str:
    """
    Modified from https://github.com/facebookresearch/iopath/blob/main/iopath/common/file_io.py
    Return a writable directory for caching static files
    (usually downloaded from the Internet).

    Args:
        cache_dir (None or str): directory to use. If None, the default
            cache directory is chosen as:
            1) $FLOWVISION_CACHE, if set
            2) otherwise ~/.oneflow/flowvision_cache

    Returns:
        str: ``cache_dir`` (created if missing) when it can be created and is
        writable; otherwise ``<system temp dir>/flowvision_cache``.
    """
    if cache_dir is None:
        cache_dir = os.path.expanduser(
            os.getenv("FLOWVISION_CACHE", "~/.oneflow/flowvision_cache")
        )

    # Use an explicit flag instead of `assert`: assertions are stripped when
    # Python runs with -O, which would silently disable the writability check.
    try:
        os.makedirs(cache_dir, exist_ok=True)
        writable = os.access(cache_dir, os.W_OK)
    except OSError:
        writable = False

    if not writable:
        tmp_dir = os.path.join(tempfile.gettempdir(), "flowvision_cache")
        logger = logging.getLogger(__name__)
        logger.warning(
            "%s is not accessible! Using %s instead!", cache_dir, tmp_dir
        )
        # The fallback must exist as well, otherwise later writes into it fail.
        try:
            os.makedirs(tmp_dir, exist_ok=True)
        except OSError:
            # Stay best-effort like the original: report and let the caller's
            # own file operation surface the error.
            logger.warning("Could not create fallback cache dir %s", tmp_dir)
        cache_dir = tmp_dir
    return cache_dir


def _is_legacy_tar_format(filename):
    """Return True if ``filename`` is a tar archive according to ``tarfile.is_tarfile``."""
    return tarfile.is_tarfile(filename)


def _legacy_tar_load(filename, model_dir, map_location, delete_tar_file=True):
    """Extract a tar checkpoint into ``model_dir`` and load its first entry.

    Args:
        filename: path of the tar archive on disk.
        model_dir: directory the archive is extracted into; created
            (including missing parents) when it does not exist.
        map_location: accepted for signature compatibility; it is not
            forwarded to ``flow.load`` in this function.
        delete_tar_file: when true, remove ``filename`` once extraction
            has succeeded.

    Returns:
        The object returned by ``flow.load`` for the first name listed in
        the archive, resolved relative to ``model_dir``.

    Raises:
        IndexError: if the archive contains no members.
        ValueError: if a member name would resolve outside ``model_dir``.
    """
    with tarfile.open(filename) as f:
        members = f.getnames()
        extracted_name = members[0]

        # The archive usually comes from a URL, so refuse member names that
        # would escape model_dir ("../x" or absolute paths).
        # NOTE(review): link members (symlinks/hardlinks) are not inspected here.
        root = os.path.realpath(model_dir)
        for name in members:
            target = os.path.realpath(os.path.join(root, name))
            if os.path.commonpath([root, target]) != root:
                raise ValueError(
                    "Refusing to extract {!r} outside of {!r}".format(name, model_dir)
                )

        # makedirs(exist_ok=True) handles nested paths and concurrent creation,
        # which a bare exists()/mkdir() pair does not.
        os.makedirs(model_dir, exist_ok=True)
        f.extractall(model_dir)

    extracted_file = os.path.join(model_dir, extracted_name)
    # Remove the archive only after its handle is closed.
    if delete_tar_file:
        os.remove(filename)
    return flow.load(extracted_file)


def _is_legacy_zip_format(filename):
    """Return True if ``filename`` is a zip archive according to ``zipfile.is_zipfile``."""
    return zipfile.is_zipfile(filename)


def _legacy_zip_load(filename, model_dir, map_location):
def _legacy_zip_load(filename, model_dir, map_location, delete_zip_file=True):
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
with zipfile.ZipFile(filename) as f:
Expand All @@ -46,18 +76,19 @@ def _legacy_zip_load(filename, model_dir, map_location):
if not os.path.exists(extracted_file):
os.mkdir(extracted_file)
f.extractall(model_dir)
# TODO: flow.load doesn't have map_location
# return flow.load(extracted_file, map_location=map_location)
return flow.load(extracted_file)
if delete_zip_file:
os.remove(filename)
return flow.load(extracted_file, map_location)


def load_state_dict_from_url(
url,
model_dir="./checkpoints",
model_dir=None,
map_location=None,
progress=True,
check_hash=False,
file_name=None,
delete_file=True,
):
r"""Loads the OneFlow serialized object at the given URL.

Expand All @@ -79,14 +110,11 @@ def load_state_dict_from_url(
ensure unique names and to verify the contents of the file.
Default: ``False``
file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set

delete_file (bool, optional): delete downloaded `.zip` file or `.tar.gz` file after unzipping them.
"""

if map_location is not None:
warnings.warn("Map location is not supported yet.")

try:
os.makedirs(model_dir)
model_dir = get_cache_dir(model_dir)
except OSError as e:
if e.errno == errno.EEXIST:
# Directory already exists, ignore.
Expand All @@ -99,6 +127,12 @@ def load_state_dict_from_url(
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
# If the weights have already been downloaded, load and return the cached state_dict directly.
pretrained_weight_dir = os.path.join(model_dir, filename.split(".")[0])
if os.path.exists(pretrained_weight_dir):
state_dict = flow.load(pretrained_weight_dir)
return state_dict

cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
Expand All @@ -109,9 +143,9 @@ def load_state_dict_from_url(
download_url_to_file(url, cached_file, hash_prefix, progress=progress)

if _is_legacy_zip_format(cached_file):
return _legacy_zip_load(cached_file, model_dir, map_location)
return _legacy_zip_load(cached_file, model_dir, map_location, delete_file)
elif _is_legacy_tar_format(cached_file):
return _legacy_tar_load(cached_file, model_dir, map_location)
return _legacy_tar_load(cached_file, model_dir, map_location, delete_file)
else:
state_dict = flow.load(cached_file)
return state_dict
Expand Down