-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
233 lines (204 loc) · 9.54 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import collections
from functools import partial
import sys
from typing import Callable, Tuple
from absl import app
from absl import flags
from flax import linen as nn
import jax
from jax import numpy as jnp
from jax import random
from jax import tree_util
import jaxopt
import ml_collections
from ml_collections import config_flags
import numpy as np
import optax
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import tqdm
import wandb
import models.densenet
import models.resnet_v1
import models.resnet_v2
import models.vgg
import models.wide_resnet
import util
# --- Command-line flags -----------------------------------------------------
# Data location / acquisition.
flags.DEFINE_string(
    name='dataset_root', default=None, help='Path to data.', required=True)
flags.DEFINE_bool(name='download', default=False, help='Download dataset.')

# Evaluation and input-pipeline settings.
flags.DEFINE_integer(
    name='eval_batch_size', default=128,
    help='Batch size to use during evaluation.')
flags.DEFINE_integer(
    name='loader_num_workers', default=4,
    help='num_workers for DataLoader')
flags.DEFINE_integer(
    name='loader_prefetch_factor', default=2,
    help='prefetch_factor for DataLoader')

# Experiment configuration, supplied as an ml_collections config file via --config.
config_flags.DEFINE_config_file('config')

FLAGS = flags.FLAGS

# Short aliases used in type annotations.
Dataset = torch.utils.data.Dataset
ModuleDef = Callable[..., nn.Module]
def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader configured from the loader flags.

    ``prefetch_factor`` is forwarded only when worker processes are used:
    torch>=2.0 raises ValueError if it is given together with
    ``num_workers=0``, which made ``--loader_num_workers=0`` unusable.
    """
    extra_kwargs = {}
    if FLAGS.loader_num_workers > 0:
        extra_kwargs['prefetch_factor'] = FLAGS.loader_prefetch_factor
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=False,
        num_workers=FLAGS.loader_num_workers,
        **extra_kwargs)


def _to_jax_batch(inputs, labels):
    """Convert a torch ``(inputs, labels)`` batch into JAX arrays.

    Axis -3 of ``inputs`` (the channel axis produced by ``ToTensor``) is moved
    to the end so images are channel-last.
    """
    inputs = jnp.moveaxis(jnp.asarray(inputs.numpy()), -3, -1)
    return inputs, jnp.asarray(labels.numpy())


def main(_):
    """Train and evaluate the model described by --config.

    Epoch 0 only evaluates the freshly initialized model. Epochs
    1..num_epochs each run one SGD pass over the training set followed by an
    evaluation pass. Metrics are logged to wandb and a summary is printed.

    Raises:
      ValueError: if ``config.train.weight_decay_vars`` is neither 'all' nor
        'kernel'.
    """
    config = ml_collections.ConfigDict(FLAGS.config)

    def filter_kernel_params(tree):
        # Leaves whose last path component is 'kernel'; these are what the
        # script counts as linear-layer weights.
        return [x for path, x in util.dict_tree_items(tree) if path[-1] == 'kernel']

    # Resolve which parameters receive weight decay before any expensive setup,
    # so an invalid config fails immediately rather than when train_step is
    # first traced (which only happens after the epoch-0 evaluation pass).
    if config.train.weight_decay_vars == 'all':
        def select_wd_vars(tree):
            return list(tree_util.tree_leaves(tree))
    elif config.train.weight_decay_vars == 'kernel':
        select_wd_vars = filter_kernel_params
    else:
        raise ValueError('unknown variable collection', config.train.weight_decay_vars)

    wandb.init(project='flax-cifar')
    wandb.config.update(config.to_dict())

    num_classes, input_shape, train_dataset, val_dataset = setup_data()
    train_loader = _make_loader(train_dataset, config.train.batch_size, shuffle=True)
    val_loader = _make_loader(val_dataset, FLAGS.eval_batch_size, shuffle=False)

    norm = nn.BatchNorm

    def norm_kwargs(train):
        # Batch statistics while training, running averages at evaluation.
        return {'use_running_average': not train}

    model = make_model(config, num_classes, input_shape, norm=norm)
    rng_init, _ = random.split(random.PRNGKey(0))
    init_vars = model.init(
        rng_init, jnp.zeros((1,) + input_shape), norm_kwargs=norm_kwargs(train=True))
    params, batch_stats = init_vars['params'], init_vars['batch_stats']

    print('params:')
    sys.stdout.writelines(x + '\n' for x in util.dict_tree_format(util.tree_shape(params)))
    print('batch_stats:')
    sys.stdout.writelines(x + '\n' for x in util.dict_tree_format(util.tree_shape(batch_stats)))
    sys.stdout.flush()
    print('total number of params:',
          tree_util.tree_reduce(np.add, tree_util.tree_map(lambda x: np.prod(x.shape), params)))
    print('number of linear layers:', sum(1 for _ in filter_kernel_params(params)))

    total_steps = config.train.num_epochs * len(train_loader)
    schedule = optax.cosine_decay_schedule(config.train.base_learning_rate, total_steps)
    tx = optax.sgd(schedule, momentum=0.9)
    opt_state = tx.init(params)

    # Per-example loss: vmapped over the leading (batch) axis of labels/logits.
    loss_with_logits = jax.vmap(jaxopt.loss.multiclass_logistic_loss)

    def objective_fn(params, mutable_vars, data):
        # Designed for use with jax.value_and_grad(..., has_aux=True).
        # Params are a separate arg (arg 0).
        # Returns scalar loss and one auxiliary output.
        inputs, labels = data
        model_vars = {'params': params, **mutable_vars}
        outputs, mutated_vars = model.apply(
            model_vars, inputs, norm_kwargs=norm_kwargs(train=True),
            mutable=list(mutable_vars.keys()))
        data_loss = jnp.mean(loss_with_logits(labels, outputs))
        wd_loss = 0.5 * sum(jnp.sum(jnp.square(x)) for x in select_wd_vars(params))
        objective = data_loss + config.train.weight_decay * wd_loss
        return objective, (outputs, mutated_vars)

    @jax.jit
    def train_step(opt_state, params, mutable_vars, data):
        objective_and_grad_fn = jax.value_and_grad(objective_fn, has_aux=True)
        (objective, aux), grads = objective_and_grad_fn(params, mutable_vars, data)
        outputs, mutated_vars = aux
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return opt_state, params, mutated_vars, objective, outputs

    @jax.jit
    def apply_model(params, batch_stats, inputs):
        return model.apply(
            {'params': params, 'batch_stats': batch_stats}, inputs,
            norm_kwargs=norm_kwargs(train=False))

    for epoch in range(config.train.num_epochs + 1):
        metrics = {}
        # Epoch 0 evaluates the initial model without any training.
        if epoch > 0:
            train_outputs = collections.defaultdict(list)
            for inputs, labels in tqdm.tqdm(train_loader, f'train epoch {epoch}'):
                inputs, labels = _to_jax_batch(inputs, labels)
                opt_state, params, mutated_vars, objective, logits = train_step(
                    opt_state, params, {'batch_stats': batch_stats}, (inputs, labels))
                batch_stats = mutated_vars['batch_stats']
                # These logits come from the training-mode forward pass, i.e.
                # they were computed before this step's parameter update.
                train_outputs['acc'].append(jnp.argmax(logits, axis=-1) == labels)
                train_outputs['loss'].append(loss_with_logits(labels, logits))
                train_outputs['objective'].append([objective])
            train_outputs = {k: np.concatenate(v) for k, v in train_outputs.items()}
            metrics.update({
                'train_loss': np.mean(train_outputs['loss']),
                'train_acc': np.mean(train_outputs['acc']),
                'train_objective': np.mean(train_outputs['objective']),
            })

        val_outputs = collections.defaultdict(list)
        for inputs, labels in tqdm.tqdm(val_loader, f'val epoch {epoch}'):
            inputs, labels = _to_jax_batch(inputs, labels)
            logits = apply_model(params, batch_stats, inputs)
            val_outputs['acc'].append(jnp.argmax(logits, axis=-1) == labels)
            val_outputs['loss'].append(loss_with_logits(labels, logits))
        val_outputs = {k: np.concatenate(v) for k, v in val_outputs.items()}
        metrics.update({
            'val_loss': np.mean(val_outputs['loss']),
            'val_acc': np.mean(val_outputs['acc']),
        })

        wandb.log(metrics)
        if epoch == 0:
            print('epoch {:d}: val_acc {:.2%}'.format(epoch, metrics['val_acc']))
        else:
            print('epoch {:d}: val_acc {:.2%}, train_objective {:.6g}'.format(
                epoch, metrics['val_acc'], metrics['train_objective']))
def setup_data() -> Tuple[int, Tuple[int, int, int], Dataset, Dataset]:
    """Build the CIFAR-10 training and evaluation datasets.

    Data is read from ``--dataset_root`` (and fetched first if ``--download``).

    Returns:
      A tuple ``(num_classes, input_shape, train_dataset, val_dataset)`` where
      ``input_shape`` is the image shape ``(32, 32, 3)``, channels last.
    """
    def with_normalization(*augmentations):
        # Optional augmentations, then tensor conversion and a shift by 0.5
        # (Normalize with mean 0.5, std 1.0).
        return transforms.Compose(
            [*augmentations, transforms.ToTensor(), transforms.Normalize(0.5, 1.0)])

    train_tf = with_normalization(
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    )
    eval_tf = with_normalization()

    def load_split(is_train, tf):
        return datasets.CIFAR10(
            FLAGS.dataset_root, train=is_train, download=FLAGS.download, transform=tf)

    return 10, (32, 32, 3), load_split(True, train_tf), load_split(False, eval_tf)
def make_model(
        config: ml_collections.ConfigDict,
        num_classes: int,
        input_shape: Tuple[int, int, int],
        norm: ModuleDef = nn.BatchNorm) -> nn.Module:
    """Instantiate the architecture named by ``config.arch``.

    Args:
      config: Experiment config; only ``config.arch`` is read here.
      num_classes: Number of output classes, forwarded to the constructor.
      input_shape: Not used by this function; kept for interface stability.
      norm: Normalization module constructor forwarded to the model.

    Returns:
      The constructed module.

    Raises:
      ValueError: If ``config.arch`` is not one of the registered names.
    """
    model_fns = {
        'resnet_v1_18': partial(models.resnet_v1.ResNet18, stem_variant='cifar'),
        'resnet_v1_34': partial(models.resnet_v1.ResNet34, stem_variant='cifar'),
        'resnet_v1_50': partial(models.resnet_v1.ResNet50, stem_variant='cifar'),
        'resnet_v2_18': partial(models.resnet_v2.ResNet18, stem_variant='cifar'),
        'resnet_v2_34': partial(models.resnet_v2.ResNet34, stem_variant='cifar'),
        'resnet_v2_50': partial(models.resnet_v2.ResNet50, stem_variant='cifar'),
        'wrn28_2': partial(models.wide_resnet.WideResNet, depth=28, width=2),
        'wrn28_8': partial(models.wide_resnet.WideResNet, depth=28, width=8),
        'densenet121_12': models.densenet.densenet_cifar,
        'densenet121_32': models.densenet.DenseNet121,
        'densenet169_32': models.densenet.DenseNet169,
        'densenet201_32': models.densenet.DenseNet201,
        'densenet161_48': models.densenet.DenseNet161,
        'vgg11_backbone': models.vgg.VGG11Backbone,
        'vgg13_backbone': models.vgg.VGG13Backbone,
        'vgg16_backbone': models.vgg.VGG16Backbone,
        'vgg19_backbone': models.vgg.VGG19Backbone,
        'vgg11': models.vgg.VGG11,
        'vgg13': models.vgg.VGG13,
        'vgg16': models.vgg.VGG16,
        'vgg19': models.vgg.VGG19,
    }
    # Keep the try body to the lookup alone so only an unknown name is
    # reported as 'unknown architecture'; chain the KeyError as the cause.
    try:
        model_fn = model_fns[config.arch]
    except KeyError as ex:
        raise ValueError('unknown architecture', ex) from ex
    return model_fn(num_classes=num_classes, norm=norm)
if __name__ == '__main__':
    # absl parses the command-line flags (including --config) and then calls main.
    app.run(main=main)