Skip to content

Commit f540560

Browse files
authored
Merge pull request #477 from ggraffieti/master
Dataset download only in ~/.avalanche/data folder
2 parents ce4ca87 + d755e2b commit f540560

12 files changed

+226
-121
lines changed

examples/confusion_matrix.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from os.path import expanduser
21+
2022
import argparse
2123
import torch
2224
from torch.nn import CrossEntropyLoss
@@ -53,10 +55,10 @@ def main(args):
5355
# ---------
5456

5557
# --- SCENARIO CREATION
56-
mnist_train = MNIST('./data/mnist', train=True,
57-
download=True, transform=train_transform)
58-
mnist_test = MNIST('./data/mnist', train=False,
59-
download=True, transform=test_transform)
58+
mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
59+
train=True, download=True, transform=train_transform)
60+
mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
61+
train=False, download=True, transform=test_transform)
6062
scenario = nc_scenario(
6163
mnist_train, mnist_test, 5, task_labels=False, seed=1234)
6264
# ---------

examples/eval_plugin.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from os.path import expanduser
21+
2022
import argparse
2123
import torch
2224
from torch.nn import CrossEntropyLoss
@@ -55,10 +57,10 @@ def main(args):
5557
# ---------
5658

5759
# --- SCENARIO CREATION
58-
mnist_train = MNIST('./data/mnist', train=True,
59-
download=True, transform=train_transform)
60-
mnist_test = MNIST('./data/mnist', train=False,
61-
download=True, transform=test_transform)
60+
mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
61+
train=True, download=True, transform=train_transform)
62+
mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
63+
train=False, download=True, transform=test_transform)
6264
scenario = nc_scenario(
6365
mnist_train, mnist_test, 5, task_labels=False, seed=1234)
6466
# ---------

examples/getting_started.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from os.path import expanduser
21+
2022
import argparse
2123
import torch
2224
from torch.nn import CrossEntropyLoss
@@ -50,10 +52,10 @@ def main(args):
5052
# ---------
5153

5254
# --- SCENARIO CREATION
53-
mnist_train = MNIST('./data/mnist', train=True,
54-
download=True, transform=train_transform)
55-
mnist_test = MNIST('./data/mnist', train=False,
56-
download=True, transform=test_transform)
55+
mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
56+
train=True, download=True, transform=train_transform)
57+
mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
58+
train=False, download=True, transform=test_transform)
5759
scenario = nc_scenario(
5860
mnist_train, mnist_test, 5, task_labels=False, seed=1234)
5961
# ---------

examples/getting_started_no_avalanche.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
from os.path import expanduser
22+
2123
import argparse
2224
import torch
2325
import torch.nn as nn
@@ -95,10 +97,10 @@ def forward(self, x):
9597
new_test_transform = transforms.Compose(test_transform_list)
9698

9799
# get the datasets with the constructed transformation
98-
permuted_train = MNIST(root='./data/mnist',
100+
permuted_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
99101
train=True,
100102
download=True, transform=new_train_transform)
101-
permuted_test = MNIST(root='./data/mnist',
103+
permuted_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
102104
train=False,
103105
download=True, transform=new_test_transform)
104106

examples/replay.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from os.path import expanduser
21+
2022
import argparse
2123
import torch
2224
from torch.nn import CrossEntropyLoss
@@ -56,10 +58,10 @@ def main(args):
5658
# ---------
5759

5860
# --- SCENARIO CREATION
59-
mnist_train = MNIST('./data/mnist', train=True,
60-
download=True, transform=train_transform)
61-
mnist_test = MNIST('./data/mnist', train=False,
62-
download=True, transform=test_transform)
61+
mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
62+
train=True, download=True, transform=train_transform)
63+
mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
64+
train=False, download=True, transform=test_transform)
6365
scenario = nc_scenario(
6466
mnist_train, mnist_test, n_batches, task_labels=False, seed=1234)
6567
# ---------

examples/tensorboard_logger.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
from os.path import expanduser
22+
2123
import argparse
2224
import torch
2325
from torch.nn import CrossEntropyLoss
@@ -56,10 +58,10 @@ def main(args):
5658
# ---------
5759

5860
# --- SCENARIO CREATION
59-
mnist_train = MNIST('./data/mnist', train=True,
60-
download=True, transform=train_transform)
61-
mnist_test = MNIST('./data/mnist', train=False,
62-
download=True, transform=test_transform)
61+
mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
62+
train=True, download=True, transform=train_transform)
63+
mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
64+
train=False, download=True, transform=test_transform)
6365
scenario = nc_scenario(
6466
mnist_train, mnist_test, 5, task_labels=False, seed=1234)
6567
# ---------

0 commit comments

Comments
 (0)