
Commit 4df90eb

Fixed CTA rates update in DDP
- added main_fully_supervised.py for debugging DDP vs DP
1 parent: 7dc64cc

File tree: 7 files changed, +774 −48 lines

README.md

Lines changed: 9 additions & 1 deletion

@@ -16,6 +16,11 @@ pip install --upgrade --pre pytorch-ignite
 python -u main_fixmatch.py
 # or python -u main_fixmatch.py --params "data_path=/path/to/cifar10"
 ```
+### DDP
+
+```bash
+python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py --params="dist_backend='nccl'"
+```
 
 ## TODO
 
@@ -26,7 +31,10 @@ BUGS:
 * [x] save/load CTA
 * [x] save ema model
 
-* [ ] DDP: Synchronize CTA across processes
+* [x] DDP: Synchronize CTA across processes
+
+* [ ] Bug: DDP performances are worse than DP on the first epochs
+* [ ] Increase batch_size -> batch_size * WS => LR, epoch_length
 
 * [ ] Logging to online platform: NeptuneML or Trains or W&B
 
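
The last TODO item is shorthand for the usual DDP bookkeeping: with WS processes each drawing its own batch, the effective batch size becomes batch_size * WS, so the learning rate is typically scaled up and the per-process epoch_length scaled down to keep the schedule comparable. A minimal sketch of that adjustment (the config keys and helper name are illustrative, not the repository's actual ones):

```python
import torch.distributed as dist

def scale_config_for_ddp(config):
    # world size is 1 when not launched with torch.distributed
    ws = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    config["learning_rate"] *= ws                                  # linear scaling rule for the larger effective batch
    config["epoch_length"] = max(1, config["epoch_length"] // ws)  # keep samples seen per epoch roughly constant
    return config

# single process: unchanged; with 2 processes it would give lr=0.06, epoch_length=512
print(scale_config_for_ddp({"learning_rate": 0.03, "epoch_length": 1024}))
```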

base_train.py

Lines changed: 28 additions & 24 deletions

@@ -54,7 +54,7 @@ def run(trainer, config):
     unsup_criterion = nn.CrossEntropyLoss(reduction='none').to(utils.device)
 
     num_epochs = config["num_epochs"]
-    epoch_length = config["epoch_length"]
+    epoch_length = config["epoch_length"]
     total_num_iters = num_epochs * epoch_length
     lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_num_iters, eta_min=0.0)
 
@@ -64,15 +64,14 @@ def run(trainer, config):
         model=model, ema_model=ema_model, optimizer=optimizer,
         sup_criterion=sup_criterion, unsup_criterion=unsup_criterion,
         cta=cta,
-        device=utils.device
     )
 
     # Setup handler to prepare data batches
     @trainer.on(Events.ITERATION_STARTED)
     def prepare_batch(e):
         sup_batch = next(supervised_train_loader_iter)
         unsup_batch = next(unsupervised_train_loader_iter)
-        cta_probe_batch = next(cta_probe_loader_iter)
+        cta_probe_batch = next(cta_probe_loader_iter)
         e.state.batch = {
             "sup_batch": utils.sup_prepare_batch(sup_batch, utils.device, non_blocking=True),
             "unsup_batch": (
@@ -119,27 +118,30 @@ def update_ema_model(ema_decay):
             ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)
 
     # Setup handlers for debugging
-    if debug and rank == 0:
+    if debug:
 
         @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
         def log_weights_norms(_):
-            wn = []
-            ema_wn = []
-            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
-                wn.append(torch.mean(param.data))
-                ema_wn.append(torch.mean(ema_param.data))
-
-            print("\n\nWeights norms")
-            print("\n- Raw model: {}".format(utils.to_list_str(torch.tensor(wn[:10] + wn[-10:]))))
-            print("- EMA model: {}\n".format(utils.to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))))
-
-        profiler = BasicTimeProfiler()
-        profiler.attach(trainer)
-
-        @trainer.on(Events.ITERATION_COMPLETED(every=200))
-        def log_profiling(_):
-            results = profiler.get_results()
-            profiler.print_results(results)
+
+            if rank == 0:
+                wn = []
+                ema_wn = []
+                for ema_param, param in zip(ema_model.parameters(), model.parameters()):
+                    wn.append(torch.mean(param.data))
+                    ema_wn.append(torch.mean(ema_param.data))
+
+                print("\n\nWeights norms")
+                print("\n- Raw model: {}".format(utils.to_list_str(torch.tensor(wn[:10] + wn[-10:]))))
+                print("- EMA model: {}\n".format(utils.to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))))
+
+        if rank == 0:
+            profiler = BasicTimeProfiler()
+            profiler.attach(trainer)
+
+            @trainer.on(Events.ITERATION_COMPLETED(every=200))
+            def log_profiling(_):
+                results = profiler.get_results()
+                profiler.print_results(results)
 
     # Setup validation engine
     metrics = {
@@ -190,7 +192,7 @@ def run_evaluation():
     if config["display_iters"]:
         ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)
         ProgressBar(persist=False, desc="Test EMA evaluation").attach(ema_evaluator)
-
+
     data = list(range(epoch_length))
 
     resume_from = list(Path(config["output_path"]).rglob("training_checkpoint*.pt*"))
@@ -212,6 +214,8 @@ def run_evaluation():
     if rank == 0:
         tb_logger.close()
 
+    supervised_train_loader_iter = unsupervised_train_loader_iter = cta_probe_loader_iter = None
+
 
 def main(trainer, config):
     parser = argparse.ArgumentParser("Semi-Supervised Learning - FixMatch with CTA: Train WRN-28-2 on CIFAR10 dataset")
@@ -238,8 +242,8 @@ def main(trainer, config):
                 value = eval(value)
             config[key] = value
 
-    if config["local_rank"] == 0:
-        ds_id = "{}".format(config["num_train_samples_per_class"] * 10)
+    ds_id = "{}".format(config["num_train_samples_per_class"] * 10)
+    if config["local_rank"] == 0:
         print("SSL Training of {} on CIFAR10@{}".format(config["model"], ds_id))
         print("- PyTorch version: {}".format(torch.__version__))
        print("- Ignite version: {}".format(ignite.__version__))

ctaugment.py

Lines changed: 2 additions & 9 deletions

@@ -1,4 +1,5 @@
 # https://raw.githubusercontent.com/google-research/fixmatch/master/libml/ctaugment.py
+#
 # Copyright 2019 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,6 +20,7 @@
 import numpy as np
 from PIL import Image, ImageOps, ImageEnhance, ImageFilter
 
+
 OPS = {}
 OP = namedtuple('OP', ('f', 'bins'))
 Sample = namedtuple('Sample', ('train', 'probe'))
@@ -32,15 +34,6 @@ def wrap(f):
     return wrap
 
 
-def apply(x, ops):
-    if ops is None:
-        return x
-    y = Image.fromarray(np.round(127.5 * (1 + x)).clip(0, 255).astype('uint8'))
-    for op, args in ops:
-        y = OPS[op].f(y, *args)
-    return np.asarray(y).astype('f') / 127.5 - 1
-
-
 class CTAugment:
     def __init__(self, depth=2, th=0.85, decay=0.99):
         self.decay = decay

main_fixmatch.py

Lines changed: 53 additions & 9 deletions

@@ -1,10 +1,34 @@
+from collections import defaultdict
 
 import torch
+import torch.distributed as dist
 
 from ignite.engine import Events
 
 import utils
 from base_train import main, BaseTrainer, get_default_config
+from ctaugment import OPS
+
+
+sorted_op_names = sorted(list(OPS.keys()))
+
+
+def pack_as_tensor(k, bins, error, size=5, pad_value=-555.0):
+    out = torch.empty(size).fill_(pad_value).to(error)
+    out[0] = sorted_op_names.index(k)
+    le = len(bins)
+    out[1] = le
+    out[2:2 + le] = torch.tensor(bins).to(error)
+    out[2 + le] = error
+    return out
+
+
+def unpack_from_tensor(t):
+    k_index = int(t[0].item())
+    le = int(t[1].item())
+    bins = t[2:2 + le].tolist()
+    error = t[2 + le].item()
+    return sorted_op_names[k_index], bins, error
 
 
 class FixMatchTrainer(BaseTrainer):
@@ -55,14 +79,15 @@ def train_step(self, engine, batch):
             "total_loss": total_loss.item(),
             "sup_loss": sup_loss.item(),
             "unsup_loss": unsup_loss.item(),
-            "mask": unsup_loss_mask.mean().item()
+            "mask": unsup_loss_mask.mean().item()  # this should not be averaged for DDP
         }
 
     def setup(self, **kwargs):
         super(FixMatchTrainer, self).setup(**kwargs)
         self.confidence_threshold = self.config["confidence_threshold"]
         self.lambda_u = self.config["lambda_u"]
-        self.add_event_handler(Events.ITERATION_COMPLETED, self.update_cta_rates)
+        # self.add_event_handler(Events.ITERATION_COMPLETED, self.update_cta_rates)
+        self.distributed = dist.is_available() and dist.is_initialized()
 
     def update_cta_rates(self):
         x, y, policies = self.state.batch["cta_probe_batch"]
@@ -71,13 +96,32 @@ def update_cta_rates(self):
             y_pred = self.ema_model(x)
             y_probas = torch.softmax(y_pred, dim=1)  # (N, C)
 
-            # for y_proba, t, policy_str in zip(y_probas, y, policies):
-            for y_proba, t, policy in zip(y_probas, y, policies):
-                error = y_proba
-                error[t] -= 1
-                error = torch.abs(error).sum()
-                self.cta.update_rates(policy, 1.0 - 0.5 * error.item())
-
+            if not self.distributed:
+                for y_proba, t, policy in zip(y_probas, y, policies):
+                    error = y_proba
+                    error[t] -= 1
+                    error = torch.abs(error).sum()
+                    self.cta.update_rates(policy, 1.0 - 0.5 * error.item())
+            else:
+                error_per_op = []
+                for y_proba, t, policy in zip(y_probas, y, policies):
+                    error = y_proba
+                    error[t] -= 1
+                    error = torch.abs(error).sum()
+                    for k, bins in policy:
+                        error_per_op.append(pack_as_tensor(k, bins, error))
+                error_per_op = torch.stack(error_per_op)
+                # all gather
+                tensor_list = [
+                    torch.empty_like(error_per_op)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, error_per_op)
+                tensor_list = torch.cat(tensor_list, dim=0)
+                # update cta rates
+                for t in tensor_list:
+                    k, bins, error = unpack_from_tensor(t)
+                    self.cta.update_rates([(k, bins), ], 1.0 - 0.5 * error)
 
 if __name__ == "__main__":
     main(FixMatchTrainer(), get_default_config())
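
The DDP branch above works around the fact that `dist.all_gather` exchanges same-shaped tensors rather than Python objects: each `(op, bins, error)` triple from the CTA probe batch is packed into a fixed-size float tensor, gathered from all ranks, and unpacked so that every process applies every update and the CTA rates stay identical across workers. A self-contained, single-process round-trip check of that encoding (the toy op list stands in for `sorted(OPS.keys())`; everything else mirrors `pack_as_tensor`/`unpack_from_tensor`):

```python
import torch

toy_op_names = sorted(["autocontrast", "blur", "rotate"])

def pack(k, bins, error, size=5, pad_value=-555.0):
    # fixed-size layout: [op_index, n_bins, bin_0, ..., bin_{n-1}, error, padding...]
    out = torch.empty(size).fill_(pad_value).to(error)
    out[0] = toy_op_names.index(k)
    out[1] = len(bins)
    out[2:2 + len(bins)] = torch.tensor(bins).to(error)
    out[2 + len(bins)] = error
    return out

def unpack(t):
    k_index, n_bins = int(t[0].item()), int(t[1].item())
    return toy_op_names[k_index], t[2:2 + n_bins].tolist(), t[2 + n_bins].item()

packed = pack("rotate", [0.25, 0.75], torch.tensor(0.4))
print(unpack(packed))  # -> ('rotate', [0.25, 0.75], ~0.4)
```

The default `size=5` leaves room for the two header values, up to two bin values, and the error, which appears to match the at-most-two-parameter policies CTAugment samples here.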

main_fully_supervised.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+from collections import defaultdict
+
+import torch
+import torch.distributed as dist
+
+from ignite.engine import Events
+
+import utils
+from base_train import main, BaseTrainer, get_default_config
+
+
+class FullySupervisedTrainer(BaseTrainer):
+
+    output_names = ["sup_loss", ]
+
+    def train_step(self, engine, batch):
+        self.model.train()
+        self.optimizer.zero_grad()
+
+        x, y = batch["sup_batch"]
+
+        y_pred = self.model(x)
+
+        # supervised learning:
+        sup_loss = self.sup_criterion(y_pred, y)
+
+        if self.config["with_amp_level"] is not None:
+            from apex import amp
+            with amp.scale_loss(sup_loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            sup_loss.backward()
+
+        self.optimizer.step()
+
+        return {
+            "sup_loss": sup_loss.item(),
+        }
+
+    def setup(self, **kwargs):
+        super(FullySupervisedTrainer, self).setup(**kwargs)
+        self.distributed = dist.is_available() and dist.is_initialized()
+
+if __name__ == "__main__":
+    main(FullySupervisedTrainer(), get_default_config())
