Commit 1774e42

Added fixmatch app
1 parent 6b0e1f2 commit 1774e42

6 files changed: 150 additions & 112 deletions


config/fixmatch.yaml

Lines changed: 30 additions & 2 deletions
@@ -1,15 +1,43 @@
 hydra:
   run:
     dir: /tmp/output-fixmatch-cifar10-hydra/fully_supervised/${now:%Y%m%d-%H%M%S}
+  job_logging:
+    handlers:
+      console:
+        level: WARN
+    root:
+      level: WARN
+
+name: fully-supervised

 seed: 543
-model: "resnet18"
+debug: false

+# model name (from torchvision) to setup model to train. For Wide-Resnet, use "WRN-28-2"
+model: "resnet18"
+num_classes: 10

+ema_decay: 0.999

 defaults:
-  - dataset: cifar10
+  - dataflow: cifar10
   - solver: default
+  - ssl: pseudo
+
+
+solver:
+  unsupervised_criterion:
+    cls: torch.nn.CrossEntropyLoss
+    params:
+      reduction: 'none'
+

+distributed:
+  # backend to use for distributed configuration. Possible values: None, "nccl", "xla-tpu", "gloo" etc. Default, None.
+  backend: null
+  # optional argument to setup number of processes per node. It is useful, when main python process is spawning training as child processes.
+  nproc_per_node: null


+online_exp_tracking:
+  wandb: false
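
Note on the new solver.unsupervised_criterion entry: main_fixmatch.py consumes it through instantiate(cfg.solver.unsupervised_criterion). A rough stand-alone equivalent of that call for this particular config, written in plain Python rather than Hydra's actual implementation, is:

# Rough equivalent of instantiate(cfg.solver.unsupervised_criterion) for the YAML above:
# resolve the dotted path in `cls` and call it with `params` as keyword arguments.
import importlib

conf = {"cls": "torch.nn.CrossEntropyLoss", "params": {"reduction": "none"}}

module_path, _, class_name = conf["cls"].rpartition(".")
criterion_cls = getattr(importlib.import_module(module_path), class_name)
unsup_criterion = criterion_cls(**conf["params"])  # per-sample losses, no averaging

reduction: 'none' matters because the per-sample unsupervised losses are masked by the confidence threshold before averaging, as shown in the main_fixmatch.py diff below.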

config/ssl/pseudo.yaml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# @package _group_
+
+num_train_samples_per_class: 25
+
+confidence_threshold: 0.95
+
+lambda_u: 1.0
+
+mu_ratio: 7
+
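
These three values drive the FixMatch objective in main_fixmatch.py: mu_ratio is the number of unlabeled images per labeled image in a batch, confidence_threshold gates which pseudo-labels contribute, and lambda_u weights the unsupervised term. A minimal sketch of how they combine, with random tensors standing in for real model outputs:

# Minimal sketch of the FixMatch loss shaped by the ssl/pseudo.yaml values above;
# the logits are random placeholders, not outputs of the actual model.
import torch
import torch.nn.functional as F

confidence_threshold, lambda_u, mu_ratio = 0.95, 1.0, 7
batch_size = 8

sup_logits = torch.randn(batch_size, 10)                 # labeled images
y = torch.randint(0, 10, (batch_size,))
weak_logits = torch.randn(mu_ratio * batch_size, 10)     # weakly augmented unlabeled images
strong_logits = torch.randn(mu_ratio * batch_size, 10)   # strongly augmented unlabeled images

sup_loss = F.cross_entropy(sup_logits, y)

weak_probas = torch.softmax(weak_logits, dim=1).detach()
max_probas, y_pseudo = weak_probas.max(dim=1)
mask = (max_probas >= confidence_threshold).float()      # keep only confident pseudo-labels
unsup_loss = (F.cross_entropy(strong_logits, y_pseudo, reduction="none") * mask).mean()

total_loss = sup_loss + lambda_u * unsup_loss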

dataflow/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def get_unsupervised_train_loader(dataset_name, root, cta, download=True, **data

     strong_transforms = partial(cifar10.cta_image_transforms, cta=cta)

-    return get_unsupervised_train_loader(
+    return cifar10.get_unsupervised_train_loader(
         full_train_dataset,
         transforms_weak=cifar10.weak_transforms,
         transforms_strong=strong_transforms,
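
The one-line fix above matters because the package-level wrapper shares its name with the dataset-specific helper, so the unqualified call re-entered the wrapper and recursed forever. A self-contained toy illustration of the pattern (names other than the two loader functions are invented):

# Toy illustration of the name-shadowing bug fixed above; `cifar10` is a stand-in object,
# not the repo's real dataset module.
import types

cifar10 = types.SimpleNamespace(
    get_unsupervised_train_loader=lambda dataset, **kw: f"loader({dataset}, {kw})"
)

def get_unsupervised_train_loader(dataset_name, **kwargs):
    dataset = f"{dataset_name}-train-set"
    # return get_unsupervised_train_loader(dataset, **kwargs)   # bare name: infinite recursion
    return cifar10.get_unsupervised_train_loader(dataset, **kwargs)

print(get_unsupervised_train_loader("cifar10", transforms_weak="weak"))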

main_fixmatch.py

Lines changed: 108 additions & 58 deletions
@@ -1,12 +1,16 @@
 import torch
-import torch.distributed as dist

+import ignite.distributed as idist
 from ignite.engine import Events
+from ignite.utils import manual_seed, setup_logger
+
+import hydra
+from hydra.utils import instantiate
+from omegaconf import DictConfig

 import utils
-from base_train import main, BaseTrainer
-from configs import get_default_config
-from ctaugment import OPS
+import trainers
+from ctaugment import get_default_cta, OPS, interleave, deinterleave


 sorted_op_names = sorted(list(OPS.keys()))
@@ -30,22 +34,43 @@ def unpack_from_tensor(t):
     return sorted_op_names[k_index], bins, error


-class FixMatchTrainer(BaseTrainer):
+def training(local_rank, cfg, logger):

-    output_names = ["total_loss", "sup_loss", "unsup_loss", "mask"]
+    if local_rank == 0:
+        logger.info(cfg.pretty())
+
+    rank = idist.get_rank()
+    manual_seed(cfg.seed + rank)
+    device = idist.device()
+
+    model, ema_model, optimizer, sup_criterion, lr_scheduler = utils.initialize(cfg)
+
+    unsup_criterion = instantiate(cfg.solver.unsupervised_criterion)
+
+    cta = get_default_cta()
+
+    supervised_train_loader, test_loader, unsup_train_loader, cta_probe_loader = \
+        utils.get_dataflow(cfg, cta=cta, with_unsup=True)

-    def train_step(self, engine, batch):
-        self.model.train()
-        self.optimizer.zero_grad()
+    def train_step(engine, batch):
+        model.train()
+        optimizer.zero_grad()

-        x, y = batch["sup_batch"]
-        weak_x, strong_x = batch["unsup_batch"]
+        x, y = batch["sup_batch"]["image"], batch["sup_batch"]["target"]
+        if x.device != device:
+            x = x.to(device, non_blocking=True)
+            y = y.to(device, non_blocking=True)
+
+        weak_x, strong_x = batch["unsup_batch"]["image"], batch["unsup_batch"]["strong_aug"]
+        if weak_x.device != device:
+            weak_x = weak_x.to(device, non_blocking=True)
+            strong_x = strong_x.to(device, non_blocking=True)

         # according to TF code: single forward pass on concat data: [x, weak_x, strong_x]
-        le = 2 * self.config["mu_ratio"] + 1
-        x_cat = utils.interleave(torch.cat([x, weak_x, strong_x], dim=0), le)
-        y_pred_cat = self.model(x_cat)
-        y_pred_cat = utils.deinterleave(y_pred_cat, le)
+        le = 2 * engine.state.mu_ratio + 1
+        x_cat = interleave(torch.cat([x, weak_x, strong_x], dim=0), le)
+        y_pred_cat = model(x_cat)
+        y_pred_cat = deinterleave(y_pred_cat, le)

         idx1 = len(x)
         idx2 = idx1 + len(weak_x)
@@ -54,25 +79,20 @@ def train_step(self, engine, batch):
         y_strong_preds = y_pred_cat[idx2:, ...]  # logits_strong

         # supervised learning:
-        sup_loss = self.sup_criterion(y_pred, y)
+        sup_loss = sup_criterion(y_pred, y)

         # unsupervised learning:
         y_weak_probas = torch.softmax(y_weak_preds, dim=1).detach()
         y_pseudo = y_weak_probas.argmax(dim=1)
         max_y_weak_probas, _ = y_weak_probas.max(dim=1)
-        unsup_loss_mask = (max_y_weak_probas >= self.confidence_threshold).float()
-        unsup_loss = (self.unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask).mean()
+        unsup_loss_mask = (max_y_weak_probas >= engine.state.confidence_threshold).float()
+        unsup_loss = (unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask).mean()

-        total_loss = sup_loss + self.lambda_u * unsup_loss
+        total_loss = sup_loss + engine.state.lambda_u * unsup_loss

-        if self.config["with_nv_amp_level"] is not None:
-            from apex import amp
-            with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            total_loss.backward()
+        total_loss.backward()

-        self.optimizer.step()
+        optimizer.step()

         return {
             "total_loss": total_loss.item(),
@@ -81,57 +101,87 @@ def train_step(self, engine, batch):
             "mask": unsup_loss_mask.mean().item()  # this should not be averaged for DDP
         }

-    def setup(self, **kwargs):
-        super(FixMatchTrainer, self).setup(**kwargs)
-        self.confidence_threshold = self.config["confidence_threshold"]
-        self.lambda_u = self.config["lambda_u"]
-        self.add_event_handler(Events.ITERATION_COMPLETED, self.update_cta_rates)
-        self.distributed = dist.is_available() and dist.is_initialized()
+    output_names = ["total_loss", "sup_loss", "unsup_loss", "mask"]

-    def update_cta_rates(self):
-        x, y, policies = self.state.batch["cta_probe_batch"]
-        self.ema_model.eval()
+    trainer = trainers.create_trainer(
+        train_step,
+        output_names=output_names,
+        model=model,
+        ema_model=ema_model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        supervised_train_loader=supervised_train_loader,
+        test_loader=test_loader,
+        cfg=cfg,
+        logger=logger,
+        cta=cta,
+        unsup_train_loader=unsup_train_loader,
+        cta_probe_loader=cta_probe_loader
+    )
+
+    trainer.state.confidence_threshold = cfg.ssl.confidence_threshold
+    trainer.state.lambda_u = cfg.ssl.lambda_u
+    trainer.state.mu_ratio = cfg.ssl.mu_ratio
+
+    distributed = idist.get_world_size() > 1
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def update_cta_rates():
+        batch = trainer.state.batch
+        x, y = batch["cta_probe_batch"]["image"], batch["cta_probe_batch"]["target"]
+        if x.device != device:
+            x = x.to(device, non_blocking=True)
+            y = y.to(device, non_blocking=True)
+
+        policies = batch["cta_probe_batch"]["policy"]
+
+        ema_model.eval()
         with torch.no_grad():
-            y_pred = self.ema_model(x)
+            y_pred = ema_model(x)
             y_probas = torch.softmax(y_pred, dim=1)  # (N, C)

-        if not self.distributed:
-            for y_proba, t, policy in zip(y_probas, y, policies):
+        if distributed:
+            for y_proba, t, policy in zip(y_probas, y, policies):
                 error = y_proba
                 error[t] -= 1
                 error = torch.abs(error).sum()
-                self.cta.update_rates(policy, 1.0 - 0.5 * error.item())
+                cta.update_rates(policy, 1.0 - 0.5 * error.item())
         else:
             error_per_op = []
             for y_proba, t, policy in zip(y_probas, y, policies):
                 error = y_proba
                 error[t] -= 1
                 error = torch.abs(error).sum()
-                for k, bins in policy:
+                for k, bins in policy:
                     error_per_op.append(pack_as_tensor(k, bins, error))
             error_per_op = torch.stack(error_per_op)
-            # all gather
-            tensor_list = [
-                torch.empty_like(error_per_op)
-                for _ in range(dist.get_world_size())
-            ]
-            dist.all_gather(tensor_list, error_per_op)
-            tensor_list = torch.cat(tensor_list, dim=0)
+            # all gather
+            tensor_list = idist.all_gather(error_per_op)
             # update cta rates
             for t in tensor_list:
-                k, bins, error = unpack_from_tensor(t)
-                self.cta.update_rates([(k, bins), ], 1.0 - 0.5 * error)
+                k, bins, error = unpack_from_tensor(t)
+                cta.update_rates([(k, bins), ], 1.0 - 0.5 * error)
+
+    epoch_length = cfg.solver.epoch_length
+    num_epochs = cfg.solver.num_epochs if not cfg.debug else 2
+    try:
+        trainer.run(supervised_train_loader, epoch_length=epoch_length, max_epochs=num_epochs)
+    except Exception as e:
+        import traceback
+
+        print(traceback.format_exc())
+

+@hydra.main(config_path="config", config_name="fixmatch")
+def main(cfg: DictConfig) -> None:

-def get_fixmatch_config():
-    config = get_default_config()
-    config.update({
-        # FixMatch settings
-        "confidence_threshold": 0.95,
-        "lambda_u": 1.0,
-    })
-    return config
+    with idist.Parallel(backend=cfg.distributed.backend, nproc_per_node=cfg.distributed.nproc_per_node) as parallel:
+        logger = setup_logger(
+            "FixMatch Training",
+            distributed_rank=idist.get_rank()
+        )
+        parallel.run(training, cfg, logger)


 if __name__ == "__main__":
-    main(FixMatchTrainer(), get_fixmatch_config())
+    main()
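
Two of the helpers used above deserve a note. interleave/deinterleave (imported from ctaugment in this commit) implement the single-forward-pass trick mentioned in the TF-code comment: labeled, weakly augmented and strongly augmented images are concatenated into le = 2 * mu_ratio + 1 chunks and shuffled so that every chunk mixes samples from all three groups, which is commonly motivated by keeping batch-norm statistics comparable across them. A widely used PyTorch formulation, shown only as a sketch and not necessarily identical to the repo's helpers:

# Sketch of the interleave/deinterleave round trip; shapes are illustrative.
import torch

def interleave(x, size):
    # (B*size, ...) -> (B, size, ...) -> (size, B, ...) -> (B*size, ...)
    s = list(x.shape)
    return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])

def deinterleave(x, size):
    # exact inverse of interleave
    s = list(x.shape)
    return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])

mu_ratio, batch = 7, 8
le = 2 * mu_ratio + 1
x_cat = torch.randn(le * batch, 3, 32, 32)
assert torch.equal(deinterleave(interleave(x_cat, le), le), x_cat)

The switch from manual torch.distributed gathering to idist.all_gather also simplifies the CTA update: ignite's helper concatenates a tensor along dim 0 across all processes and returns its input unchanged when no distributed backend is initialized, so the same code path serves single-process runs. A short sketch, with an assumed row layout for the packed errors:

# idist.all_gather collects every rank's packed (op, bins, error) rows into one tensor.
import ignite.distributed as idist
import torch

error_per_op = torch.randn(4, 19)             # 4 local rows; width 19 is an assumed layout
tensor_list = idist.all_gather(error_per_op)  # shape: (world_size * 4, 19)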

main_fully_supervised.py

Lines changed: 0 additions & 41 deletions
@@ -77,44 +77,3 @@ def main(cfg: DictConfig) -> None:

 if __name__ == "__main__":
     main()
-
-
-# from base_train import main, BaseTrainer
-# from configs import get_default_config
-# import dist_utils
-#
-#
-# class FullySupervisedTrainer(BaseTrainer):
-#
-#     output_names = ["sup_loss", ]
-#
-#     def train_step(self, engine, batch):
-#         self.model.train()
-#         self.optimizer.zero_grad()
-#
-#         x, y = batch["sup_batch"]
-#
-#         y_pred = self.model(x)
-#
-#         # supervised learning:
-#         sup_loss = self.sup_criterion(y_pred, y)
-#
-#         if self.config["with_nv_amp_level"] is not None:
-#             from apex import amp
-#             with amp.scale_loss(sup_loss, self.optimizer) as scaled_loss:
-#                 scaled_loss.backward()
-#         else:
-#             sup_loss.backward()
-#
-#         if dist_utils.is_tpu_distributed():
-#             dist_utils.xm.optimizer_step(self.optimizer)
-#         else:
-#             self.optimizer.step()
-#
-#         return {
-#             "sup_loss": sup_loss.item(),
-#         }
-#
-#
-# if __name__ == "__main__":
-#     main(FullySupervisedTrainer(), get_default_config())

trainers/__init__.py

Lines changed: 1 addition & 10 deletions
@@ -93,15 +93,6 @@ def prepare_batch(e):
         cta_probe_batch["policy"] = [deserialize(p) for p in cta_probe_batch["policy"]]
         e.state.batch["cta_probe_batch"] = cta_probe_batch

-        # "unsup_batch": (
-        #     convert_tensor(unsup_batch["image"], device, non_blocking=True),
-        #     convert_tensor(unsup_batch["strong_aug"], device, non_blocking=True)
-        # ),
-        # "cta_probe_batch": (
-        #     *utils.sup_prepare_batch(cta_probe_batch, device, non_blocking=True),
-        #     [utils.deserialize(p) for p in cta_probe_batch['policy']]
-        # )
-
     # Setup handler to update EMA model
     @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
     def update_ema_model(ema_decay):
@@ -174,7 +165,7 @@ def log_results(epoch, max_epochs, metrics, ema_metrics):
         msg2 = "\n".join(["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()])
         logger.info("\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2))
         if cta is not None:
-            logger.info(stats(cta))
+            logger.info("\n" + stats(cta))

     @trainer.on(Events.EPOCH_COMPLETED(every=cfg.solver.validate_every) | Events.STARTED | Events.COMPLETED)
     def run_evaluation():
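
The body of the EMA handler registered with @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay) is not part of this diff; ignite forwards the extra positional argument, here cfg.ema_decay = 0.999 from config/fixmatch.yaml, to the handler. For reference, a generic sketch of the kind of update such a handler performs (not the body from this repository):

# Generic EMA parameter update; ema_model/model are any two modules with matching parameters.
import torch

@torch.no_grad()
def ema_update(ema_model, model, ema_decay=0.999):
    # ema_param <- decay * ema_param + (1 - decay) * param, applied after every iteration
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(ema_decay).add_(p, alpha=1.0 - ema_decay)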
