Skip to content

Commit c16cd5e

Browse files
committed
update multi gpu
1 parent cf1218b commit c16cd5e

File tree

3 files changed

+28 -11 lines changed

3 files changed

+28 -11 lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
9. [参考资料 Reference](#Reference)
1414

1515
## Top News
16+
**`2022-04`**:**支持多GPU训练,新增各个种类目标数量计算,新增heatmap。**
17+
1618
**`2022-03`**:**进行了大幅度的更新,修改了loss组成,使得分类、目标、回归loss的比例合适、支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整、新增图片裁剪。**
1719
BiliBili视频中的原仓库地址为:https://github.com/bubbliiiing/yolov4-tiny-tf2/tree/bilibili
1820

train.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,11 @@
230230
for gpu in gpus:
231231
tf.config.experimental.set_memory_growth(gpu, True)
232232

233-
strategy = tf.distribute.MirroredStrategy()
234-
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
233+
if ngpus_per_node > 1:
234+
strategy = tf.distribute.MirroredStrategy()
235+
else:
236+
strategy = None
237+
print('Number of devices: {}'.format(ngpus_per_node))
235238

236239
#----------------------------------------------------#
237240
# 获取classes和anchor
@@ -386,7 +389,7 @@
386389
K.set_value(optimizer.lr, lr)
387390

388391
fit_one_epoch(model_body, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val,
389-
end_epoch, input_shape, anchors, anchors_mask, num_classes, label_smoothing, save_period, save_dir)
392+
end_epoch, input_shape, anchors, anchors_mask, num_classes, label_smoothing, save_period, save_dir, strategy)
390393

391394
train_dataloader.on_epoch_end()
392395
val_dataloader.on_epoch_end()
@@ -418,8 +421,8 @@
418421

419422
if start_epoch < end_epoch:
420423
print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
421-
model.fit_generator(
422-
generator = train_dataloader,
424+
model.fit(
425+
x = train_dataloader,
423426
steps_per_epoch = epoch_step,
424427
validation_data = val_dataloader,
425428
validation_steps = epoch_step_val,
@@ -471,8 +474,8 @@
471474
val_dataloader.batch_size = Unfreeze_batch_size
472475

473476
print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
474-
model.fit_generator(
475-
generator = train_dataloader,
477+
model.fit(
478+
x = train_dataloader,
476479
steps_per_epoch = epoch_step,
477480
validation_data = val_dataloader,
478481
validation_steps = epoch_step_val,

utils/utils_fit.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#------------------------------#
99
# 防止bug
1010
#------------------------------#
11-
def get_train_step_fn(input_shape, anchors, anchors_mask, num_classes, label_smoothing):
11+
def get_train_step_fn(input_shape, anchors, anchors_mask, num_classes, label_smoothing, strategy):
1212
@tf.function
1313
def train_step(imgs, targets, net, optimizer):
1414
with tf.GradientTape() as tape:
@@ -32,11 +32,23 @@ def train_step(imgs, targets, net, optimizer):
3232
grads = tape.gradient(loss_value, net.trainable_variables)
3333
optimizer.apply_gradients(zip(grads, net.trainable_variables))
3434
return loss_value
35-
return train_step
35+
36+
if strategy == None:
37+
return train_step
38+
else:
39+
#----------------------#
40+
# 多gpu训练
41+
#----------------------#
42+
@tf.function
43+
def distributed_train_step(images, targets, net, optimizer):
44+
per_replica_losses = strategy.run(train_step, args=(images, targets, net, optimizer,))
45+
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
46+
axis=None)
47+
return distributed_train_step
3648

3749
def fit_one_epoch(net, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch,
38-
input_shape, anchors, anchors_mask, num_classes, label_smoothing, save_period, save_dir):
39-
train_step = get_train_step_fn(input_shape, anchors, anchors_mask, num_classes, label_smoothing)
50+
input_shape, anchors, anchors_mask, num_classes, label_smoothing, save_period, save_dir, strategy):
51+
train_step = get_train_step_fn(input_shape, anchors, anchors_mask, num_classes, label_smoothing, strategy)
4052
loss = 0
4153
val_loss = 0
4254
print('Start Train')

0 commit comments

Comments (0)