Skip to content

Commit 6896e2a

Browse files
committed
add wandb
1 parent 592dedb commit 6896e2a

File tree

11 files changed

+145
-50
lines changed

11 files changed

+145
-50
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
*.ckpt
22
*.pt
33
*.tar
4+
*.zip
45
temp.ipynb
56

67
checkpoints
78
logs
9+
wandb
810
VOC2012
911
VOCdevkit
1012
__pycache__

config/config.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
# DATA
1515
cfg.data = CN()
1616
cfg.data.type = 'voc2012_aug'
17-
cfg.data_root = './data'
17+
cfg.data.data_dir = './data'
18+
cfg.data.crop_size = 512
1819
cfg.data.num_classes = 21
1920
cfg.data.batch_size = 1
2021
cfg.data.num_workers = 0
@@ -50,5 +51,15 @@
5051
# LOGGING
5152
cfg.train.logger = CN()
5253
cfg.train.logger.log_dir = './logs'
53-
cfg.train.logger.tensorboard = True
54+
55+
# tensorboard setting
56+
cfg.train.logger.use_tensorboard = True
57+
cfg.train.logger.tensorboard = CN()
58+
59+
# wandb setting
60+
cfg.train.logger.use_wandb = False
61+
cfg.train.logger.wandb = CN()
62+
cfg.train.logger.wandb.project = 'UNet3Plus'
63+
cfg.train.logger.wandb.run_id = ''
64+
5465

config/test_voc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ train:
2323
loss_type: u3p
2424

2525
logger:
26-
tensorboard: True
26+
use_tensorboard: True
2727
log_dir: ./logs/

config/test_voc_cpu.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
model:
2+
encoder: resnet18
3+
skip_ch: 16
4+
aux_losses: -1
5+
pretrained: True
6+
7+
data:
8+
type: voc2012_aug
9+
num_classes: 21
10+
num_workers: 2
11+
batch_size: 1
12+
max_training_samples: 10
13+
14+
train:
15+
seed: 42
16+
num_epochs: 20
17+
lr: 0.001
18+
weight_decay: 0.0001
19+
optimizer: adamw
20+
accum_steps: 2
21+
resume: ''
22+
device: cpu
23+
loss_type: focal
24+
25+
logger:
26+
use_tensorboard: False
27+
use_wandb: False
28+
log_dir: ./logs/

config/u3p_resnet18_voc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ train:
2525
warmup_iters: 1
2626

2727
logger:
28-
tensorboard: True
28+
use_tensorboard: True
2929
log_dir: ./logs
3030

3131

config/u3p_resnet34_voc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ train:
2525
warmup_iters: 1000
2626

2727
logger:
28-
tensorboard: True
28+
use_tensorboard: True
2929
log_dir: ./logs
3030

3131

datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def get_voc(data_root='./data', crop_size=SIZE, crop_val=SIZE, year='2012_aug',
4040

4141
return train_dst, val_dst
4242

43-
def build_data_loader(batch_size=1, num_workers=0, max_training_samples=-1) -> Tuple[DataLoader, DataLoader]:
44-
train_dataset, val_dataset = get_voc()
43+
def build_data_loader(data_root='./data', batch_size=1, num_workers=0, max_training_samples=-1, crop_size=512) -> Tuple[DataLoader, DataLoader]:
44+
train_dataset, val_dataset = get_voc(data_root, crop_size=crop_size, crop_val=crop_size)
4545
if max_training_samples > 0: # for testing
4646
num_samples = len(train_dataset)
4747
train_dataset.image_set

model/unet3plus.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def en2dec_layer(in_ch, out_ch, scale):
2626
def dec2dec_layer(in_ch, out_ch, scale, efficient=False):
2727
up = [nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True) if scale != 1 else nn.Identity()]
2828
m = [u3pblock(in_ch, out_ch, num_block=1)]
29+
efficient = True
2930
if efficient:
3031
m = m + up
3132
else:

train.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from datasets import build_data_loader
1717
from config.config import cfg
1818
from utils.loss import build_u3p_loss
19-
from utils.log import AverageMeter
19+
from utils.logging import AverageMeter, SummaryLogger
2020
from utils.metrics import StreamSegMetrics
2121

2222
def one_cycle(y1=0.0, y2=1.0, steps=100):
@@ -50,7 +50,7 @@ def __init__(self, cfg, model, train_loader, val_loader):
5050

5151
# build loss
5252
self.criterion = build_u3p_loss(cfg.loss_type, cfg.aux_weight)
53-
self.scaler = amp.GradScaler(enabled=True) # mixed precision training
53+
self.scaler = amp.GradScaler(enabled=cfg.device == 'cuda') # mixed precision training
5454

5555
# build optimizer
5656
if cfg.optimizer == 'sgd':
@@ -71,10 +71,7 @@ def __init__(self, cfg, model, train_loader, val_loader):
7171
# build scheduler
7272
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lr_func)
7373

74-
if cfg.logger.tensorboard:
75-
self.writer = SummaryWriter(log_dir=cfg.logger.log_dir)
76-
else:
77-
self.writer = None
74+
self.logger = SummaryLogger(self.cfg_all)
7875

7976
self.model.to(cfg.device)
8077
if cfg.resume:
@@ -174,22 +171,27 @@ def update_loss_dict(self, loss_dict, batch_loss_dict=None):
174171
loss_dict[k].update(v)
175172

176173
def log_results(self):
177-
if self.writer is not None:
178-
for k, v in self.loss_dict.items():
179-
self.writer.add_scalars('Train_metrics/' + k, {"Train": v.avg}, self.global_iter)
180-
self.update_loss_dict(self.loss_dict, None) # clean loss meters
181-
lr = self.optimizer.param_groups[0]['lr']
182-
self.writer.add_scalars('Train_metrics/lr', {"lr": lr}, self.global_iter)
183-
184-
for k, v in self.val_loss_dict.items():
185-
self.writer.add_scalars('Val_metrics/' + k, {"Val": v.avg}, self.global_iter)
186-
self.update_loss_dict(self.val_loss_dict, None)
187-
188-
for k, v in self.val_score_dict.items():
189-
if k == 'Class IoU':
190-
continue
191-
self.writer.add_scalars('Val_metrics/' + k, {"Val": v}, self.global_iter)
192-
self.writer.flush()
174+
log_dict = {
175+
'Train': {},
176+
'Val': {}
177+
}
178+
179+
for k, v in self.loss_dict.items():
180+
log_dict['Train'][k] = v.avg
181+
self.update_loss_dict(self.loss_dict, None)
182+
log_dict['Train']['lr'] = self.optimizer.param_groups[0]['lr']
183+
184+
for k, v in self.val_loss_dict.items():
185+
log_dict['Val'][k] = v.avg
186+
self.update_loss_dict(self.val_loss_dict, None)
187+
188+
for k, v in self.val_score_dict.items():
189+
if k == 'Class IoU':
190+
print(v)
191+
# self.logger.cmd_logger.info(v)
192+
continue
193+
log_dict['Val'][k] = v
194+
self.logger.summary(log_dict, self.global_iter)
193195

194196

195197
def validate(self):
@@ -226,14 +228,15 @@ def main(args):
226228
cfg.train.seed = int(args.seed)
227229
if args.resume:
228230
cfg.train.resume = args.resume
231+
cfg.data.data_dir = args.data_dir
229232

230233
cfg.freeze()
231234
print(cfg)
232235
model, data = cfg.model, cfg.data
233236
model = build_unet3plus(data.num_classes, model.encoder, model.skip_ch, model.aux_losses, model.use_cgm, model.pretrained)
234237
# model = UNet_3Plus_DeepSup()
235238
if data.type in ['voc2012', 'voc2012_aug']:
236-
train_loader, val_loader = build_data_loader(data.batch_size, data.num_workers, data.max_training_samples)
239+
train_loader, val_loader = build_data_loader(data.data_dir, data.batch_size, data.num_workers, data.max_training_samples)
237240
else:
238241
raise NotImplementedError
239242

@@ -254,6 +257,9 @@ def main(args):
254257
help='resume from checkpoint',
255258
default='',
256259
type=str)
260+
parser.add_argument('--data_dir',
261+
default="./data",
262+
type=str)
257263

258264
args = parser.parse_args()
259265
main(args)

utils/log.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

0 commit comments

Comments
 (0)