# mnist_rot.yaml — experiment configuration for MNIST rotation runs
# (web-scrape residue removed; see config below)
---
# Experiment configuration: MNIST rotation.
# NOTE(review): indentation was lost in this copy; the nesting under
# `optim:`, `weight:`, and `task:` below was reconstructed so that
# duplicated keys (dataset/device/print_freq/freeze_encoder/batch_size)
# no longer collide at top level — confirm against the original file.
dataset: 'mnist'
seed: 1337
exp_name: 'mnist_rotation'
# variant_name: use when you want to have many variants
# of the same experiment
variant_name: 'exp_1'
# config corresponding to the prior
config: 'data/generator/config/mnist.json'
logdir: 'logs/'
device: 'cuda'

# training
print_freq: 40
# need a high batch size for a good estimate of mmd
batch_size: 50
num_real_images: 200
max_epochs: 2000
epoch_length: 1000  # number of samples that constitute one epoch
train_reconstruction: true
freeze_encoder: true
reconstruction_epochs: 7
use_dist_loss: true
use_task_loss: true
moving_avg_alpha: 0.7  # moving_avg_alpha for baseline

# MMD
# sizes of layers of inception to use for MMD. Check
# the inception file for possible values
mmd_dims: [64, 192]
mmd_resize_input: false

optim:
  lr: 0.001
  lr_decay: 200  # number of epochs to decay after
  lr_decay_gamma: 0.5  # gamma to decay
  weight_decay: 0.00001

# loss-term weights
weight:
  class: 0.1  # weight for class during reconstruction training
  dist_mmd: 100.0  # multiplier for mmd

# settings for the task network trained on the target configuration
task:
  # data corresponding to the target configuration
  # usually you would generate one small version
  # and one large version of the target
  # Use the small version while training and
  # the large version to report final results
  # this is not included in this code for simplicity
  # but is easy to add by editing the test
  # function in the task network to report two
  # accuracies, out of which one would be used to train
  val_root: 'data/datagen/mnist/target_rot/'
  device: 'cuda'
  print_freq: 100
  freeze_encoder: true
  batch_size: 8
  input_dim: [32, 32]
  epochs: 2
  dataset: 'mnist'