Skip to content

Commit cf1218b

Browse files
committed
update CA Attention, multi GPU, heatmap, lr, count
1 parent 177d6eb commit cf1218b

File tree

6 files changed

+193
-46
lines changed

6 files changed

+193
-46
lines changed

nets/attention.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,28 @@ def eca_block(input_feature, b=1, gamma=2, name=""):
9898

9999
output = multiply([input_feature,x])
100100
return output
101+
102+
def ca_block(input_feature, ratio=16, name=""):
103+
channel = K.int_shape(input_feature)[-1]
104+
h = K.int_shape(input_feature)[1]
105+
w = K.int_shape(input_feature)[2]
106+
107+
x_h = Lambda(lambda x: K.mean(x, axis=2, keepdims=True))(input_feature)
108+
x_h = Lambda(lambda x: K.permute_dimensions(x, [0, 2, 1, 3]))(x_h)
109+
x_w = Lambda(lambda x: K.max(x, axis=1, keepdims=True))(input_feature)
110+
111+
x_cat_conv_relu = Concatenate(axis=2)([x_w, x_h])
112+
x_cat_conv_relu = Conv2D(channel // ratio, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv1_"+str(name))(x_cat_conv_relu)
113+
x_cat_conv_relu = Activation('relu')(x_cat_conv_relu)
114+
115+
x_cat_conv_split_h, x_cat_conv_split_w = Lambda(lambda x: tf.split(x, num_or_size_splits=[h, w], axis=2))(x_cat_conv_relu)
116+
x_cat_conv_split_h = Lambda(lambda x: K.permute_dimensions(x, [0, 2, 1, 3]))(x_cat_conv_split_h)
117+
x_cat_conv_split_h = Conv2D(channel, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv2_"+str(name))(x_cat_conv_split_h)
118+
x_cat_conv_split_h = Activation('sigmoid')(x_cat_conv_split_h)
119+
120+
x_cat_conv_split_w = Conv2D(channel, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv3_"+str(name))(x_cat_conv_split_w)
121+
x_cat_conv_split_w = Activation('sigmoid')(x_cat_conv_split_w)
122+
123+
output = multiply([input_feature, x_cat_conv_split_h])
124+
output = multiply([output, x_cat_conv_split_w])
125+
return output

nets/yolo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from tensorflow.keras.models import Model
33
from utils.utils import compose
44

5-
from nets.attention import cbam_block, eca_block, se_block
5+
from nets.attention import cbam_block, eca_block, se_block, ca_block
66
from nets.CSPdarknet53_tiny import (DarknetConv2D, DarknetConv2D_BN_Leaky,
77
darknet_body)
88
from nets.yolo_training import yolo_loss
99

10-
attention = [se_block, cbam_block, eca_block]
10+
attention = [se_block, cbam_block, eca_block, ca_block]
1111

1212
#---------------------------------------------------#
1313
# 特征层->最后的输出
@@ -20,7 +20,7 @@ def yolo_body(input_shape, anchors_mask, num_classes, phi = 0, weight_decay=5e-4
2020
# feat2的shape为13,13,512
2121
#---------------------------------------------------#
2222
feat1, feat2 = darknet_body(inputs, weight_decay=weight_decay)
23-
if phi >= 1 and phi <= 3:
23+
if phi >= 1 and phi <= 4:
2424
feat1 = attention[phi - 1](feat1, name='feat1')
2525
feat2 = attention[phi - 1](feat2, name='feat2')
2626

@@ -32,7 +32,7 @@ def yolo_body(input_shape, anchors_mask, num_classes, phi = 0, weight_decay=5e-4
3232

3333
# 13,13,256 -> 13,13,128 -> 26,26,128
3434
P5_upsample = compose(DarknetConv2D_BN_Leaky(128, (1,1), weight_decay=weight_decay), UpSampling2D(2))(P5)
35-
if phi >= 1 and phi <= 3:
35+
if phi >= 1 and phi <= 4:
3636
P5_upsample = attention[phi - 1](P5_upsample, name='P5_upsample')
3737

3838
# 26,26,256 + 26,26,128 -> 26,26,384

nets/yolo_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def loop_body(b, ignore_mask):
291291
loss = tf.Print(loss, [loss, location_loss, confidence_loss, class_loss, tf.shape(ignore_mask)], summarize=100, message='loss: ')
292292
return loss
293293

294-
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.1, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.3, step_num = 10):
294+
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
295295
def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
296296
if iters <= warmup_total_iters:
297297
# lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start

predict.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,41 +19,55 @@
1919
yolo = YOLO()
2020
#----------------------------------------------------------------------------------------------------------#
2121
# mode用于指定测试的模式:
22-
# 'predict'表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
23-
# 'video'表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
24-
# 'fps'表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
25-
# 'dir_predict'表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
22+
# 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
23+
# 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
24+
# 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
25+
# 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
26+
# 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。
2627
#----------------------------------------------------------------------------------------------------------#
2728
mode = "predict"
2829
#-------------------------------------------------------------------------#
29-
# crop指定了是否在单张图片预测后对目标进行截取
30-
# crop仅在mode='predict'时有效
30+
# crop 指定了是否在单张图片预测后对目标进行截取
31+
# count 指定了是否进行目标的计数
32+
# crop、count仅在mode='predict'时有效
3133
#-------------------------------------------------------------------------#
3234
crop = False
35+
count = False
3336
#----------------------------------------------------------------------------------------------------------#
34-
# video_path用于指定视频的路径,当video_path=0时表示检测摄像头
35-
# 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
36-
# video_save_path表示视频保存的路径,当video_save_path=""时表示不保存
37-
# 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
38-
# video_fps用于保存的视频的fps
37+
# video_path 用于指定视频的路径,当video_path=0时表示检测摄像头
38+
# 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
39+
# video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存
40+
# 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
41+
# video_fps 用于保存的视频的fps
42+
#
3943
# video_path、video_save_path和video_fps仅在mode='video'时有效
4044
# 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
4145
#----------------------------------------------------------------------------------------------------------#
4246
video_path = 0
4347
video_save_path = ""
4448
video_fps = 25.0
45-
#-------------------------------------------------------------------------#
46-
# test_interval用于指定测量fps的时候,图片检测的次数
47-
# 理论上test_interval越大,fps越准确。
48-
#-------------------------------------------------------------------------#
49+
#----------------------------------------------------------------------------------------------------------#
50+
# test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
51+
# fps_image_path 用于指定测试的fps图片
52+
#
53+
# test_interval和fps_image_path仅在mode='fps'有效
54+
#----------------------------------------------------------------------------------------------------------#
4955
test_interval = 100
56+
fps_image_path = "img/street.jpg"
5057
#-------------------------------------------------------------------------#
51-
# dir_origin_path指定了用于检测的图片的文件夹路径
52-
# dir_save_path指定了检测完图片的保存路径
58+
# dir_origin_path 指定了用于检测的图片的文件夹路径
59+
# dir_save_path 指定了检测完图片的保存路径
60+
#
5361
# dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
5462
#-------------------------------------------------------------------------#
5563
dir_origin_path = "img/"
5664
dir_save_path = "img_out/"
65+
#-------------------------------------------------------------------------#
66+
# heatmap_save_path 热力图的保存路径,默认保存在model_data下
67+
#
68+
# heatmap_save_path仅在mode='heatmap'有效
69+
#-------------------------------------------------------------------------#
70+
heatmap_save_path = "model_data/heatmap_vision.png"
5771

5872
if mode == "predict":
5973
'''
@@ -72,7 +86,7 @@
7286
print('Open Error! Try again!')
7387
continue
7488
else:
75-
r_image = yolo.detect_image(image, crop = crop)
89+
r_image = yolo.detect_image(image, crop = crop, count = count)
7690
r_image.show()
7791

7892
elif mode == "video":
@@ -121,16 +135,17 @@
121135
print("Save processed video to the path :" + video_save_path)
122136
out.release()
123137
cv2.destroyAllWindows()
124-
138+
125139
elif mode == "fps":
126-
img = Image.open('img/street.jpg')
140+
img = Image.open(fps_image_path)
127141
tact_time = yolo.get_FPS(img, test_interval)
128142
print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
129143

130144
elif mode == "dir_predict":
131145
import os
146+
132147
from tqdm import tqdm
133-
148+
134149
img_names = os.listdir(dir_origin_path)
135150
for img_name in tqdm(img_names):
136151
if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
@@ -141,5 +156,16 @@
141156
os.makedirs(dir_save_path)
142157
r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
143158

159+
elif mode == "heatmap":
160+
while True:
161+
img = input('Input image filename:')
162+
try:
163+
image = Image.open(img)
164+
except:
165+
print('Open Error! Try again!')
166+
continue
167+
else:
168+
yolo.detect_heatmap(image, heatmap_save_path)
169+
144170
else:
145171
raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")

train.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@
4646
#----------------------------------------------------#
4747
eager = False
4848
#---------------------------------------------------------------------#
49+
# train_gpu 训练用到的GPU
50+
# 默认为第一张卡、双卡为[0, 1]、三卡为[0, 1, 2]
51+
# 在使用多GPU时,每个卡上的batch为总batch除以卡的数量。
52+
#---------------------------------------------------------------------#
53+
train_gpu = [0,]
54+
#---------------------------------------------------------------------#
4955
# classes_path 指向model_data下的txt,与自己训练的数据集相关
5056
# 训练前一定要修改classes_path,使其对应自己的数据集
5157
#---------------------------------------------------------------------#
@@ -87,6 +93,7 @@
8793
# phi = 1为SE
8894
# phi = 2为CBAM
8995
# phi = 3为ECA
96+
# phi = 4为CA
9097
#-------------------------------#
9198
phi = 0
9299
#------------------------------------------------------------------#
@@ -213,26 +220,53 @@
213220
train_annotation_path = '2007_train.txt'
214221
val_annotation_path = '2007_val.txt'
215222

223+
#------------------------------------------------------#
224+
# 设置用到的显卡
225+
#------------------------------------------------------#
226+
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in train_gpu)
227+
ngpus_per_node = len(train_gpu)
228+
229+
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
230+
for gpu in gpus:
231+
tf.config.experimental.set_memory_growth(gpu, True)
232+
233+
strategy = tf.distribute.MirroredStrategy()
234+
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
235+
216236
#----------------------------------------------------#
217237
# 获取classes和anchor
218238
#----------------------------------------------------#
219239
class_names, num_classes = get_classes(classes_path)
220240
anchors, num_anchors = get_anchors(anchors_path)
221241

222-
#------------------------------------------------------#
223-
# 创建yolo模型
224-
#------------------------------------------------------#
225-
model_body = yolo_body((None, None, 3), anchors_mask, num_classes, phi = phi, weight_decay = weight_decay)
226-
if model_path != '':
242+
if ngpus_per_node > 1:
243+
with strategy.scope():
244+
#------------------------------------------------------#
245+
# 创建yolo模型
246+
#------------------------------------------------------#
247+
model_body = yolo_body((input_shape[0], input_shape[1], 3), anchors_mask, num_classes, phi = phi, weight_decay = weight_decay)
248+
if model_path != '':
249+
#------------------------------------------------------#
250+
# 载入预训练权重
251+
#------------------------------------------------------#
252+
print('Load weights {}.'.format(model_path))
253+
model_body.load_weights(model_path, by_name=True, skip_mismatch=True)
254+
if not eager:
255+
model = get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask, label_smoothing)
256+
else:
227257
#------------------------------------------------------#
228-
# 载入预训练权重
258+
# 创建yolo模型
229259
#------------------------------------------------------#
230-
print('Load weights {}.'.format(model_path))
231-
model_body.load_weights(model_path, by_name=True, skip_mismatch=True)
232-
233-
if not eager:
234-
model = get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask, label_smoothing)
235-
260+
model_body = yolo_body((input_shape[0], input_shape[1], 3), anchors_mask, num_classes, phi = phi, weight_decay = weight_decay)
261+
if model_path != '':
262+
#------------------------------------------------------#
263+
# 载入预训练权重
264+
#------------------------------------------------------#
265+
print('Load weights {}.'.format(model_path))
266+
model_body.load_weights(model_path, by_name=True, skip_mismatch=True)
267+
if not eager:
268+
model = get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask, label_smoothing)
269+
236270
#---------------------------#
237271
# 读取数据集对应的txt
238272
#---------------------------#
@@ -360,7 +394,11 @@
360394
start_epoch = Init_Epoch
361395
end_epoch = Freeze_Epoch if Freeze_Train else UnFreeze_Epoch
362396

363-
model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred})
397+
if ngpus_per_node > 1:
398+
with strategy.scope():
399+
model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred})
400+
else:
401+
model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred})
364402
#-------------------------------------------------------------------------------#
365403
# 训练参数的设置
366404
# logging 用于设置tensorboard的保存地址
@@ -417,7 +455,11 @@
417455

418456
for i in range(len(model_body.layers)):
419457
model_body.layers[i].trainable = True
420-
model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred})
458+
if ngpus_per_node > 1:
459+
with strategy.scope():
460+
model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred})
461+
else:
462+
model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred})
421463

422464
epoch_step = num_train // batch_size
423465
epoch_step_val = num_val // batch_size

0 commit comments

Comments
 (0)