Skip to content

Commit

Permalink
[feat] add default download root instead of using current working dir
Browse files Browse the repository at this point in the history
  • Loading branch information
geniuspatrick committed Mar 1, 2023
1 parent e579ea2 commit 687e5de
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 16 deletions.
20 changes: 16 additions & 4 deletions mindcv/data/dataset_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,22 @@
"""

import os
from typing import Optional

from mindcv.utils.download import DownLoad
from mindcv.utils.download import DownLoad, get_default_download_root

__all__ = [
"get_dataset_download_root",
"MnistDownload",
"Cifar10Download",
"Cifar100Download",
]


def get_dataset_download_root():
return os.path.join(get_default_download_root(), "datasets")


class MnistDownload(DownLoad):
"""Utility class for downloading Mnist dataset.
Expand All @@ -29,8 +35,10 @@ class MnistDownload(DownLoad):
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]

def __init__(self, root: str):
def __init__(self, root: Optional[str] = None):
super().__init__()
if root is None:
root = os.path.join(get_dataset_download_root(), "mnist")
self.root = root
self.path = root

Expand Down Expand Up @@ -77,8 +85,10 @@ class Cifar10Download(DownLoad):
"batches.meta.txt",
]

def __init__(self, root: str):
def __init__(self, root: Optional[str] = None):
super().__init__()
if root is None:
root = os.path.join(get_dataset_download_root(), "cifar10")
self.root = root
self.path = os.path.join(self.root, self.base_dir)

Expand Down Expand Up @@ -118,8 +128,10 @@ class Cifar100Download(DownLoad):
"coarse_label_names.txt",
]

def __init__(self, root: str):
def __init__(self, root: Optional[str] = None):
super().__init__()
if root is None:
root = os.path.join(get_dataset_download_root(), "cifar100")
self.root = root
self.path = os.path.join(self.root, self.base_dir)

Expand Down
12 changes: 7 additions & 5 deletions mindcv/data/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import mindspore.dataset as ds
from mindspore.dataset import Cifar10Dataset, Cifar100Dataset, DistributedSampler, ImageFolderDataset, MnistDataset

from .dataset_download import Cifar10Download, Cifar100Download, MnistDownload
from .dataset_download import Cifar10Download, Cifar100Download, MnistDownload, get_dataset_download_root
from .distributed_sampler import RepeatAugSampler

__all__ = [
Expand All @@ -24,10 +24,10 @@

def create_dataset(
name: str = "",
root: str = "./",
root: Optional[str] = None,
split: str = "train",
shuffle: bool = True,
num_samples: Optional[bool] = None,
num_samples: Optional[int] = None,
num_shards: Optional[int] = None,
shard_id: Optional[int] = None,
num_parallel_workers: Optional[int] = None,
Expand All @@ -39,7 +39,7 @@ def create_dataset(
Args:
name: dataset name like MNIST, CIFAR10, ImageNeT, ''. '' means a customized dataset. Default: ''.
root: dataset root dir. Default: './'.
root: dataset root dir. Default: None.
split: data split: '' or split name string (train/val/test), if it is '', no split is used.
Otherwise, it is a subfolder of root dir, e.g., train, val, test. Default: 'train'.
shuffle: whether to shuffle the dataset. Default: True.
Expand Down Expand Up @@ -79,10 +79,12 @@ def create_dataset(
Returns:
Dataset object
"""
name = name.lower()
if root is None:
root = os.path.join(get_dataset_download_root(), name)

assert (num_samples is None) or (num_aug_repeats == 0), "num_samples and num_aug_repeats can NOT be set together."

name = name.lower()
# subset sampling
if num_samples is not None and num_samples > 0:
# TODO: rewrite ordered distributed sampler (subset sampling in distributed mode is not tested)
Expand Down
15 changes: 10 additions & 5 deletions mindcv/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from mindspore import load_checkpoint, load_param_into_net

from mindcv.utils.download import DownLoad
from mindcv.utils.download import DownLoad, get_default_download_root


def get_checkpoint_download_root():
return os.path.join(get_default_download_root(), "models")


class ConfigDict(dict):
Expand All @@ -20,17 +24,18 @@ class ConfigDict(dict):
__delattr__ = dict.__delitem__


def load_pretrained(model, default_cfg, path="./", num_classes=1000, in_channels=3, filter_fn=None):
def load_pretrained(model, default_cfg, num_classes=1000, in_channels=3, filter_fn=None):
"""load pretrained model depending on cfgs of model"""
if "url" not in default_cfg or not default_cfg["url"]:
logging.warning("Pretrained model URL is invalid")
return

# download files
os.makedirs(path, exist_ok=True)
DownLoad().download_url(default_cfg["url"], path=path)
download_path = get_checkpoint_download_root()
os.makedirs(download_path, exist_ok=True)
DownLoad().download_url(default_cfg["url"], path=download_path)

param_dict = load_checkpoint(os.path.join(path, os.path.basename(default_cfg["url"])))
param_dict = load_checkpoint(os.path.join(download_path, os.path.basename(default_cfg["url"])))

if in_channels == 1:
conv1_name = default_cfg["first_conv"]
Expand Down
25 changes: 23 additions & 2 deletions mindcv/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,35 @@
import bz2
import gzip
import hashlib
import logging
import os
import ssl
import tarfile
import urllib
import urllib.error
import urllib.request
import zipfile
from copy import deepcopy
from typing import Optional

from tqdm import tqdm

from .path import detect_file_type

_logger = logging.getLogger(__name__)
# The default root directory where we save downloaded files.
# Use Get/Set to R/W this variable.
_DEFAULT_DOWNLOAD_ROOT = os.path.join(os.path.expanduser("~"), ".mindspore")


def get_default_download_root():
return deepcopy(_DEFAULT_DOWNLOAD_ROOT)


def set_default_download_root(path):
global _DEFAULT_DOWNLOAD_ROOT
_DEFAULT_DOWNLOAD_ROOT = path


class DownLoad:
"""Base utility class for downloading."""
Expand Down Expand Up @@ -85,6 +101,7 @@ def download_file(self, url: str, file_path: str, chunk_size: int = 1024):
# Define request headers.
headers = {"User-Agent": self.USER_AGENT}

_logger.info(f"Downloading from {url} to {file_path} ...")
with open(file_path, "wb") as f:
request = urllib.request.Request(url, headers=headers)
with urllib.request.urlopen(request) as response:
Expand All @@ -98,11 +115,13 @@ def download_file(self, url: str, file_path: str, chunk_size: int = 1024):
def download_url(
self,
url: str,
path: str = "./",
path: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
) -> None:
"""Download a file from a url and place it in root."""
if path is None:
path = get_default_download_root()
path = os.path.expanduser(path)
os.makedirs(path, exist_ok=True)

Expand Down Expand Up @@ -135,13 +154,15 @@ def download_url(
def download_and_extract_archive(
self,
url: str,
download_path: str,
download_path: Optional[str] = None,
extract_path: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
"""Download and extract archive."""
if download_path is None:
download_path = get_default_download_root()
download_path = os.path.expanduser(download_path)

if not filename:
Expand Down

0 comments on commit 687e5de

Please sign in to comment.