|
46 | 46 | #----------------------------------------------------#
|
47 | 47 | eager = False
|
48 | 48 | #---------------------------------------------------------------------#
|
| 49 | + # train_gpu 训练用到的GPU |
| 50 | + # 默认为第一张卡、双卡为[0, 1]、三卡为[0, 1, 2] |
| 51 | + # 在使用多GPU时,每个卡上的batch为总batch除以卡的数量。 |
| 52 | + #---------------------------------------------------------------------# |
| 53 | + train_gpu = [0,] |
| 54 | + #---------------------------------------------------------------------# |
49 | 55 | # classes_path 指向model_data下的txt,与自己训练的数据集相关
|
50 | 56 | # 训练前一定要修改classes_path,使其对应自己的数据集
|
51 | 57 | #---------------------------------------------------------------------#
|
|
87 | 93 | # phi = 1为SE
|
88 | 94 | # phi = 2为CBAM
|
89 | 95 | # phi = 3为ECA
|
| 96 | + # phi = 4为CA |
90 | 97 | #-------------------------------#
|
91 | 98 | phi = 0
|
92 | 99 | #------------------------------------------------------------------#
|
|
213 | 220 | train_annotation_path = '2007_train.txt'
|
214 | 221 | val_annotation_path = '2007_val.txt'
|
215 | 222 |
|
| 223 | + #------------------------------------------------------# |
| 224 | + # 设置用到的显卡 |
| 225 | + #------------------------------------------------------# |
| 226 | + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in train_gpu) |
| 227 | + ngpus_per_node = len(train_gpu) |
| 228 | + |
| 229 | + gpus = tf.config.experimental.list_physical_devices(device_type='GPU') |
| 230 | + for gpu in gpus: |
| 231 | + tf.config.experimental.set_memory_growth(gpu, True) |
| 232 | + |
| 233 | + strategy = tf.distribute.MirroredStrategy() |
| 234 | + print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) |
| 235 | + |
216 | 236 | #----------------------------------------------------#
|
217 | 237 | # 获取classes和anchor
|
218 | 238 | #----------------------------------------------------#
|
219 | 239 | class_names, num_classes = get_classes(classes_path)
|
220 | 240 | anchors, num_anchors = get_anchors(anchors_path)
|
221 | 241 |
|
222 |
| - #------------------------------------------------------# |
223 |
| - # 创建yolo模型 |
224 |
| - #------------------------------------------------------# |
225 |
| - model_body = yolo_body((None, None, 3), anchors_mask, num_classes, phi = phi, weight_decay = weight_decay) |
226 |
| - if model_path != '': |
| 242 | + if ngpus_per_node > 1: |
| 243 | + with strategy.scope(): |
| 244 | + #------------------------------------------------------# |
| 245 | + # 创建yolo模型 |
| 246 | + #------------------------------------------------------# |
| 247 | + model_body = yolo_body((input_shape[0], input_shape[1], 3), anchors_mask, num_classes, phi = phi, weight_decay = weight_decay) |
| 248 | + if model_path != '': |
| 249 | + #------------------------------------------------------# |
| 250 | + # 载入预训练权重 |
| 251 | + #------------------------------------------------------# |
| 252 | + print('Load weights {}.'.format(model_path)) |
| 253 | + model_body.load_weights(model_path, by_name=True, skip_mismatch=True) |
| 254 | + if not eager: |
| 255 | + model = get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask, label_smoothing) |
| 256 | + else: |
227 | 257 | #------------------------------------------------------#
|
228 |
| - # 载入预训练权重 |
| 258 | + # 创建yolo模型 |
229 | 259 | #------------------------------------------------------#
|
230 |
| - print('Load weights {}.'.format(model_path)) |
231 |
| - model_body.load_weights(model_path, by_name=True, skip_mismatch=True) |
232 |
| - |
233 |
| - if not eager: |
234 |
| - model = get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask, label_smoothing) |
235 |
| - |
| 260 | + model_body = yolo_body((input_shape[0], input_shape[1], 3), anchors_mask, num_classes, phi = phi, weight_decay = weight_decay) |
| 261 | + if model_path != '': |
| 262 | + #------------------------------------------------------# |
| 263 | + # 载入预训练权重 |
| 264 | + #------------------------------------------------------# |
| 265 | + print('Load weights {}.'.format(model_path)) |
| 266 | + model_body.load_weights(model_path, by_name=True, skip_mismatch=True) |
| 267 | + if not eager: |
| 268 | + model = get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask, label_smoothing) |
| 269 | + |
236 | 270 | #---------------------------#
|
237 | 271 | # 读取数据集对应的txt
|
238 | 272 | #---------------------------#
|
|
360 | 394 | start_epoch = Init_Epoch
|
361 | 395 | end_epoch = Freeze_Epoch if Freeze_Train else UnFreeze_Epoch
|
362 | 396 |
|
363 |
| - model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred}) |
| 397 | + if ngpus_per_node > 1: |
| 398 | + with strategy.scope(): |
| 399 | + model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred}) |
| 400 | + else: |
| 401 | + model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred}) |
364 | 402 | #-------------------------------------------------------------------------------#
|
365 | 403 | # 训练参数的设置
|
366 | 404 | # logging 用于设置tensorboard的保存地址
|
|
417 | 455 |
|
418 | 456 | for i in range(len(model_body.layers)):
|
419 | 457 | model_body.layers[i].trainable = True
|
420 |
| - model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred}) |
| 458 | + if ngpus_per_node > 1: |
| 459 | + with strategy.scope(): |
| 460 | + model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred}) |
| 461 | + else: |
| 462 | + model.compile(optimizer = optimizer, loss={'yolo_loss': lambda y_true, y_pred: y_pred}) |
421 | 463 |
|
422 | 464 | epoch_step = num_train // batch_size
|
423 | 465 | epoch_step_val = num_val // batch_size
|
|
0 commit comments